diff --git a/nonebot/adapters/mirai/__init__.py b/nonebot/adapters/mirai/__init__.py index 991f30fd..1107af38 100644 --- a/nonebot/adapters/mirai/__init__.py +++ b/nonebot/adapters/mirai/__init__.py @@ -1,3 +1,4 @@ from .bot import MiraiBot +from .bot_ws import MiraiWebsocketBot from .event import * from .message import MessageChain, MessageSegment diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 2414dca8..ebb9b768 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -1,19 +1,19 @@ from datetime import datetime, timedelta from io import BytesIO from ipaddress import IPv4Address -from typing import Any, Dict, List, NoReturn, Optional, Tuple +from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union import httpx from nonebot.adapters import Bot as BaseBot -from nonebot.adapters import Event as BaseEvent from nonebot.config import Config from nonebot.drivers import Driver, WebSocket -from nonebot.exception import RequestDenied from nonebot.exception import ActionFailed as BaseActionFailed +from nonebot.exception import RequestDenied from nonebot.log import logger from nonebot.message import handle_event from nonebot.typing import overrides +from nonebot.utils import escape_tag from .config import Config as MiraiConfig from .event import Event, FriendMessage, GroupMessage, TempMessage @@ -41,7 +41,8 @@ class SessionManager: @staticmethod def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]: code = data.get('code', 0) - logger.debug(f'Mirai API returned data: {data}') + logger.opt(colors=True).debug('Mirai API returned data: ' + f'{escape_tag(str(data))}') if code != 0: raise ActionFailed(code, message=data['msg']) return data @@ -85,10 +86,10 @@ class SessionManager: @classmethod async def new(cls, self_id: int, *, host: IPv4Address, port: int, auth_key: str): - if self_id in cls.sessions: - manager = cls.get(self_id) - if manager is not None: - return manager + session = cls.get(self_id) + if session is not None: + return session + client = httpx.AsyncClient(base_url=f'http://{host}:{port}') response = await client.post('/auth', json={'authKey': auth_key}) response.raise_for_status() @@ -102,10 +103,13 @@ class SessionManager: }) assert response.json()['code'] == 0 cls.sessions[self_id] = session_key, datetime.now(), client + return cls(session_key, client) @classmethod def get(cls, self_id: int): + if self_id not in cls.sessions: + return None key, time, client = cls.sessions[self_id] if datetime.now() - time > cls.session_expiry: return None @@ -114,6 +118,7 @@ class SessionManager: class MiraiBot(BaseBot): + @overrides(BaseBot) def __init__(self, connection_type: str, self_id: str, @@ -179,17 +184,20 @@ class MiraiBot(BaseBot): @overrides(BaseBot) async def send(self, event: Event, - message: MessageChain, - at_sender: bool = False, - **kwargs): + message: Union[MessageChain, MessageSegment, str], + at_sender: bool = False): + if isinstance(message, MessageSegment): + message = MessageChain(message) + elif isinstance(message, str): + message = MessageChain(MessageSegment.plain(message)) if isinstance(event, FriendMessage): return await self.send_friend_message(target=event.sender.id, message_chain=message) elif isinstance(event, GroupMessage): - return await self.send_group_message( - group=event.sender.group.id, - message_chain=message if not at_sender else - (MessageSegment.at(target=event.sender.id) + message)) + if at_sender: + message = MessageSegment.at(event.sender.id) + message + return await self.send_group_message(group=event.sender.group.id, + message_chain=message) elif isinstance(event, TempMessage): return await self.send_temp_message(qq=event.sender.id, group=event.sender.group.id, diff --git a/nonebot/adapters/mirai/bot_ws.py b/nonebot/adapters/mirai/bot_ws.py index d9803c47..d20d81dd 100644 --- a/nonebot/adapters/mirai/bot_ws.py +++ b/nonebot/adapters/mirai/bot_ws.py @@ -7,50 +7,21 @@ from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set, import httpx import websockets -from nonebot.adapters import Bot as BaseBot -from nonebot.adapters import Event as BaseEvent from nonebot.config import Config from nonebot.drivers import Driver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.exception import RequestDenied from nonebot.log import logger -from nonebot.message import handle_event from nonebot.typing import overrides +from .bot import MiraiBot, SessionManager from .config import Config as MiraiConfig -from .event import Event WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] WebsocketHandler_T = TypeVar('WebsocketHandler_T', bound=WebsocketHandlerFunction) -async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str, - qq: int) -> str: - - async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]: - response = await client.request(method, path, **kwargs) - response.raise_for_status() - return response.json() - - about = await request('GET', path='/about') - logger.opt(colors=True).debug('Mirai API HTTP backend version: ' - f'{about["data"]["version"]}') - - status = await request('POST', path='/auth', json={'authKey': auth_key}) - assert status['code'] == 0 - session_key = status['session'] - - verify = await request('POST', - path='/verify', - json={ - 'sessionKey': session_key, - 'qq': qq - }) - assert verify['code'] == 0, verify['msg'] - return session_key - - class WebSocket(BaseWebSocket): @classmethod @@ -59,6 +30,7 @@ class WebSocket(BaseWebSocket): listen_address = httpx.URL(f'ws://{host}:{port}/all', params={'sessionKey': session_key}) websocket = await websockets.connect(uri=str(listen_address)) + await (await websocket.ping()) return cls(websocket) @overrides(BaseWebSocket) @@ -116,25 +88,24 @@ class WebSocket(BaseWebSocket): return callable -class MiraiWebsocketBot(BaseBot): +class MiraiWebsocketBot(MiraiBot): + @overrides(MiraiBot) def __init__(self, connection_type: str, self_id: str, *, websocket: WebSocket): super().__init__(connection_type, self_id, websocket=websocket) - websocket.handle(self.handle_message) - self.driver._bot_connect(self) @property - @overrides(BaseBot) + @overrides(MiraiBot) def type(self) -> str: - return "mirai" + return "mirai-ws" @property def alive(self) -> bool: return not self.websocket.closed @classmethod - @overrides(BaseBot) + @overrides(MiraiBot) async def check_permission(cls, driver: "Driver", connection_type: str, headers: dict, body: Optional[dict]) -> NoReturn: raise RequestDenied( @@ -142,7 +113,7 @@ class MiraiWebsocketBot(BaseBot): reason=f'Connection {connection_type} not implented') @classmethod - @overrides(BaseBot) + @overrides(MiraiBot) def register(cls, driver: "Driver", config: "Config", qq: int): cls.mirai_config = MiraiConfig(**config.dict()) cls.active = True @@ -152,32 +123,33 @@ class MiraiWebsocketBot(BaseBot): super().register(driver, config) async def _bot_connection(): - async with httpx.AsyncClient( - base_url= - f'http://{cls.mirai_config.host}:{cls.mirai_config.port}' - ) as client: - session_key = await _ws_authorization( - client, - auth_key=cls.mirai_config.auth_key, # type: ignore - qq=qq) # type: ignore - + session: SessionManager = await SessionManager.new( + qq, + host=cls.mirai_config.host, # type: ignore + port=cls.mirai_config.port, # type: ignore + auth_key=cls.mirai_config.auth_key # type: ignore + ) websocket = await WebSocket.new( host=cls.mirai_config.host, # type: ignore port=cls.mirai_config.port, # type: ignore - session_key=session_key) + session_key=session.session_key) bot = cls(connection_type='forward_ws', self_id=str(qq), websocket=websocket) websocket.handle(bot.handle_message) - driver._clients[str(qq)] = bot await websocket.accept() + return bot async def _connection_ensure(): - if str(qq) not in driver._clients: - await _bot_connection() - elif not driver._clients[str(qq)].alive: - driver._clients.pop(str(qq), None) - await _bot_connection() + self_id = str(qq) + if self_id not in driver._clients: + bot = await _bot_connection() + driver._bot_connect(bot) + else: + bot = driver._clients[self_id] + if not bot.alive: + driver._bot_disconnect(bot) + return @driver.on_startup async def _startup(): @@ -202,19 +174,3 @@ class MiraiWebsocketBot(BaseBot): if bot is None: return await bot.websocket.close() #type:ignore - - @overrides(BaseBot) - async def handle_message(self, message: dict): - event = Event.new(message) - await handle_event(self, event) - - @overrides(BaseBot) - async def call_api(self, api: str, **data): - return super().call_api(api, **data) - - @overrides(BaseBot) - async def send(self, event: "BaseEvent", message: str, **kwargs): - return super().send(event, message, **kwargs) - - def __del__(self): - self.driver._bot_disconnect(self) diff --git a/nonebot/adapters/mirai/event/base.py b/nonebot/adapters/mirai/event/base.py index 6fbb30ff..3b6916f5 100644 --- a/nonebot/adapters/mirai/event/base.py +++ b/nonebot/adapters/mirai/event/base.py @@ -86,7 +86,7 @@ class Event(BaseEvent): @overrides(BaseEvent) def get_event_description(self) -> str: - return str(self.dict()) + return str(self.normalize_dict()) @overrides(BaseEvent) def get_message(self) -> BaseMessage: diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py index ef3949a6..a577a807 100644 --- a/nonebot/adapters/mirai/message.py +++ b/nonebot/adapters/mirai/message.py @@ -135,10 +135,11 @@ class MessageSegment(BaseMessageSegment): return cls(type=MessageType.POKE, name=name) -class MessageChain(BaseMessage): +class MessageChain(BaseMessage): #type:List[MessageSegment] @overrides(BaseMessage) - def __init__(self, message: Union[List[Dict[str, Any]], MessageSegment], + def __init__(self, message: Union[List[Dict[str, Any]], + Iterable[MessageSegment], MessageSegment], **kwargs): super().__init__(**kwargs) if isinstance(message, MessageSegment): @@ -152,15 +153,16 @@ class MessageChain(BaseMessage): @overrides(BaseMessage) def _construct( - self, message: Iterable[Union[Dict[str, Any], MessageSegment]] + self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment]] ) -> List[MessageSegment]: if isinstance(message, str): raise ValueError( "String operation is not supported in mirai adapter") return [ *map( - lambda segment: segment if isinstance(segment, MessageSegment) - else MessageSegment(**segment), message) + lambda x: x + if isinstance(x, MessageSegment) else MessageSegment(**x), + message) ] def export(self) -> List[Dict[str, Any]]: