finish matcher process

This commit is contained in:
yanyongyu 2021-11-14 18:51:23 +08:00
parent 7495fee2a2
commit 0a1ae75b70
8 changed files with 139 additions and 90 deletions

View File

@ -10,14 +10,12 @@ r"""
"""
import asyncio
from typing import TYPE_CHECKING, Union, Callable, NoReturn, Optional, Awaitable
from typing import Union, Callable, NoReturn, Optional, Awaitable
from nonebot.utils import run_sync
from nonebot.adapters import Bot, Event
from nonebot.typing import T_PermissionChecker
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
class Permission:
"""
@ -36,9 +34,8 @@ class Permission:
"""
__slots__ = ("checkers",)
def __init__(
self, *checkers: Callable[["Bot", "Event"],
Awaitable[bool]]) -> None:
def __init__(self, *checkers: Callable[[Bot, Event],
Awaitable[bool]]) -> None:
"""
:参数:
@ -55,7 +52,7 @@ class Permission:
* ``Set[Callable[[Bot, Event], Awaitable[bool]]]``
"""
async def __call__(self, bot: "Bot", event: "Event") -> bool:
async def __call__(self, bot: Bot, event: Event) -> bool:
"""
:说明:
@ -94,19 +91,19 @@ class Permission:
return Permission(*checkers)
async def _message(bot: "Bot", event: "Event") -> bool:
async def _message(bot: Bot, event: Event) -> bool:
return event.get_type() == "message"
async def _notice(bot: "Bot", event: "Event") -> bool:
async def _notice(bot: Bot, event: Event) -> bool:
return event.get_type() == "notice"
async def _request(bot: "Bot", event: "Event") -> bool:
async def _request(bot: Bot, event: Event) -> bool:
return event.get_type() == "request"
async def _metaevent(bot: "Bot", event: "Event") -> bool:
async def _metaevent(bot: Bot, event: Event) -> bool:
return event.get_type() == "meta_event"
@ -140,14 +137,14 @@ def USER(*user: str, perm: Optional[Permission] = None):
* ``perm: Optional[Permission]``: 需要同时满足的权限
"""
async def _user(bot: "Bot", event: "Event") -> bool:
async def _user(bot: Bot, event: Event) -> bool:
return bool(event.get_session_id() in user and
(perm is None or await perm(bot, event)))
return Permission(_user)
async def _superuser(bot: "Bot", event: "Event") -> bool:
async def _superuser(bot: Bot, event: Event) -> bool:
return (event.get_type() == "message" and
event.get_user_id() in bot.config.superusers)

View File

