mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🚸 add factory classmethods in MessageSegment at mirai adapter
This commit is contained in:
parent
95f27824ee
commit
73be9151b0
@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides
|
||||
@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage
|
||||
from .message import MessageChain, MessageSegment
|
||||
|
||||
|
||||
class ActionFailed(BaseActionFailed):
|
||||
|
||||
def __init__(self, code: int, message: str = ''):
|
||||
super().__init__('mirai')
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})"
|
||||
|
||||
|
||||
class SessionManager:
|
||||
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
|
||||
session_expiry: timedelta = timedelta(minutes=15)
|
||||
@ -26,14 +38,22 @@ class SessionManager:
|
||||
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
||||
self.session_key, self.client = session_key, client
|
||||
|
||||
@staticmethod
|
||||
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
code = data.get('code', 0)
|
||||
logger.debug(f'Mirai API returned data: {data}')
|
||||
if code != 0:
|
||||
raise ActionFailed(code, message=data['msg'])
|
||||
return data
|
||||
|
||||
async def post(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
params = {**(params or {}), 'sessionKey': self.session_key}
|
||||
response = await self.client.post(path, json=params)
|
||||
response = await self.client.post(path, json=params, timeout=3)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return self._raise_code(response.json())
|
||||
|
||||
async def request(self,
|
||||
path: str,
|
||||
@ -44,9 +64,10 @@ class SessionManager:
|
||||
params={
|
||||
**(params or {}), 'sessionKey':
|
||||
self.session_key
|
||||
})
|
||||
},
|
||||
timeout=3)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return self._raise_code(response.json())
|
||||
|
||||
async def upload(self, path: str, *, type: str,
|
||||
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
|
||||
@ -59,7 +80,7 @@ class SessionManager:
|
||||
files={file_type: file_io},
|
||||
timeout=6)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return self._raise_code(response.json())
|
||||
|
||||
@classmethod
|
||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||
@ -152,7 +173,7 @@ class MiraiBot(BaseBot):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def __getattr__(self, key: str) -> NoReturn:
|
||||
def __getattr__(self, key: str) -> NoReturn:
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -165,8 +186,10 @@ class MiraiBot(BaseBot):
|
||||
return await self.send_friend_message(target=event.sender.id,
|
||||
message_chain=message)
|
||||
elif isinstance(event, GroupMessage):
|
||||
return await self.send_group_message(target=event.sender.group.id,
|
||||
message_chain=message)
|
||||
return await self.send_group_message(
|
||||
group=event.sender.group.id,
|
||||
message_chain=message if not at_sender else
|
||||
(MessageSegment.at(target=event.sender.id) + message))
|
||||
elif isinstance(event, TempMessage):
|
||||
return await self.send_temp_message(qq=event.sender.id,
|
||||
group=event.sender.group.id,
|
||||
@ -191,12 +214,15 @@ class MiraiBot(BaseBot):
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
|
||||
async def send_group_message(self, target: int,
|
||||
message_chain: MessageChain):
|
||||
async def send_group_message(self,
|
||||
group: int,
|
||||
message_chain: MessageChain,
|
||||
quote: Optional[int] = None):
|
||||
return await self.api.post('sendGroupMessage',
|
||||
params={
|
||||
'target': target,
|
||||
'messageChain': message_chain.export()
|
||||
'group': group,
|
||||
'messageChain': message_chain.export(),
|
||||
'quote': quote
|
||||
})
|
||||
|
||||
async def recall(self, target: int):
|
||||
|
@ -18,7 +18,7 @@ class MessageEvent(Event):
|
||||
|
||||
@overrides(Event)
|
||||
def get_plaintext(self) -> str:
|
||||
return self.message_chain.__str__()
|
||||
return self.message_chain.extract_plain_text()
|
||||
|
||||
@overrides(Event)
|
||||
def get_user_id(self) -> str:
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable, List, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from pydantic import validate_arguments
|
||||
|
||||
@ -31,7 +31,8 @@ class MessageSegment(BaseMessageSegment):
|
||||
@overrides(BaseMessageSegment)
|
||||
@validate_arguments
|
||||
def __init__(self, type: MessageType, **data):
|
||||
super().__init__(type=type, data=data)
|
||||
super().__init__(type=type,
|
||||
data={k: v for k, v in data.items() if v is not None})
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __str__(self) -> str:
|
||||
@ -60,6 +61,79 @@ class MessageSegment(BaseMessageSegment):
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
return {'type': self.type.value, **self.data}
|
||||
|
||||
@classmethod
|
||||
def source(cls, id: int, time: int):
|
||||
return cls(type=MessageType.SOURCE, id=id, time=time)
|
||||
|
||||
@classmethod
|
||||
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
|
||||
origin: "MessageChain"):
|
||||
return cls(type=MessageType.QUOTE,
|
||||
id=id,
|
||||
groupId=group_id,
|
||||
senderId=sender_id,
|
||||
targetId=target_id,
|
||||
origin=origin.export())
|
||||
|
||||
@classmethod
|
||||
def at(cls, target: int):
|
||||
return cls(type=MessageType.AT, target=target)
|
||||
|
||||
@classmethod
|
||||
def at_all(cls):
|
||||
return cls(type=MessageType.AT_ALL)
|
||||
|
||||
@classmethod
|
||||
def face(cls, face_id: Optional[int] = None, name: Optional[str] = None):
|
||||
return cls(type=MessageType.FACE, faceId=face_id, name=name)
|
||||
|
||||
@classmethod
|
||||
def plain(cls, text: str):
|
||||
return cls(type=MessageType.PLAIN, text=text)
|
||||
|
||||
@classmethod
|
||||
def image(cls,
|
||||
image_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None):
|
||||
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
|
||||
|
||||
@classmethod
|
||||
def flash_image(cls,
|
||||
image_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None):
|
||||
return cls(type=MessageType.FLASH_IMAGE,
|
||||
imageId=image_id,
|
||||
url=url,
|
||||
path=path)
|
||||
|
||||
@classmethod
|
||||
def voice(cls,
|
||||
voice_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
path: Optional[str] = None):
|
||||
return cls(type=MessageType.FLASH_IMAGE,
|
||||
imageId=voice_id,
|
||||
url=url,
|
||||
path=path)
|
||||
|
||||
@classmethod
|
||||
def xml(cls, xml: str):
|
||||
return cls(type=MessageType.XML, xml=xml)
|
||||
|
||||
@classmethod
|
||||
def json(cls, json: str):
|
||||
return cls(type=MessageType.JSON, json=json)
|
||||
|
||||
@classmethod
|
||||
def app(cls, content: str):
|
||||
return cls(type=MessageType.APP, content=content)
|
||||
|
||||
@classmethod
|
||||
def poke(cls, name: str):
|
||||
return cls(type=MessageType.POKE, name=name)
|
||||
|
||||
|
||||
class MessageChain(BaseMessage):
|
||||
|
||||
@ -90,11 +164,9 @@ class MessageChain(BaseMessage):
|
||||
]
|
||||
|
||||
def export(self) -> List[Dict[str, Any]]:
|
||||
chain: List[Dict[str, Any]] = []
|
||||
for segment in self.copy():
|
||||
segment: MessageSegment
|
||||
chain.append({'type': segment.type.value, **segment.data})
|
||||
return chain
|
||||
return [
|
||||
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
|
||||
]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} {[*self.copy()]}>'
|
||||
|
Loading…
Reference in New Issue
Block a user