finish matcher process

This commit is contained in:
yanyongyu 2021-11-14 18:51:23 +08:00
parent 7495fee2a2
commit 0a1ae75b70
8 changed files with 139 additions and 90 deletions

View File

@ -10,14 +10,12 @@ r"""
""" """
import asyncio import asyncio
from typing import TYPE_CHECKING, Union, Callable, NoReturn, Optional, Awaitable from typing import Union, Callable, NoReturn, Optional, Awaitable
from nonebot.utils import run_sync from nonebot.utils import run_sync
from nonebot.adapters import Bot, Event
from nonebot.typing import T_PermissionChecker from nonebot.typing import T_PermissionChecker
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
class Permission: class Permission:
""" """
@ -36,9 +34,8 @@ class Permission:
""" """
__slots__ = ("checkers",) __slots__ = ("checkers",)
def __init__( def __init__(self, *checkers: Callable[[Bot, Event],
self, *checkers: Callable[["Bot", "Event"], Awaitable[bool]]) -> None:
Awaitable[bool]]) -> None:
""" """
:参数: :参数:
@ -55,7 +52,7 @@ class Permission:
* ``Set[Callable[[Bot, Event], Awaitable[bool]]]`` * ``Set[Callable[[Bot, Event], Awaitable[bool]]]``
""" """
async def __call__(self, bot: "Bot", event: "Event") -> bool: async def __call__(self, bot: Bot, event: Event) -> bool:
""" """
:说明: :说明:
@ -94,19 +91,19 @@ class Permission:
return Permission(*checkers) return Permission(*checkers)
async def _message(bot: "Bot", event: "Event") -> bool: async def _message(bot: Bot, event: Event) -> bool:
return event.get_type() == "message" return event.get_type() == "message"
async def _notice(bot: "Bot", event: "Event") -> bool: async def _notice(bot: Bot, event: Event) -> bool:
return event.get_type() == "notice" return event.get_type() == "notice"
async def _request(bot: "Bot", event: "Event") -> bool: async def _request(bot: Bot, event: Event) -> bool:
return event.get_type() == "request" return event.get_type() == "request"
async def _metaevent(bot: "Bot", event: "Event") -> bool: async def _metaevent(bot: Bot, event: Event) -> bool:
return event.get_type() == "meta_event" return event.get_type() == "meta_event"
@ -140,14 +137,14 @@ def USER(*user: str, perm: Optional[Permission] = None):
* ``perm: Optional[Permission]``: 需要同时满足的权限 * ``perm: Optional[Permission]``: 需要同时满足的权限
""" """
async def _user(bot: "Bot", event: "Event") -> bool: async def _user(bot: Bot, event: Event) -> bool:
return bool(event.get_session_id() in user and return bool(event.get_session_id() in user and
(perm is None or await perm(bot, event))) (perm is None or await perm(bot, event)))
return Permission(_user) return Permission(_user)
async def _superuser(bot: "Bot", event: "Event") -> bool: async def _superuser(bot: Bot, event: Event) -> bool:
return (event.get_type() == "message" and return (event.get_type() == "message" and
event.get_user_id() in bot.config.superusers) event.get_user_id() in bot.config.superusers)

View File

@ -2,19 +2,16 @@ import re
import sys import sys
import inspect import inspect
from types import ModuleType from types import ModuleType
from typing import (TYPE_CHECKING, Any, Set, Dict, List, Type, Tuple, Union, from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
Optional)
from .manager import _current_plugin from .manager import _current_plugin
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.processor import Handler, Matcher from nonebot.processor import Handler, Matcher
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword,
endswith, startswith, shell_command) endswith, startswith, shell_command)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
def _store_matcher(matcher: Type[Matcher]) -> None: def _store_matcher(matcher: Type[Matcher]) -> None:
plugin = _current_plugin.get() plugin = _current_plugin.get()
@ -375,7 +372,7 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): async def _strip_cmd(bot: Bot, event: Event, state: T_State):
message = event.get_message() message = event.get_message()
if len(message) < 1: if len(message) < 1:
return return
@ -432,7 +429,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): async def _strip_cmd(bot: Bot, event: Event, state: T_State):
message = event.get_message() message = event.get_message()
segment = message.pop(0) segment = message.pop(0)
new_message = message.__class__( new_message = message.__class__(

View File

@ -3,6 +3,7 @@ from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from .models import Dependent from .models import Dependent
from nonebot.log import logger
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass from .models import Depends as DependsClass
@ -70,7 +71,7 @@ def get_dependent(*,
f"{dependent.event_param_name} / {param_name}") f"{dependent.event_param_name} / {param_name}")
dependent.event_param_name = param_name dependent.event_param_name = param_name
dependent.event_param_type = generic_get_types(param.annotation) dependent.event_param_type = generic_get_types(param.annotation)
elif generic_check_issubclass(param.annotation, dict): elif generic_check_issubclass(param.annotation, Dict):
if dependent.state_param_name is not None: if dependent.state_param_name is not None:
raise ValueError(f"{func} has more than one State parameter: " raise ValueError(f"{func} has more than one State parameter: "
f"{dependent.state_param_name} / {param_name}") f"{dependent.state_param_name} / {param_name}")
@ -114,9 +115,15 @@ async def solve_dependencies(
# check bot and event type # check bot and event type
if sub_dependent.bot_param_type and not isinstance( if sub_dependent.bot_param_type and not isinstance(
bot, sub_dependent.bot_param_type): bot, sub_dependent.bot_param_type):
logger.debug(
f"Matcher {matcher} bot type {type(bot)} not match depends {func} "
f"annotation {sub_dependent.bot_param_type}, ignored")
return values, dependency_cache, True return values, dependency_cache, True
elif sub_dependent.event_param_type and not isinstance( elif sub_dependent.event_param_type and not isinstance(
event, sub_dependent.event_param_type): event, sub_dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match depends {func} "
f"annotation {sub_dependent.event_param_type}, ignored")
return values, dependency_cache, True return values, dependency_cache, True
# dependency overrides # dependency overrides

View File

@ -8,6 +8,7 @@
import asyncio import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger
from .models import Depends, Dependent from .models import Depends, Dependent
from nonebot.utils import get_name, run_sync from nonebot.utils import get_name, run_sync
from nonebot.typing import T_State, T_Handler from nonebot.typing import T_State, T_Handler
@ -48,7 +49,18 @@ class Handler:
self.dependency_overrides_provider = dependency_overrides_provider self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func) self.dependent = get_dependent(func=func)
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event, def __repr__(self) -> str:
return (
f"<Handler {self.func}("
f"[bot {self.dependent.bot_param_name}]: {self.dependent.bot_param_type}, "
f"[event {self.dependent.event_param_name}]: {self.dependent.event_param_type}, "
f"[state {self.dependent.state_param_name}], "
f"[matcher {self.dependent.matcher_param_name}])>")
def __str__(self) -> str:
return repr(self)
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
state: T_State): state: T_State):
values, _, ignored = await solve_dependencies( values, _, ignored = await solve_dependencies(
dependent=self.dependent, dependent=self.dependent,
@ -68,9 +80,14 @@ class Handler:
# check bot and event type # check bot and event type
if self.dependent.bot_param_type and not isinstance( if self.dependent.bot_param_type and not isinstance(
bot, self.dependent.bot_param_type): bot, self.dependent.bot_param_type):
logger.debug(f"Matcher {matcher} bot type {type(bot)} not match "
f"annotation {self.dependent.bot_param_type}, ignored")
return return
elif self.dependent.event_param_type and not isinstance( elif self.dependent.event_param_type and not isinstance(
event, self.dependent.event_param_type): event, self.dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match "
f"annotation {self.dependent.event_param_type}, ignored")
return return
if asyncio.iscoroutinefunction(self.func): if asyncio.iscoroutinefunction(self.func):

View File

@ -17,8 +17,9 @@ from .handler import Handler
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot import get_driver from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.adapters import MessageTemplate
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate)
from nonebot.exception import (PausedException, StopPropagation, from nonebot.exception import (PausedException, StopPropagation,
FinishedException, RejectedException) FinishedException, RejectedException)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
@ -26,15 +27,14 @@ from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.plugin import Plugin from nonebot.plugin import Plugin
from nonebot.adapters import Bot, Event, Message, MessageSegment
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
""" """
:类型: ``Dict[int, List[Type[Matcher]]]`` :类型: ``Dict[int, List[Type[Matcher]]]``
:说明: 用于存储当前所有的事件响应器 :说明: 用于存储当前所有的事件响应器
""" """
current_bot: ContextVar["Bot"] = ContextVar("current_bot") current_bot: ContextVar[Bot] = ContextVar("current_bot")
current_event: ContextVar["Event"] = ContextVar("current_event") current_event: ContextVar[Event] = ContextVar("current_event")
current_state: ContextVar[T_State] = ContextVar("current_state") current_state: ContextVar[T_State] = ContextVar("current_state")
@ -259,7 +259,7 @@ class Matcher(metaclass=MatcherMeta):
return NewMatcher return NewMatcher
@classmethod @classmethod
async def check_perm(cls, bot: "Bot", event: "Event") -> bool: async def check_perm(cls, bot: Bot, event: Event) -> bool:
""" """
:说明: :说明:
@ -279,8 +279,7 @@ class Matcher(metaclass=MatcherMeta):
await cls.permission(bot, event)) await cls.permission(bot, event))
@classmethod @classmethod
async def check_rule(cls, bot: "Bot", event: "Event", async def check_rule(cls, bot: Bot, event: Event, state: T_State) -> bool:
state: T_State) -> bool:
""" """
:说明: :说明:
@ -383,18 +382,21 @@ class Matcher(metaclass=MatcherMeta):
* *
""" """
async def _receive(state: T_State) -> Union[None, NoReturn]:
if state.get(_receive):
return
state[_receive] = True
raise RejectedException
def _decorator(func: T_Handler) -> T_Handler: def _decorator(func: T_Handler) -> T_Handler:
async def _receive() -> NoReturn:
func_handler.remove_dependency(depend)
raise PausedException
depend = Depends(_receive) depend = Depends(_receive)
if cls.handlers and cls.handlers[-1].func is func: if cls.handlers and cls.handlers[-1].func is func:
func_handler = cls.handlers[-1] func_handler = cls.handlers[-1]
func_handler.prepend_dependency(depend) func_handler.prepend_dependency(depend)
else: else:
func_handler = cls.append_handler( cls.append_handler(
func, dependencies=[depend] if cls.handlers else []) func, dependencies=[depend] if cls.handlers else [])
return func return func
@ -405,7 +407,7 @@ class Matcher(metaclass=MatcherMeta):
def got( def got(
cls, cls,
key: str, key: str,
prompt: Optional[Union[str, "Message", "MessageSegment", prompt: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None, MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None args_parser: Optional[T_ArgsParser] = None
) -> Callable[[T_Handler], T_Handler]: ) -> Callable[[T_Handler], T_Handler]:
@ -421,32 +423,36 @@ class Matcher(metaclass=MatcherMeta):
* ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数空则使用默认解析函数 * ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数空则使用默认解析函数
""" """
async def _key_getter(bot: Bot, event: Event, state: T_State):
if state.get(f"_{key}_prompted"):
return
state["_current_key"] = key
state[f"_{key}_prompted"] = True
if key not in state:
if prompt is not None:
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
_prompt = prompt
await bot.send(event=event, message=_prompt)
raise RejectedException
else:
state[f"_{key}_parsed"] = True
async def _key_parser(bot: Bot, event: Event, state: T_State):
if key in state and state.get(f"_{key}_parsed"):
return
parser = args_parser or cls._default_parser
if parser:
await parser(bot, event, state)
else:
state[key] = str(event.get_message())
state[f"_{key}_parsed"] = True
def _decorator(func: T_Handler) -> T_Handler: def _decorator(func: T_Handler) -> T_Handler:
async def _key_getter(bot: "Bot", event: "Event", state: T_State):
func_handler.remove_dependency(get_depend)
state["_current_key"] = key
if key not in state:
if prompt is not None:
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
_prompt = prompt
await bot.send(event=event, message=_prompt)
raise PausedException
else:
state["_skip_key"] = True
async def _key_parser(bot: "Bot", event: "Event", state: T_State):
if key in state and state.get("_skip_key"):
del state["_skip_key"]
return
parser = args_parser or cls._default_parser
if parser:
await parser(bot, event, state)
else:
state[state["_current_key"]] = str(event.get_message())
get_depend = Depends(_key_getter) get_depend = Depends(_key_getter)
parser_depend = Depends(_key_parser) parser_depend = Depends(_key_parser)
@ -455,15 +461,15 @@ class Matcher(metaclass=MatcherMeta):
func_handler.prepend_dependency(parser_depend) func_handler.prepend_dependency(parser_depend)
func_handler.prepend_dependency(get_depend) func_handler.prepend_dependency(get_depend)
else: else:
func_handler = cls.append_handler( cls.append_handler(func,
func, dependencies=[get_depend, parser_depend]) dependencies=[get_depend, parser_depend])
return func return func
return _decorator return _decorator
@classmethod @classmethod
async def send(cls, message: Union[str, "Message", "MessageSegment", async def send(cls, message: Union[str, Message, MessageSegment,
MessageTemplate], **kwargs) -> Any: MessageTemplate], **kwargs) -> Any:
""" """
:说明: :说明:
@ -486,7 +492,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
async def finish(cls, async def finish(cls,
message: Optional[Union[str, "Message", "MessageSegment", message: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None, MessageTemplate]] = None,
**kwargs) -> NoReturn: **kwargs) -> NoReturn:
""" """
@ -512,7 +518,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
async def pause(cls, async def pause(cls,
prompt: Optional[Union[str, "Message", "MessageSegment", prompt: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None, MessageTemplate]] = None,
**kwargs) -> NoReturn: **kwargs) -> NoReturn:
""" """
@ -538,8 +544,8 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
async def reject(cls, async def reject(cls,
prompt: Optional[Union[str, "Message", prompt: Optional[Union[str, Message,
"MessageSegment"]] = None, MessageSegment]] = None,
**kwargs) -> NoReturn: **kwargs) -> NoReturn:
""" """
:说明: :说明:
@ -554,6 +560,8 @@ class Matcher(metaclass=MatcherMeta):
bot = current_bot.get() bot = current_bot.get()
event = current_event.get() event = current_event.get()
state = current_state.get() state = current_state.get()
if "_current_key" in state and f"_{state['_current_key']}_parsed" in state:
del state[f"_{state['_current_key']}_parsed"]
if isinstance(prompt, MessageTemplate): if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state) _prompt = prompt.format(**state)
else: else:
@ -571,7 +579,7 @@ class Matcher(metaclass=MatcherMeta):
self.block = True self.block = True
# 运行handlers # 运行handlers
async def run(self, bot: "Bot", event: "Event", state: T_State): async def run(self, bot: Bot, event: Event, state: T_State):
b_t = current_bot.set(bot) b_t = current_bot.set(bot)
e_t = current_event.set(event) e_t = current_event.set(event)
s_t = current_state.set(self.state) s_t = current_state.set(self.state)

View File

@ -1,8 +1,9 @@
import inspect import inspect
from typing import Any, Dict, Type, Tuple, Union, Callable from typing import Any, Dict, Type, Tuple, Union, Callable
from typing_extensions import GenericAlias, get_args, get_origin # type: ignore
from pydantic.typing import (ForwardRef, get_args, get_origin, from loguru import logger
evaluate_forwardref) from pydantic.typing import ForwardRef, evaluate_forwardref
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
@ -25,7 +26,13 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
annotation = param.annotation annotation = param.annotation
if isinstance(annotation, str): if isinstance(annotation, str):
annotation = ForwardRef(annotation) annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns) try:
annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception as e:
logger.opt(colors=True, exception=e).warning(
f"Unknown ForwardRef[\"{param.annotation}\"] for parameter {param.name}"
)
return inspect.Parameter.empty
return annotation return annotation
@ -33,13 +40,16 @@ def generic_check_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any],
...]]) -> bool: ...]]) -> bool:
try: try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) return issubclass(cls, class_or_tuple)
except TypeError: except TypeError:
if get_origin(cls) is Union: if get_origin(cls) is Union:
for type_ in get_args(cls): for type_ in get_args(cls):
if not generic_check_issubclass(type_, class_or_tuple): if not generic_check_issubclass(type_, class_or_tuple):
return False return False
return True return True
elif isinstance(cls, GenericAlias):
origin = get_origin(cls)
return bool(origin and issubclass(origin, class_or_tuple))
raise raise

View File

@ -15,20 +15,18 @@ import asyncio
from itertools import product from itertools import product
from argparse import Namespace from argparse import Namespace
from argparse import ArgumentParser as ArgParser from argparse import ArgumentParser as ArgParser
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Callable, NoReturn, from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Optional, Sequence, Awaitable) Sequence, Awaitable)
from pygtrie import CharTrie from pygtrie import CharTrie
from nonebot import get_driver from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import run_sync from nonebot.utils import run_sync
from nonebot.adapters import Bot, Event
from nonebot.exception import ParserExit from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker from nonebot.typing import T_State, T_RuleChecker
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
class Rule: class Rule:
""" """
@ -48,8 +46,8 @@ class Rule:
__slots__ = ("checkers",) __slots__ = ("checkers",)
def __init__( def __init__(
self, *checkers: Callable[["Bot", "Event", T_State], self, *checkers: Callable[[Bot, Event, T_State],
Awaitable[bool]]) -> None: Awaitable[bool]]) -> None:
""" """
:参数: :参数:
@ -67,8 +65,7 @@ class Rule:
* ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]`` * ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]``
""" """
async def __call__(self, bot: "Bot", event: "Event", async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
state: T_State) -> bool:
""" """
:说明: :说明:
@ -123,7 +120,7 @@ class TrieRule:
cls.suffix[suffix[::-1]] = value cls.suffix[suffix[::-1]] = value
@classmethod @classmethod
def get_value(cls, bot: "Bot", event: "Event", def get_value(cls, bot: Bot, event: Event,
state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]: state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if event.get_type() != "message": if event.get_type() != "message":
state["_prefix"] = {"raw_command": None, "command": None} state["_prefix"] = {"raw_command": None, "command": None}
@ -195,7 +192,7 @@ def startswith(msg: Union[str, Tuple[str, ...]],
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})", f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0) re.IGNORECASE if ignorecase else 0)
async def _startswith(bot: "Bot", event: "Event", state: T_State) -> bool: async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
return False return False
text = event.get_plaintext() text = event.get_plaintext()
@ -222,7 +219,7 @@ def endswith(msg: Union[str, Tuple[str, ...]],
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$", f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0) re.IGNORECASE if ignorecase else 0)
async def _endswith(bot: "Bot", event: "Event", state: T_State) -> bool: async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
return False return False
text = event.get_plaintext() text = event.get_plaintext()
@ -242,7 +239,7 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词 * ``*keywords: str``: 关键词
""" """
async def _keyword(bot: "Bot", event: "Event", state: T_State) -> bool: async def _keyword(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
return False return False
text = event.get_plaintext() text = event.get_plaintext()
@ -290,7 +287,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep): for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: "Bot", event: "Event", state: T_State) -> bool: async def _command(bot: Bot, event: Event, state: T_State) -> bool:
return state["_prefix"]["command"] in commands return state["_prefix"]["command"] in commands
return Rule(_command) return Rule(_command)
@ -376,8 +373,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
for start, sep in product(command_start, command_sep): for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: "Bot", event: "Event", async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool:
state: T_State) -> bool:
if state["_prefix"]["command"] in commands: if state["_prefix"]["command"] in commands:
message = str(event.get_message()) message = str(event.get_message())
strip_message = message[len(state["_prefix"]["raw_command"] strip_message = message[len(state["_prefix"]["raw_command"]
@ -417,7 +413,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags) pattern = re.compile(regex, flags)
async def _regex(bot: "Bot", event: "Event", state: T_State) -> bool: async def _regex(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
return False return False
matched = pattern.search(str(event.get_message())) matched = pattern.search(str(event.get_message()))
@ -443,7 +439,7 @@ def to_me() -> Rule:
* *
""" """
async def _to_me(bot: "Bot", event: "Event", state: T_State) -> bool: async def _to_me(bot: Bot, event: Event, state: T_State) -> bool:
return event.is_tome() return event.is_tome()
return Rule(_to_me) return Rule(_to_me)

View File

@ -0,0 +1,17 @@
from nonebot import on_command
from nonebot.log import logger
from nonebot.processor import Depends
test = on_command("123")
def depend(state: dict):
return state
@test.got("a", prompt="a")
@test.got("b", prompt="b")
@test.receive()
@test.got("c", prompt="c")
async def _(state: dict = Depends(depend)):
logger.info(f"=======, {state}")