🚸 add factory classmethods in MessageSegment at mirai adapter

This commit is contained in:
Mix 2021-01-30 21:51:51 +08:00
parent 95f27824ee
commit 73be9151b0
3 changed files with 118 additions and 20 deletions

View File

@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import RequestDenied
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment
class ActionFailed(BaseActionFailed):
def __init__(self, code: int, message: str = ''):
super().__init__('mirai')
self.code = code
self.message = message
def __repr__(self):
return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})"
class SessionManager:
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
session_expiry: timedelta = timedelta(minutes=15)
@ -26,14 +38,22 @@ class SessionManager:
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
@staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0)
logger.debug(f'Mirai API returned data: {data}')
if code != 0:
raise ActionFailed(code, message=data['msg'])
return data
async def post(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
params = {**(params or {}), 'sessionKey': self.session_key}
response = await self.client.post(path, json=params)
response = await self.client.post(path, json=params, timeout=3)
response.raise_for_status()
return response.json()
return self._raise_code(response.json())
async def request(self,
path: str,
@ -44,9 +64,10 @@ class SessionManager:
params={
**(params or {}), 'sessionKey':
self.session_key
})
},
timeout=3)
response.raise_for_status()
return response.json()
return self._raise_code(response.json())
async def upload(self, path: str, *, type: str,
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
@ -59,7 +80,7 @@ class SessionManager:
files={file_type: file_io},
timeout=6)
response.raise_for_status()
return response.json()
return self._raise_code(response.json())
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
@ -152,7 +173,7 @@ class MiraiBot(BaseBot):
raise NotImplementedError
@overrides(BaseBot)
async def __getattr__(self, key: str) -> NoReturn:
def __getattr__(self, key: str) -> NoReturn:
raise NotImplementedError
@overrides(BaseBot)
@ -165,8 +186,10 @@ class MiraiBot(BaseBot):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
elif isinstance(event, GroupMessage):
return await self.send_group_message(target=event.sender.group.id,
message_chain=message)
return await self.send_group_message(
group=event.sender.group.id,
message_chain=message if not at_sender else
(MessageSegment.at(target=event.sender.id) + message))
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,
@ -191,12 +214,15 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export()
})
async def send_group_message(self, target: int,
message_chain: MessageChain):
async def send_group_message(self,
group: int,
message_chain: MessageChain,
quote: Optional[int] = None):
return await self.api.post('sendGroupMessage',
params={
'target': target,
'messageChain': message_chain.export()
'group': group,
'messageChain': message_chain.export(),
'quote': quote
})
async def recall(self, target: int):

View File

@ -18,7 +18,7 @@ class MessageEvent(Event):
@overrides(Event)
def get_plaintext(self) -> str:
return self.message_chain.__str__()
return self.message_chain.extract_plain_text()
@overrides(Event)
def get_user_id(self) -> str:

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, Iterable, List, Union
from typing import Any, Dict, Iterable, List, Optional, Union
from pydantic import validate_arguments
@ -31,7 +31,8 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
@validate_arguments
def __init__(self, type: MessageType, **data):
super().__init__(type=type, data=data)
super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None})
@overrides(BaseMessageSegment)
def __str__(self) -> str:
@ -60,6 +61,79 @@ class MessageSegment(BaseMessageSegment):
def as_dict(self) -> Dict[str, Any]:
return {'type': self.type.value, **self.data}
@classmethod
def source(cls, id: int, time: int):
return cls(type=MessageType.SOURCE, id=id, time=time)
@classmethod
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
origin: "MessageChain"):
return cls(type=MessageType.QUOTE,
id=id,
groupId=group_id,
senderId=sender_id,
targetId=target_id,
origin=origin.export())
@classmethod
def at(cls, target: int):
return cls(type=MessageType.AT, target=target)
@classmethod
def at_all(cls):
return cls(type=MessageType.AT_ALL)
@classmethod
def face(cls, face_id: Optional[int] = None, name: Optional[str] = None):
return cls(type=MessageType.FACE, faceId=face_id, name=name)
@classmethod
def plain(cls, text: str):
return cls(type=MessageType.PLAIN, text=text)
@classmethod
def image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
@classmethod
def flash_image(cls,
image_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
return cls(type=MessageType.FLASH_IMAGE,
imageId=image_id,
url=url,
path=path)
@classmethod
def voice(cls,
voice_id: Optional[str] = None,
url: Optional[str] = None,
path: Optional[str] = None):
return cls(type=MessageType.FLASH_IMAGE,
imageId=voice_id,
url=url,
path=path)
@classmethod
def xml(cls, xml: str):
return cls(type=MessageType.XML, xml=xml)
@classmethod
def json(cls, json: str):
return cls(type=MessageType.JSON, json=json)
@classmethod
def app(cls, content: str):
return cls(type=MessageType.APP, content=content)
@classmethod
def poke(cls, name: str):
return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage):
@ -90,11 +164,9 @@ class MessageChain(BaseMessage):
]
def export(self) -> List[Dict[str, Any]]:
chain: List[Dict[str, Any]] = []
for segment in self.copy():
segment: MessageSegment
chain.append({'type': segment.type.value, **segment.data})
return chain
return [
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
]
def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>'