From 358528b4953a47614dfc79f18e9f635753abd6c3 Mon Sep 17 00:00:00 2001 From: Mix Date: Wed, 4 Aug 2021 00:35:31 +0800 Subject: [PATCH] :alien: :sparkles: Add forward driver support for mirai-api-http adapter --- .../nonebot/adapters/mirai/__init__.py | 5 +- .../nonebot/adapters/mirai/bot.py | 54 +++-- .../nonebot/adapters/mirai/bot_ws.py | 202 ------------------ 3 files changed, 45 insertions(+), 216 deletions(-) delete mode 100644 packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py index 5adc7a16..68c35ca4 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py @@ -28,6 +28,9 @@ Mirai-API-HTTP 的适配器以 `AGPLv3许可`_ 单独开源 """ from .bot import Bot -from .bot_ws import WebsocketBot from .event import * from .message import MessageChain, MessageSegment +""" +``WebsocketBot``现在已经和``Bot``合并, 并已经被弃用, 请直接使用``Bot`` +""" +WebsocketBot = Bot \ No newline at end of file diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py index 4b10d446..de96a29e 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py @@ -1,15 +1,17 @@ +import json from datetime import datetime, timedelta from io import BytesIO from ipaddress import IPv4Address from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union import httpx +from loguru import logger -from nonebot.config import Config -from nonebot.typing import overrides from nonebot.adapters import Bot as BaseBot +from nonebot.config import Config +from nonebot.drivers import Driver, ReverseDriver, HTTPConnection, HTTPResponse, WebSocket, ForwardDriver, WebSocketSetup from nonebot.exception import ApiNotAvailable -from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket +from nonebot.typing import overrides from .config import Config as MiraiConfig from .event import Event, FriendMessage, GroupMessage, TempMessage @@ -152,15 +154,12 @@ class Bot(BaseBot): """ + _type = 'mirai' + @property @overrides(BaseBot) def type(self) -> str: - return "mirai" - - @property - def alive(self) -> bool: - assert isinstance(self.request, WebSocket) - return not self.request.closed + return self._type @property def api(self) -> SessionManager: @@ -190,21 +189,50 @@ class Bot(BaseBot): @classmethod @overrides(BaseBot) - def register(cls, driver: Driver, config: "Config"): + def register(cls, + driver: Driver, + config: "Config", + qq: Optional[int] = None): cls.mirai_config = MiraiConfig(**config.dict()) if (cls.mirai_config.auth_key and cls.mirai_config.host and cls.mirai_config.port) is None: - raise ApiNotAvailable('mirai') + raise ApiNotAvailable(cls._type) + super().register(driver, config) + if not isinstance(driver, ForwardDriver) and qq: + logger.warning( + f"Current driver {cls.config.driver} don't support forward connections" + ) + elif isinstance(driver, ForwardDriver) and qq: + + async def url_factory(): + assert cls.mirai_config.host and cls.mirai_config.port and cls.mirai_config.auth_key + session = await SessionManager.new( + qq, # type: ignore + host=cls.mirai_config.host, + port=cls.mirai_config.port, + auth_key=cls.mirai_config.auth_key) + return WebSocketSetup( + adapter=cls._type, + self_id=str(qq), + url=(f'ws://{cls.mirai_config.host}:{cls.mirai_config.port}' + f'/all?sessionKey={session.session_key}')) + + driver.setup_websocket(url_factory) + elif isinstance(driver, ReverseDriver): + logger.debug( + 'Param "qq" does not set for mirai adapter, use http post instead' + ) + @overrides(BaseBot) - async def handle_message(self, message: dict): + async def handle_message(self, message: bytes): Log.debug(f'received message {message}') try: await process_event( bot=self, event=Event.new({ - **message, + **json.loads(message), 'self_id': self.self_id, }), ) diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py deleted file mode 100644 index 29fc12bf..00000000 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py +++ /dev/null @@ -1,202 +0,0 @@ -import json -import asyncio -from dataclasses import dataclass -from ipaddress import IPv4Address -from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine - -import httpx -import websockets - -from nonebot.log import logger -from nonebot.config import Config -from nonebot.typing import overrides -from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket - -from .bot import SessionManager, Bot - -WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] -WebsocketHandler_T = TypeVar('WebsocketHandler_T', - bound=WebsocketHandlerFunction) - - -@dataclass -class WebSocket(BaseWebSocket): - websocket: websockets.WebSocketClientProtocol = None # type: ignore - - @classmethod - async def new(cls, *, host: IPv4Address, port: int, - session_key: str) -> "WebSocket": - listen_address = httpx.URL(f'ws://{host}:{port}/all', - params={'sessionKey': session_key}) - websocket = await websockets.connect(uri=str(listen_address)) - await (await websocket.ping()) - return cls("1.1", - listen_address.scheme, - listen_address.path, - listen_address.query, - websocket=websocket) - - @overrides(BaseWebSocket) - def __init__(self, - http_version: str, - scheme: str, - path: str, - query_string: bytes = b"", - headers: Dict[str, str] = None, - websocket: websockets.WebSocketClientProtocol = None): - self.event_handlers: Set[WebsocketHandlerFunction] = set() - self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore - super(WebSocket, self).__init__(http_version=http_version, - scheme=scheme, - path=path, - query_string=query_string, - headers=headers or {}) - - @property - @overrides(BaseWebSocket) - def closed(self) -> bool: - return self.websocket.closed - - @overrides(BaseWebSocket) - async def send(self, data: str): - return await self.websocket.send(data) - - @overrides(BaseWebSocket) - async def send_bytes(self, data: str): - return await self.websocket.send(data) - - @overrides(BaseWebSocket) - async def receive(self) -> str: - return await self.websocket.recv() # type: ignore - - @overrides(BaseWebSocket) - async def receive_bytes(self) -> bytes: - return await self.websocket.recv() # type: ignore - - async def _dispatcher(self): - while not self.closed: - try: - data = await self.receive() - except websockets.ConnectionClosedOK: - logger.debug(f'Websocket connection {self.websocket} closed') - break - except websockets.ConnectionClosedError: - logger.exception(f'Websocket connection {self.websocket} ' - 'connection closed abnormally:') - break - except json.JSONDecodeError as e: - logger.exception(f'Websocket client listened {self.websocket} ' - f'failed to decode data: {e}') - continue - asyncio.gather( - *map(lambda f: f(data), self.event_handlers), #type: ignore - return_exceptions=True) - - @overrides(BaseWebSocket) - async def accept(self): - asyncio.create_task(self._dispatcher()) - - @overrides(BaseWebSocket) - async def close(self): - await self.websocket.close() - - def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T: - self.event_handlers.add(callable) - return callable - - -class WebsocketBot(Bot): - """ - mirai-api-http 正向 Websocket 协议 Bot 适配。 - """ - - @property - @overrides(Bot) - def type(self) -> str: - return "mirai-ws" - - @property - def alive(self) -> bool: - assert isinstance(self.request, WebSocket) - return not self.request.closed - - @property - def api(self) -> SessionManager: - api = SessionManager.get(self_id=int(self.self_id), check_expire=False) - assert api is not None, 'SessionManager has not been initialized' - return api - - @classmethod - @overrides(Bot) - async def check_permission( - cls, driver: Driver, - request: HTTPConnection) -> Tuple[None, HTTPResponse]: - return None, HTTPResponse(501, b'Connection not implented') - - @classmethod - @overrides(Bot) - def register(cls, driver: Driver, config: "Config", qq: int): - """ - :说明: - - 注册该Adapter - - :参数: - - * ``driver: Driver``: 程序所使用的``Driver`` - * ``config: Config``: 程序配置对象 - * ``qq: int``: 要使用的Bot的QQ号 **注意: 在使用正向Websocket时必须指定该值!** - """ - super().register(driver, config) - cls.active = True - - async def _bot_connection(): - session: SessionManager = await SessionManager.new( - qq, - host=cls.mirai_config.host, # type: ignore - port=cls.mirai_config.port, # type: ignore - auth_key=cls.mirai_config.auth_key # type: ignore - ) - websocket = await WebSocket.new( - host=cls.mirai_config.host, # type: ignore - port=cls.mirai_config.port, # type: ignore - session_key=session.session_key) - bot = cls(self_id=str(qq), request=websocket) - websocket.handle(bot.handle_message) - await websocket.accept() - return bot - - async def _connection_ensure(): - self_id = str(qq) - if self_id not in driver._clients: - bot = await _bot_connection() - driver._bot_connect(bot) - else: - bot = driver._clients[self_id] - if not bot.alive: - driver._bot_disconnect(bot) - return - - @driver.on_startup - async def _startup(): - - async def _checker(): - while cls.active: - try: - await _connection_ensure() - except Exception as e: - logger.opt(colors=True).warning( - 'Failed to create mirai connection to ' - f'{qq}, reason: {e}. ' - 'Will retry after 3 seconds') - await asyncio.sleep(3) - - asyncio.create_task(_checker()) - - @driver.on_shutdown - async def _shutdown(): - cls.active = False - bot = driver._clients.pop(str(qq), None) - if bot is None: - return - await bot.websocket.close() #type:ignore