make matcher running concurrently and add to me checking

This commit is contained in:
yanyongyu 2020-08-21 14:24:32 +08:00
parent c01d3c7ca1
commit c1d0eae34b
6 changed files with 220 additions and 42 deletions

View File

@ -51,7 +51,16 @@ class BaseEvent(abc.ABC):
def __repr__(self) -> str: def __repr__(self) -> str:
# TODO: pretty print # TODO: pretty print
return f"<Event: >" return f"<Event: {self.type}/{self.detail_type} {self.raw_message}>"
@property
def raw_event(self) -> dict:
return self._raw_event
@property
@abc.abstractmethod
def self_id(self) -> str:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -93,6 +102,16 @@ class BaseEvent(abc.ABC):
def user_id(self, value) -> None: def user_id(self, value) -> None:
raise NotImplementedError raise NotImplementedError
@property
@abc.abstractmethod
def to_me(self) -> Optional[bool]:
raise NotImplementedError
@to_me.setter
@abc.abstractmethod
def to_me(self, value) -> None:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def message(self) -> Optional[Message]: def message(self) -> Optional[Message]:

View File

@ -7,11 +7,12 @@ import asyncio
import httpx import httpx
from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
@ -41,6 +42,67 @@ def _b2s(b: bool) -> str:
return str(b).lower() return str(b).lower()
def _check_at_me(bot: "Bot", event: "Event"):
if event.type != "message":
return
if event.detail_type == "private":
event.to_me = True
else:
event.to_me = False
at_me_seg = MessageSegment.at(event.self_id)
# check the first segment
first_msg_seg = event.message[0]
if first_msg_seg == at_me_seg:
event.to_me = True
del event.message[0]
if not event.to_me:
# check the last segment
i = -1
last_msg_seg = event.message[i]
if last_msg_seg.type == "text" and \
not last_msg_seg.data["text"].strip() and \
len(event.message) >= 2:
i -= 1
last_msg_seg = event.message[i]
if last_msg_seg == at_me_seg:
event.to_me = True
del event.message[i:]
if not event.message:
event.message.append(MessageSegment.text(""))
def _check_nickname(bot: "Bot", event: "Event"):
if event.type != "message":
return
first_msg_seg = event.message[0]
if first_msg_seg.type != "text":
return
first_text = first_msg_seg.data["text"]
if bot.config.NICKNAME:
# check if the user is calling me with my nickname
if isinstance(bot.config.NICKNAME, str) or \
not isinstance(bot.config.NICKNAME, Iterable):
nicknames = (bot.config.NICKNAME,)
else:
nicknames = filter(lambda n: n, bot.config.NICKNAME)
nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text,
re.IGNORECASE)
if m:
nickname = m.group(1)
logger.debug(f"User is calling me {nickname}")
event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():]
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
if isinstance(result, dict): if isinstance(result, dict):
if result.get("status") == "failed": if result.get("status") == "failed":
@ -108,6 +170,10 @@ class Bot(BaseBot):
event = Event(message) event = Event(message)
# Check whether user is calling me
_check_at_me(self, event)
_check_nickname(self, event)
await handle_event(self, event) await handle_event(self, event)
@overrides(BaseBot) @overrides(BaseBot)
@ -166,6 +232,11 @@ class Event(BaseEvent):
super().__init__(raw_event) super().__init__(raw_event)
@property
@overrides(BaseEvent)
def self_id(self) -> str:
return str(self._raw_event["self_id"])
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def type(self) -> str: def type(self) -> str:
@ -206,6 +277,16 @@ class Event(BaseEvent):
def user_id(self, value) -> None: def user_id(self, value) -> None:
self._raw_event["user_id"] = value self._raw_event["user_id"] = value
@property
@overrides(BaseEvent)
def to_me(self) -> Optional[bool]:
return self._raw_event.get("to_me")
@to_me.setter
@overrides(BaseEvent)
def to_me(self, value) -> None:
self._raw_event["to_me"] = value
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def message(self) -> Optional["Message"]: def message(self) -> Optional["Message"]:
@ -244,6 +325,18 @@ class Event(BaseEvent):
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __init__(self, type: str, data: Dict[str, str]) -> None:
if type == "at" and data.get("qq") == "all":
type = "at_all"
data.clear()
elif type == "shake":
type = "poke"
data = {"type": "Poke"}
elif type == "text":
data["text"] = unescape(data["text"])
super().__init__(type=type, data=data)
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __str__(self): def __str__(self):
type_ = self.type type_ = self.type
@ -271,7 +364,7 @@ class MessageSegment(BaseMessageSegment):
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
@staticmethod @staticmethod
def at(user_id: int) -> "MessageSegment": def at(user_id: Union[int, str]) -> "MessageSegment":
return MessageSegment("at", {"qq": str(user_id)}) return MessageSegment("at", {"qq": str(user_id)})
@staticmethod @staticmethod

