🚸 add factory classmethods in MessageSegment at mirai adapter

This commit is contained in:
Mix 2021-01-30 21:51:51 +08:00
parent 95f27824ee
commit 73be9151b0
3 changed files with 118 additions and 20 deletions

View File

@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver, WebSocket
from nonebot.exception import RequestDenied from nonebot.exception import RequestDenied
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.log import logger from nonebot.log import logger
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment from .message import MessageChain, MessageSegment
class ActionFailed(BaseActionFailed):
def __init__(self, code: int, message: str = ''):
super().__init__('mirai')
self.code = code
self.message = message
def __repr__(self):
return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})"
class SessionManager: class SessionManager:
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {} sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
session_expiry: timedelta = timedelta(minutes=15) session_expiry: timedelta = timedelta(minutes=15)
@ -26,14 +38,22 @@ class SessionManager:
def __init__(self, session_key: str, client: httpx.AsyncClient): def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client self.session_key, self.client = session_key, client
@staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0)
logger.debug(f'Mirai API returned data: {data}')
if code != 0:
raise ActionFailed(code, message=data['msg'])
return data
async def post(self, async def post(self,
path: str, path: str,
*, *,
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
params = {**(params or {}), 'sessionKey': self.session_key} params = {**(params or {}), 'sessionKey': self.session_key}
response = await self.client.post(path, json=params) response = await self.client.post(path, json=params, timeout=3)
response.raise_for_status() response.raise_for_status()
return response.json() return self._raise_code(response.json())
async def request(self, async def request(self,
path: str, path: str,
@ -44,9 +64,10 @@ class SessionManager:
params={ params={
**(params or {}), 'sessionKey': **(params or {}), 'sessionKey':
self.session_key self.session_key
}) },
timeout=3)
response.raise_for_status() response.raise_for_status()
return response.json() return self._raise_code(response.json())
async def upload(self, path: str, *, type: str, async def upload(self, path: str, *, type: str,
file: Tuple[str, BytesIO]) -> Dict[str, Any]: file: Tuple[str, BytesIO]) -> Dict[str, Any]:
@ -59,7 +80,7 @@ class SessionManager:
files={file_type: file_io}, files={file_type: file_io},
timeout=6) timeout=6)
response.raise_for_status() response.raise_for_status()
return response.json() return self._raise_code(response.json())
@classmethod @classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int, async def new(cls, self_id: int, *, host: IPv4Address, port: int,
@ -152,7 +173,7 @@ class MiraiBot(BaseBot):
raise NotImplementedError raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)
async def __getattr__(self, key: str) -> NoReturn: def __getattr__(self, key: str) -> NoReturn:
raise NotImplementedError raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)
@ -165,8 +186,10 @@ class MiraiBot(BaseBot):
return await self.send_friend_message(target=event.sender.id, return await self.send_friend_message(target=event.sender.id,
message_chain=message) message_chain=message)
elif isinstance(event, GroupMessage): elif isinstance(event, GroupMessage):
return await self.send_group_message(target=event.sender.group.id, return await self.send_group_message(
message_chain=message) group=event.sender.group.id,
message_chain=message if not at_sender else
(MessageSegment.at(target=event.sender.id) + message))
elif isinstance(event, TempMessage): elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id, return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id, group=event.sender.group.id,
@ -191,12 +214,15 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export() 'messageChain': message_chain.export()
}) })
async def send_group_message(self, target: int, async def send_group_message(self,
message_chain: MessageChain): group: int,
message_chain: MessageChain,
quote: Optional[int] = None):
return await self.api.post('sendGroupMessage', return await self.api.post('sendGroupMessage',
params={ params={
'target': target, 'group': group,
'messageChain': message_chain.export() 'messageChain': message_chain.export(),
'quote': quote
}) })
async def recall(self, target: int): async def recall(self, target: int):

View File

@ -18,7 +18,7 @@ class MessageEvent(Event):
@overrides(Event) @overrides(Event)
def get_plaintext(self) -> str: def get_plaintext(self) -> str:
return self.message_chain.__str__() return self.message_chain.extract_plain_text()
@overrides(Event) @overrides(Event)
def get_user_id(self) -> str: def get_user_id(self) -> str:

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Iterable, List, Union from typing import Any, Dict, Iterable, List, Optional, Union
from pydantic import validate_arguments from pydantic import validate_arguments
@ -31,7 +31,8 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
@validate_arguments @validate_arguments
def __init__(self, type: MessageType, **data): def __init__(self, type: MessageType, **data):
super().__init__(type=type, data=data) super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None})
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __str__(self) -> str: def __str__(self) -> str:
@ -60,6 +61,79 @@ class MessageSegment(BaseMessageSegment):
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
return {'type': self.type.value, **self.data} return {'type': self.type.value, **self.data}
@classmethod
def source(cls, id: int, time: int):
return cls(type=MessageType.SOURCE, id=id, time=time)
@classmethod
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
origin: "MessageChain"):
return cls(type=MessageType.QUOTE,
id=id,
groupId=group_id,
senderId=sender_id,
targetId=target_id,
origin=origin.export())
@classmethod
def at(cls, target: int):
return cls(type=MessageType.AT, target=target)
@classmethod
def at_all(cls):
return cls(type=MessageType.AT_ALL)
@classmethod
def face(cls, face_id: Optional[int] = None, name: Optional[str] = None):
return cls(type=MessageType.FACE, faceId=face_id, name=name)
@classmethod
def plain(cls, text: str):
return cls(type=MessageType.PLAIN, text=text)
@classmethod
def image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
@classmethod
def flash_image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
return cls(type=MessageType.FLASH_IMAGE,
imageId=image_id,
url=url,
path=path)
@classmethod
def voice(cls,
voice_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
return cls(type=MessageType.FLASH_IMAGE,
imageId=voice_id,
url=url,
path=path)
@classmethod
def xml(cls, xml: str):
return cls(type=MessageType.XML, xml=xml)
@classmethod
def json(cls, json: str):
return cls(type=MessageType.JSON, json=json)
@classmethod
def app(cls, content: str):
return cls(type=MessageType.APP, content=content)
@classmethod
def poke(cls, name: str):
return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage): class MessageChain(BaseMessage):
@ -90,11 +164,9 @@ class MessageChain(BaseMessage):
] ]
def export(self) -> List[Dict[str, Any]]: def export(self) -> List[Dict[str, Any]]:
chain: List[Dict[str, Any]] = [] return [
for segment in self.copy(): *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
segment: MessageSegment ]
chain.append({'type': segment.type.value, **segment.data})
return chain
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>' return f'<{self.__class__.__name__} {[*self.copy()]}>'