🔀 Merge pull request #199

Fix mirai adapter command process
This commit is contained in:
Ju4tCode 2021-02-07 12:55:49 +08:00 committed by GitHub
commit a3fe3a1ad8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 124 additions and 91 deletions

View File

@ -1,8 +1,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps
from io import BytesIO from io import BytesIO
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union) from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx import httpx
@ -10,15 +9,12 @@ from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ApiNotAvailable, RequestDenied from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment from .message import MessageChain, MessageSegment
from .utils import catch_network_error, argument_validation, check_tome, Log from .utils import Log, argument_validation, catch_network_error, process_event
class SessionManager: class SessionManager:
@ -212,20 +208,15 @@ class Bot(BaseBot):
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
Log.debug(f'received message {message}') Log.debug(f'received message {message}')
try: try:
await handle_event( await process_event(
bot=self, bot=self,
event=await check_tome( event=Event.new({
bot=self, **message,
event=Event.new({ 'self_id': self.self_id,
**message, }),
'self_id': self.self_id,
}),
),
) )
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).exception( Log.error(f'Failed to handle message: {message}', e)
'Failed to handle message '
f'<r>{escape_tag(str(message))}</r>: ')
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn: async def call_api(self, api: str, **data) -> NoReturn:
@ -262,10 +253,8 @@ class Bot(BaseBot):
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息 * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主体 * ``at_sender: bool``: 是否 @ 事件主体
""" """
if isinstance(message, MessageSegment): if not isinstance(message, MessageChain):
message = MessageChain(message) message = MessageChain(message)
elif isinstance(message, str):
message = MessageChain(MessageSegment.plain(message))
if isinstance(event, FriendMessage): if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id, return await self.send_friend_message(target=event.sender.id,
message_chain=message) message_chain=message)

View File

