From 66ba25494a0ef4501c890f2747d3dbd5bbbd1023 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Sun, 12 Dec 2021 18:19:08 +0800 Subject: [PATCH] :recycle: rewrite dependency injection system --- nonebot/consts.py | 4 + nonebot/dependencies/__init__.py | 337 +++++++++++++------------------ nonebot/dependencies/models.py | 52 ----- nonebot/dependencies/utils.py | 6 +- nonebot/handler.py | 125 ------------ nonebot/matcher.py | 190 +++++++++-------- nonebot/message.py | 55 +++-- nonebot/params.py | 215 +++++++++++++++++--- nonebot/permission.py | 35 ++-- nonebot/plugin/on.py | 103 +++++----- nonebot/plugin/on.pyi | 101 ++++----- nonebot/rule.py | 39 ++-- nonebot/typing.py | 69 ++++--- nonebot/utils.py | 76 ++----- tests/plugins/depends.py | 23 +++ tests/test_init.py | 17 ++ tests/utils.py | 14 ++ 17 files changed, 728 insertions(+), 733 deletions(-) create mode 100644 nonebot/consts.py delete mode 100644 nonebot/dependencies/models.py delete mode 100644 nonebot/handler.py create mode 100644 tests/plugins/depends.py create mode 100644 tests/utils.py diff --git a/nonebot/consts.py b/nonebot/consts.py new file mode 100644 index 00000000..a8867395 --- /dev/null +++ b/nonebot/consts.py @@ -0,0 +1,4 @@ +RECEIVE_KEY = "_receive_{id}" +ARG_KEY = "_arg_{key}" +ARG_STR_KEY = "{key}" +REJECT_TARGET = "_current_target" diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index 776a25e5..da5f9273 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -5,227 +5,170 @@ 该模块实现了依赖注入的定义与处理。 """ +import abc import inspect -from itertools import chain -from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast -from contextlib import AsyncExitStack, contextmanager, asynccontextmanager +from typing import Any, Dict, List, Type, Generic, TypeVar, Callable, Optional from pydantic import BaseConfig from pydantic.schema import get_annotation_from_field_info -from pydantic.fields import Required, Undefined, ModelField +from pydantic.fields import Required, FieldInfo, Undefined, ModelField from nonebot.log import logger -from .models import Param as Param from .utils import get_typed_signature -from .models import Dependent as Dependent from nonebot.exception import SkippedException -from .models import DependsWrapper as DependsWrapper -from nonebot.typing import T_Handler, T_DependencyCache -from nonebot.utils import ( - CacheLock, - run_sync, - is_gen_callable, - run_sync_ctx_manager, - is_async_gen_callable, - is_coroutine_callable, -) +from nonebot.utils import run_sync, is_coroutine_callable -cache_lock = CacheLock() +T = TypeVar("T", bound="Dependent") +R = TypeVar("R") + + +class Param(abc.ABC, FieldInfo): + @classmethod + def _check_param( + cls, dependent: "Dependent", name: str, param: inspect.Parameter + ) -> Optional["Param"]: + return None + + @classmethod + def _check_parameterless( + cls, dependent: "Dependent", value: Any + ) -> Optional["Param"]: + return None + + @abc.abstractmethod + async def _solve(self, **kwargs: Any) -> Any: + raise NotImplementedError class CustomConfig(BaseConfig): arbitrary_types_allowed = True -def get_param_sub_dependent( - *, param: inspect.Parameter, allow_types: Optional[List[Type[Param]]] = None -) -> Dependent: - depends: DependsWrapper = param.default - if depends.dependency: - dependency = depends.dependency - else: - dependency = param.annotation - return get_sub_dependant( - depends=depends, dependency=dependency, name=param.name, allow_types=allow_types - ) +class Dependent(Generic[R]): + def __init__( + self, + *, + call: Callable[..., Any], + params: Optional[List[ModelField]] = None, + parameterless: Optional[List[Param]] = None, + allow_types: Optional[List[Type[Param]]] = None, + ) -> None: + self.call = call + self.params = params or [] + self.parameterless = parameterless or [] + self.allow_types = allow_types or [] + async def __call__(self, **kwargs: Any) -> R: + values = await self.solve(**kwargs) -def get_parameterless_sub_dependant( - *, depends: DependsWrapper, allow_types: Optional[List[Type[Param]]] = None -) -> Dependent: - assert callable( - depends.dependency - ), "A parameter-less dependency must have a callable dependency" - return get_sub_dependant( - depends=depends, dependency=depends.dependency, allow_types=allow_types - ) - - -def get_sub_dependant( - *, - depends: DependsWrapper, - dependency: T_Handler, - name: Optional[str] = None, - allow_types: Optional[List[Type[Param]]] = None, -) -> Dependent: - sub_dependant = get_dependent( - call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types - ) - return sub_dependant - - -def get_dependent( - *, - call: T_Handler, - name: Optional[str] = None, - use_cache: bool = True, - allow_types: Optional[List[Type[Param]]] = None, -) -> Dependent: - signature = get_typed_signature(call) - params = signature.parameters - dependent = Dependent( - call=call, name=name, allow_types=allow_types, use_cache=use_cache - ) - for param_name, param in params.items(): - if isinstance(param.default, DependsWrapper): - sub_dependent = get_param_sub_dependent( - param=param, allow_types=allow_types - ) - dependent.dependencies.append(sub_dependent) - continue - - default_value = Required - if param.default != param.empty: - default_value = param.default - - if isinstance(default_value, Param): - field_info = default_value - default_value = field_info.default + if is_coroutine_callable(self.call): + return await self.call(**values) else: - for allow_type in dependent.allow_types: - if allow_type._check(param_name, param): - field_info = allow_type(default_value) - break + return await run_sync(self.call)(**values) + + def parse_param(self, name: str, param: inspect.Parameter) -> Param: + for allow_type in self.allow_types: + field_info = allow_type._check_param(self, name, param) + if field_info: + return field_info + else: + raise ValueError( + f"Unknown parameter {name} for function {self.call} with type {param.annotation}" + ) + + def parse_parameterless(self, value: Any) -> Param: + for allow_type in self.allow_types: + field_info = allow_type._check_parameterless(self, value) + if field_info: + return field_info + else: + raise ValueError( + f"Unknown parameterless {value} for function {self.call} with type {type(value)}" + ) + + def prepend_parameterless(self, value: Any) -> None: + self.parameterless.insert(0, self.parse_parameterless(value)) + + def append_parameterless(self, value: Any) -> None: + self.parameterless.append(self.parse_parameterless(value)) + + @classmethod + def parse( + cls: Type[T], + *, + call: Callable[..., Any], + parameterless: Optional[List[Any]] = None, + allow_types: Optional[List[Type[Param]]] = None, + ) -> T: + signature = get_typed_signature(call) + params = signature.parameters + dependent = cls( + call=call, + allow_types=allow_types, + ) + + parameterless_params = [ + dependent.parse_parameterless(param) for param in (parameterless or []) + ] + dependent.parameterless.extend(parameterless_params) + + for param_name, param in params.items(): + default_value = Required + if param.default != param.empty: + default_value = param.default + + if isinstance(default_value, Param): + field_info = default_value + default_value = field_info.default else: - raise ValueError( - f"Unknown parameter {param_name} for function {call} with type {param.annotation}" + field_info = dependent.parse_param(param_name, param) + default_value = field_info.default + + annotation: Any = Any + required = default_value == Required + if param.annotation != param.empty: + annotation = param.annotation + annotation = get_annotation_from_field_info( + annotation, field_info, param_name + ) + dependent.params.append( + ModelField( + name=param_name, + type_=annotation, + class_validators=None, + model_config=CustomConfig, + default=None if required else default_value, + required=required, + field_info=field_info, ) - - annotation: Any = Any - required = default_value == Required - if param.annotation != param.empty: - annotation = param.annotation - annotation = get_annotation_from_field_info(annotation, field_info, param_name) - dependent.params.append( - ModelField( - name=param_name, - type_=annotation, - class_validators=None, - model_config=CustomConfig, - default=None if required else default_value, - required=required, - field_info=field_info, ) - ) - return dependent + return dependent + async def solve( + self, + **params: Any, + ) -> Dict[str, Any]: + values: Dict[str, Any] = {} -async def solve_dependencies( - *, - _dependent: Dependent, - _stack: Optional[AsyncExitStack] = None, - _sub_dependents: Optional[List[Dependent]] = None, - _dependency_cache: Optional[T_DependencyCache] = None, - **params: Any, -) -> Tuple[Dict[str, Any], T_DependencyCache]: - values: Dict[str, Any] = {} - dependency_cache = {} if _dependency_cache is None else _dependency_cache - - # usual dependency - for field in _dependent.params: - field_info = field.field_info - assert isinstance(field_info, Param), "Params must be subclasses of Param" - value = field_info._solve(**params) - if value == Undefined: - value = field.get_default() - _, errs_ = field.validate(value, values, loc=(str(field_info), field.alias)) - if errs_: - logger.debug( - f"{field_info} " - f"type {type(value)} not match depends {_dependent.call} " - f"annotation {field._type_display()}, ignored" - ) - raise SkippedException(field, value) - else: - values[field.name] = value - - # solve sub dependencies - sub_dependent: Dependent - for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies): - sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) - sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key) - call = sub_dependent.call - - # solve sub dependency with current cache - solved_result = await solve_dependencies( - _dependent=sub_dependent, _dependency_cache=dependency_cache, **params - ) - sub_values, sub_dependency_cache = solved_result - # update cache? - # dependency_cache.update(sub_dependency_cache) - - # run dependency function - async with cache_lock: - if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: - solved = dependency_cache[sub_dependent.cache_key] - elif is_gen_callable(call) or is_async_gen_callable(call): - assert isinstance( - _stack, AsyncExitStack - ), "Generator dependency should be called in context" - if is_gen_callable(call): - cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) - else: - cm = asynccontextmanager(call)(**sub_values) - solved = await _stack.enter_async_context(cm) - elif is_coroutine_callable(call): - solved = await call(**sub_values) + for field in self.params: + field_info = field.field_info + assert isinstance(field_info, Param), "Params must be subclasses of Param" + value = await field_info._solve(**params) + if value == Undefined: + value = field.get_default() + _, errs_ = field.validate(value, values, loc=(str(field_info), field.alias)) + if errs_: + logger.debug( + f"{field_info} " + f"type {type(value)} not match depends {self.call} " + f"annotation {field._type_display()}, ignored" + ) + raise SkippedException(field, value) else: - solved = await run_sync(call)(**sub_values) + values[field.name] = value - # parameter dependency - if sub_dependent.name is not None: - values[sub_dependent.name] = solved - # save current dependency to cache - if sub_dependent.cache_key not in dependency_cache: - dependency_cache[sub_dependent.cache_key] = solved + for param in self.parameterless: + await param._solve(**params) - return values, dependency_cache - - -def Depends(dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> Any: - """ - :说明: - - 参数依赖注入装饰器 - - :参数: - - * ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。 - * ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。 - - .. code-block:: python - - def depend_func() -> Any: - return ... - - def depend_gen_func(): - try: - yield ... - finally: - ... - - async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)): - ... - """ - return DependsWrapper(dependency=dependency, use_cache=use_cache) + return values diff --git a/nonebot/dependencies/models.py b/nonebot/dependencies/models.py deleted file mode 100644 index 3875a6d7..00000000 --- a/nonebot/dependencies/models.py +++ /dev/null @@ -1,52 +0,0 @@ -import abc -import inspect -from typing import Any, List, Type, Optional - -from pydantic.fields import FieldInfo, ModelField - -from nonebot.utils import get_name -from nonebot.typing import T_Handler - - -class Param(abc.ABC, FieldInfo): - @classmethod - @abc.abstractmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - raise NotImplementedError - - @abc.abstractmethod - def _solve(self, **kwargs: Any) -> Any: - raise NotImplementedError - - -class DependsWrapper: - def __init__( - self, dependency: Optional[T_Handler] = None, *, use_cache: bool = True - ) -> None: - self.dependency = dependency - self.use_cache = use_cache - - def __repr__(self) -> str: - dep = get_name(self.dependency) - cache = "" if self.use_cache else ", use_cache=False" - return f"{self.__class__.__name__}({dep}{cache})" - - -class Dependent: - def __init__( - self, - *, - call: Optional[T_Handler] = None, - name: Optional[str] = None, - params: Optional[List[ModelField]] = None, - allow_types: Optional[List[Type[Param]]] = None, - dependencies: Optional[List["Dependent"]] = None, - use_cache: bool = True, - ) -> None: - self.call = call - self.name = name - self.params = params or [] - self.allow_types = allow_types or [] - self.dependencies = dependencies or [] - self.use_cache = use_cache - self.cache_key = self.call diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index d82976c5..56a815ff 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -1,13 +1,11 @@ import inspect -from typing import Any, Dict +from typing import Any, Dict, Callable from loguru import logger from pydantic.typing import ForwardRef, evaluate_forwardref -from nonebot.typing import T_Handler - -def get_typed_signature(call: T_Handler) -> inspect.Signature: +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) typed_params = [ diff --git a/nonebot/handler.py b/nonebot/handler.py deleted file mode 100644 index b381a228..00000000 --- a/nonebot/handler.py +++ /dev/null @@ -1,125 +0,0 @@ -""" -事件处理函数 -============ - -该模块实现事件处理函数的封装,以实现动态参数等功能。 -""" - -from contextlib import AsyncExitStack -from typing import Any, Dict, List, Type, Callable, Optional - -from nonebot.utils import get_name, run_sync, is_coroutine_callable -from nonebot.dependencies import ( - Param, - Dependent, - DependsWrapper, - get_dependent, - solve_dependencies, - get_parameterless_sub_dependant, -) - - -class Handler: - """事件处理器类。支持依赖注入。""" - - def __init__( - self, - call: Callable[..., Any], - *, - name: Optional[str] = None, - dependencies: Optional[List[DependsWrapper]] = None, - allow_types: Optional[List[Type[Param]]] = None, - ): - """ - :说明: - - 装饰一个函数为事件处理器。 - - :参数: - - * ``call: Callable[..., Any]``: 事件处理函数。 - * ``name: Optional[str]``: 事件处理器名称。默认为函数名。 - * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。 - * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。 - """ - self.call = call - """ - :类型: ``Callable[..., Any]`` - :说明: 事件处理函数 - """ - self.name = get_name(call) if name is None else name - """ - :类型: ``str`` - :说明: 事件处理函数名 - """ - self.allow_types = allow_types or [] - """ - :类型: ``List[Type[Param]]`` - :说明: 事件处理器允许的参数类型 - """ - - self.dependencies = dependencies or [] - """ - :类型: ``List[DependsWrapper]`` - :说明: 事件处理器的额外依赖 - """ - self.sub_dependents: Dict[Callable[..., Any], Dependent] = {} - if dependencies: - for depends in dependencies: - self.cache_dependent(depends) - self.dependent = get_dependent(call=call, allow_types=self.allow_types) - - def __repr__(self) -> str: - return f"" - - def __str__(self) -> str: - return repr(self) - - async def __call__( - self, - *, - _stack: Optional[AsyncExitStack] = None, - _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, - **params, - ) -> Any: - values, _ = await solve_dependencies( - _dependent=self.dependent, - _stack=_stack, - _sub_dependents=[ - self.sub_dependents[dependency.dependency] # type: ignore - for dependency in self.dependencies - ], - _dependency_cache=_dependency_cache, - **params, - ) - - if is_coroutine_callable(self.call): - return await self.call(**values) - else: - return await run_sync(self.call)(**values) - - def cache_dependent(self, dependency: DependsWrapper): - if not dependency.dependency: - raise ValueError(f"{dependency} has no dependency") - if dependency.dependency in self.sub_dependents: - raise ValueError(f"{dependency} is already in dependencies") - sub_dependant = get_parameterless_sub_dependant( - depends=dependency, allow_types=self.allow_types - ) - self.sub_dependents[dependency.dependency] = sub_dependant - - def prepend_dependency(self, dependency: DependsWrapper): - self.cache_dependent(dependency) - self.dependencies.insert(0, dependency) - - def append_dependency(self, dependency: DependsWrapper): - self.cache_dependent(dependency) - self.dependencies.append(dependency) - - def remove_dependency(self, dependency: DependsWrapper): - if not dependency.dependency: - raise ValueError(f"{dependency} has no dependency") - if dependency.dependency in self.sub_dependents: - del self.sub_dependents[dependency.dependency] - if dependency in self.dependencies: - self.dependencies.remove(dependency) diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 4701ff55..5d079718 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -17,6 +17,7 @@ from typing import ( List, Type, Union, + TypeVar, Callable, NoReturn, Optional, @@ -25,9 +26,10 @@ from typing import ( from nonebot import params from nonebot.rule import Rule from nonebot.log import logger -from nonebot.handler import Handler -from nonebot.dependencies import DependsWrapper +from nonebot.utils import CacheDict +from nonebot.dependencies import Dependent from nonebot.permission import USER, Permission +from nonebot.consts import ARG_KEY, ARG_STR_KEY, RECEIVE_KEY, REJECT_TARGET from nonebot.adapters import ( Bot, Event, @@ -35,6 +37,14 @@ from nonebot.adapters import ( MessageSegment, MessageTemplate, ) +from nonebot.typing import ( + Any, + T_State, + T_Handler, + T_ArgsParser, + T_TypeUpdater, + T_PermissionUpdater, +) from nonebot.exception import ( PausedException, StopPropagation, @@ -42,19 +52,12 @@ from nonebot.exception import ( FinishedException, RejectedException, ) -from nonebot.typing import ( - T_State, - T_Handler, - T_ArgsParser, - T_TypeUpdater, - T_StateFactory, - T_DependencyCache, - T_PermissionUpdater, -) if TYPE_CHECKING: from nonebot.plugin import Plugin +T = TypeVar("T") + matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) """ :类型: ``Dict[int, List[Type[Matcher]]]`` @@ -63,7 +66,7 @@ matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) current_bot: ContextVar[Bot] = ContextVar("current_bot") current_event: ContextVar[Event] = ContextVar("current_event") current_state: ContextVar[T_State] = ContextVar("current_state") -current_handler: ContextVar[Handler] = ContextVar("current_handler") +current_handler: ContextVar[Dependent] = ContextVar("current_handler") class MatcherMeta(type): @@ -131,7 +134,7 @@ class Matcher(metaclass=MatcherMeta): :类型: ``Permission`` :说明: 事件响应器触发权限 """ - handlers: List[Handler] = [] + handlers: List[Dependent[Any]] = [] """ :类型: ``List[Handler]`` :说明: 事件响应器拥有的事件处理函数列表 @@ -163,23 +166,24 @@ class Matcher(metaclass=MatcherMeta): :说明: 事件响应器默认状态 """ - _default_parser: Optional[T_ArgsParser] = None + _default_parser: Optional[Dependent[None]] = None """ - :类型: ``Optional[T_ArgsParser]`` + :类型: ``Optional[Dependent]`` :说明: 事件响应器默认参数解析函数 """ - _default_type_updater: Optional[T_TypeUpdater] = None + _default_type_updater: Optional[Dependent[str]] = None """ - :类型: ``Optional[T_TypeUpdater]`` + :类型: ``Optional[Dependent]`` :说明: 事件响应器类型更新函数 """ - _default_permission_updater: Optional[T_PermissionUpdater] = None + _default_permission_updater: Optional[Dependent[Permission]] = None """ - :类型: ``Optional[T_PermissionUpdater]`` + :类型: ``Optional[Dependent]`` :说明: 事件响应器权限更新函数 """ HANDLER_PARAM_TYPES = [ + params.DependParam, params.BotParam, params.EventParam, params.StateParam, @@ -207,9 +211,7 @@ class Matcher(metaclass=MatcherMeta): type_: str = "", rule: Optional[Rule] = None, permission: Optional[Permission] = None, - handlers: Optional[ - Union[List[T_Handler], List[Handler], List[Union[T_Handler, Handler]]] - ] = None, + handlers: Optional[List[Union[T_Handler, Dependent[Any]]]] = None, temp: bool = False, priority: int = 1, block: bool = False, @@ -259,8 +261,10 @@ class Matcher(metaclass=MatcherMeta): "permission": permission or Permission(), "handlers": [ handler - if isinstance(handler, Handler) - else Handler(handler, allow_types=cls.HANDLER_PARAM_TYPES) + if isinstance(handler, Dependent) + else Dependent[Any].parse( + call=handler, allow_types=cls.HANDLER_PARAM_TYPES + ) for handler in handlers ] if handlers @@ -286,7 +290,7 @@ class Matcher(metaclass=MatcherMeta): bot: Bot, event: Event, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ) -> bool: """ :说明: @@ -314,7 +318,7 @@ class Matcher(metaclass=MatcherMeta): event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ) -> bool: """ :说明: @@ -347,7 +351,9 @@ class Matcher(metaclass=MatcherMeta): * ``func: T_ArgsParser``: 参数解析函数 """ - cls._default_parser = func + cls._default_parser = Dependent[None].parse( + call=func, allow_types=cls.HANDLER_PARAM_TYPES + ) return func @classmethod @@ -361,7 +367,9 @@ class Matcher(metaclass=MatcherMeta): * ``func: T_TypeUpdater``: 响应事件类型更新函数 """ - cls._default_type_updater = func + cls._default_type_updater = Dependent[str].parse( + call=func, allow_types=cls.HANDLER_PARAM_TYPES + ) return func @classmethod @@ -375,22 +383,26 @@ class Matcher(metaclass=MatcherMeta): * ``func: T_PermissionUpdater``: 会话权限更新函数 """ - cls._default_permission_updater = func + cls._default_permission_updater = Dependent[Permission].parse( + call=func, allow_types=cls.HANDLER_PARAM_TYPES + ) return func @classmethod def append_handler( - cls, handler: T_Handler, dependencies: Optional[List[DependsWrapper]] = None - ) -> Handler: - handler_ = Handler( - handler, dependencies=dependencies, allow_types=cls.HANDLER_PARAM_TYPES + cls, handler: T_Handler, parameterless: Optional[List[Any]] = None + ) -> Dependent[Any]: + handler_ = Dependent[Any].parse( + call=handler, + parameterless=parameterless, + allow_types=cls.HANDLER_PARAM_TYPES, ) cls.handlers.append(handler_) return handler_ @classmethod def handle( - cls, dependencies: Optional[List[DependsWrapper]] = None + cls, parameterless: Optional[List[Any]] = None ) -> Callable[[T_Handler], T_Handler]: """ :说明: @@ -399,18 +411,18 @@ class Matcher(metaclass=MatcherMeta): :参数: - * ``dependencies: Optional[List[DependsWrapper]]``: 非参数类型依赖列表 + * ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表 """ def _decorator(func: T_Handler) -> T_Handler: - cls.append_handler(func, dependencies=dependencies) + cls.append_handler(func, parameterless=parameterless) return func return _decorator @classmethod def receive( - cls, dependencies: Optional[List[DependsWrapper]] = None + cls, id: str = "", parameterless: Optional[List[Any]] = None ) -> Callable[[T_Handler], T_Handler]: """ :说明: @@ -419,28 +431,30 @@ class Matcher(metaclass=MatcherMeta): :参数: - * ``dependencies: Optional[List[DependsWrapper]]``: 非参数类型依赖列表 + * ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表 """ - async def _receive(state: T_State) -> Union[None, NoReturn]: - if state.get(_receive): + async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]: + if matcher.get_receive(id): return - state[_receive] = True - del state["_current_key"] + if matcher.get_target() == RECEIVE_KEY.format(id=id): + matcher.set_receive(id, event) + return + matcher.set_target(RECEIVE_KEY.format(id=id)) raise RejectedException - _dependencies = [DependsWrapper(_receive), *(dependencies or [])] + parameterless = [params.Depends(_receive), *(parameterless or [])] def _decorator(func: T_Handler) -> T_Handler: if cls.handlers and cls.handlers[-1].call is func: func_handler = cls.handlers[-1] - for depend in reversed(_dependencies): - func_handler.prepend_dependency(depend) + for depend in reversed(parameterless): + func_handler.prepend_parameterless(depend) else: cls.append_handler( func, - dependencies=_dependencies if cls.handlers else dependencies, + parameterless=parameterless if cls.handlers else parameterless, ) return func @@ -453,7 +467,7 @@ class Matcher(metaclass=MatcherMeta): key: str, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, args_parser: Optional[T_ArgsParser] = None, - dependencies: Optional[List[DependsWrapper]] = None, + parameterless: Optional[List[Any]] = None, ) -> Callable[[T_Handler], T_Handler]: """ :说明: @@ -465,51 +479,31 @@ class Matcher(metaclass=MatcherMeta): * ``key: str``: 参数名 * ``prompt: Optional[Union[str, Message, MessageSegment, MessageFormatter]]``: 在参数不存在时向用户发送的消息 * ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数,空则使用默认解析函数 - * ``dependencies: Optional[List[DependsWrapper]]``: 非参数类型依赖列表 + * ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表 """ - async def _key_getter(bot: Bot, event: Event, state: T_State): - if state.get(f"_{key}_prompted"): + async def _key_getter(event: Event, matcher: "Matcher"): + if matcher.get_arg(key): return - - state["_current_key"] = key - state[f"_{key}_prompted"] = True - if key not in state: - if prompt is not None: - if isinstance(prompt, MessageTemplate): - _prompt = prompt.format(**state) - else: - _prompt = prompt - await bot.send(event=event, message=_prompt) - raise RejectedException - else: - state[f"_{key}_parsed"] = True - - async def _key_parser(bot: Bot, event: Event, state: T_State): - if key in state and state.get(f"_{key}_parsed"): + if matcher.get_target() == ARG_KEY.format(key=key): + matcher.set_arg(key, event) return + matcher.set_target(ARG_KEY.format(key=key)) + raise RejectedException - parser = args_parser or cls._default_parser - if parser: - await parser(bot, event, state) - else: - state[key] = str(event.get_message()) - state[f"_{key}_parsed"] = True - - _dependencies = [ - DependsWrapper(_key_getter), - DependsWrapper(_key_parser), - *(dependencies or []), + _parameterless = [ + params.Depends(_key_getter), + *(parameterless or []), ] def _decorator(func: T_Handler) -> T_Handler: if cls.handlers and cls.handlers[-1].call is func: func_handler = cls.handlers[-1] - for depend in reversed(_dependencies): - func_handler.prepend_dependency(depend) + for depend in reversed(_parameterless): + func_handler.prepend_parameterless(depend) else: - cls.append_handler(func, dependencies=_dependencies) + cls.append_handler(func, parameterless=_parameterless) return func @@ -609,8 +603,6 @@ class Matcher(metaclass=MatcherMeta): bot = current_bot.get() event = current_event.get() state = current_state.get() - if "_current_key" in state and f"_{state['_current_key']}_parsed" in state: - del state[f"_{state['_current_key']}_parsed"] if isinstance(prompt, MessageTemplate): _prompt = prompt.format(**state) else: @@ -619,6 +611,28 @@ class Matcher(metaclass=MatcherMeta): await bot.send(event=event, message=_prompt, **kwargs) raise RejectedException + def get_receive(self, id: str, default: T = None) -> Union[Event, T]: + return self.state.get(RECEIVE_KEY.format(id=id), default) + + def set_receive(self, id: str, event: Event) -> None: + self.state[RECEIVE_KEY.format(id=id)] = event + + def get_arg(self, key: str, default: T = None) -> Union[Event, T]: + return self.state.get(ARG_KEY.format(key=key), default) + + def get_arg_str(self, key: str, default: T = None) -> Union[str, T]: + return self.state.get(ARG_STR_KEY.format(key=key), default) + + def set_arg(self, key: str, event: Event) -> None: + self.state[ARG_KEY.format(key=key)] = event + self.state[ARG_STR_KEY.format(key=key)] = str(event.get_message()) + + def set_target(self, target: str) -> None: + self.state[REJECT_TARGET] = target + + def get_target(self, default: T = None) -> Union[str, T]: + return self.state.get(REJECT_TARGET, default) + def stop_propagation(self): """ :说明: @@ -631,13 +645,13 @@ class Matcher(metaclass=MatcherMeta): updater = self.__class__._default_type_updater if not updater: return "message" - return await updater(bot, event, self.state, self.type) + return await updater(bot=bot, event=event, state=self.state, matcher=self) async def update_permission(self, bot: Bot, event: Event) -> Permission: updater = self.__class__._default_permission_updater if not updater: return USER(event.get_session_id(), perm=self.permission) - return await updater(bot, event, self.state, self.permission) + return await updater(bot=bot, event=event, state=self.state, matcher=self) async def simple_run( self, @@ -645,7 +659,7 @@ class Matcher(metaclass=MatcherMeta): event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[T_DependencyCache] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ): b_t = current_bot.set(bot) e_t = current_event.set(event) @@ -664,8 +678,8 @@ class Matcher(metaclass=MatcherMeta): bot=bot, event=event, state=self.state, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) except SkippedException as e: logger.debug( @@ -687,7 +701,7 @@ class Matcher(metaclass=MatcherMeta): event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[T_DependencyCache] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ): try: await self.simple_run(bot, event, state, stack, dependency_cache) diff --git a/nonebot/message.py b/nonebot/message.py index 02863aa5..9b45c62a 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -22,9 +22,9 @@ from typing import ( from nonebot import params from nonebot.log import logger from nonebot.rule import TrieRule -from nonebot.handler import Handler -from nonebot.utils import escape_tag +from nonebot.dependencies import Dependent from nonebot.matcher import Matcher, matchers +from nonebot.utils import CacheDict, escape_tag from nonebot.exception import ( NoLogException, StopPropagation, @@ -33,7 +33,7 @@ from nonebot.exception import ( ) from nonebot.typing import ( T_State, - T_DependencyCache, + T_Handler, T_RunPreProcessor, T_RunPostProcessor, T_EventPreProcessor, @@ -43,18 +43,20 @@ from nonebot.typing import ( if TYPE_CHECKING: from nonebot.adapters import Bot, Event -_event_preprocessors: Set[Handler] = set() -_event_postprocessors: Set[Handler] = set() -_run_preprocessors: Set[Handler] = set() -_run_postprocessors: Set[Handler] = set() +_event_preprocessors: Set[Dependent[None]] = set() +_event_postprocessors: Set[Dependent[None]] = set() +_run_preprocessors: Set[Dependent[None]] = set() +_run_postprocessors: Set[Dependent[None]] = set() EVENT_PCS_PARAMS = [ + params.DependParam, params.BotParam, params.EventParam, params.StateParam, params.DefaultParam, ] RUN_PREPCS_PARAMS = [ + params.DependParam, params.MatcherParam, params.BotParam, params.EventParam, @@ -62,6 +64,7 @@ RUN_PREPCS_PARAMS = [ params.DefaultParam, ] RUN_POSTPCS_PARAMS = [ + params.DependParam, params.MatcherParam, params.ExceptionParam, params.BotParam, @@ -77,7 +80,9 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor: 事件预处理。装饰一个函数,使它在每次接收到事件并分发给各响应器之前执行。 """ - _event_preprocessors.add(Handler(func, allow_types=EVENT_PCS_PARAMS)) + _event_preprocessors.add( + Dependent[None].parse(call=func, allow_types=EVENT_PCS_PARAMS) + ) return func @@ -87,7 +92,9 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor: 事件后处理。装饰一个函数,使它在每次接收到事件并分发给各响应器之后执行。 """ - _event_postprocessors.add(Handler(func, allow_types=EVENT_PCS_PARAMS)) + _event_postprocessors.add( + Dependent[None].parse(call=func, allow_types=EVENT_PCS_PARAMS) + ) return func @@ -97,7 +104,9 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor: 运行预处理。装饰一个函数,使它在每次事件响应器运行前执行。 """ - _run_preprocessors.add(Handler(func, allow_types=RUN_PREPCS_PARAMS)) + _run_preprocessors.add( + Dependent[None].parse(call=func, allow_types=RUN_PREPCS_PARAMS) + ) return func @@ -107,7 +116,9 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor: 运行后处理。装饰一个函数,使它在每次事件响应器运行后执行。 """ - _run_postprocessors.add(Handler(func, allow_types=RUN_POSTPCS_PARAMS)) + _run_postprocessors.add( + Dependent[None].parse(call=func, allow_types=RUN_POSTPCS_PARAMS) + ) return func @@ -125,7 +136,7 @@ async def _check_matcher( event: "Event", state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[T_DependencyCache] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ) -> None: if Matcher.expire_time and datetime.now() > Matcher.expire_time: try: @@ -160,7 +171,7 @@ async def _run_matcher( event: "Event", state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[T_DependencyCache] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ) -> None: logger.info(f"Event will be handled by {Matcher}") @@ -174,8 +185,8 @@ async def _run_matcher( bot=bot, event=event, state=state, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) ), _run_preprocessors, @@ -216,8 +227,8 @@ async def _run_matcher( bot=bot, event=event, state=state, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) ), _run_postprocessors, @@ -264,7 +275,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None: logger.opt(colors=True).success(log_msg) state: Dict[Any, Any] = {} - dependency_cache: T_DependencyCache = {} + dependency_cache: CacheDict[T_Handler, Any] = CacheDict() async with AsyncExitStack() as stack: coros = list( @@ -274,8 +285,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None: bot=bot, event=event, state=state, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) ), _event_preprocessors, @@ -336,8 +347,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None: bot=bot, event=event, state=state, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) ), _event_postprocessors, diff --git a/nonebot/params.py b/nonebot/params.py index 79747ca4..b92e9a15 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -1,20 +1,160 @@ import inspect -from typing import Any, Dict, Optional +from typing import Any, List, Type, Callable, Optional, cast +from contextlib import AsyncExitStack, contextmanager, asynccontextmanager -from pydantic.fields import Undefined +from pydantic.fields import Required, Undefined -from nonebot.typing import T_State -from nonebot.dependencies import Param from nonebot.adapters import Bot, Event -from nonebot.utils import generic_check_issubclass +from nonebot.typing import T_State, T_Handler +from nonebot.dependencies import Param, Dependent +from nonebot.utils import ( + CacheDict, + get_name, + run_sync, + is_gen_callable, + run_sync_ctx_manager, + is_async_gen_callable, + is_coroutine_callable, + generic_check_issubclass, +) + + +class DependsInner: + def __init__( + self, + dependency: Optional[T_Handler] = None, + *, + use_cache: bool = True, + ) -> None: + self.dependency = dependency + self.use_cache = use_cache + + def __repr__(self) -> str: + dep = get_name(self.dependency) + cache = "" if self.use_cache else ", use_cache=False" + return f"{self.__class__.__name__}({dep}{cache})" + + +def Depends( + dependency: Optional[T_Handler] = None, + *, + use_cache: bool = True, +) -> Any: + """ + :说明: + + 参数依赖注入装饰器 + + :参数: + + * ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。 + * ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。 + * ``allow_types: Optional[List[Type[Param]]] = None``: 允许的参数类型。默认为 ``None``。 + + .. code-block:: python + + def depend_func() -> Any: + return ... + + def depend_gen_func(): + try: + yield ... + finally: + ... + + async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)): + ... + """ + return DependsInner(dependency, use_cache=use_cache) + + +class DependParam(Param): + @classmethod + def _check_param( + cls, + dependent: Dependent, + name: str, + param: inspect.Parameter, + ) -> Optional["DependParam"]: + if isinstance(param.default, DependsInner): + dependency: T_Handler + if param.default.dependency is None: + assert param.annotation is not param.empty, "Dependency cannot be empty" + dependency = param.annotation + else: + dependency = param.default.dependency + dependent = Dependent[Any].parse( + call=dependency, + allow_types=dependent.allow_types, + ) + return cls(Required, use_cache=param.default.use_cache, dependent=dependent) + + @classmethod + def _check_parameterless( + cls, dependent: "Dependent", value: Any + ) -> Optional["Param"]: + if isinstance(value, DependsInner): + assert value.dependency, "Dependency cannot be empty" + dependent = Dependent[Any].parse( + call=value.dependency, allow_types=dependent.allow_types + ) + return cls(Required, use_cache=value.use_cache, dependent=dependent) + + async def _solve( + self, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + **kwargs: Any, + ) -> Any: + use_cache: bool = self.extra["use_cache"] + dependency_cache = CacheDict() if dependency_cache is None else dependency_cache + + sub_dependent: Dependent = self.extra["dependent"] + sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) + call = sub_dependent.call + + # solve sub dependency with current cache + sub_values = await sub_dependent.solve( + stack=stack, + dependency_cache=dependency_cache, + **kwargs, + ) + + # run dependency function + async with dependency_cache: + if use_cache and call in dependency_cache: + solved = dependency_cache[call] + elif is_gen_callable(call) or is_async_gen_callable(call): + assert isinstance( + stack, AsyncExitStack + ), "Generator dependency should be called in context" + if is_gen_callable(call): + cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) + else: + cm = asynccontextmanager(call)(**sub_values) + solved = await stack.enter_async_context(cm) + elif is_coroutine_callable(call): + return await call(**sub_values) + else: + return await run_sync(call)(**sub_values) + + # save current dependency to cache + if call not in dependency_cache: + dependency_cache[call] = solved + + return solved class BotParam(Param): @classmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Bot) or ( - param.annotation == param.empty and name == "bot" - ) + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["BotParam"]: + if param.default == param.empty and ( + generic_check_issubclass(param.annotation, Bot) + or (param.annotation == param.empty and name == "bot") + ): + return cls(Required) def _solve(self, bot: Bot, **kwargs: Any) -> Any: return bot @@ -22,22 +162,34 @@ class BotParam(Param): class EventParam(Param): @classmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Event) or ( - param.annotation == param.empty and name == "event" - ) + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["EventParam"]: + if param.default == param.empty and ( + generic_check_issubclass(param.annotation, Event) + or (param.annotation == param.empty and name == "event") + ): + return cls(Required) def _solve(self, event: Event, **kwargs: Any) -> Any: return event -# FIXME: may detect error param +class StateInner: + ... + + +def State() -> Any: + return StateInner() + + class StateParam(Param): @classmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Dict) or ( - param.annotation == param.empty and name == "state" - ) + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["StateParam"]: + if isinstance(param.default, StateInner): + return cls(Required) def _solve(self, state: T_State, **kwargs: Any) -> Any: return state @@ -45,21 +197,27 @@ class StateParam(Param): class MatcherParam(Param): @classmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Matcher) or ( + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["MatcherParam"]: + if generic_check_issubclass(param.annotation, Matcher) or ( param.annotation == param.empty and name == "matcher" - ) + ): + return cls(Required) - def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any: + def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: return matcher class ExceptionParam(Param): @classmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Exception) or ( + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["ExceptionParam"]: + if generic_check_issubclass(param.annotation, Exception) or ( param.annotation == param.empty and name == "exception" - ) + ): + return cls(Required) def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: return exception @@ -67,8 +225,11 @@ class ExceptionParam(Param): class DefaultParam(Param): @classmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - return param.default != param.empty + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["DefaultParam"]: + if param.default != param.empty: + return cls(param.default) def _solve(self, **kwargs: Any) -> Any: return Undefined diff --git a/nonebot/permission.py b/nonebot/permission.py index 3f33dd5d..9ca25f64 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -13,6 +13,7 @@ import asyncio from contextlib import AsyncExitStack from typing import ( Any, + Set, Dict, Tuple, Union, @@ -23,10 +24,11 @@ from typing import ( ) from nonebot import params -from nonebot.handler import Handler +from nonebot.utils import CacheDict from nonebot.adapters import Bot, Event +from nonebot.dependencies import Dependent from nonebot.exception import SkippedException -from nonebot.typing import T_PermissionChecker +from nonebot.typing import T_Handler, T_PermissionChecker async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]): @@ -54,19 +56,26 @@ class Permission: __slots__ = ("checkers",) - HANDLER_PARAM_TYPES = [params.BotParam, params.EventParam, params.DefaultParam] + HANDLER_PARAM_TYPES = [ + params.DependParam, + params.BotParam, + params.EventParam, + params.DefaultParam, + ] - def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None: + def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None: """ :参数: - * ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker + * ``*checkers: Union[T_PermissionChecker, Dependent[bool]``: PermissionChecker """ - self.checkers = set( + self.checkers: Set[Dependent[bool]] = set( checker - if isinstance(checker, Handler) - else Handler(checker, allow_types=self.HANDLER_PARAM_TYPES) + if isinstance(checker, Dependent) + else Dependent[bool].parse( + call=checker, allow_types=self.HANDLER_PARAM_TYPES + ) for checker in checkers ) """ @@ -76,7 +85,7 @@ class Permission: :类型: - * ``Set[Handler]`` + * ``Set[Dependent[bool]]`` """ async def __call__( @@ -84,7 +93,7 @@ class Permission: bot: Bot, event: Event, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ) -> bool: """ :说明: @@ -96,7 +105,7 @@ class Permission: * ``bot: Bot``: Bot 对象 * ``event: Event``: Event 对象 * ``stack: Optional[AsyncExitStack]``: 异步上下文栈 - * ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存 + * ``dependency_cache: Optional[CacheDict[T_Handler, Any]]``: 依赖缓存 :返回: @@ -110,8 +119,8 @@ class Permission: checker( bot=bot, event=event, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) ) for checker in self.checkers diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 2c7446e6..f0a5523c 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -5,11 +5,16 @@ from types import ModuleType from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional from nonebot.adapters import Event -from nonebot.handler import Handler from nonebot.matcher import Matcher from .manager import _current_plugin from nonebot.permission import Permission -from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory +from nonebot.dependencies import Dependent +from nonebot.typing import ( + T_State, + T_Handler, + T_RuleChecker, + T_PermissionChecker, +) from nonebot.rule import ( PREFIX_KEY, RAW_CMD_KEY, @@ -43,9 +48,9 @@ def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]: def on( type: str = "", rule: Optional[Union[Rule, T_RuleChecker]] = None, - permission: Optional[Permission] = None, + permission: Optional[Union[Permission, T_PermissionChecker]] = None, *, - handlers: Optional[List[Union[T_Handler, Handler]]] = None, + handlers: Optional[List[Union[T_Handler, Dependent]]] = None, temp: bool = False, priority: int = 1, block: bool = False, @@ -61,8 +66,8 @@ def on( * ``type: str``: 事件响应器类型 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -75,7 +80,7 @@ def on( matcher = Matcher.new( type, Rule() & rule, - permission or Permission(), + Permission() | permission, temp=temp, priority=priority, block=block, @@ -91,7 +96,7 @@ def on( def on_metaevent( rule: Optional[Union[Rule, T_RuleChecker]] = None, *, - handlers: Optional[List[Union[T_Handler, Handler]]] = None, + handlers: Optional[List[Union[T_Handler, Dependent]]] = None, temp: bool = False, priority: int = 1, block: bool = False, @@ -106,7 +111,7 @@ def on_metaevent( :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -134,9 +139,9 @@ def on_metaevent( def on_message( rule: Optional[Union[Rule, T_RuleChecker]] = None, - permission: Optional[Permission] = None, + permission: Optional[Union[Permission, T_PermissionChecker]] = None, *, - handlers: Optional[List[Union[T_Handler, Handler]]] = None, + handlers: Optional[List[Union[T_Handler, Dependent]]] = None, temp: bool = False, priority: int = 1, block: bool = True, @@ -151,8 +156,8 @@ def on_message( :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -165,7 +170,7 @@ def on_message( matcher = Matcher.new( "message", Rule() & rule, - permission or Permission(), + Permission() | permission, temp=temp, priority=priority, block=block, @@ -181,7 +186,7 @@ def on_message( def on_notice( rule: Optional[Union[Rule, T_RuleChecker]] = None, *, - handlers: Optional[List[Union[T_Handler, Handler]]] = None, + handlers: Optional[List[Union[T_Handler, Dependent]]] = None, temp: bool = False, priority: int = 1, block: bool = False, @@ -196,7 +201,7 @@ def on_notice( :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -225,7 +230,7 @@ def on_notice( def on_request( rule: Optional[Union[Rule, T_RuleChecker]] = None, *, - handlers: Optional[List[Union[T_Handler, Handler]]] = None, + handlers: Optional[List[Union[T_Handler, Dependent]]] = None, temp: bool = False, priority: int = 1, block: bool = False, @@ -240,7 +245,7 @@ def on_request( :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -283,8 +288,8 @@ def on_startswith( * ``msg: Union[str, Tuple[str, ...]]``: 指定消息开头内容 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``ignorecase: bool``: 是否忽略大小写 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -314,8 +319,8 @@ def on_endswith( * ``msg: Union[str, Tuple[str, ...]]``: 指定消息结尾内容 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``ignorecase: bool``: 是否忽略大小写 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -343,8 +348,8 @@ def on_keyword( * ``keywords: Set[str]``: 关键词列表 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -376,8 +381,8 @@ def on_command( * ``cmd: Union[str, Tuple[str, ...]]``: 指定命令内容 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``aliases: Optional[Set[Union[str, Tuple[str, ...]]]]``: 命令别名 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -434,8 +439,8 @@ def on_shell_command( * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 * ``aliases: Optional[Set[Union[str, Tuple[str, ...]]]]``: 命令别名 * ``parser: Optional[ArgumentParser]``: ``nonebot.rule.ArgumentParser`` 对象 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -486,8 +491,8 @@ def on_regex( * ``pattern: str``: 正则表达式 * ``flags: Union[int, re.RegexFlag]``: 正则匹配标志 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -600,8 +605,8 @@ class MatcherGroup: * ``type: str``: 事件响应器类型 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -626,7 +631,7 @@ class MatcherGroup: :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -653,8 +658,8 @@ class MatcherGroup: :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -680,7 +685,7 @@ class MatcherGroup: :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -706,7 +711,7 @@ class MatcherGroup: :参数: * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -736,8 +741,8 @@ class MatcherGroup: * ``msg: Union[str, Tuple[str, ...]]``: 指定消息开头内容 * ``ignorecase: bool``: 是否忽略大小写 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -765,8 +770,8 @@ class MatcherGroup: * ``msg: Union[str, Tuple[str, ...]]``: 指定消息结尾内容 * ``ignorecase: bool``: 是否忽略大小写 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -793,8 +798,8 @@ class MatcherGroup: * ``keywords: Set[str]``: 关键词列表 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -829,8 +834,8 @@ class MatcherGroup: * ``cmd: Union[str, Tuple[str, ...]]``: 指定命令内容 * ``aliases: Optional[Set[Union[str, Tuple[str, ...]]]]``: 命令别名 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -869,8 +874,8 @@ class MatcherGroup: * ``aliases: Optional[Set[Union[str, Tuple[str, ...]]]]``: 命令别名 * ``parser: Optional[ArgumentParser]``: ``nonebot.rule.ArgumentParser`` 对象 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 @@ -904,8 +909,8 @@ class MatcherGroup: * ``pattern: str``: 正则表达式 * ``flags: Union[int, re.RegexFlag]``: 正则匹配标志 * ``rule: Optional[Union[Rule, T_RuleChecker]]``: 事件响应规则 - * ``permission: Optional[Permission]``: 事件响应权限 - * ``handlers: Optional[List[Union[T_Handler, Handler]]]``: 事件处理函数列表 + * ``permission: Optional[Union[Permission, T_PermissionChecker]] =]]``: 事件响应权限 + * ``handlers: Optional[List[Union[T_Handler, Dependent]]]``: 事件处理函数列表 * ``temp: bool``: 是否为临时事件响应器(仅执行一次) * ``priority: int``: 事件响应器优先级 * ``block: bool``: 是否阻止事件向更低优先级传递 diff --git a/nonebot/plugin/on.pyi b/nonebot/plugin/on.pyi index bf1222ef..8cf990fd 100644 --- a/nonebot/plugin/on.pyi +++ b/nonebot/plugin/on.pyi @@ -1,18 +1,23 @@ import re from typing import Set, List, Type, Tuple, Union, Optional -from nonebot.handler import Handler from nonebot.matcher import Matcher from nonebot.permission import Permission +from nonebot.dependencies import Dependent from nonebot.rule import Rule, ArgumentParser -from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory +from nonebot.typing import ( + T_State, + T_Handler, + T_RuleChecker, + T_PermissionChecker, +) def on( type: str = "", rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., *, - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -21,7 +26,7 @@ def on( def on_metaevent( rule: Optional[Union[Rule, T_RuleChecker]] = ..., *, - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -29,9 +34,9 @@ def on_metaevent( ) -> Type[Matcher]: ... def on_message( rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., *, - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -40,7 +45,7 @@ def on_message( def on_notice( rule: Optional[Union[Rule, T_RuleChecker]] = ..., *, - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -49,7 +54,7 @@ def on_notice( def on_request( rule: Optional[Union[Rule, T_RuleChecker]] = ..., *, - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -60,8 +65,8 @@ def on_startswith( rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ..., ignorecase: bool = ..., *, - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -72,8 +77,8 @@ def on_endswith( rule: Optional[Optional[Union[Rule, T_RuleChecker]]] = ..., ignorecase: bool = ..., *, - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -83,8 +88,8 @@ def on_keyword( keywords: Set[str], rule: Optional[Union[Rule, T_RuleChecker]] = ..., *, - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -95,8 +100,8 @@ def on_command( rule: Optional[Union[Rule, T_RuleChecker]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., *, - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -108,8 +113,8 @@ def on_shell_command( aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., parser: Optional[ArgumentParser] = ..., *, - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -120,8 +125,8 @@ def on_regex( flags: Union[int, re.RegexFlag] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ..., *, - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -134,8 +139,8 @@ class CommandGroup: cmd: Union[str, Tuple[str, ...]], *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -147,8 +152,8 @@ class CommandGroup: *, aliases: Optional[Set[Union[str, Tuple[str, ...]]]], rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -161,8 +166,8 @@ class CommandGroup: rule: Optional[Union[Rule, T_RuleChecker]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]], parser: Optional[ArgumentParser] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -175,8 +180,8 @@ class MatcherGroup: *, type: str = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -187,8 +192,8 @@ class MatcherGroup: *, type: str = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -198,7 +203,7 @@ class MatcherGroup: self, *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -208,8 +213,8 @@ class MatcherGroup: self, *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -219,7 +224,7 @@ class MatcherGroup: self, *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -229,7 +234,7 @@ class MatcherGroup: self, *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -241,8 +246,8 @@ class MatcherGroup: *, ignorecase: bool = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -254,8 +259,8 @@ class MatcherGroup: *, ignorecase: bool = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -266,8 +271,8 @@ class MatcherGroup: keywords: Set[str], *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -279,8 +284,8 @@ class MatcherGroup: aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -293,8 +298,8 @@ class MatcherGroup: parser: Optional[ArgumentParser] = ..., *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., @@ -306,8 +311,8 @@ class MatcherGroup: flags: Union[int, re.RegexFlag] = ..., *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - permission: Optional[Permission] = ..., - handlers: Optional[List[Union[T_Handler, Handler]]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., priority: int = ..., block: bool = ..., diff --git a/nonebot/rule.py b/nonebot/rule.py index 2408e4eb..a58d3125 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -17,27 +17,17 @@ from argparse import Namespace from contextlib import AsyncExitStack from typing_extensions import TypedDict from argparse import ArgumentParser as ArgParser -from typing import ( - Any, - Dict, - List, - Type, - Tuple, - Union, - Callable, - NoReturn, - Optional, - Sequence, -) +from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence from pygtrie import CharTrie from nonebot.log import logger -from nonebot.handler import Handler +from nonebot.utils import CacheDict from nonebot import params, get_driver -from nonebot.typing import T_State, T_RuleChecker +from nonebot.dependencies import Dependent from nonebot.adapters import Bot, Event, MessageSegment from nonebot.exception import ParserExit, SkippedException +from nonebot.typing import T_State, T_Handler, T_RuleChecker PREFIX_KEY = "_prefix" SUFFIX_KEY = "_suffix" @@ -74,23 +64,24 @@ class Rule: __slots__ = ("checkers",) HANDLER_PARAM_TYPES = [ + params.DependParam, params.BotParam, params.EventParam, params.StateParam, params.DefaultParam, ] - def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None: + def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None: """ :参数: - * ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker + * ``*checkers: Union[T_RuleChecker, Dependent[bool]]``: RuleChecker """ - self.checkers = set( + self.checkers: Set[Dependent[bool]] = set( checker - if isinstance(checker, Handler) - else Handler(checker, allow_types=self.HANDLER_PARAM_TYPES) + if isinstance(checker, Dependent) + else Dependent[bool](call=checker, allow_types=self.HANDLER_PARAM_TYPES) for checker in checkers ) """ @@ -100,7 +91,7 @@ class Rule: :类型: - * ``Set[Handler]`` + * ``Set[Dependent[bool]]`` """ async def __call__( @@ -109,7 +100,7 @@ class Rule: event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, + dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, ) -> bool: """ :说明: @@ -122,7 +113,7 @@ class Rule: * ``event: Event``: Event 对象 * ``state: T_State``: 当前 State * ``stack: Optional[AsyncExitStack]``: 异步上下文栈 - * ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存 + * ``dependency_cache: Optional[CacheDict[T_Handler, Any]]``: 依赖缓存 :返回: @@ -137,8 +128,8 @@ class Rule: bot=bot, event=event, state=state, - _stack=stack, - _dependency_cache=dependency_cache, + stack=stack, + dependency_cache=dependency_cache, ) for checker in self.checkers ) diff --git a/nonebot/typing.py b/nonebot/typing.py index 04632c72..bf9d1042 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -28,9 +28,11 @@ from typing import ( NoReturn, Optional, Awaitable, + ForwardRef, ) if TYPE_CHECKING: + from nonebot.utils import CacheDict from nonebot.adapters import Bot, Event from nonebot.permission import Permission @@ -53,14 +55,6 @@ T_State = Dict[Any, Any] 事件处理状态 State 类型 """ -T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]] -""" -:类型: ``Callable[[Bot, Event], Awaitable[T_State]]`` - -:说明: - - 事件处理状态 State 类工厂函数 -""" T_BotConnectionHook = Callable[["Bot"], Awaitable[None]] """ @@ -103,9 +97,11 @@ T_EventPreProcessor = Callable[..., Union[None, Awaitable[None]]] :依赖参数: + * ``DependParam``: 子依赖参数 * ``BotParam``: Bot 对象 * ``EventParam``: Event 对象 * ``StateParam``: State 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: @@ -117,9 +113,11 @@ T_EventPostProcessor = Callable[..., Union[None, Awaitable[None]]] :依赖参数: + * ``DependParam``: 子依赖参数 * ``BotParam``: Bot 对象 * ``EventParam``: Event 对象 * ``StateParam``: State 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: @@ -131,10 +129,12 @@ T_RunPreProcessor = Callable[..., Union[None, Awaitable[None]]] :依赖参数: + * ``DependParam``: 子依赖参数 * ``BotParam``: Bot 对象 * ``EventParam``: Event 对象 * ``StateParam``: State 对象 * ``MatcherParam``: Matcher 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: @@ -146,11 +146,13 @@ T_RunPostProcessor = Callable[..., Union[None, Awaitable[None]]] :依赖参数: + * ``DependParam``: 子依赖参数 * ``BotParam``: Bot 对象 * ``EventParam``: Event 对象 * ``StateParam``: State 对象 * ``MatcherParam``: Matcher 对象 * ``ExceptionParam``: 异常对象(可能为 None) + * ``DefaultParam``: 带有默认值的参数 :说明: @@ -163,9 +165,11 @@ T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]] :依赖参数: + * ``DependParam``: 子依赖参数 * ``BotParam``: Bot 对象 * ``EventParam``: Event 对象 * ``StateParam``: State 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: @@ -177,8 +181,10 @@ T_PermissionChecker = Callable[..., Union[bool, Awaitable[bool]]] :依赖参数: + * ``DependParam``: 子依赖参数 * ``BotParam``: Bot 对象 * ``EventParam``: Event 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: @@ -193,37 +199,52 @@ T_Handler = Callable[..., Any] Handler 处理函数。 """ -T_DependencyCache = Dict[T_Handler, Any] +T_ArgsParser = Callable[..., Union[None, Awaitable[None]]] """ -:类型: ``Dict[T_Handler, Any]`` +:类型: ``Callable[..., Union[None, Awaitable[None]]]`` -:说明: +:依赖参数: - 依赖缓存, 用于存储依赖函数的返回值 -""" -T_ArgsParser = Callable[ - ["Bot", "Event", T_State], Union[Awaitable[None], Awaitable[NoReturn]] -] -""" -:类型: ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` + * ``DependParam``: 子依赖参数 + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 + * ``MatcherParam``: Matcher 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: ArgsParser 即消息参数解析函数,在 Matcher.got 获取参数时被运行。 """ -T_TypeUpdater = Callable[["Bot", "Event", T_State, str], Awaitable[str]] +T_TypeUpdater = Callable[..., Union[str, Awaitable[str]]] """ -:类型: ``Callable[[Bot, Event, T_State, str], Awaitable[str]]`` +:类型: ``Callable[..., Union[None, Awaitable[None]]]`` + +:依赖参数: + + * ``DependParam``: 子依赖参数 + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 + * ``MatcherParam``: Matcher 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: TypeUpdater 在 Matcher.pause, Matcher.reject 时被运行,用于更新响应的事件类型。默认会更新为 ``message``。 """ -T_PermissionUpdater = Callable[ - ["Bot", "Event", T_State, "Permission"], Awaitable["Permission"] -] +T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]]] """ -:类型: ``Callable[[Bot, Event, T_State, Permission], Awaitable[Permission]]`` +:类型: ``Callable[..., Union[Permission, Awaitable[Permission]]]`` + +:依赖参数: + + * ``DependParam``: 子依赖参数 + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 + * ``MatcherParam``: Matcher 对象 + * ``DefaultParam``: 带有默认值的参数 :说明: diff --git a/nonebot/utils.py b/nonebot/utils.py index 1c1f2e67..075ba667 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -2,17 +2,17 @@ import re import json import asyncio import inspect -import collections import dataclasses from functools import wraps, partial from contextlib import asynccontextmanager from typing_extensions import ParamSpec, get_args, get_origin from typing import ( Any, + Dict, Type, - Deque, Tuple, Union, + Generic, TypeVar, Callable, Optional, @@ -27,6 +27,8 @@ from nonebot.typing import overrides P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") def escape_tag(s: str) -> str: @@ -133,77 +135,31 @@ def get_name(obj: Any) -> str: return obj.__class__.__name__ -class CacheLock: - def __init__(self): - self._waiters: Optional[Deque[asyncio.Future]] = None - self._locked = False +class CacheDict(Dict[K, V], Generic[K, V]): + def __init__(self, *args, **kwargs): + super(CacheDict, self).__init__(*args, **kwargs) + self._lock = asyncio.Lock() + + @property + def locked(self): + return self._lock.locked() def __repr__(self): - extra = "locked" if self._locked else "unlocked" - if self._waiters: - extra = f"{extra}, waiters: {len(self._waiters)}" + extra = "locked" if self.locked else "unlocked" return f"<{self.__class__.__name__} [{extra}]>" - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() return None async def __aexit__(self, exc_type, exc, tb): self.release() - def locked(self): - return self._locked - async def acquire(self): - if not self._locked and ( - self._waiters is None or all(w.cancelled() for w in self._waiters) - ): - self._locked = True - return True - - if self._waiters is None: - self._waiters = collections.deque() - - loop = asyncio.get_running_loop() - future = loop.create_future() - self._waiters.append(future) - - # Finally block should be called before the CancelledError - # handling as we don't want CancelledError to call - # _wake_up_first() and attempt to wake up itself. - try: - try: - await future - finally: - self._waiters.remove(future) - except asyncio.CancelledError: - if not self._locked: - self._wake_up_first() - raise - - self._locked = True - return True + return await self._lock.acquire() def release(self): - if self._locked: - self._locked = False - self._wake_up_first() - else: - raise RuntimeError("Lock is not acquired.") - - def _wake_up_first(self): - if not self._waiters: - return - try: - future = next(iter(self._waiters)) - except StopIteration: - return - - # .done() necessarily means that a waiter will wake up later on and - # either take the lock, or, if it was cancelled and lock wasn't - # taken already, will hit this again and wake up a new waiter. - if not future.done(): - future.set_result(True) + self._lock.release() class DataclassEncoder(json.JSONEncoder): diff --git a/tests/plugins/depends.py b/tests/plugins/depends.py new file mode 100644 index 00000000..36f3c3fc --- /dev/null +++ b/tests/plugins/depends.py @@ -0,0 +1,23 @@ +from nonebot import on_message +from nonebot.adapters import Event +from nonebot.params import Depends + +test = on_message() +test2 = on_message() + +runned = False + + +def dependency(event: Event): + # test cache + global runned + assert not runned + runned = True + return event + + +@test.handle() +@test2.handle() +async def handle(x: Event = Depends(dependency)): + # test dependency + return x diff --git a/tests/test_init.py b/tests/test_init.py index b5f6ac96..61dbdae7 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,6 +1,12 @@ import os +import sys +from typing import TYPE_CHECKING, Set import pytest +from utils import load_plugin + +if TYPE_CHECKING: + from nonebot.plugin import Plugin os.environ["CONFIG_FROM_ENV"] = "env" @@ -17,3 +23,14 @@ async def test_init(nonebug_init): assert config.config_from_env == "env" assert config.config_from_init == "init" assert config.common_config == "common" + + +@pytest.mark.asyncio +async def test_load_plugin(load_plugin: Set["Plugin"]): + import nonebot + + assert nonebot.get_loaded_plugins() == load_plugin + plugin = nonebot.get_plugin("depends") + assert plugin + assert plugin.module_name == "plugins.depends" + assert "plugins.depends" in sys.modules diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..0223b365 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,14 @@ +from pathlib import Path +from typing import TYPE_CHECKING, Set + +import pytest + +if TYPE_CHECKING: + from nonebot.plugin import Plugin + + +@pytest.fixture +def load_plugin(nonebug_init: None) -> Set["Plugin"]: + import nonebot + + return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))