From 73be9151b0c265ca0afc15333f4baa5bc1e2932f Mon Sep 17 00:00:00 2001 From: Mix Date: Sat, 30 Jan 2021 21:51:51 +0800 Subject: [PATCH] :children_crossing: add factory classmethods in MessageSegment at mirai adapter --- nonebot/adapters/mirai/bot.py | 50 ++++++++++---- nonebot/adapters/mirai/event/message.py | 2 +- nonebot/adapters/mirai/message.py | 86 +++++++++++++++++++++++-- 3 files changed, 118 insertions(+), 20 deletions(-) diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index e89eb245..2414dca8 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent from nonebot.config import Config from nonebot.drivers import Driver, WebSocket from nonebot.exception import RequestDenied +from nonebot.exception import ActionFailed as BaseActionFailed from nonebot.log import logger from nonebot.message import handle_event from nonebot.typing import overrides @@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage from .message import MessageChain, MessageSegment +class ActionFailed(BaseActionFailed): + + def __init__(self, code: int, message: str = ''): + super().__init__('mirai') + self.code = code + self.message = message + + def __repr__(self): + return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})" + + class SessionManager: sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {} session_expiry: timedelta = timedelta(minutes=15) @@ -26,14 +38,22 @@ class SessionManager: def __init__(self, session_key: str, client: httpx.AsyncClient): self.session_key, self.client = session_key, client + @staticmethod + def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]: + code = data.get('code', 0) + logger.debug(f'Mirai API returned data: {data}') + if code != 0: + raise ActionFailed(code, message=data['msg']) + return data + async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: params = {**(params or {}), 'sessionKey': self.session_key} - response = await self.client.post(path, json=params) + response = await self.client.post(path, json=params, timeout=3) response.raise_for_status() - return response.json() + return self._raise_code(response.json()) async def request(self, path: str, @@ -44,9 +64,10 @@ class SessionManager: params={ **(params or {}), 'sessionKey': self.session_key - }) + }, + timeout=3) response.raise_for_status() - return response.json() + return self._raise_code(response.json()) async def upload(self, path: str, *, type: str, file: Tuple[str, BytesIO]) -> Dict[str, Any]: @@ -59,7 +80,7 @@ class SessionManager: files={file_type: file_io}, timeout=6) response.raise_for_status() - return response.json() + return self._raise_code(response.json()) @classmethod async def new(cls, self_id: int, *, host: IPv4Address, port: int, @@ -152,7 +173,7 @@ class MiraiBot(BaseBot): raise NotImplementedError @overrides(BaseBot) - async def __getattr__(self, key: str) -> NoReturn: + def __getattr__(self, key: str) -> NoReturn: raise NotImplementedError @overrides(BaseBot) @@ -165,8 +186,10 @@ class MiraiBot(BaseBot): return await self.send_friend_message(target=event.sender.id, message_chain=message) elif isinstance(event, GroupMessage): - return await self.send_group_message(target=event.sender.group.id, - message_chain=message) + return await self.send_group_message( + group=event.sender.group.id, + message_chain=message if not at_sender else + (MessageSegment.at(target=event.sender.id) + message)) elif isinstance(event, TempMessage): return await self.send_temp_message(qq=event.sender.id, group=event.sender.group.id, @@ -191,12 +214,15 @@ class MiraiBot(BaseBot): 'messageChain': message_chain.export() }) - async def send_group_message(self, target: int, - message_chain: MessageChain): + async def send_group_message(self, + group: int, + message_chain: MessageChain, + quote: Optional[int] = None): return await self.api.post('sendGroupMessage', params={ - 'target': target, - 'messageChain': message_chain.export() + 'group': group, + 'messageChain': message_chain.export(), + 'quote': quote }) async def recall(self, target: int): diff --git a/nonebot/adapters/mirai/event/message.py b/nonebot/adapters/mirai/event/message.py index 1cfca586..10574d5e 100644 --- a/nonebot/adapters/mirai/event/message.py +++ b/nonebot/adapters/mirai/event/message.py @@ -18,7 +18,7 @@ class MessageEvent(Event): @overrides(Event) def get_plaintext(self) -> str: - return self.message_chain.__str__() + return self.message_chain.extract_plain_text() @overrides(Event) def get_user_id(self) -> str: diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py index 7562b6be..ef3949a6 100644 --- a/nonebot/adapters/mirai/message.py +++ b/nonebot/adapters/mirai/message.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, Iterable, List, Union +from typing import Any, Dict, Iterable, List, Optional, Union from pydantic import validate_arguments @@ -31,7 +31,8 @@ class MessageSegment(BaseMessageSegment): @overrides(BaseMessageSegment) @validate_arguments def __init__(self, type: MessageType, **data): - super().__init__(type=type, data=data) + super().__init__(type=type, + data={k: v for k, v in data.items() if v is not None}) @overrides(BaseMessageSegment) def __str__(self) -> str: @@ -60,6 +61,79 @@ class MessageSegment(BaseMessageSegment): def as_dict(self) -> Dict[str, Any]: return {'type': self.type.value, **self.data} + @classmethod + def source(cls, id: int, time: int): + return cls(type=MessageType.SOURCE, id=id, time=time) + + @classmethod + def quote(cls, id: int, group_id: int, sender_id: int, target_id: int, + origin: "MessageChain"): + return cls(type=MessageType.QUOTE, + id=id, + groupId=group_id, + senderId=sender_id, + targetId=target_id, + origin=origin.export()) + + @classmethod + def at(cls, target: int): + return cls(type=MessageType.AT, target=target) + + @classmethod + def at_all(cls): + return cls(type=MessageType.AT_ALL) + + @classmethod + def face(cls, face_id: Optional[int] = None, name: Optional[str] = None): + return cls(type=MessageType.FACE, faceId=face_id, name=name) + + @classmethod + def plain(cls, text: str): + return cls(type=MessageType.PLAIN, text=text) + + @classmethod + def image(cls, + image_id: Optional[str] = None, + url: Optional[str] = None, + path: Optional[str] = None): + return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path) + + @classmethod + def flash_image(cls, + image_id: Optional[str] = None, + url: Optional[str] = None, + path: Optional[str] = None): + return cls(type=MessageType.FLASH_IMAGE, + imageId=image_id, + url=url, + path=path) + + @classmethod + def voice(cls, + voice_id: Optional[str] = None, + url: Optional[str] = None, + path: Optional[str] = None): + return cls(type=MessageType.FLASH_IMAGE, + imageId=voice_id, + url=url, + path=path) + + @classmethod + def xml(cls, xml: str): + return cls(type=MessageType.XML, xml=xml) + + @classmethod + def json(cls, json: str): + return cls(type=MessageType.JSON, json=json) + + @classmethod + def app(cls, content: str): + return cls(type=MessageType.APP, content=content) + + @classmethod + def poke(cls, name: str): + return cls(type=MessageType.POKE, name=name) + class MessageChain(BaseMessage): @@ -90,11 +164,9 @@ class MessageChain(BaseMessage): ] def export(self) -> List[Dict[str, Any]]: - chain: List[Dict[str, Any]] = [] - for segment in self.copy(): - segment: MessageSegment - chain.append({'type': segment.type.value, **segment.data}) - return chain + return [ + *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore + ] def __repr__(self) -> str: return f'<{self.__class__.__name__} {[*self.copy()]}>'