@ -13,7 +13,7 @@ from .request import *
__all__ = [ __all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission', 'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
'MessageChain', 'MessageEvent', 'GroupMessage', 'FriendMessage', 'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent', 'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent', 'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent', 'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',

View File

@ -1,6 +1,7 @@
from typing import Any from datetime import datetime
from typing import Any, Optional
from pydantic import Field from pydantic import BaseModel, Field
from nonebot.typing import overrides from nonebot.typing import overrides
@ -8,9 +9,15 @@ from ..message import MessageChain
from .base import Event, GroupChatInfo, PrivateChatInfo from .base import Event, GroupChatInfo, PrivateChatInfo
class MessageSource(BaseModel):
id: int
time: datetime
class MessageEvent(Event): class MessageEvent(Event):
"""消息事件基类""" """消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain') message_chain: MessageChain = Field(alias='messageChain')
source: Optional[MessageSource] = None
sender: Any sender: Any
@overrides(Event) @overrides(Event)

View File

@ -44,8 +44,9 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __str__(self) -> str: def __str__(self) -> str:
if self.is_text(): return self.data['text'] if self.is_text() else repr(self)
return self.data.get('text', '')
def __repr__(self) -> str:
return '[mirai:%s]' % ','.join([ return '[mirai:%s]' % ','.join([
self.type.value, self.type.value,
*map( *map(
@ -273,12 +274,14 @@ class MessageChain(BaseMessage):
""" """
@overrides(BaseMessage) @overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str, Any]], def __init__(self, message: Union[List[Dict[str,
Iterable[MessageSegment], MessageSegment], Any]], Iterable[MessageSegment],
**kwargs): MessageSegment, str], **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if isinstance(message, MessageSegment): if isinstance(message, MessageSegment):
self.append(message) self.append(message)
elif isinstance(message, str):
self.append(MessageSegment.plain(text=message))
elif isinstance(message, Iterable): elif isinstance(message, Iterable):
self.extend(self._construct(message)) self.extend(self._construct(message))
else: else:
@ -306,5 +309,13 @@ class MessageChain(BaseMessage):
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
] ]
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
if not len(self):
return None
first: MessageSegment = self[0]
if (not type) or (first.type in type):
return self.pop(0)
return None
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>' return f'<{self.__class__.__name__} {[*self.copy()]}>'

View File

@ -7,10 +7,11 @@ from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception import nonebot.exception as exception
from nonebot.log import logger from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper from nonebot.utils import escape_tag, logger_wrapper
from .event import Event, GroupMessage from .event import Event, GroupMessage, MessageEvent, MessageSource
from .message import MessageSegment, MessageType from .message import MessageType
if TYPE_CHECKING: if TYPE_CHECKING:
from .bot import Bot from .bot import Bot
@ -22,27 +23,26 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class Log: class Log:
@staticmethod @staticmethod
def _log(level: str, message: Any, exception: Optional[Exception] = None): def log(level: str, message: str, exception: Optional[Exception] = None):
logger = logger_wrapper('MIRAI') logger = logger_wrapper('MIRAI')
logger(level=level, message = '<e>' + escape_tag(message) + '</e>'
message=escape_tag(str(message)), logger(level=level.upper(), message=message, exception=exception)
exception=exception)
@classmethod @classmethod
def info(cls, message: Any): def info(cls, message: Any):
cls._log('INFO', escape_tag(str(message))) cls.log('INFO', str(message))
@classmethod @classmethod
def debug(cls, message: Any): def debug(cls, message: Any):
cls._log('DEBUG', escape_tag(str(message))) cls.log('DEBUG', str(message))
@classmethod @classmethod
def warn(cls, message: Any): def warn(cls, message: Any):
cls._log('WARNING', escape_tag(str(message))) cls.log('WARNING', str(message))
@classmethod @classmethod
def error(cls, message: Any, exception: Optional[Exception] = None): def error(cls, message: Any, exception: Optional[Exception] = None):
cls._log('ERROR', escape_tag(str(message)), exception=exception) cls.log('ERROR', str(message), exception=exception)
class ActionFailed(exception.ActionFailed): class ActionFailed(exception.ActionFailed):
@ -124,39 +124,54 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
return wrapper # type: ignore return wrapper # type: ignore
async def check_tome(bot: "Bot", event: "Event") -> "Event": def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
if not isinstance(event, GroupMessage): source = event.message_chain.extract_first(MessageType.SOURCE)
return event if source is not None:
event.source = MessageSource.parse_obj(source.data)
def _is_at(event: GroupMessage) -> bool:
for segment in event.message_chain:
segment: MessageSegment
if segment.type != MessageType.AT:
continue
if segment.data['target'] == event.self_id:
return True
return False
def _is_nick(event: GroupMessage) -> bool:
text = event.get_plaintext()
if not text:
return False
nick_regex = '|'.join(
{i.strip() for i in bot.config.nickname if i.strip()})
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is None:
return False
Log.info(f'User is calling me {matched.group(1)}')
return True
def _is_reply(event: GroupMessage) -> bool:
for segment in event.message_chain:
segment: MessageSegment
if segment.type != MessageType.QUOTE:
continue
if segment.data['senderId'] == event.self_id:
return True
return False
event.to_me = any([_is_at(event), _is_reply(event), _is_nick(event)])
return event return event
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
at = event.message_chain.extract_first(MessageType.AT)
if at is not None:
if at.data['target'] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, at)
return event
def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
plain = event.message_chain.extract_first(MessageType.PLAIN)
if plain is not None:
text = str(plain)
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is not None:
event.to_me = True
nickname = matched.group(1)
Log.info(f'User is calling me {nickname}')
plain.data['text'] = text[matched.end():]
event.message_chain.insert(0, plain)
return event
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
reply = event.message_chain.extract_first(MessageType.QUOTE)
if reply is not None:
if reply.data['senderId'] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, reply)
return event
async def process_event(bot: "Bot", event: Event) -> None:
if isinstance(event, MessageEvent):
Log.debug(event.message_chain)
event = process_source(bot, event)
if isinstance(event, GroupMessage):
event = process_nick(bot, event)
event = process_reply(bot, event)
event = process_at(bot, event)
await handle_event(bot, event)

View File

@ -418,7 +418,7 @@ class Matcher(metaclass=MatcherMeta):
""" """
bot = current_bot.get() bot = current_bot.get()
event = current_event.get() event = current_event.get()
return await bot.send(event=event, message=message, **kwargs) await bot.send(event=event, message=message, **kwargs)
@classmethod @classmethod
async def finish(cls, async def finish(cls,

View File

@ -13,7 +13,7 @@ from types import ModuleType
from dataclasses import dataclass from dataclasses import dataclass
from importlib._bootstrap import _load from importlib._bootstrap import _load
from contextvars import Context, ContextVar, copy_context from contextvars import Context, ContextVar, copy_context
from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING, Iterable
from nonebot.log import logger from nonebot.log import logger
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
@ -22,7 +22,7 @@ from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event, MessageSegment
plugins: Dict[str, "Plugin"] = {} plugins: Dict[str, "Plugin"] = {}
""" """
@ -422,12 +422,15 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
message = event.get_message() message = event.get_message()
segment = message.pop(0) text_processed = False
new_message = message.__class__( for index, segment in enumerate(message):
str(segment) segment: "MessageSegment" = message.pop(index)
[len(state["_prefix"]["raw_command"]):].lstrip()) # type: ignore if segment.is_text() and not text_processed:
for new_segment in reversed(new_message): segment, *_ = message.__class__(
message.insert(0, new_segment) str(segment)[len(state["_prefix"]["raw_command"]):].lstrip(
)) # type: ignore
text_processed = True
message.insert(index, segment)
handlers = kwargs.pop("handlers", []) handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd) handlers.insert(0, _strip_cmd)

View File

@ -25,7 +25,7 @@ from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker from nonebot.typing import T_State, T_RuleChecker
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event, MessageSegment
class Rule: class Rule:
@ -137,8 +137,9 @@ class TrieRule:
prefix = None prefix = None
suffix = None suffix = None
message = event.get_message() message = event.get_message()
message_seg = message[0] message_seg: Optional["MessageSegment"] = next(
if message_seg.is_text(): filter(lambda x: x.is_text(), message), None)
if message_seg is not None:
prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) prefix = cls.prefix.longest_prefix(str(message_seg).lstrip())
message_seg_r = message[-1] message_seg_r = message[-1]
if message_seg_r.is_text(): if message_seg_r.is_text():

View File

@ -1,13 +1,20 @@
from nonebot.plugin import on_message from nonebot.plugin import on_keyword, on_command
from nonebot.rule import to_me
from nonebot.adapters.mirai import Bot, MessageEvent from nonebot.adapters.mirai import Bot, MessageEvent
message_test = on_message() message_test = on_keyword({'reply'}, rule=to_me())
@message_test.handle() @message_test.handle()
async def _message(bot: Bot, event: MessageEvent): async def _message(bot: Bot, event: MessageEvent):
text = event.get_plaintext() text = event.get_plaintext()
if not text: await bot.send(event, text, at_sender=True)
return
reversed_text = ''.join(reversed(text))
await bot.send(event, reversed_text, at_sender=True) command_test = on_command('miecho')
@command_test.handle()
async def _echo(bot: Bot, event: MessageEvent):
text = event.get_plaintext()
await bot.send(event, text, at_sender=True)