mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
make matcher running concurrently and add to me checking
This commit is contained in:
parent
c01d3c7ca1
commit
c1d0eae34b
@ -51,7 +51,16 @@ class BaseEvent(abc.ABC):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO: pretty print
|
||||
return f"<Event: >"
|
||||
return f"<Event: {self.type}/{self.detail_type} {self.raw_message}>"
|
||||
|
||||
@property
|
||||
def raw_event(self) -> dict:
|
||||
return self._raw_event
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def self_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@ -93,6 +102,16 @@ class BaseEvent(abc.ABC):
|
||||
def user_id(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def to_me(self) -> Optional[bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@to_me.setter
|
||||
@abc.abstractmethod
|
||||
def to_me(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def message(self) -> Optional[Message]:
|
||||
|
@ -7,11 +7,12 @@ import asyncio
|
||||
|
||||
import httpx
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
|
||||
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||
|
||||
|
||||
@ -41,6 +42,67 @@ def _b2s(b: bool) -> str:
|
||||
return str(b).lower()
|
||||
|
||||
|
||||
def _check_at_me(bot: "Bot", event: "Event"):
|
||||
if event.type != "message":
|
||||
return
|
||||
|
||||
if event.detail_type == "private":
|
||||
event.to_me = True
|
||||
else:
|
||||
event.to_me = False
|
||||
at_me_seg = MessageSegment.at(event.self_id)
|
||||
|
||||
# check the first segment
|
||||
first_msg_seg = event.message[0]
|
||||
if first_msg_seg == at_me_seg:
|
||||
event.to_me = True
|
||||
del event.message[0]
|
||||
|
||||
if not event.to_me:
|
||||
# check the last segment
|
||||
i = -1
|
||||
last_msg_seg = event.message[i]
|
||||
if last_msg_seg.type == "text" and \
|
||||
not last_msg_seg.data["text"].strip() and \
|
||||
len(event.message) >= 2:
|
||||
i -= 1
|
||||
last_msg_seg = event.message[i]
|
||||
|
||||
if last_msg_seg == at_me_seg:
|
||||
event.to_me = True
|
||||
del event.message[i:]
|
||||
|
||||
if not event.message:
|
||||
event.message.append(MessageSegment.text(""))
|
||||
|
||||
|
||||
def _check_nickname(bot: "Bot", event: "Event"):
|
||||
if event.type != "message":
|
||||
return
|
||||
|
||||
first_msg_seg = event.message[0]
|
||||
if first_msg_seg.type != "text":
|
||||
return
|
||||
|
||||
first_text = first_msg_seg.data["text"]
|
||||
|
||||
if bot.config.NICKNAME:
|
||||
# check if the user is calling me with my nickname
|
||||
if isinstance(bot.config.NICKNAME, str) or \
|
||||
not isinstance(bot.config.NICKNAME, Iterable):
|
||||
nicknames = (bot.config.NICKNAME,)
|
||||
else:
|
||||
nicknames = filter(lambda n: n, bot.config.NICKNAME)
|
||||
nickname_regex = "|".join(nicknames)
|
||||
m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text,
|
||||
re.IGNORECASE)
|
||||
if m:
|
||||
nickname = m.group(1)
|
||||
logger.debug(f"User is calling me {nickname}")
|
||||
event.to_me = True
|
||||
first_msg_seg.data["text"] = first_text[m.end():]
|
||||
|
||||
|
||||
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
|
||||
if isinstance(result, dict):
|
||||
if result.get("status") == "failed":
|
||||
@ -108,6 +170,10 @@ class Bot(BaseBot):
|
||||
|
||||
event = Event(message)
|
||||
|
||||
# Check whether user is calling me
|
||||
_check_at_me(self, event)
|
||||
_check_nickname(self, event)
|
||||
|
||||
await handle_event(self, event)
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -166,6 +232,11 @@ class Event(BaseEvent):
|
||||
|
||||
super().__init__(raw_event)
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def self_id(self) -> str:
|
||||
return str(self._raw_event["self_id"])
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def type(self) -> str:
|
||||
@ -206,6 +277,16 @@ class Event(BaseEvent):
|
||||
def user_id(self, value) -> None:
|
||||
self._raw_event["user_id"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def to_me(self) -> Optional[bool]:
|
||||
return self._raw_event.get("to_me")
|
||||
|
||||
@to_me.setter
|
||||
@overrides(BaseEvent)
|
||||
def to_me(self, value) -> None:
|
||||
self._raw_event["to_me"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def message(self) -> Optional["Message"]:
|
||||
@ -244,6 +325,18 @@ class Event(BaseEvent):
|
||||
|
||||
class MessageSegment(BaseMessageSegment):
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __init__(self, type: str, data: Dict[str, str]) -> None:
|
||||
if type == "at" and data.get("qq") == "all":
|
||||
type = "at_all"
|
||||
data.clear()
|
||||
elif type == "shake":
|
||||
type = "poke"
|
||||
data = {"type": "Poke"}
|
||||
elif type == "text":
|
||||
data["text"] = unescape(data["text"])
|
||||
super().__init__(type=type, data=data)
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __str__(self):
|
||||
type_ = self.type
|
||||
@ -271,7 +364,7 @@ class MessageSegment(BaseMessageSegment):
|
||||
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
|
||||
|
||||
@staticmethod
|
||||
def at(user_id: int) -> "MessageSegment":
|
||||
def at(user_id: Union[int, str]) -> "MessageSegment":
|
||||
return MessageSegment("at", {"qq": str(user_id)})
|
||||
|
||||
@staticmethod
|
||||
|
@ -8,7 +8,13 @@
|
||||
这些异常并非所有需要用户处理,在 NoneBot 内部运行时被捕获,并进行对应操作。
|
||||
"""
|
||||
|
||||
from nonebot.typing import Optional
|
||||
from nonebot.typing import List, Type, Optional
|
||||
|
||||
|
||||
class _ExceptionContainer(Exception):
|
||||
|
||||
def __init__(self, exceptions: List[Type[Exception]]) -> None:
|
||||
self.exceptions = exceptions
|
||||
|
||||
|
||||
class IgnoredException(Exception):
|
||||
@ -37,12 +43,12 @@ class PausedException(Exception):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 Handler 并等待下一条消息后继续下一个 Handler。
|
||||
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。
|
||||
可用于用户输入新信息。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 Handler 中通过 Matcher.pause() 抛出。
|
||||
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -51,12 +57,12 @@ class RejectedException(Exception):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 Handler 并等待下一条消息后重新运行当前 Handler。
|
||||
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。
|
||||
可用于用户重新输入。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 Handler 中通过 Matcher.reject() 抛出。
|
||||
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -65,12 +71,38 @@ class FinishedException(Exception):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 结束当前 Handler 且后续 Handler 不再被运行。
|
||||
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。
|
||||
可用于结束用户会话。
|
||||
|
||||
:用法:
|
||||
|
||||
可以在 Handler 中通过 Matcher.finish() 抛出。
|
||||
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ExpiredException(Exception):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 当前 ``Matcher`` 已失效。
|
||||
|
||||
:用法:
|
||||
|
||||
当 ``Matcher`` 运行前检查时抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StopPropagation(Exception):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
指示 NoneBot 终止事件向下层传播。
|
||||
|
||||
:用法:
|
||||
|
||||
在 ``Matcher.block == True`` 时抛出。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -26,6 +26,7 @@ class Matcher:
|
||||
temp: bool = False
|
||||
expire_time: Optional[datetime] = None
|
||||
priority: int = 1
|
||||
block: bool = False
|
||||
|
||||
_default_state: dict = {}
|
||||
|
||||
@ -45,6 +46,7 @@ class Matcher:
|
||||
handlers: list = [],
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
*,
|
||||
default_state: dict = {},
|
||||
expire_time: Optional[datetime] = None) -> Type["Matcher"]:
|
||||
@ -63,6 +65,7 @@ class Matcher:
|
||||
"temp": temp,
|
||||
"expire_time": expire_time,
|
||||
"priority": priority,
|
||||
"block": block,
|
||||
"_default_state": default_state
|
||||
})
|
||||
|
||||
|
@ -7,8 +7,10 @@ from datetime import datetime
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import TrieRule
|
||||
from nonebot.matcher import matchers
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.typing import Bot, Set, Event, PreProcessor
|
||||
from nonebot.typing import Set, Type, Union, NoReturn
|
||||
from nonebot.typing import Bot, Event, Matcher, PreProcessor
|
||||
from nonebot.exception import IgnoredException, ExpiredException
|
||||
from nonebot.exception import StopPropagation, _ExceptionContainer
|
||||
|
||||
_event_preprocessors: Set[PreProcessor] = set()
|
||||
|
||||
@ -18,6 +20,38 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
|
||||
return func
|
||||
|
||||
|
||||
async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event,
|
||||
state: dict) -> Union[None, NoReturn]:
|
||||
if datetime.now() > Matcher.expire_time:
|
||||
raise _ExceptionContainer([ExpiredException])
|
||||
|
||||
try:
|
||||
if not await Matcher.check_perm(
|
||||
bot, event) or not await Matcher.check_rule(bot, event, state):
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Rule check failed for matcher {Matcher}. Ignored.")
|
||||
logger.exception(e)
|
||||
return
|
||||
|
||||
matcher = Matcher()
|
||||
# TODO: BeforeMatcherRun
|
||||
try:
|
||||
logger.debug(f"Running matcher {matcher}")
|
||||
await matcher.run(bot, event, state)
|
||||
except Exception as e:
|
||||
logger.error(f"Running matcher {matcher} failed.")
|
||||
logger.exception(e)
|
||||
|
||||
exceptions = []
|
||||
if Matcher.temp:
|
||||
exceptions.append(ExpiredException)
|
||||
if Matcher.block:
|
||||
exceptions.append(StopPropagation)
|
||||
if exceptions:
|
||||
raise _ExceptionContainer(exceptions)
|
||||
|
||||
|
||||
async def handle_event(bot: Bot, event: Event):
|
||||
coros = []
|
||||
state = {}
|
||||
@ -33,37 +67,24 @@ async def handle_event(bot: Bot, event: Event):
|
||||
# Trie Match
|
||||
_, _ = TrieRule.get_value(bot, event, state)
|
||||
|
||||
break_flag = False
|
||||
for priority in sorted(matchers.keys()):
|
||||
index = 0
|
||||
while index <= len(matchers[priority]):
|
||||
Matcher = matchers[priority][index]
|
||||
if break_flag:
|
||||
break
|
||||
|
||||
# Delete expired Matcher
|
||||
if datetime.now() > Matcher.expire_time:
|
||||
del matchers[priority][index]
|
||||
continue
|
||||
pending_tasks = [
|
||||
_run_matcher(matcher, bot, event, state.copy())
|
||||
for matcher in matchers[priority]
|
||||
]
|
||||
|
||||
# Check rule
|
||||
try:
|
||||
if not await Matcher.check_perm(
|
||||
bot, event) or not await Matcher.check_rule(
|
||||
bot, event, state):
|
||||
index += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Rule check failed for matcher {Matcher}. Ignored.")
|
||||
logger.exception(e)
|
||||
continue
|
||||
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
|
||||
|
||||
matcher = Matcher()
|
||||
# TODO: BeforeMatcherRun
|
||||
if Matcher.temp:
|
||||
del matchers[priority][index]
|
||||
|
||||
try:
|
||||
await matcher.run(bot, event, state)
|
||||
except Exception as e:
|
||||
logger.error(f"Running matcher {matcher} failed.")
|
||||
logger.exception(e)
|
||||
return
|
||||
i = 0
|
||||
for index, result in enumerate(results):
|
||||
if isinstance(result, _ExceptionContainer):
|
||||
e_list = result.exceptions
|
||||
if StopPropagation in e_list:
|
||||
break_flag = True
|
||||
if ExpiredException in e_list:
|
||||
del matchers[priority][index - i]
|
||||
i += 1
|
||||
|
@ -33,12 +33,14 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new("",
|
||||
Rule() & rule,
|
||||
permission,
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
@ -50,12 +52,14 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new("meta_event",
|
||||
Rule() & rule,
|
||||
Permission(),
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
@ -68,12 +72,14 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
block: bool = True,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new("message",
|
||||
Rule() & rule,
|
||||
permission,
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
@ -85,12 +91,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new("notice",
|
||||
Rule() & rule,
|
||||
Permission(),
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
@ -102,12 +110,14 @@ def on_request(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new("request",
|
||||
Rule() & rule,
|
||||
Permission(),
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
|
Loading…
Reference in New Issue
Block a user