👽 Add forward driver support for mirai-api-http adapter

This commit is contained in:
Mix 2021-08-04 00:35:31 +08:00
parent cda1ad093f
commit 358528b495
3 changed files with 45 additions and 216 deletions

View File

@ -28,6 +28,9 @@ Mirai-API-HTTP 的适配器以 `AGPLv3许可`_ 单独开源
"""
from .bot import Bot
from .bot_ws import WebsocketBot
from .event import *
from .message import MessageChain, MessageSegment
"""
``WebsocketBot``现在已经和``Bot``合并, 并已经被弃用, 请直接使用``Bot``
"""
WebsocketBot = Bot

View File

@ -1,15 +1,17 @@
import json
from datetime import datetime, timedelta
from io import BytesIO
from ipaddress import IPv4Address
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
from loguru import logger
from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config
from nonebot.drivers import Driver, ReverseDriver, HTTPConnection, HTTPResponse, WebSocket, ForwardDriver, WebSocketSetup
from nonebot.exception import ApiNotAvailable
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket
from nonebot.typing import overrides
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
@ -152,15 +154,12 @@ class Bot(BaseBot):
"""
_type = 'mirai'
@property
@overrides(BaseBot)
def type(self) -> str:
return "mirai"
@property
def alive(self) -> bool:
assert isinstance(self.request, WebSocket)
return not self.request.closed
return self._type
@property
def api(self) -> SessionManager:
@ -190,21 +189,50 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
def register(cls, driver: Driver, config: "Config"):
def register(cls,
driver: Driver,
config: "Config",
qq: Optional[int] = None):
cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
raise ApiNotAvailable('mirai')
raise ApiNotAvailable(cls._type)
super().register(driver, config)
if not isinstance(driver, ForwardDriver) and qq:
logger.warning(
f"Current driver {cls.config.driver} don't support forward connections"
)
elif isinstance(driver, ForwardDriver) and qq:
async def url_factory():
assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key
session = await SessionManager.new(
qq, # type: ignore
host=cls.mirai_config.host,
port=cls.mirai_config.port,
auth_key=cls.mirai_config.auth_key)
return WebSocketSetup(
adapter=cls._type,
self_id=str(qq),
url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}'
f'/all?sessionKey={session.session_key}'))
driver.setup_websocket(url_factory)
elif isinstance(driver, ReverseDriver):
logger.debug(
'Param "qq" does not set for mirai adapter, use http post instead'
)
@overrides(BaseBot)
async def handle_message(self, message: dict):
async def handle_message(self, message: bytes):
Log.debug(f'received message {message}')
try:
await process_event(
bot=self,
event=Event.new({
**message,
**json.loads(message),
'self_id': self.self_id,
}),
)

View File

@ -1,202 +0,0 @@
import json
import asyncio
from dataclasses import dataclass
from ipaddress import IPv4Address
from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine
import httpx
import websockets
from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket
from .bot import SessionManager, Bot
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
@dataclass
class WebSocket(BaseWebSocket):
websocket: websockets.WebSocketClientProtocol = None # type: ignore
@classmethod
async def new(cls, *, host: IPv4Address, port: int,
session_key: str) -> "WebSocket":
listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping())
return cls("1.1",
listen_address.scheme,
listen_address.path,
listen_address.query,
websocket=websocket)
@overrides(BaseWebSocket)
def __init__(self,
http_version: str,
scheme: str,
path: str,
query_string: bytes = b"",
headers: Dict[str, str] = None,
websocket: websockets.WebSocketClientProtocol = None):
self.event_handlers: Set[WebsocketHandlerFunction] = set()
self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
super(WebSocket, self).__init__(http_version=http_version,
scheme=scheme,
path=path,
query_string=query_string,
headers=headers or {})
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return self.websocket.closed
@overrides(BaseWebSocket)
async def send(self, data: str):
return await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: str):
return await self.websocket.send(data)
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.recv() # type: ignore
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.recv() # type: ignore
async def _dispatcher(self):
while not self.closed:
try:
data = await self.receive()
except websockets.ConnectionClosedOK:
logger.debug(f'Websocket connection {self.websocket} closed')
break
except websockets.ConnectionClosedError:
logger.exception(f'Websocket connection {self.websocket} '
'connection closed abnormally:')
break
except json.JSONDecodeError as e:
logger.exception(f'Websocket client listened {self.websocket} '
f'failed to decode data: {e}')
continue
asyncio.gather(
*map(lambda f: f(data), self.event_handlers), #type: ignore
return_exceptions=True)
@overrides(BaseWebSocket)
async def accept(self):
asyncio.create_task(self._dispatcher())
@overrides(BaseWebSocket)
async def close(self):
await self.websocket.close()
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
self.event_handlers.add(callable)
return callable
class WebsocketBot(Bot):
"""
mirai-api-http 正向 Websocket 协议 Bot 适配
"""
@property
@overrides(Bot)
def type(self) -> str:
return "mirai-ws"
@property
def alive(self) -> bool:
assert isinstance(self.request, WebSocket)
return not self.request.closed
@property
def api(self) -> SessionManager:
api = SessionManager.get(self_id=int(self.self_id), check_expire=False)
assert api is not None, 'SessionManager has not been initialized'
return api
@classmethod
@overrides(Bot)
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[None, HTTPResponse]:
return None, HTTPResponse(501, b'Connection not implented')
@classmethod
@overrides(Bot)
def register(cls, driver: Driver, config: "Config", qq: int):
"""
:说明:
注册该Adapter
:参数:
* ``driver: Driver``: 程序所使用的``Driver``
* ``config: Config``: 程序配置对象
* ``qq: int``: 要使用的Bot的QQ号 **注意: 在使用正向Websocket时必须指定该值!**
"""
super().register(driver, config)
cls.active = True
async def _bot_connection():
session: SessionManager = await SessionManager.new(
qq,
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key # type: ignore
)
websocket = await WebSocket.new(
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
session_key=session.session_key)
bot = cls(self_id=str(qq), request=websocket)
websocket.handle(bot.handle_message)
await websocket.accept()
return bot
async def _connection_ensure():
self_id = str(qq)
if self_id not in driver._clients:
bot = await _bot_connection()
driver._bot_connect(bot)
else:
bot = driver._clients[self_id]
if not bot.alive:
driver._bot_disconnect(bot)
return
@driver.on_startup
async def _startup():
async def _checker():
while cls.active:
try:
await _connection_ensure()
except Exception as e:
logger.opt(colors=True).warning(
'Failed to create mirai connection to '
f'<y>{qq}</y>, reason: <r>{e}</r>. '
'Will retry after 3 seconds')
await asyncio.sleep(3)
asyncio.create_task(_checker())
@driver.on_shutdown
async def _shutdown():
cls.active = False
bot = driver._clients.pop(str(qq), None)
if bot is None:
return
await bot.websocket.close() #type:ignore