View File

@ -8,7 +8,13 @@
这些异常并非所有需要用户处理 NoneBot 内部运行时被捕获并进行对应操作 这些异常并非所有需要用户处理 NoneBot 内部运行时被捕获并进行对应操作
""" """
from nonebot.typing import Optional from nonebot.typing import List, Type, Optional
class _ExceptionContainer(Exception):
def __init__(self, exceptions: List[Type[Exception]]) -> None:
self.exceptions = exceptions
class IgnoredException(Exception): class IgnoredException(Exception):
@ -37,12 +43,12 @@ class PausedException(Exception):
""" """
:说明: :说明:
指示 NoneBot 结束当前 Handler 并等待下一条消息后继续下一个 Handler 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``
可用于用户输入新信息 可用于用户输入新信息
:用法: :用法:
可以在 Handler 中通过 Matcher.pause() 抛出 可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出
""" """
pass pass
@ -51,12 +57,12 @@ class RejectedException(Exception):
""" """
:说明: :说明:
指示 NoneBot 结束当前 Handler 并等待下一条消息后重新运行当前 Handler 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``
可用于用户重新输入 可用于用户重新输入
:用法: :用法:
可以在 Handler 中通过 Matcher.reject() 可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出
""" """
pass pass
@ -65,12 +71,38 @@ class FinishedException(Exception):
""" """
:说明: :说明:
指示 NoneBot 结束当前 Handler 且后续 Handler 不再被运行 指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行
可用于结束用户会话 可用于结束用户会话
:用法: :用法:
可以在 Handler 中通过 Matcher.finish() 抛出 可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出
"""
pass
class ExpiredException(Exception):
"""
:说明:
指示 NoneBot 当前 ``Matcher`` 已失效
:用法:
``Matcher`` 运行前检查时抛出
"""
pass
class StopPropagation(Exception):
"""
:说明:
指示 NoneBot 终止事件向下层传播
:用法:
``Matcher.block == True`` 时抛出
""" """
pass pass

View File

@ -26,6 +26,7 @@ class Matcher:
temp: bool = False temp: bool = False
expire_time: Optional[datetime] = None expire_time: Optional[datetime] = None
priority: int = 1 priority: int = 1
block: bool = False
_default_state: dict = {} _default_state: dict = {}
@ -45,6 +46,7 @@ class Matcher:
handlers: list = [], handlers: list = [],
temp: bool = False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False,
*, *,
default_state: dict = {}, default_state: dict = {},
expire_time: Optional[datetime] = None) -> Type["Matcher"]: expire_time: Optional[datetime] = None) -> Type["Matcher"]:
@ -63,6 +65,7 @@ class Matcher:
"temp": temp, "temp": temp,
"expire_time": expire_time, "expire_time": expire_time,
"priority": priority, "priority": priority,
"block": block,
"_default_state": default_state "_default_state": default_state
}) })

View File

