🔀 Merge pull request #199

Fix mirai adapter command process
This commit is contained in:
Ju4tCode 2021-02-07 12:55:49 +08:00 committed by GitHub
commit a3fe3a1ad8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 124 additions and 91 deletions

View File

@ -1,8 +1,7 @@
from datetime import datetime, timedelta
from functools import wraps
from io import BytesIO
from ipaddress import IPv4Address
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
@ -10,15 +9,12 @@ from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment
from .utils import catch_network_error, argument_validation, check_tome, Log
from .utils import Log, argument_validation, catch_network_error, process_event
class SessionManager:
@ -212,20 +208,15 @@ class Bot(BaseBot):
async def handle_message(self, message: dict):
Log.debug(f'received message {message}')
try:
await handle_event(
await process_event(
bot=self,
event=await check_tome(
bot=self,
event=Event.new({
**message,
'self_id': self.self_id,
}),
),
event=Event.new({
**message,
'self_id': self.self_id,
}),
)
except Exception as e:
logger.opt(colors=True, exception=e).exception(
'Failed to handle message '
f'<r>{escape_tag(str(message))}</r>: ')
Log.error(f'Failed to handle message: {message}', e)
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
@ -262,10 +253,8 @@ class Bot(BaseBot):
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主体
"""
if isinstance(message, MessageSegment):
if not isinstance(message, MessageChain):
message = MessageChain(message)
elif isinstance(message, str):
message = MessageChain(MessageSegment.plain(message))
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)

View File

@ -13,7 +13,7 @@ from .request import *
__all__ = [
'Event', 'GroupChatInfo', 'GroupInfo', 'PrivateChatInfo', 'UserPermission',
'MessageChain', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'MessageSource', 'MessageEvent', 'GroupMessage', 'FriendMessage',
'TempMessage', 'NoticeEvent', 'MuteEvent', 'BotMuteEvent', 'BotUnmuteEvent',
'MemberMuteEvent', 'MemberUnmuteEvent', 'BotJoinGroupEvent',
'BotLeaveEventActive', 'BotLeaveEventKick', 'MemberJoinEvent',

View File

@ -1,6 +1,7 @@
from typing import Any
from datetime import datetime
from typing import Any, Optional
from pydantic import Field
from pydantic import BaseModel, Field
from nonebot.typing import overrides
@ -8,9 +9,15 @@ from ..message import MessageChain
from .base import Event, GroupChatInfo, PrivateChatInfo
class MessageSource(BaseModel):
id: int
time: datetime
class MessageEvent(Event):
"""消息事件基类"""
message_chain: MessageChain = Field(alias='messageChain')
source: Optional[MessageSource] = None
sender: Any
@overrides(Event)

View File

@ -44,8 +44,9 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __str__(self) -> str:
if self.is_text():
return self.data.get('text', '')
return self.data['text'] if self.is_text() else repr(self)
def __repr__(self) -> str:
return '[mirai:%s]' % ','.join([
self.type.value,
*map(
@ -273,12 +274,14 @@ class MessageChain(BaseMessage):
"""
@overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str, Any]],
Iterable[MessageSegment], MessageSegment],
**kwargs):
def __init__(self, message: Union[List[Dict[str,
Any]], Iterable[MessageSegment],
MessageSegment, str], **kwargs):
super().__init__(**kwargs)
if isinstance(message, MessageSegment):
self.append(message)
elif isinstance(message, str):
self.append(MessageSegment.plain(text=message))
elif isinstance(message, Iterable):
self.extend(self._construct(message))
else:
@ -306,5 +309,13 @@ class MessageChain(BaseMessage):
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
]
def extract_first(self, *type: MessageType) -> Optional[MessageSegment]:
if not len(self):
return None
first: MessageSegment = self[0]
if (not type) or (first.type in type):
return self.pop(0)
return None
def __repr__(self) -> str:
return f'<{self.__class__.__name__} {[*self.copy()]}>'

View File