@ -2,19 +2,16 @@ import re
import sys
import inspect
from types import ModuleType
from typing import (TYPE_CHECKING, Any, Set, Dict, List, Type, Tuple, Union,
Optional)
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
from .manager import _current_plugin
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission
from nonebot.processor import Handler, Matcher
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword,
endswith, startswith, shell_command)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
def _store_matcher(matcher: Type[Matcher]) -> None:
plugin = _current_plugin.get()
@ -375,7 +372,7 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]``
"""
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
async def _strip_cmd(bot: Bot, event: Event, state: T_State):
message = event.get_message()
if len(message) < 1:
return
@ -432,7 +429,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]``
"""
async def _strip_cmd(bot: "Bot", event: "Event", state: T_State):
async def _strip_cmd(bot: Bot, event: Event, state: T_State):
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(

View File

@ -3,6 +3,7 @@ from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from .models import Dependent
from nonebot.log import logger
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass
@ -70,7 +71,7 @@ def get_dependent(*,
f"{dependent.event_param_name} / {param_name}")
dependent.event_param_name = param_name
dependent.event_param_type = generic_get_types(param.annotation)
elif generic_check_issubclass(param.annotation, dict):
elif generic_check_issubclass(param.annotation, Dict):
if dependent.state_param_name is not None:
raise ValueError(f"{func} has more than one State parameter: "
f"{dependent.state_param_name} / {param_name}")
@ -114,9 +115,15 @@ async def solve_dependencies(
# check bot and event type
if sub_dependent.bot_param_type and not isinstance(
bot, sub_dependent.bot_param_type):
logger.debug(
f"Matcher {matcher} bot type {type(bot)} not match depends {func} "
f"annotation {sub_dependent.bot_param_type}, ignored")
return values, dependency_cache, True
elif sub_dependent.event_param_type and not isinstance(
event, sub_dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match depends {func} "
f"annotation {sub_dependent.event_param_type}, ignored")
return values, dependency_cache, True
# dependency overrides

View File

@ -8,6 +8,7 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger
from .models import Depends, Dependent
from nonebot.utils import get_name, run_sync
from nonebot.typing import T_State, T_Handler
@ -48,7 +49,18 @@ class Handler:
self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func)
async def __call__(self, matcher: "Matcher", bot: Bot, event: Event,
def __repr__(self) -> str:
return (
f"<Handler {self.func}("
f"[bot {self.dependent.bot_param_name}]: {self.dependent.bot_param_type}, "
f"[event {self.dependent.event_param_name}]: {self.dependent.event_param_type}, "
f"[state {self.dependent.state_param_name}], "
f"[matcher {self.dependent.matcher_param_name}])>")
def __str__(self) -> str:
return repr(self)
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
state: T_State):
values, _, ignored = await solve_dependencies(
dependent=self.dependent,
@ -68,9 +80,14 @@ class Handler:
# check bot and event type
if self.dependent.bot_param_type and not isinstance(
bot, self.dependent.bot_param_type):
logger.debug(f"Matcher {matcher} bot type {type(bot)} not match "
f"annotation {self.dependent.bot_param_type}, ignored")
return
elif self.dependent.event_param_type and not isinstance(
event, self.dependent.event_param_type):
logger.debug(
f"Matcher {matcher} event type {type(event)} not match "
f"annotation {self.dependent.event_param_type}, ignored")
return
if asyncio.iscoroutinefunction(self.func):

View File

@ -17,8 +17,9 @@ from .handler import Handler
from nonebot.rule import Rule
from nonebot import get_driver
from nonebot.log import logger
from nonebot.adapters import MessageTemplate
from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate)
from nonebot.exception import (PausedException, StopPropagation,
FinishedException, RejectedException)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
@ -26,15 +27,14 @@ from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
if TYPE_CHECKING:
from nonebot.plugin import Plugin
from nonebot.adapters import Bot, Event, Message, MessageSegment
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
"""
:类型: ``Dict[int, List[Type[Matcher]]]``
:说明: 用于存储当前所有的事件响应器
"""
current_bot: ContextVar["Bot"] = ContextVar("current_bot")
current_event: ContextVar["Event"] = ContextVar("current_event")
current_bot: ContextVar[Bot] = ContextVar("current_bot")
current_event: ContextVar[Event] = ContextVar("current_event")
current_state: ContextVar[T_State] = ContextVar("current_state")
@ -259,7 +259,7 @@ class Matcher(metaclass=MatcherMeta):
return NewMatcher
@classmethod
async def check_perm(cls, bot: "Bot", event: "Event") -> bool:
async def check_perm(cls, bot: Bot, event: Event) -> bool:
"""
:说明:
@ -279,8 +279,7 @@ class Matcher(metaclass=MatcherMeta):
await cls.permission(bot, event))
@classmethod
async def check_rule(cls, bot: "Bot", event: "Event",
state: T_State) -> bool:
async def check_rule(cls, bot: Bot, event: Event, state: T_State) -> bool:
"""
:说明:
@ -383,18 +382,21 @@ class Matcher(metaclass=MatcherMeta):
*
"""
async def _receive(state: T_State) -> Union[None, NoReturn]:
if state.get(_receive):
return
state[_receive] = True
raise RejectedException
def _decorator(func: T_Handler) -> T_Handler:
async def _receive() -> NoReturn:
func_handler.remove_dependency(depend)
raise PausedException
depend = Depends(_receive)
if cls.handlers and cls.handlers[-1].func is func:
func_handler = cls.handlers[-1]
func_handler.prepend_dependency(depend)
else:
func_handler = cls.append_handler(
cls.append_handler(
func, dependencies=[depend] if cls.handlers else [])
return func
@ -405,7 +407,7 @@ class Matcher(metaclass=MatcherMeta):
def got(
cls,
key: str,
prompt: Optional[Union[str, "Message", "MessageSegment",
prompt: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None
) -> Callable[[T_Handler], T_Handler]:
@ -421,32 +423,36 @@ class Matcher(metaclass=MatcherMeta):
* ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数空则使用默认解析函数
"""
async def _key_getter(bot: Bot, event: Event, state: T_State):
if state.get(f"_{key}_prompted"):
return
state["_current_key"] = key
state[f"_{key}_prompted"] = True
if key not in state:
if prompt is not None:
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
_prompt = prompt
await bot.send(event=event, message=_prompt)
raise RejectedException
else:
state[f"_{key}_parsed"] = True
async def _key_parser(bot: Bot, event: Event, state: T_State):
if key in state and state.get(f"_{key}_parsed"):
return
parser = args_parser or cls._default_parser
if parser:
await parser(bot, event, state)
else:
state[key] = str(event.get_message())
state[f"_{key}_parsed"] = True
def _decorator(func: T_Handler) -> T_Handler:
async def _key_getter(bot: "Bot", event: "Event", state: T_State):
func_handler.remove_dependency(get_depend)
state["_current_key"] = key
if key not in state:
if prompt is not None:
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
_prompt = prompt
await bot.send(event=event, message=_prompt)
raise PausedException
else:
state["_skip_key"] = True
async def _key_parser(bot: "Bot", event: "Event", state: T_State):
if key in state and state.get("_skip_key"):
del state["_skip_key"]
return
parser = args_parser or cls._default_parser
if parser:
await parser(bot, event, state)
else:
state[state["_current_key"]] = str(event.get_message())
get_depend = Depends(_key_getter)
parser_depend = Depends(_key_parser)
@ -455,15 +461,15 @@ class Matcher(metaclass=MatcherMeta):
func_handler.prepend_dependency(parser_depend)
func_handler.prepend_dependency(get_depend)
else:
func_handler = cls.append_handler(
func, dependencies=[get_depend, parser_depend])
cls.append_handler(func,
dependencies=[get_depend, parser_depend])
return func
return _decorator
@classmethod
async def send(cls, message: Union[str, "Message", "MessageSegment",
async def send(cls, message: Union[str, Message, MessageSegment,
MessageTemplate], **kwargs) -> Any:
"""
:说明:
@ -486,7 +492,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def finish(cls,
message: Optional[Union[str, "Message", "MessageSegment",
message: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None,
**kwargs) -> NoReturn:
"""
@ -512,7 +518,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def pause(cls,
prompt: Optional[Union[str, "Message", "MessageSegment",
prompt: Optional[Union[str, Message, MessageSegment,
MessageTemplate]] = None,
**kwargs) -> NoReturn:
"""
@ -538,8 +544,8 @@ class Matcher(metaclass=MatcherMeta):
@classmethod
async def reject(cls,
prompt: Optional[Union[str, "Message",
"MessageSegment"]] = None,
prompt: Optional[Union[str, Message,
MessageSegment]] = None,
**kwargs) -> NoReturn:
"""
:说明:
@ -554,6 +560,8 @@ class Matcher(metaclass=MatcherMeta):
bot = current_bot.get()
event = current_event.get()
state = current_state.get()
if "_current_key" in state and f"_{state['_current_key']}_parsed" in state:
del state[f"_{state['_current_key']}_parsed"]
if isinstance(prompt, MessageTemplate):
_prompt = prompt.format(**state)
else:
@ -571,7 +579,7 @@ class Matcher(metaclass=MatcherMeta):
self.block = True
# 运行handlers
async def run(self, bot: "Bot", event: "Event", state: T_State):
async def run(self, bot: Bot, event: Event, state: T_State):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
s_t = current_state.set(self.state)

View File

@ -1,8 +1,9 @@
import inspect
from typing import Any, Dict, Type, Tuple, Union, Callable
from typing_extensions import GenericAlias, get_args, get_origin # type: ignore
from pydantic.typing import (ForwardRef, get_args, get_origin,
evaluate_forwardref)
from loguru import logger
from pydantic.typing import ForwardRef, evaluate_forwardref
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
@ -25,7 +26,13 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
try:
annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception as e:
logger.opt(colors=True, exception=e).warning(
f"Unknown ForwardRef[\"{param.annotation}\"] for parameter {param.name}"
)
return inspect.Parameter.empty
return annotation
@ -33,13 +40,16 @@ def generic_check_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any],
...]]) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
return issubclass(cls, class_or_tuple)
except TypeError:
if get_origin(cls) is Union:
for type_ in get_args(cls):
if not generic_check_issubclass(type_, class_or_tuple):
return False
return True
elif isinstance(cls, GenericAlias):
origin = get_origin(cls)
return bool(origin and issubclass(origin, class_or_tuple))
raise

View File

@ -15,20 +15,18 @@ import asyncio
from itertools import product
from argparse import Namespace
from argparse import ArgumentParser as ArgParser
from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Callable, NoReturn,
Optional, Sequence, Awaitable)
from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Sequence, Awaitable)
from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger
from nonebot.utils import run_sync
from nonebot.adapters import Bot, Event
from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
class Rule:
"""
@ -48,8 +46,8 @@ class Rule:
__slots__ = ("checkers",)
def __init__(
self, *checkers: Callable[["Bot", "Event", T_State],
Awaitable[bool]]) -> None:
self, *checkers: Callable[[Bot, Event, T_State],
Awaitable[bool]]) -> None:
"""
:参数:
@ -67,8 +65,7 @@ class Rule:
* ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]``
"""
async def __call__(self, bot: "Bot", event: "Event",
state: T_State) -> bool:
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
"""
:说明:
@ -123,7 +120,7 @@ class TrieRule:
cls.suffix[suffix[::-1]] = value
@classmethod
def get_value(cls, bot: "Bot", event: "Event",
def get_value(cls, bot: Bot, event: Event,
state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if event.get_type() != "message":
state["_prefix"] = {"raw_command": None, "command": None}
@ -195,7 +192,7 @@ def startswith(msg: Union[str, Tuple[str, ...]],
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0)
async def _startswith(bot: "Bot", event: "Event", state: T_State) -> bool:
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
@ -222,7 +219,7 @@ def endswith(msg: Union[str, Tuple[str, ...]],
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0)
async def _endswith(bot: "Bot", event: "Event", state: T_State) -> bool:
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
@ -242,7 +239,7 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
async def _keyword(bot: "Bot", event: "Event", state: T_State) -> bool:
async def _keyword(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
@ -290,7 +287,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: "Bot", event: "Event", state: T_State) -> bool:
async def _command(bot: Bot, event: Event, state: T_State) -> bool:
return state["_prefix"]["command"] in commands
return Rule(_command)
@ -376,8 +373,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: "Bot", event: "Event",
state: T_State) -> bool:
async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool:
if state["_prefix"]["command"] in commands:
message = str(event.get_message())
strip_message = message[len(state["_prefix"]["raw_command"]
@ -417,7 +413,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags)
async def _regex(bot: "Bot", event: "Event", state: T_State) -> bool:
async def _regex(bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = pattern.search(str(event.get_message()))
@ -443,7 +439,7 @@ def to_me() -> Rule:
*
"""
async def _to_me(bot: "Bot", event: "Event", state: T_State) -> bool:
async def _to_me(bot: Bot, event: Event, state: T_State) -> bool:
return event.is_tome()
return Rule(_to_me)

View File

@ -0,0 +1,17 @@
from nonebot import on_command
from nonebot.log import logger
from nonebot.processor import Depends
test = on_command("123")
def depend(state: dict):
return state
@test.got("a", prompt="a")
@test.got("b", prompt="b")
@test.receive()
@test.got("c", prompt="c")
async def _(state: dict = Depends(depend)):
logger.info(f"=======, {state}")