@ -7,8 +7,10 @@ from datetime import datetime
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import TrieRule from nonebot.rule import TrieRule
from nonebot.matcher import matchers from nonebot.matcher import matchers
from nonebot.exception import IgnoredException from nonebot.typing import Set, Type, Union, NoReturn
from nonebot.typing import Bot, Set, Event, PreProcessor from nonebot.typing import Bot, Event, Matcher, PreProcessor
from nonebot.exception import IgnoredException, ExpiredException
from nonebot.exception import StopPropagation, _ExceptionContainer
_event_preprocessors: Set[PreProcessor] = set() _event_preprocessors: Set[PreProcessor] = set()
@ -18,6 +20,38 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
return func return func
async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event,
state: dict) -> Union[None, NoReturn]:
if datetime.now() > Matcher.expire_time:
raise _ExceptionContainer([ExpiredException])
try:
if not await Matcher.check_perm(
bot, event) or not await Matcher.check_rule(bot, event, state):
return
except Exception as e:
logger.error(f"Rule check failed for matcher {Matcher}. Ignored.")
logger.exception(e)
return
matcher = Matcher()
# TODO: BeforeMatcherRun
try:
logger.debug(f"Running matcher {matcher}")
await matcher.run(bot, event, state)
except Exception as e:
logger.error(f"Running matcher {matcher} failed.")
logger.exception(e)
exceptions = []
if Matcher.temp:
exceptions.append(ExpiredException)
if Matcher.block:
exceptions.append(StopPropagation)
if exceptions:
raise _ExceptionContainer(exceptions)
async def handle_event(bot: Bot, event: Event): async def handle_event(bot: Bot, event: Event):
coros = [] coros = []
state = {} state = {}
@ -33,37 +67,24 @@ async def handle_event(bot: Bot, event: Event):
# Trie Match # Trie Match
_, _ = TrieRule.get_value(bot, event, state) _, _ = TrieRule.get_value(bot, event, state)
break_flag = False
for priority in sorted(matchers.keys()): for priority in sorted(matchers.keys()):
index = 0 if break_flag:
while index <= len(matchers[priority]): break
Matcher = matchers[priority][index]
# Delete expired Matcher pending_tasks = [
if datetime.now() > Matcher.expire_time: _run_matcher(matcher, bot, event, state.copy())
del matchers[priority][index] for matcher in matchers[priority]
continue ]
# Check rule results = await asyncio.gather(*pending_tasks, return_exceptions=True)
try:
if not await Matcher.check_perm(
bot, event) or not await Matcher.check_rule(
bot, event, state):
index += 1
continue
except Exception as e:
logger.error(
f"Rule check failed for matcher {Matcher}. Ignored.")
logger.exception(e)
continue
matcher = Matcher() i = 0
# TODO: BeforeMatcherRun for index, result in enumerate(results):
if Matcher.temp: if isinstance(result, _ExceptionContainer):
del matchers[priority][index] e_list = result.exceptions
if StopPropagation in e_list:
try: break_flag = True
await matcher.run(bot, event, state) if ExpiredException in e_list:
except Exception as e: del matchers[priority][index - i]
logger.error(f"Running matcher {matcher} failed.") i += 1
logger.exception(e)
return

View File

@ -33,12 +33,14 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[], handlers=[],
temp=False, temp=False,
priority: int = 1, priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]: state={}) -> Type[Matcher]:
matcher = Matcher.new("", matcher = Matcher.new("",
Rule() & rule, Rule() & rule,
permission, permission,
temp=temp, temp=temp,
priority=priority, priority=priority,
block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.add(matcher)
@ -50,12 +52,14 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[], handlers=[],
temp=False, temp=False,
priority: int = 1, priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]: state={}) -> Type[Matcher]:
matcher = Matcher.new("meta_event", matcher = Matcher.new("meta_event",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
temp=temp, temp=temp,
priority=priority, priority=priority,
block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.add(matcher)
@ -68,12 +72,14 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[], handlers=[],
temp=False, temp=False,
priority: int = 1, priority: int = 1,
block: bool = True,
state={}) -> Type[Matcher]: state={}) -> Type[Matcher]:
matcher = Matcher.new("message", matcher = Matcher.new("message",
Rule() & rule, Rule() & rule,
permission, permission,
temp=temp, temp=temp,
priority=priority, priority=priority,
block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.add(matcher)
@ -85,12 +91,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[], handlers=[],
temp=False, temp=False,
priority: int = 1, priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]: state={}) -> Type[Matcher]:
matcher = Matcher.new("notice", matcher = Matcher.new("notice",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
temp=temp, temp=temp,
priority=priority, priority=priority,
block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.add(matcher)
@ -102,12 +110,14 @@ def on_request(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[], handlers=[],
temp=False, temp=False,
priority: int = 1, priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]: state={}) -> Type[Matcher]:
matcher = Matcher.new("request", matcher = Matcher.new("request",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
temp=temp, temp=temp,
priority=priority, priority=priority,
block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.add(matcher)