⚗️ trying to change mirai adapter message processing behavior

This commit is contained in:
Mix 2021-02-07 11:52:50 +08:00
parent b59ff03abf
commit 49010bf5b7
5 changed files with 77 additions and 57 deletions

View File

@ -1,8 +1,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps
from io import BytesIO from io import BytesIO
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union) from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx import httpx
@ -10,15 +9,12 @@ from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ApiNotAvailable, RequestDenied from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment from .message import MessageChain, MessageSegment
from .utils import catch_network_error, argument_validation, check_tome, Log from .utils import Log, argument_validation, catch_network_error, process_event
class SessionManager: class SessionManager:
@ -212,20 +208,15 @@ class Bot(BaseBot):
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
Log.debug(f'received message {message}') Log.debug(f'received message {message}')
try: try:
await handle_event( await process_event(
bot=self, bot=self,
event=await check_tome( event=Event.new({
bot=self, **message,
event=Event.new({ 'self_id': self.self_id,
**message, }),
'self_id': self.self_id,
}),
),
) )
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).exception( Log.error(f'Failed to handle message: {message}', e)
'Failed to handle message '
f'<r>{escape_tag(str(message))}</r>: ')
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn: async def call_api(self, api: str, **data) -> NoReturn:

View File

@ -13,7 +13,7 @@ from .request import *
__all__ = [ __all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission', 'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
'MessageChain', 'MessageEvent', 'GroupMessage', 'FriendMessage', 'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent', 'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent', 'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent', 'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',

View File

@ -1,6 +1,7 @@
from typing import Any from datetime import datetime
from typing import Any, Optional
from pydantic import Field from pydantic import BaseModel, Field
from nonebot.typing import overrides from nonebot.typing import overrides
@ -8,9 +9,15 @@ from ..message import MessageChain
from .base import Event, GroupChatInfo, PrivateChatInfo from .base import Event, GroupChatInfo, PrivateChatInfo
class MessageSource(BaseModel):
id: int
time: datetime
class MessageEvent(Event): class MessageEvent(Event):
"""消息事件基类""" """消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain') message_chain: MessageChain = Field(alias='messageChain')
source: Optional[MessageSource] = None
sender: Any sender: Any
@overrides(Event) @overrides(Event)

View File

@ -306,5 +306,13 @@ class MessageChain(BaseMessage):
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
] ]
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
if not len(self):
return None
first: MessageSegment = self[0]
if (not type) or (first.type in type):
return self.pop(0)
return None
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>' return f'<{self.__class__.__name__} {[*self.copy()]}>'

View File

@ -7,10 +7,11 @@ from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception import nonebot.exception as exception
from nonebot.log import logger from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper from nonebot.utils import escape_tag, logger_wrapper
from .event import Event, GroupMessage from .event import Event, GroupMessage, MessageEvent, MessageSource
from .message import MessageSegment, MessageType from .message import MessageType
if TYPE_CHECKING: if TYPE_CHECKING:
from .bot import Bot from .bot import Bot
@ -124,39 +125,52 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
return wrapper # type: ignore return wrapper # type: ignore
async def check_tome(bot: "Bot", event: "Event") -> "Event": def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
if not isinstance(event, GroupMessage): source = event.message_chain.extract_first(MessageType.SOURCE)
return event if source is not None:
event.source = MessageSource.parse_obj(source.data)
def _is_at(event: GroupMessage) -> bool:
for segment in event.message_chain:
segment: MessageSegment
if segment.type != MessageType.AT:
continue
if segment.data['target'] == event.self_id:
return True
return False
def _is_nick(event: GroupMessage) -> bool:
text = event.get_plaintext()
if not text:
return False
nick_regex = '|'.join(
{i.strip() for i in bot.config.nickname if i.strip()})
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is None:
return False
Log.info(f'User is calling me {matched.group(1)}')
return True
def _is_reply(event: GroupMessage) -> bool:
for segment in event.message_chain:
segment: MessageSegment
if segment.type != MessageType.QUOTE:
continue
if segment.data['senderId'] == event.self_id:
return True
return False
event.to_me = any([_is_at(event), _is_reply(event), _is_nick(event)])
return event return event
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
at = event.message_chain.extract_first(MessageType.AT)
if at is not None:
if at.data['target'] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, at)
return event
def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
plain = event.message_chain.extract_first(MessageType.PLAIN)
if plain is not None:
text = str(plain)
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is not None:
nickname = matched.group(1)
Log.info(f'User is calling me {nickname}')
plain.data['text'] = text[matched.end():]
event.message_chain.insert(0, plain)
return event
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
reply = event.message_chain.extract_first(MessageType.QUOTE)
if reply is not None:
if reply.data['sender_id'] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, reply)
return event
async def process_event(bot: "Bot", event: Event) -> None:
if isinstance(event, MessageEvent):
event = process_source(bot, event)
if isinstance(event, GroupMessage):
event = process_nick(bot, event)
event = process_reply(bot, event)
event = process_at(bot, event)
await handle_event(bot, event)