@ -7,10 +7,11 @@ from pydantic import Extra, ValidationError, validate_arguments
import nonebot.exception as exception
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.utils import escape_tag, logger_wrapper
from .event import Event, GroupMessage
from .message import MessageSegment, MessageType
from .event import Event, GroupMessage, MessageEvent, MessageSource
from .message import MessageType
if TYPE_CHECKING:
from .bot import Bot
@ -22,27 +23,26 @@ _AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class Log:
@staticmethod
def _log(level: str, message: Any, exception: Optional[Exception] = None):
def log(level: str, message: str, exception: Optional[Exception] = None):
logger = logger_wrapper('MIRAI')
logger(level=level,
message=escape_tag(str(message)),
exception=exception)
message = '<e>' + escape_tag(message) + '</e>'
logger(level=level.upper(), message=message, exception=exception)
@classmethod
def info(cls, message: Any):
cls._log('INFO', escape_tag(str(message)))
cls.log('INFO', str(message))
@classmethod
def debug(cls, message: Any):
cls._log('DEBUG', escape_tag(str(message)))
cls.log('DEBUG', str(message))
@classmethod
def warn(cls, message: Any):
cls._log('WARNING', escape_tag(str(message)))
cls.log('WARNING', str(message))
@classmethod
def error(cls, message: Any, exception: Optional[Exception] = None):
cls._log('ERROR', escape_tag(str(message)), exception=exception)
cls.log('ERROR', str(message), exception=exception)
class ActionFailed(exception.ActionFailed):
@ -124,39 +124,54 @@ def argument_validation(function: _AnyCallable) -> _AnyCallable:
return wrapper # type: ignore
async def check_tome(bot: "Bot", event: "Event") -> "Event":
if not isinstance(event, GroupMessage):
return event
def _is_at(event: GroupMessage) -> bool:
for segment in event.message_chain:
segment: MessageSegment
if segment.type != MessageType.AT:
continue
if segment.data['target'] == event.self_id:
return True
return False
def _is_nick(event: GroupMessage) -> bool:
text = event.get_plaintext()
if not text:
return False
nick_regex = '|'.join(
{i.strip() for i in bot.config.nickname if i.strip()})
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is None:
return False
Log.info(f'User is calling me {matched.group(1)}')
return True
def _is_reply(event: GroupMessage) -> bool:
for segment in event.message_chain:
segment: MessageSegment
if segment.type != MessageType.QUOTE:
continue
if segment.data['senderId'] == event.self_id:
return True
return False
event.to_me = any([_is_at(event), _is_reply(event), _is_nick(event)])
def process_source(bot: "Bot", event: MessageEvent) -> MessageEvent:
source = event.message_chain.extract_first(MessageType.SOURCE)
if source is not None:
event.source = MessageSource.parse_obj(source.data)
return event
def process_at(bot: "Bot", event: GroupMessage) -> GroupMessage:
at = event.message_chain.extract_first(MessageType.AT)
if at is not None:
if at.data['target'] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, at)
return event
def process_nick(bot: "Bot", event: GroupMessage) -> GroupMessage:
plain = event.message_chain.extract_first(MessageType.PLAIN)
if plain is not None:
text = str(plain)
nick_regex = '|'.join(filter(lambda x: x, bot.config.nickname))
matched = re.search(rf"^({nick_regex})([\s,]*|$)", text, re.IGNORECASE)
if matched is not None:
event.to_me = True
nickname = matched.group(1)
Log.info(f'User is calling me {nickname}')
plain.data['text'] = text[matched.end():]
event.message_chain.insert(0, plain)
return event
def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
reply = event.message_chain.extract_first(MessageType.QUOTE)
if reply is not None:
if reply.data['senderId'] == event.self_id:
event.to_me = True
else:
event.message_chain.insert(0, reply)
return event
async def process_event(bot: "Bot", event: Event) -> None:
if isinstance(event, MessageEvent):
Log.debug(event.message_chain)
event = process_source(bot, event)
if isinstance(event, GroupMessage):
event = process_nick(bot, event)
event = process_reply(bot, event)
event = process_at(bot, event)
await handle_event(bot, event)

View File

@ -418,7 +418,7 @@ class Matcher(metaclass=MatcherMeta):
"""
bot = current_bot.get()
event = current_event.get()
return await bot.send(event=event, message=message, **kwargs)
await bot.send(event=event, message=message, **kwargs)
@classmethod
async def finish(cls,

View File

@ -13,7 +13,7 @@ from types import ModuleType
from dataclasses import dataclass
from importlib._bootstrap import _load
from contextvars import Context, ContextVar, copy_context
from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING
from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING, Iterable
from nonebot.log import logger
from nonebot.matcher import Matcher
@ -22,7 +22,7 @@ from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from nonebot.adapters import Bot, Event, MessageSegment
plugins: Dict[str, "Plugin"] = {}
"""
@ -422,12 +422,15 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(
str(segment)
[len(state["_prefix"]["raw_command"]):].lstrip()) # type: ignore
for new_segment in reversed(new_message):
message.insert(0, new_segment)
text_processed = False
for index, segment in enumerate(message):
segment: "MessageSegment" = message.pop(index)
if segment.is_text() and not text_processed:
segment, *_ = message.__class__(
str(segment)[len(state["_prefix"]["raw_command"]):].lstrip(
)) # type: ignore
text_processed = True
message.insert(index, segment)
handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd)

View File

@ -25,7 +25,7 @@ from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
from nonebot.adapters import Bot, Event, MessageSegment
class Rule:
@ -137,8 +137,9 @@ class TrieRule:
prefix = None
suffix = None
message = event.get_message()
message_seg = message[0]
if message_seg.is_text():
message_seg: Optional["MessageSegment"] = next(
filter(lambda x: x.is_text(), message), None)
if message_seg is not None:
prefix = cls.prefix.longest_prefix(str(message_seg).lstrip())
message_seg_r = message[-1]
if message_seg_r.is_text():

View File

@ -1,13 +1,20 @@
from nonebot.plugin import on_message
from nonebot.plugin import on_keyword, on_command
from nonebot.rule import to_me
from nonebot.adapters.mirai import Bot, MessageEvent
message_test = on_message()
message_test = on_keyword({'reply'}, rule=to_me())
@message_test.handle()
async def _message(bot: Bot, event: MessageEvent):
text = event.get_plaintext()
if not text:
return
reversed_text = ''.join(reversed(text))
await bot.send(event, reversed_text, at_sender=True)
await bot.send(event, text, at_sender=True)
command_test = on_command('miecho')
@command_test.handle()
async def _echo(bot: Bot, event: MessageEvent):
text = event.get_plaintext()
await bot.send(event, text, at_sender=True)