mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-01 01:25:07 +08:00
🚸 add factory classmethods in MessageSegment at mirai adapter
This commit is contained in:
parent
95f27824ee
commit
73be9151b0
@ -10,6 +10,7 @@ from nonebot.adapters import Event as BaseEvent
|
|||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.drivers import Driver, WebSocket
|
from nonebot.drivers import Driver, WebSocket
|
||||||
from nonebot.exception import RequestDenied
|
from nonebot.exception import RequestDenied
|
||||||
|
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
@ -19,6 +20,17 @@ from .event import Event, FriendMessage, GroupMessage, TempMessage
|
|||||||
from .message import MessageChain, MessageSegment
|
from .message import MessageChain, MessageSegment
|
||||||
|
|
||||||
|
|
||||||
|
class ActionFailed(BaseActionFailed):
|
||||||
|
|
||||||
|
def __init__(self, code: int, message: str = ''):
|
||||||
|
super().__init__('mirai')
|
||||||
|
self.code = code
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(code={self.code}, message={self.message!r})"
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
|
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
|
||||||
session_expiry: timedelta = timedelta(minutes=15)
|
session_expiry: timedelta = timedelta(minutes=15)
|
||||||
@ -26,14 +38,22 @@ class SessionManager:
|
|||||||
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
||||||
self.session_key, self.client = session_key, client
|
self.session_key, self.client = session_key, client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
code = data.get('code', 0)
|
||||||
|
logger.debug(f'Mirai API returned data: {data}')
|
||||||
|
if code != 0:
|
||||||
|
raise ActionFailed(code, message=data['msg'])
|
||||||
|
return data
|
||||||
|
|
||||||
async def post(self,
|
async def post(self,
|
||||||
path: str,
|
path: str,
|
||||||
*,
|
*,
|
||||||
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
params = {**(params or {}), 'sessionKey': self.session_key}
|
params = {**(params or {}), 'sessionKey': self.session_key}
|
||||||
response = await self.client.post(path, json=params)
|
response = await self.client.post(path, json=params, timeout=3)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return self._raise_code(response.json())
|
||||||
|
|
||||||
async def request(self,
|
async def request(self,
|
||||||
path: str,
|
path: str,
|
||||||
@ -44,9 +64,10 @@ class SessionManager:
|
|||||||
params={
|
params={
|
||||||
**(params or {}), 'sessionKey':
|
**(params or {}), 'sessionKey':
|
||||||
self.session_key
|
self.session_key
|
||||||
})
|
},
|
||||||
|
timeout=3)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return self._raise_code(response.json())
|
||||||
|
|
||||||
async def upload(self, path: str, *, type: str,
|
async def upload(self, path: str, *, type: str,
|
||||||
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
|
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
|
||||||
@ -59,7 +80,7 @@ class SessionManager:
|
|||||||
files={file_type: file_io},
|
files={file_type: file_io},
|
||||||
timeout=6)
|
timeout=6)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return self._raise_code(response.json())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||||
@ -152,7 +173,7 @@ class MiraiBot(BaseBot):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def __getattr__(self, key: str) -> NoReturn:
|
def __getattr__(self, key: str) -> NoReturn:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
@ -165,8 +186,10 @@ class MiraiBot(BaseBot):
|
|||||||
return await self.send_friend_message(target=event.sender.id,
|
return await self.send_friend_message(target=event.sender.id,
|
||||||
message_chain=message)
|
message_chain=message)
|
||||||
elif isinstance(event, GroupMessage):
|
elif isinstance(event, GroupMessage):
|
||||||
return await self.send_group_message(target=event.sender.group.id,
|
return await self.send_group_message(
|
||||||
message_chain=message)
|
group=event.sender.group.id,
|
||||||
|
message_chain=message if not at_sender else
|
||||||
|
(MessageSegment.at(target=event.sender.id) + message))
|
||||||
elif isinstance(event, TempMessage):
|
elif isinstance(event, TempMessage):
|
||||||
return await self.send_temp_message(qq=event.sender.id,
|
return await self.send_temp_message(qq=event.sender.id,
|
||||||
group=event.sender.group.id,
|
group=event.sender.group.id,
|
||||||
@ -191,12 +214,15 @@ class MiraiBot(BaseBot):
|
|||||||
'messageChain': message_chain.export()
|
'messageChain': message_chain.export()
|
||||||
})
|
})
|
||||||
|
|
||||||
async def send_group_message(self, target: int,
|
async def send_group_message(self,
|
||||||
message_chain: MessageChain):
|
group: int,
|
||||||
|
message_chain: MessageChain,
|
||||||
|
quote: Optional[int] = None):
|
||||||
return await self.api.post('sendGroupMessage',
|
return await self.api.post('sendGroupMessage',
|
||||||
params={
|
params={
|
||||||
'target': target,
|
'group': group,
|
||||||
'messageChain': message_chain.export()
|
'messageChain': message_chain.export(),
|
||||||
|
'quote': quote
|
||||||
})
|
})
|
||||||
|
|
||||||
async def recall(self, target: int):
|
async def recall(self, target: int):
|
||||||
|
@ -18,7 +18,7 @@ class MessageEvent(Event):
|
|||||||
|
|
||||||
@overrides(Event)
|
@overrides(Event)
|
||||||
def get_plaintext(self) -> str:
|
def get_plaintext(self) -> str:
|
||||||
return self.message_chain.__str__()
|
return self.message_chain.extract_plain_text()
|
||||||
|
|
||||||
@overrides(Event)
|
@overrides(Event)
|
||||||
def get_user_id(self) -> str:
|
def get_user_id(self) -> str:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Iterable, List, Union
|
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import validate_arguments
|
from pydantic import validate_arguments
|
||||||
|
|
||||||
@ -31,7 +31,8 @@ class MessageSegment(BaseMessageSegment):
|
|||||||
@overrides(BaseMessageSegment)
|
@overrides(BaseMessageSegment)
|
||||||
@validate_arguments
|
@validate_arguments
|
||||||
def __init__(self, type: MessageType, **data):
|
def __init__(self, type: MessageType, **data):
|
||||||
super().__init__(type=type, data=data)
|
super().__init__(type=type,
|
||||||
|
data={k: v for k, v in data.items() if v is not None})
|
||||||
|
|
||||||
@overrides(BaseMessageSegment)
|
@overrides(BaseMessageSegment)
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@ -60,6 +61,79 @@ class MessageSegment(BaseMessageSegment):
|
|||||||
def as_dict(self) -> Dict[str, Any]:
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
return {'type': self.type.value, **self.data}
|
return {'type': self.type.value, **self.data}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def source(cls, id: int, time: int):
|
||||||
|
return cls(type=MessageType.SOURCE, id=id, time=time)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
|
||||||
|
origin: "MessageChain"):
|
||||||
|
return cls(type=MessageType.QUOTE,
|
||||||
|
id=id,
|
||||||
|
groupId=group_id,
|
||||||
|
senderId=sender_id,
|
||||||
|
targetId=target_id,
|
||||||
|
origin=origin.export())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def at(cls, target: int):
|
||||||
|
return cls(type=MessageType.AT, target=target)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def at_all(cls):
|
||||||
|
return cls(type=MessageType.AT_ALL)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def face(cls, face_id: Optional[int] = None, name: Optional[str] = None):
|
||||||
|
return cls(type=MessageType.FACE, faceId=face_id, name=name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def plain(cls, text: str):
|
||||||
|
return cls(type=MessageType.PLAIN, text=text)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def image(cls,
|
||||||
|
image_id: Optional[str] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
path: Optional[str] = None):
|
||||||
|
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def flash_image(cls,
|
||||||
|
image_id: Optional[str] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
path: Optional[str] = None):
|
||||||
|
return cls(type=MessageType.FLASH_IMAGE,
|
||||||
|
imageId=image_id,
|
||||||
|
url=url,
|
||||||
|
path=path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def voice(cls,
|
||||||
|
voice_id: Optional[str] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
path: Optional[str] = None):
|
||||||
|
return cls(type=MessageType.FLASH_IMAGE,
|
||||||
|
imageId=voice_id,
|
||||||
|
url=url,
|
||||||
|
path=path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def xml(cls, xml: str):
|
||||||
|
return cls(type=MessageType.XML, xml=xml)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def json(cls, json: str):
|
||||||
|
return cls(type=MessageType.JSON, json=json)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def app(cls, content: str):
|
||||||
|
return cls(type=MessageType.APP, content=content)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def poke(cls, name: str):
|
||||||
|
return cls(type=MessageType.POKE, name=name)
|
||||||
|
|
||||||
|
|
||||||
class MessageChain(BaseMessage):
|
class MessageChain(BaseMessage):
|
||||||
|
|
||||||
@ -90,11 +164,9 @@ class MessageChain(BaseMessage):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def export(self) -> List[Dict[str, Any]]:
|
def export(self) -> List[Dict[str, Any]]:
|
||||||
chain: List[Dict[str, Any]] = []
|
return [
|
||||||
for segment in self.copy():
|
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
|
||||||
segment: MessageSegment
|
]
|
||||||
chain.append({'type': segment.type.value, **segment.data})
|
|
||||||
return chain
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'<{self.__class__.__name__} {[*self.copy()]}>'
|
return f'<{self.__class__.__name__} {[*self.copy()]}>'
|
||||||
|
Loading…
Reference in New Issue
Block a user