From 471d306e13772670946fb91e229467c6e3838438 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Fri, 19 Nov 2021 18:18:53 +0800 Subject: [PATCH] :alembic: change rule to use handler --- nonebot/dependencies/__init__.py | 48 +++++----- nonebot/drivers/aiohttp.py | 3 +- nonebot/drivers/fastapi.py | 3 +- nonebot/exception.py | 157 ++++++++++++++++++------------- nonebot/handler.py | 32 +++---- nonebot/matcher.py | 16 ++-- nonebot/rule.py | 66 ++++++++----- nonebot/typing.py | 5 +- 8 files changed, 182 insertions(+), 148 deletions(-) diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index bbf835fb..f9a6ab63 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -18,6 +18,7 @@ from nonebot.log import logger from .models import Param as Param from .utils import get_typed_signature from .models import Dependent as Dependent +from nonebot.exception import SkippedException from .models import DependsWrapper as DependsWrapper from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, is_async_gen_callable, is_coroutine_callable) @@ -112,21 +113,20 @@ def get_dependent(*, async def solve_dependencies( - *, - dependent: Dependent, - stack: Optional[AsyncExitStack] = None, - sub_dependents: Optional[List[Dependent]] = None, - dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, - **params: Any -) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]: + *, + _dependent: Dependent, + _stack: Optional[AsyncExitStack] = None, + _sub_dependents: Optional[List[Dependent]] = None, + _dependency_overrides_provider: Optional[Any] = None, + _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, + **params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]: values: Dict[str, Any] = {} - dependency_cache = dependency_cache or {} + dependency_cache = _dependency_cache or {} # solve sub dependencies sub_dependent: Dependent - for sub_dependent in chain(sub_dependents or tuple(), - dependent.dependencies): + for sub_dependent in chain(_sub_dependents or tuple(), + _dependent.dependencies): sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key) @@ -134,10 +134,10 @@ async def solve_dependencies( # dependency overrides use_sub_dependant = sub_dependent - if (dependency_overrides_provider and - hasattr(dependency_overrides_provider, "dependency_overrides")): + if (_dependency_overrides_provider and hasattr( + _dependency_overrides_provider, "dependency_overrides")): original_call = sub_dependent.func - func = getattr(dependency_overrides_provider, + func = getattr(_dependency_overrides_provider, "dependency_overrides", {}).get(original_call, original_call) use_sub_dependant = get_dependent( @@ -148,13 +148,11 @@ async def solve_dependencies( # solve sub dependency with current cache solved_result = await solve_dependencies( - dependent=use_sub_dependant, - dependency_overrides_provider=dependency_overrides_provider, + _dependent=use_sub_dependant, + _dependency_overrides_provider=_dependency_overrides_provider, dependency_cache=dependency_cache, **params) - sub_values, sub_dependency_cache, ignored = solved_result - if ignored: - return values, dependency_cache, True + sub_values, sub_dependency_cache = solved_result # update cache? dependency_cache.update(sub_dependency_cache) @@ -163,13 +161,13 @@ async def solve_dependencies( solved = dependency_cache[sub_dependent.cache_key] elif is_gen_callable(func) or is_async_gen_callable(func): assert isinstance( - stack, AsyncExitStack + _stack, AsyncExitStack ), "Generator dependency should be called in context" if is_gen_callable(func): cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) else: cm = asynccontextmanager(func)(**sub_values) - solved = await stack.enter_async_context(cm) + solved = await _stack.enter_async_context(cm) elif is_coroutine_callable(func): solved = await func(**sub_values) else: @@ -183,7 +181,7 @@ async def solve_dependencies( dependency_cache[sub_dependent.cache_key] = solved # usual dependency - for field in dependent.params: + for field in _dependent.params: field_info = field.field_info assert isinstance(field_info, Param), "Params must be subclasses of Param" @@ -194,13 +192,13 @@ async def solve_dependencies( if errs_: logger.debug( f"{field_info} " - f"type {type(value)} not match depends {dependent.func} " + f"type {type(value)} not match depends {_dependent.func} " f"annotation {field._type_display()}, ignored") - return values, dependency_cache, True + raise SkippedException else: values[field.name] = value - return values, dependency_cache, False + return values, dependency_cache def Depends(dependency: Optional[Callable[..., Any]] = None, diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 20ca668f..3670c414 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -248,6 +248,8 @@ class Driver(ForwardDriver): await asyncio.sleep(3) continue + setup_ = cast(HTTPPollingSetup, setup_) + if not bot: request = await _build_request(setup_) if not request: @@ -264,7 +266,6 @@ class Driver(ForwardDriver): bot.request = request request = cast(HTTPRequest, request) - setup_ = cast(HTTPPollingSetup, setup_) headers = request.headers timeout = aiohttp.ClientTimeout(30) diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 0ace2e31..e0d47d79 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -409,6 +409,8 @@ class Driver(ReverseDriver, ForwardDriver): await asyncio.sleep(3) continue + setup_ = cast(HTTPPollingSetup, setup_) + if not bot: request = await _build_request(setup_) if not request: @@ -423,7 +425,6 @@ class Driver(ReverseDriver, ForwardDriver): continue bot.request = request - setup_ = cast(HTTPPollingSetup, setup_) request = cast(HTTPRequest, request) headers = request.headers diff --git a/nonebot/exception.py b/nonebot/exception.py index 3cad317a..227d8e10 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -6,6 +6,8 @@ 这些异常并非所有需要用户处理,在 NoneBot 内部运行时被捕获,并进行对应操作。 """ +from typing import Optional + class NoneBotException(Exception): """ @@ -13,9 +15,33 @@ class NoneBotException(Exception): 所有 NoneBot 发生的异常基类。 """ - pass +# Rule Exception +class ParserExit(NoneBotException): + """ + :说明: + + ``shell command`` 处理消息失败时返回的异常 + + :参数: + + * ``status`` + * ``message`` + """ + + def __init__(self, status: int = 0, message: Optional[str] = None): + self.status = status + self.message = message + + def __repr__(self): + return f"" + + def __str__(self): + return self.__repr__() + + +# Processor Exception class IgnoredException(NoneBotException): """ :说明: @@ -37,71 +63,6 @@ class IgnoredException(NoneBotException): return self.__repr__() -class ParserExit(NoneBotException): - """ - :说明: - - ``shell command`` 处理消息失败时返回的异常 - - :参数: - - * ``status`` - * ``message`` - """ - - def __init__(self, status=0, message=None): - self.status = status - self.message = message - - def __repr__(self): - return f"" - - def __str__(self): - return self.__repr__() - - -class PausedException(NoneBotException): - """ - :说明: - - 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。 - 可用于用户输入新信息。 - - :用法: - - 可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。 - """ - pass - - -class RejectedException(NoneBotException): - """ - :说明: - - 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。 - 可用于用户重新输入。 - - :用法: - - 可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。 - """ - pass - - -class FinishedException(NoneBotException): - """ - :说明: - - 指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。 - 可用于结束用户会话。 - - :用法: - - 可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。 - """ - pass - - class StopPropagation(NoneBotException): """ :说明: @@ -112,9 +73,69 @@ class StopPropagation(NoneBotException): 在 ``Matcher.block == True`` 时抛出。 """ - pass +# Matcher Exceptions +class MatcherException(NoneBotException): + """ + :说明: + + 所有 Matcher 发生的异常基类。 + """ + + +class SkippedException(MatcherException): + """ + :说明: + + 指示 NoneBot 立即结束当前 ``Handler`` 的处理,继续处理下一个 ``Handler``。 + + :用法: + + 可以在 ``Handler`` 中通过 ``Matcher.skip()`` 抛出。 + """ + + +class PausedException(MatcherException): + """ + :说明: + + 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``。 + 可用于用户输入新信息。 + + :用法: + + 可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出。 + """ + + +class RejectedException(MatcherException): + """ + :说明: + + 指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``。 + 可用于用户重新输入。 + + :用法: + + 可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出。 + """ + + +class FinishedException(MatcherException): + """ + :说明: + + 指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行。 + 可用于结束用户会话。 + + :用法: + + 可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出。 + """ + + +# Adapter Exceptions class AdapterException(NoneBotException): """ :说明: @@ -130,7 +151,7 @@ class AdapterException(NoneBotException): self.adapter_name = adapter_name -class NoLogException(Exception): +class NoLogException(AdapterException): """ :说明: diff --git a/nonebot/handler.py b/nonebot/handler.py index 232f2200..d6df70dd 100644 --- a/nonebot/handler.py +++ b/nonebot/handler.py @@ -7,24 +7,19 @@ import asyncio from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Any, Dict, List, Type, Callable, Optional +from typing import Any, Dict, List, Type, Callable, Optional -from nonebot.typing import T_Handler from nonebot.utils import get_name, run_sync from nonebot.dependencies import (Param, Dependent, DependsWrapper, get_dependent, solve_dependencies, get_parameterless_sub_dependant) -if TYPE_CHECKING: - from nonebot.matcher import Matcher - from nonebot.adapters import Bot, Event - class Handler: """事件处理器类。支持依赖注入。""" def __init__(self, - func: T_Handler, + func: Callable[..., Any], *, name: Optional[str] = None, dependencies: Optional[List[DependsWrapper]] = None, @@ -37,7 +32,7 @@ class Handler: :参数: - * ``func: T_Handler``: 事件处理函数。 + * ``func: Callable[..., Any]``: 事件处理函数。 * ``name: Optional[str]``: 事件处理器名称。默认为函数名。 * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。 * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。 @@ -45,7 +40,7 @@ class Handler: """ self.func = func """ - :类型: ``T_Handler`` + :类型: ``Callable[..., Any]`` :说明: 事件处理函数 """ self.name = get_name(func) if name is None else name @@ -85,24 +80,21 @@ class Handler: _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, **params) -> Any: - values, _, ignored = await solve_dependencies( - dependent=self.dependent, - stack=_stack, - sub_dependents=[ + values, cache = await solve_dependencies( + _dependent=self.dependent, + _stack=_stack, + _sub_dependents=[ self.sub_dependents[dependency.dependency] # type: ignore for dependency in self.dependencies ], - dependency_overrides_provider=self.dependency_overrides_provider, - dependency_cache=_dependency_cache, + _dependency_overrides_provider=self.dependency_overrides_provider, + _dependency_cache=_dependency_cache, **params) - if ignored: - return - if asyncio.iscoroutinefunction(self.func): - await self.func(**values) + return await self.func(**values) else: - await run_sync(self.func)(**values) + return await run_sync(self.func)(**values) def cache_dependent(self, dependency: DependsWrapper): if not dependency.dependency: diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 8856a3cf..0967d5c3 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -20,10 +20,11 @@ from nonebot.dependencies import DependsWrapper from nonebot.permission import USER, Permission from nonebot.adapters import (Bot, Event, Message, MessageSegment, MessageTemplate) -from nonebot.exception import (PausedException, StopPropagation, - FinishedException, RejectedException) from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, T_StateFactory, T_PermissionUpdater) +from nonebot.exception import (PausedException, StopPropagation, + SkippedException, FinishedException, + RejectedException) if TYPE_CHECKING: from nonebot.plugin import Plugin @@ -601,10 +602,13 @@ class Matcher(metaclass=MatcherMeta): while self.handlers: handler = self.handlers.pop(0) logger.debug(f"Running handler {handler}") - await handler(matcher=self, - bot=bot, - event=event, - state=self.state) + try: + await handler(matcher=self, + bot=bot, + event=event, + state=self.state) + except SkippedException: + pass except RejectedException: self.handlers.insert(0, handler) # type: ignore diff --git a/nonebot/rule.py b/nonebot/rule.py index 7ac0fe8d..3a1c2c33 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -14,16 +14,18 @@ import shlex import asyncio from itertools import product from argparse import Namespace +from contextlib import AsyncExitStack from typing_extensions import TypedDict from argparse import ArgumentParser as ArgParser -from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence, - Awaitable) +from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional, + Sequence, Awaitable) from pygtrie import CharTrie -from nonebot import get_driver from nonebot.log import logger from nonebot.utils import run_sync +from nonebot.handler import Handler +from nonebot import params, get_driver from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker from nonebot.adapters import Bot, Event, MessageSegment @@ -62,16 +64,22 @@ class Rule: """ __slots__ = ("checkers",) - def __init__( - self, *checkers: Callable[[Bot, Event, T_State], - Awaitable[bool]]) -> None: + HANDLER_PARAM_TYPES = [ + params.BotParam, params.EventParam, params.StateParam + ] + + def __init__(self, *checkers: T_RuleChecker) -> None: """ :参数: - * ``*checkers: Callable[[Bot, Event, T_State], Awaitable[bool]]``: **异步** RuleChecker + * ``*checkers: T_RuleChecker``: RuleChecker """ - self.checkers = set(checkers) + self.checkers = set( + Handler(checker, + allow_types=self.HANDLER_PARAM_TYPES, + dependency_overrides_provider=get_driver()) + for checker in checkers) """ :说明: @@ -79,10 +87,17 @@ class Rule: :类型: - * ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]`` + * ``Set[Handler]`` """ - async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: + async def __call__( + self, + bot: Bot, + event: Event, + state: T_State, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[Dict[Callable[..., Any], + Any]] = None) -> bool: """ :说明: @@ -99,19 +114,21 @@ class Rule: - ``bool`` """ results = await asyncio.gather( - *map(lambda c: c(bot, event, state), self.checkers)) + checker(bot=bot, + event=event, + state=state, + _stack=stack, + _dependency_cache=dependency_cache) + for checker in self.checkers) return all(results) def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule": - checkers = self.checkers.copy() if other is None: return self elif isinstance(other, Rule): - checkers |= other.checkers - elif asyncio.iscoroutinefunction(other): - checkers.add(other) # type: ignore + checkers = [*self.checkers, *other.checkers] else: - checkers.add(run_sync(other)) + checkers = [*self.checkers, other] return Rule(*checkers) def __or__(self, other) -> NoReturn: @@ -226,7 +243,7 @@ def keyword(*keywords: str) -> Rule: * ``*keywords: str``: 关键词 """ - async def _keyword(bot: Bot, event: Event, state: T_State) -> bool: + async def _keyword(event: Event) -> bool: if event.get_type() != "message": return False text = event.get_plaintext() @@ -274,7 +291,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - async def _command(bot: Bot, event: Event, state: T_State) -> bool: + async def _command(state: T_State) -> bool: return state[PREFIX_KEY][CMD_KEY] in commands return Rule(_command) @@ -294,7 +311,7 @@ class ArgumentParser(ArgParser): old_message += message setattr(self, "message", old_message) - def exit(self, status=0, message=None): + def exit(self, status: int = 0, message: Optional[str] = None): raise ParserExit(status=status, message=message or getattr(self, "message", None)) @@ -360,7 +377,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]], for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool: + async def _shell_command(event: Event, state: T_State) -> bool: if state[PREFIX_KEY][CMD_KEY] in commands: message = str(event.get_message()) strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY] @@ -400,7 +417,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: pattern = re.compile(regex, flags) - async def _regex(bot: Bot, event: Event, state: T_State) -> bool: + async def _regex(event: Event, state: T_State) -> bool: if event.get_type() != "message": return False matched = pattern.search(str(event.get_message())) @@ -415,6 +432,10 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: return Rule(_regex) +async def _to_me(event: Event) -> bool: + return event.is_tome() + + def to_me() -> Rule: """ :说明: @@ -426,7 +447,4 @@ def to_me() -> Rule: * 无 """ - async def _to_me(bot: Bot, event: Event, state: T_State) -> bool: - return event.is_tome() - return Rule(_to_me) diff --git a/nonebot/typing.py b/nonebot/typing.py index 3725b301..337ff426 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -123,10 +123,9 @@ T_RunPostProcessor = Callable[..., Awaitable[None]] 事件响应器运行前预处理函数 RunPostProcessor 类型,第二个参数为运行时产生的错误(如果存在) """ -T_RuleChecker = Callable[["Bot", "Event", T_State], Union[bool, - Awaitable[bool]]] +T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]] """ -:类型: ``Callable[[Bot, Event, T_State], Union[bool, Awaitable[bool]]]`` +:类型: ``Callable[..., Union[bool, Awaitable[bool]]]`` :说明: