make matcher running concurrently and add to me checking

This commit is contained in:
yanyongyu 2020-08-21 14:24:32 +08:00
parent c01d3c7ca1
commit c1d0eae34b
6 changed files with 220 additions and 42 deletions

View File

@ -51,7 +51,16 @@ class BaseEvent(abc.ABC):
def __repr__(self) -> str:
# TODO: pretty print
return f"<Event: >"
return f"<Event: {self.type}/{self.detail_type} {self.raw_message}>"
@property
def raw_event(self) -> dict:
return self._raw_event
@property
@abc.abstractmethod
def self_id(self) -> str:
raise NotImplementedError
@property
@abc.abstractmethod
@ -93,6 +102,16 @@ class BaseEvent(abc.ABC):
def user_id(self, value) -> None:
raise NotImplementedError
@property
@abc.abstractmethod
def to_me(self) -> Optional[bool]:
raise NotImplementedError
@to_me.setter
@abc.abstractmethod
def to_me(self, value) -> None:
raise NotImplementedError
@property
@abc.abstractmethod
def message(self) -> Optional[Message]:

View File

@ -7,11 +7,12 @@ import asyncio
import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
@ -41,6 +42,67 @@ def _b2s(b: bool) -> str:
return str(b).lower()
def _check_at_me(bot: "Bot", event: "Event"):
if event.type != "message":
return
if event.detail_type == "private":
event.to_me = True
else:
event.to_me = False
at_me_seg = MessageSegment.at(event.self_id)
# check the first segment
first_msg_seg = event.message[0]
if first_msg_seg == at_me_seg:
event.to_me = True
del event.message[0]
if not event.to_me:
# check the last segment
i = -1
last_msg_seg = event.message[i]
if last_msg_seg.type == "text" and \
not last_msg_seg.data["text"].strip() and \
len(event.message) >= 2:
i -= 1
last_msg_seg = event.message[i]
if last_msg_seg == at_me_seg:
event.to_me = True
del event.message[i:]
if not event.message:
event.message.append(MessageSegment.text(""))
def _check_nickname(bot: "Bot", event: "Event"):
if event.type != "message":
return
first_msg_seg = event.message[0]
if first_msg_seg.type != "text":
return
first_text = first_msg_seg.data["text"]
if bot.config.NICKNAME:
# check if the user is calling me with my nickname
if isinstance(bot.config.NICKNAME, str) or \
not isinstance(bot.config.NICKNAME, Iterable):
nicknames = (bot.config.NICKNAME,)
else:
nicknames = filter(lambda n: n, bot.config.NICKNAME)
nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text,
re.IGNORECASE)
if m:
nickname = m.group(1)
logger.debug(f"User is calling me {nickname}")
event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():]
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
if isinstance(result, dict):
if result.get("status") == "failed":
@ -108,6 +170,10 @@ class Bot(BaseBot):
event = Event(message)
# Check whether user is calling me
_check_at_me(self, event)
_check_nickname(self, event)
await handle_event(self, event)
@overrides(BaseBot)
@ -166,6 +232,11 @@ class Event(BaseEvent):
super().__init__(raw_event)
@property
@overrides(BaseEvent)
def self_id(self) -> str:
return str(self._raw_event["self_id"])
@property
@overrides(BaseEvent)
def type(self) -> str:
@ -206,6 +277,16 @@ class Event(BaseEvent):
def user_id(self, value) -> None:
self._raw_event["user_id"] = value
@property
@overrides(BaseEvent)
def to_me(self) -> Optional[bool]:
return self._raw_event.get("to_me")
@to_me.setter
@overrides(BaseEvent)
def to_me(self, value) -> None:
self._raw_event["to_me"] = value
@property
@overrides(BaseEvent)
def message(self) -> Optional["Message"]:
@ -244,6 +325,18 @@ class Event(BaseEvent):
class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __init__(self, type: str, data: Dict[str, str]) -> None:
if type == "at" and data.get("qq") == "all":
type = "at_all"
data.clear()
elif type == "shake":
type = "poke"
data = {"type": "Poke"}
elif type == "text":
data["text"] = unescape(data["text"])
super().__init__(type=type, data=data)
@overrides(BaseMessageSegment)
def __str__(self):
type_ = self.type
@ -271,7 +364,7 @@ class MessageSegment(BaseMessageSegment):
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
@staticmethod
def at(user_id: int) -> "MessageSegment":
def at(user_id: Union[int, str]) -> "MessageSegment":
return MessageSegment("at", {"qq": str(user_id)})
@staticmethod

View File

@ -8,7 +8,13 @@
这些异常并非所有需要用户处理 NoneBot 内部运行时被捕获并进行对应操作
"""
from nonebot.typing import Optional
from nonebot.typing import List, Type, Optional
class _ExceptionContainer(Exception):
def __init__(self, exceptions: List[Type[Exception]]) -> None:
self.exceptions = exceptions
class IgnoredException(Exception):
@ -37,12 +43,12 @@ class PausedException(Exception):
"""
:说明:
指示 NoneBot 结束当前 Handler 并等待下一条消息后继续下一个 Handler
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``
可用于用户输入新信息
:用法:
可以在 Handler 中通过 Matcher.pause() 抛出
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出
"""
pass
@ -51,12 +57,12 @@ class RejectedException(Exception):
"""
:说明:
指示 NoneBot 结束当前 Handler 并等待下一条消息后重新运行当前 Handler
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``
可用于用户重新输入
:用法:
可以在 Handler 中通过 Matcher.reject() 抛出
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出
"""
pass
@ -65,12 +71,38 @@ class FinishedException(Exception):
"""
:说明:
指示 NoneBot 结束当前 Handler 且后续 Handler 不再被运行
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行
可用于结束用户会话
:用法:
可以在 Handler 中通过 Matcher.finish() 抛出
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出
"""
pass
class ExpiredException(Exception):
"""
:说明:
指示 NoneBot 当前 ``Matcher`` 已失效
:用法:
``Matcher`` 运行前检查时抛出
"""
pass
class StopPropagation(Exception):
"""
:说明:
指示 NoneBot 终止事件向下层传播
:用法:
``Matcher.block == True`` 时抛出
"""
pass

View File

@ -26,6 +26,7 @@ class Matcher:
temp: bool = False
expire_time: Optional[datetime] = None
priority: int = 1
block: bool = False
_default_state: dict = {}
@ -45,6 +46,7 @@ class Matcher:
handlers: list = [],
temp: bool = False,
priority: int = 1,
block: bool = False,
*,
default_state: dict = {},
expire_time: Optional[datetime] = None) -> Type["Matcher"]:
@ -63,6 +65,7 @@ class Matcher:
"temp": temp,
"expire_time": expire_time,
"priority": priority,
"block": block,
"_default_state": default_state
})

View File

@ -7,8 +7,10 @@ from datetime import datetime
from nonebot.log import logger
from nonebot.rule import TrieRule
from nonebot.matcher import matchers
from nonebot.exception import IgnoredException
from nonebot.typing import Bot, Set, Event, PreProcessor
from nonebot.typing import Set, Type, Union, NoReturn
from nonebot.typing import Bot, Event, Matcher, PreProcessor
from nonebot.exception import IgnoredException, ExpiredException
from nonebot.exception import StopPropagation, _ExceptionContainer
_event_preprocessors: Set[PreProcessor] = set()
@ -18,6 +20,38 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
return func
async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event,
state: dict) -> Union[None, NoReturn]:
if datetime.now() > Matcher.expire_time:
raise _ExceptionContainer([ExpiredException])
try:
if not await Matcher.check_perm(
bot, event) or not await Matcher.check_rule(bot, event, state):
return
except Exception as e:
logger.error(f"Rule check failed for matcher {Matcher}. Ignored.")
logger.exception(e)
return
matcher = Matcher()
# TODO: BeforeMatcherRun
try:
logger.debug(f"Running matcher {matcher}")
await matcher.run(bot, event, state)
except Exception as e:
logger.error(f"Running matcher {matcher} failed.")
logger.exception(e)
exceptions = []
if Matcher.temp:
exceptions.append(ExpiredException)
if Matcher.block:
exceptions.append(StopPropagation)
if exceptions:
raise _ExceptionContainer(exceptions)
async def handle_event(bot: Bot, event: Event):
coros = []
state = {}
@ -33,37 +67,24 @@ async def handle_event(bot: Bot, event: Event):
# Trie Match
_, _ = TrieRule.get_value(bot, event, state)
break_flag = False
for priority in sorted(matchers.keys()):
index = 0
while index <= len(matchers[priority]):
Matcher = matchers[priority][index]
if break_flag:
break
# Delete expired Matcher
if datetime.now() > Matcher.expire_time:
del matchers[priority][index]
continue
pending_tasks = [
_run_matcher(matcher, bot, event, state.copy())
for matcher in matchers[priority]
]
# Check rule
try:
if not await Matcher.check_perm(
bot, event) or not await Matcher.check_rule(
bot, event, state):
index += 1
continue
except Exception as e:
logger.error(
f"Rule check failed for matcher {Matcher}. Ignored.")
logger.exception(e)
continue
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
matcher = Matcher()
# TODO: BeforeMatcherRun
if Matcher.temp:
del matchers[priority][index]
try:
await matcher.run(bot, event, state)
except Exception as e:
logger.error(f"Running matcher {matcher} failed.")
logger.exception(e)
return
i = 0
for index, result in enumerate(results):
if isinstance(result, _ExceptionContainer):
e_list = result.exceptions
if StopPropagation in e_list:
break_flag = True
if ExpiredException in e_list:
del matchers[priority][index - i]
i += 1

View File

@ -33,12 +33,14 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[],
temp=False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
matcher = Matcher.new("",
Rule() & rule,
permission,
temp=temp,
priority=priority,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
@ -50,12 +52,14 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[],
temp=False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
matcher = Matcher.new("meta_event",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
@ -68,12 +72,14 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[],
temp=False,
priority: int = 1,
block: bool = True,
state={}) -> Type[Matcher]:
matcher = Matcher.new("message",
Rule() & rule,
permission,
temp=temp,
priority=priority,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
@ -85,12 +91,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[],
temp=False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
matcher = Matcher.new("notice",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
@ -102,12 +110,14 @@ def on_request(rule: Union[Rule, RuleChecker] = Rule(),
handlers=[],
temp=False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
matcher = Matcher.new("request",
Rule() & rule,
Permission(),
temp=temp,
priority=priority,
block=block,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)