diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 75c59541..b12652d9 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -51,7 +51,16 @@ class BaseEvent(abc.ABC): def __repr__(self) -> str: # TODO: pretty print - return f"" + return f"" + + @property + def raw_event(self) -> dict: + return self._raw_event + + @property + @abc.abstractmethod + def self_id(self) -> str: + raise NotImplementedError @property @abc.abstractmethod @@ -93,6 +102,16 @@ class BaseEvent(abc.ABC): def user_id(self, value) -> None: raise NotImplementedError + @property + @abc.abstractmethod + def to_me(self) -> Optional[bool]: + raise NotImplementedError + + @to_me.setter + @abc.abstractmethod + def to_me(self, value) -> None: + raise NotImplementedError + @property @abc.abstractmethod def message(self) -> Optional[Message]: diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 2519710f..a62d8546 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -7,11 +7,12 @@ import asyncio import httpx +from nonebot.log import logger from nonebot.config import Config from nonebot.message import handle_event -from nonebot.typing import overrides, Driver, WebSocket, NoReturn from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable +from nonebot.typing import overrides, Driver, WebSocket, NoReturn from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment @@ -41,6 +42,67 @@ def _b2s(b: bool) -> str: return str(b).lower() +def _check_at_me(bot: "Bot", event: "Event"): + if event.type != "message": + return + + if event.detail_type == "private": + event.to_me = True + else: + event.to_me = False + at_me_seg = MessageSegment.at(event.self_id) + + # check the first segment + first_msg_seg = event.message[0] + if first_msg_seg == at_me_seg: + event.to_me = True + del event.message[0] + + if not event.to_me: + # check the last segment + i = -1 + last_msg_seg = event.message[i] + if last_msg_seg.type == "text" and \ + not last_msg_seg.data["text"].strip() and \ + len(event.message) >= 2: + i -= 1 + last_msg_seg = event.message[i] + + if last_msg_seg == at_me_seg: + event.to_me = True + del event.message[i:] + + if not event.message: + event.message.append(MessageSegment.text("")) + + +def _check_nickname(bot: "Bot", event: "Event"): + if event.type != "message": + return + + first_msg_seg = event.message[0] + if first_msg_seg.type != "text": + return + + first_text = first_msg_seg.data["text"] + + if bot.config.NICKNAME: + # check if the user is calling me with my nickname + if isinstance(bot.config.NICKNAME, str) or \ + not isinstance(bot.config.NICKNAME, Iterable): + nicknames = (bot.config.NICKNAME,) + else: + nicknames = filter(lambda n: n, bot.config.NICKNAME) + nickname_regex = "|".join(nicknames) + m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, + re.IGNORECASE) + if m: + nickname = m.group(1) + logger.debug(f"User is calling me {nickname}") + event.to_me = True + first_msg_seg.data["text"] = first_text[m.end():] + + def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: if isinstance(result, dict): if result.get("status") == "failed": @@ -108,6 +170,10 @@ class Bot(BaseBot): event = Event(message) + # Check whether user is calling me + _check_at_me(self, event) + _check_nickname(self, event) + await handle_event(self, event) @overrides(BaseBot) @@ -166,6 +232,11 @@ class Event(BaseEvent): super().__init__(raw_event) + @property + @overrides(BaseEvent) + def self_id(self) -> str: + return str(self._raw_event["self_id"]) + @property @overrides(BaseEvent) def type(self) -> str: @@ -206,6 +277,16 @@ class Event(BaseEvent): def user_id(self, value) -> None: self._raw_event["user_id"] = value + @property + @overrides(BaseEvent) + def to_me(self) -> Optional[bool]: + return self._raw_event.get("to_me") + + @to_me.setter + @overrides(BaseEvent) + def to_me(self, value) -> None: + self._raw_event["to_me"] = value + @property @overrides(BaseEvent) def message(self) -> Optional["Message"]: @@ -244,6 +325,18 @@ class Event(BaseEvent): class MessageSegment(BaseMessageSegment): + @overrides(BaseMessageSegment) + def __init__(self, type: str, data: Dict[str, str]) -> None: + if type == "at" and data.get("qq") == "all": + type = "at_all" + data.clear() + elif type == "shake": + type = "poke" + data = {"type": "Poke"} + elif type == "text": + data["text"] = unescape(data["text"]) + super().__init__(type=type, data=data) + @overrides(BaseMessageSegment) def __str__(self): type_ = self.type @@ -271,7 +364,7 @@ class MessageSegment(BaseMessageSegment): return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) @staticmethod - def at(user_id: int) -> "MessageSegment": + def at(user_id: Union[int, str]) -> "MessageSegment": return MessageSegment("at", {"qq": str(user_id)}) @staticmethod diff --git a/nonebot/exception.py b/nonebot/exception.py index 1a0eeb10..e129bd01 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -8,7 +8,13 @@ 这些异常并非所有需要用户处理,在 NoneBot 内部运行时被捕获,并进行对应操作。 """ -from nonebot.typing import Optional +from nonebot.typing import List, Type, Optional + + +class _ExceptionContainer(Exception): + + def __init__(self, exceptions: List[Type[Exception]]) -> None: + self.exceptions = exceptions class IgnoredException(Exception): @@ -37,12 +43,12 @@ class PausedException(Exception): """ :说明: - 指示 NoneBot 结束当前 Handler 并等待下一条消息后继续下一个 Handler。 + 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。 可用于用户输入新信息。 :用法: - 可以在 Handler 中通过 Matcher.pause() 抛出。 + 可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。 """ pass @@ -51,12 +57,12 @@ class RejectedException(Exception): """ :说明: - 指示 NoneBot 结束当前 Handler 并等待下一条消息后重新运行当前 Handler。 + 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。 可用于用户重新输入。 :用法: - 可以在 Handler 中通过 Matcher.reject() 抛出。 + 可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。 """ pass @@ -65,12 +71,38 @@ class FinishedException(Exception): """ :说明: - 指示 NoneBot 结束当前 Handler 且后续 Handler 不再被运行。 + 指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。 可用于结束用户会话。 :用法: - 可以在 Handler 中通过 Matcher.finish() 抛出。 + 可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。 + """ + pass + + +class ExpiredException(Exception): + """ + :说明: + + 指示 NoneBot 当前 ``Matcher`` 已失效。 + + :用法: + + 当 ``Matcher`` 运行前检查时抛出。 + """ + pass + + +class StopPropagation(Exception): + """ + :说明: + + 指示 NoneBot 终止事件向下层传播。 + + :用法: + + 在 ``Matcher.block == True`` 时抛出。 """ pass diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 53158cff..acfabce0 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -26,6 +26,7 @@ class Matcher: temp: bool = False expire_time: Optional[datetime] = None priority: int = 1 + block: bool = False _default_state: dict = {} @@ -45,6 +46,7 @@ class Matcher: handlers: list = [], temp: bool = False, priority: int = 1, + block: bool = False, *, default_state: dict = {}, expire_time: Optional[datetime] = None) -> Type["Matcher"]: @@ -63,6 +65,7 @@ class Matcher: "temp": temp, "expire_time": expire_time, "priority": priority, + "block": block, "_default_state": default_state }) diff --git a/nonebot/message.py b/nonebot/message.py index 5a6306bf..0d41155a 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -7,8 +7,10 @@ from datetime import datetime from nonebot.log import logger from nonebot.rule import TrieRule from nonebot.matcher import matchers -from nonebot.exception import IgnoredException -from nonebot.typing import Bot, Set, Event, PreProcessor +from nonebot.typing import Set, Type, Union, NoReturn +from nonebot.typing import Bot, Event, Matcher, PreProcessor +from nonebot.exception import IgnoredException, ExpiredException +from nonebot.exception import StopPropagation, _ExceptionContainer _event_preprocessors: Set[PreProcessor] = set() @@ -18,6 +20,38 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor: return func +async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event, + state: dict) -> Union[None, NoReturn]: + if datetime.now() > Matcher.expire_time: + raise _ExceptionContainer([ExpiredException]) + + try: + if not await Matcher.check_perm( + bot, event) or not await Matcher.check_rule(bot, event, state): + return + except Exception as e: + logger.error(f"Rule check failed for matcher {Matcher}. Ignored.") + logger.exception(e) + return + + matcher = Matcher() + # TODO: BeforeMatcherRun + try: + logger.debug(f"Running matcher {matcher}") + await matcher.run(bot, event, state) + except Exception as e: + logger.error(f"Running matcher {matcher} failed.") + logger.exception(e) + + exceptions = [] + if Matcher.temp: + exceptions.append(ExpiredException) + if Matcher.block: + exceptions.append(StopPropagation) + if exceptions: + raise _ExceptionContainer(exceptions) + + async def handle_event(bot: Bot, event: Event): coros = [] state = {} @@ -33,37 +67,24 @@ async def handle_event(bot: Bot, event: Event): # Trie Match _, _ = TrieRule.get_value(bot, event, state) + break_flag = False for priority in sorted(matchers.keys()): - index = 0 - while index <= len(matchers[priority]): - Matcher = matchers[priority][index] + if break_flag: + break - # Delete expired Matcher - if datetime.now() > Matcher.expire_time: - del matchers[priority][index] - continue + pending_tasks = [ + _run_matcher(matcher, bot, event, state.copy()) + for matcher in matchers[priority] + ] - # Check rule - try: - if not await Matcher.check_perm( - bot, event) or not await Matcher.check_rule( - bot, event, state): - index += 1 - continue - except Exception as e: - logger.error( - f"Rule check failed for matcher {Matcher}. Ignored.") - logger.exception(e) - continue + results = await asyncio.gather(*pending_tasks, return_exceptions=True) - matcher = Matcher() - # TODO: BeforeMatcherRun - if Matcher.temp: - del matchers[priority][index] - - try: - await matcher.run(bot, event, state) - except Exception as e: - logger.error(f"Running matcher {matcher} failed.") - logger.exception(e) - return + i = 0 + for index, result in enumerate(results): + if isinstance(result, _ExceptionContainer): + e_list = result.exceptions + if StopPropagation in e_list: + break_flag = True + if ExpiredException in e_list: + del matchers[priority][index - i] + i += 1 diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 4c01130b..292c954b 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -33,12 +33,14 @@ def on(rule: Union[Rule, RuleChecker] = Rule(), handlers=[], temp=False, priority: int = 1, + block: bool = False, state={}) -> Type[Matcher]: matcher = Matcher.new("", Rule() & rule, permission, temp=temp, priority=priority, + block=block, handlers=handlers, default_state=state) _tmp_matchers.add(matcher) @@ -50,12 +52,14 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), handlers=[], temp=False, priority: int = 1, + block: bool = False, state={}) -> Type[Matcher]: matcher = Matcher.new("meta_event", Rule() & rule, Permission(), temp=temp, priority=priority, + block=block, handlers=handlers, default_state=state) _tmp_matchers.add(matcher) @@ -68,12 +72,14 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(), handlers=[], temp=False, priority: int = 1, + block: bool = True, state={}) -> Type[Matcher]: matcher = Matcher.new("message", Rule() & rule, permission, temp=temp, priority=priority, + block=block, handlers=handlers, default_state=state) _tmp_matchers.add(matcher) @@ -85,12 +91,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(), handlers=[], temp=False, priority: int = 1, + block: bool = False, state={}) -> Type[Matcher]: matcher = Matcher.new("notice", Rule() & rule, Permission(), temp=temp, priority=priority, + block=block, handlers=handlers, default_state=state) _tmp_matchers.add(matcher) @@ -102,12 +110,14 @@ def on_request(rule: Union[Rule, RuleChecker] = Rule(), handlers=[], temp=False, priority: int = 1, + block: bool = False, state={}) -> Type[Matcher]: matcher = Matcher.new("request", Rule() & rule, Permission(), temp=temp, priority=priority, + block=block, handlers=handlers, default_state=state) _tmp_matchers.add(matcher)