From 4cbdd726e51dfa476c2d75f388e9f01f7c9e51dd Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Tue, 16 Nov 2021 18:30:16 +0800 Subject: [PATCH] :building_construction: change code structure --- docs_build/dependencies.rst | 12 ++ nonebot/__init__.py | 1 + nonebot/adapters/_bot.py | 10 -- .../{processor => dependencies}/__init__.py | 117 +++++++++--------- nonebot/{processor => dependencies}/models.py | 28 ++++- nonebot/{processor => dependencies}/utils.py | 26 +--- nonebot/{processor => }/handler.py | 93 ++++++++------ nonebot/{processor => }/matcher.py | 28 +++-- nonebot/message.py | 60 ++++++--- nonebot/{processor => }/params.py | 55 +++----- nonebot/plugin/on.py | 3 +- nonebot/plugin/on.pyi | 3 +- nonebot/plugin/plugin.py | 2 +- nonebot/plugins/echo.py | 10 +- nonebot/plugins/single_session.py | 7 +- nonebot/typing.py | 12 +- nonebot/utils.py | 25 +++- tests/test_plugins/test_depends.py | 2 +- tests/test_plugins/test_processor.py | 8 +- 19 files changed, 276 insertions(+), 226 deletions(-) create mode 100644 docs_build/dependencies.rst rename nonebot/{processor => dependencies}/__init__.py (71%) rename nonebot/{processor => dependencies}/models.py (59%) rename nonebot/{processor => dependencies}/utils.py (57%) rename nonebot/{processor => }/handler.py (52%) rename nonebot/{processor => }/matcher.py (96%) rename nonebot/{processor => }/params.py (53%) diff --git a/docs_build/dependencies.rst b/docs_build/dependencies.rst new file mode 100644 index 00000000..4db2e19a --- /dev/null +++ b/docs_build/dependencies.rst @@ -0,0 +1,12 @@ +\-\-\- +contentSidebar: true +sidebarDepth: 0 +\-\-\- + +NoneBot.handler 模块 +==================== + +.. automodule:: nonebot.dependencies + :members: + :private-members: + :show-inheritance: diff --git a/nonebot/__init__.py b/nonebot/__init__.py index a9c6d350..809b696c 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -278,6 +278,7 @@ def run(host: Optional[str] = None, get_driver().run(host, port, *args, **kwargs) +import nonebot.params as params from nonebot.plugin import export as export from nonebot.plugin import require as require from nonebot.plugin import on_regex as on_regex diff --git a/nonebot/adapters/_bot.py b/nonebot/adapters/_bot.py index fba3a134..b96e31ae 100644 --- a/nonebot/adapters/_bot.py +++ b/nonebot/adapters/_bot.py @@ -55,16 +55,6 @@ class Bot(abc.ABC): def __getattr__(self, name: str) -> _ApiCall: return partial(self.call_api, name) - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - if not isinstance(v, cls): - raise TypeError(f"{v} is not an instance of {cls}") - return v - @property @abc.abstractmethod def type(self) -> str: diff --git a/nonebot/processor/__init__.py b/nonebot/dependencies/__init__.py similarity index 71% rename from nonebot/processor/__init__.py rename to nonebot/dependencies/__init__.py index 9860fc8d..bbf835fb 100644 --- a/nonebot/processor/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -1,3 +1,10 @@ +""" +依赖注入处理模块 +=============== + +该模块实现了依赖注入的定义与处理。 +""" + import inspect from itertools import chain from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast @@ -7,33 +14,38 @@ from pydantic import BaseConfig from pydantic.fields import Required, ModelField from pydantic.schema import get_annotation_from_field_info -from .models import Dependent from nonebot.log import logger -from nonebot.typing import T_State +from .models import Param as Param from .utils import get_typed_signature -from nonebot.adapters import Bot, Event -from .models import Depends as DependsClass +from .models import Dependent as Dependent +from .models import DependsWrapper as DependsWrapper from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, is_async_gen_callable, is_coroutine_callable) -def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: - depends: DependsClass = param.default +class CustomConfig(BaseConfig): + arbitrary_types_allowed = True + + +def get_param_sub_dependent( + *, + param: inspect.Parameter, + allow_types: Optional[List[Type[Param]]] = None) -> Dependent: + depends: DependsWrapper = param.default if depends.dependency: dependency = depends.dependency else: dependency = param.annotation - return get_sub_dependant( - depends=depends, - dependency=dependency, - name=param.name, - ) + return get_sub_dependant(depends=depends, + dependency=dependency, + name=param.name, + allow_types=allow_types) def get_parameterless_sub_dependant( *, - depends: DependsClass, - allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: + depends: DependsWrapper, + allow_types: Optional[List[Type[Param]]] = None) -> Dependent: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" @@ -44,10 +56,10 @@ def get_parameterless_sub_dependant( def get_sub_dependant( *, - depends: DependsClass, + depends: DependsWrapper, dependency: Callable[..., Any], name: Optional[str] = None, - allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: + allow_types: Optional[List[Type[Param]]] = None) -> Dependent: sub_dependant = get_dependent(func=dependency, name=name, use_cache=depends.use_cache, @@ -55,32 +67,32 @@ def get_sub_dependant( return sub_dependant -def get_dependent( - *, - func: Callable[..., Any], - name: Optional[str] = None, - use_cache: bool = True, - allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: +def get_dependent(*, + func: Callable[..., Any], + name: Optional[str] = None, + use_cache: bool = True, + allow_types: Optional[List[Type[Param]]] = None) -> Dependent: signature = get_typed_signature(func) params = signature.parameters - allow_types = allow_types or [ - ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE - ] - dependent = Dependent(func=func, name=name, use_cache=use_cache) + dependent = Dependent(func=func, + name=name, + allow_types=allow_types, + use_cache=use_cache) for param_name, param in params.items(): - if isinstance(param.default, DependsClass): - sub_dependent = get_param_sub_dependent(param=param) + if isinstance(param.default, DependsWrapper): + sub_dependent = get_param_sub_dependent(param=param, + allow_types=allow_types) dependent.dependencies.append(sub_dependent) continue - for allow_type in allow_types: - field_info_class: Type[Param] = allow_type.value - if field_info_class._check(param_name, param): - field_info = field_info_class(param.default) + for allow_type in dependent.allow_types: + if allow_type._check(param_name, param): + field_info = allow_type(param.default) break else: raise ValueError( - f"Unknown parameter {param_name} with type {param.annotation}") + f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}" + ) annotation: Any = Any if param.annotation != param.empty: @@ -91,7 +103,7 @@ def get_dependent( ModelField(name=param_name, type_=annotation, class_validators=None, - model_config=BaseConfig, + model_config=CustomConfig, default=Required, required=True, field_info=field_info)) @@ -102,15 +114,11 @@ def get_dependent( async def solve_dependencies( *, dependent: Dependent, - bot: Bot, - event: Event, - state: T_State, - matcher: Optional["Matcher"] = None, - exception: Optional[Exception] = None, stack: Optional[AsyncExitStack] = None, sub_dependents: Optional[List[Dependent]] = None, dependency_overrides_provider: Optional[Any] = None, dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, + **params: Any ) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]: values: Dict[str, Any] = {} dependency_cache = dependency_cache or {} @@ -135,18 +143,15 @@ async def solve_dependencies( use_sub_dependant = get_dependent( func=func, name=sub_dependent.name, + allow_types=sub_dependent.allow_types, ) # solve sub dependency with current cache solved_result = await solve_dependencies( dependent=use_sub_dependant, - bot=bot, - event=event, - state=state, - matcher=matcher, dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, - ) + **params) sub_values, sub_dependency_cache, ignored = solved_result if ignored: return values, dependency_cache, True @@ -182,18 +187,13 @@ async def solve_dependencies( field_info = field.field_info assert isinstance(field_info, Param), "Params must be subclasses of Param" - value = field_info._solve(bot=bot, - event=event, - state=state, - matcher=matcher, - exception=exception) + value = field_info._solve(**params) _, errs_ = field.validate(value, values, - loc=(ParamTypes(type(field_info)).name, - field.alias)) + loc=(str(field_info), field.alias)) if errs_: logger.debug( - f"Matcher {matcher} {ParamTypes(type(field_info)).name} " + f"{field_info} " f"type {type(value)} not match depends {dependent.func} " f"annotation {field._type_display()}, ignored") return values, dependency_cache, True @@ -206,11 +206,14 @@ async def solve_dependencies( def Depends(dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True) -> Any: - return DependsClass(dependency=dependency, use_cache=use_cache) + """ + :说明: + 参数依赖注入装饰器 -from .params import Param -from .handler import Handler as Handler -from .matcher import Matcher as Matcher -from .matcher import matchers as matchers -from .params import ParamTypes as ParamTypes + :参数: + + * ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数。默认为参数的类型注释。 + * ``use_cache: bool = True``: 是否使用缓存。默认为 ``True``。 + """ + return DependsWrapper(dependency=dependency, use_cache=use_cache) diff --git a/nonebot/processor/models.py b/nonebot/dependencies/models.py similarity index 59% rename from nonebot/processor/models.py rename to nonebot/dependencies/models.py index 06d11890..ca764f9b 100644 --- a/nonebot/processor/models.py +++ b/nonebot/dependencies/models.py @@ -1,11 +1,31 @@ -from typing import Any, List, Callable, Optional +import abc +import inspect +from typing import Any, List, Type, Callable, Optional -from pydantic.fields import ModelField +from pydantic.fields import FieldInfo, ModelField from nonebot.utils import get_name -class Depends: +class Param(FieldInfo, abc.ABC): + + def __repr__(self) -> str: + return f"{self.__class__.__name__}" + + def __str__(self) -> str: + return repr(self) + + @classmethod + @abc.abstractmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + raise NotImplementedError + + @abc.abstractmethod + def _solve(self, **kwargs: Any) -> Any: + raise NotImplementedError + + +class DependsWrapper: def __init__(self, dependency: Optional[Callable[..., Any]] = None, @@ -27,11 +47,13 @@ class Dependent: func: Optional[Callable[..., Any]] = None, name: Optional[str] = None, params: Optional[List[ModelField]] = None, + allow_types: Optional[List[Type[Param]]] = None, dependencies: Optional[List["Dependent"]] = None, use_cache: bool = True) -> None: self.func = func self.name = name self.params = params or [] + self.allow_types = allow_types or [] self.dependencies = dependencies or [] self.use_cache = use_cache self.cache_key = self.func diff --git a/nonebot/processor/utils.py b/nonebot/dependencies/utils.py similarity index 57% rename from nonebot/processor/utils.py rename to nonebot/dependencies/utils.py index 6f13e96c..ed08b228 100644 --- a/nonebot/processor/utils.py +++ b/nonebot/dependencies/utils.py @@ -1,6 +1,5 @@ import inspect -from typing import Any, Dict, Type, Tuple, Union, Callable -from typing_extensions import GenericAlias, get_args, get_origin # type: ignore +from typing import Any, Dict, Callable from loguru import logger from pydantic.typing import ForwardRef, evaluate_forwardref @@ -34,26 +33,3 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, ) return inspect.Parameter.empty return annotation - - -def generic_check_issubclass( - cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], - ...]]) -> bool: - try: - return issubclass(cls, class_or_tuple) - except TypeError: - if get_origin(cls) is Union: - for type_ in get_args(cls): - if not generic_check_issubclass(type_, class_or_tuple): - return False - return True - elif isinstance(cls, GenericAlias): - origin = get_origin(cls) - return bool(origin and issubclass(origin, class_or_tuple)) - raise - - -def generic_get_types(cls: Any) -> Tuple[Type[Any], ...]: - if get_origin(cls) is Union: - return get_args(cls) - return (cls,) diff --git a/nonebot/processor/handler.py b/nonebot/handler.py similarity index 52% rename from nonebot/processor/handler.py rename to nonebot/handler.py index fed8d92c..232f2200 100644 --- a/nonebot/processor/handler.py +++ b/nonebot/handler.py @@ -7,83 +7,94 @@ import asyncio from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Type, Callable, Optional -from .models import Depends, Dependent +from nonebot.typing import T_Handler from nonebot.utils import get_name, run_sync -from nonebot.typing import T_State, T_Handler -from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant +from nonebot.dependencies import (Param, Dependent, DependsWrapper, + get_dependent, solve_dependencies, + get_parameterless_sub_dependant) if TYPE_CHECKING: - from .matcher import Matcher - from .params import ParamTypes + from nonebot.matcher import Matcher from nonebot.adapters import Bot, Event class Handler: - """事件处理函数类""" + """事件处理器类。支持依赖注入。""" def __init__(self, func: T_Handler, *, name: Optional[str] = None, - dependencies: Optional[List[Depends]] = None, - allow_types: Optional[List["ParamTypes"]] = None, + dependencies: Optional[List[DependsWrapper]] = None, + allow_types: Optional[List[Type[Param]]] = None, dependency_overrides_provider: Optional[Any] = None): - """装饰事件处理函数以便根据动态参数运行""" - self.func: T_Handler = func + """ + :说明: + + 装饰一个函数为事件处理器。 + + :参数: + + * ``func: T_Handler``: 事件处理函数。 + * ``name: Optional[str]``: 事件处理器名称。默认为函数名。 + * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。 + * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。 + * ``dependency_overrides_provider: Optional[Any]``: 依赖注入覆盖提供者。 + """ + self.func = func """ :类型: ``T_Handler`` :说明: 事件处理函数 """ self.name = get_name(func) if name is None else name - self.allow_types = allow_types + """ + :类型: ``str`` + :说明: 事件处理函数名 + """ + self.allow_types = allow_types or [] + """ + :类型: ``List[Type[Param]]`` + :说明: 事件处理器允许的参数类型 + """ self.dependencies = dependencies or [] + """ + :类型: ``List[DependsWrapper]`` + :说明: 事件处理器的额外依赖 + """ self.sub_dependents: Dict[Callable[..., Any], Dependent] = {} if dependencies: for depends in dependencies: - if not depends.dependency: - raise ValueError(f"{depends} has no dependency") - if depends.dependency in self.sub_dependents: - raise ValueError(f"{depends} is already in dependencies") - sub_dependant = get_parameterless_sub_dependant( - depends=depends, allow_types=self.allow_types) - self.sub_dependents[depends.dependency] = sub_dependant + self.cache_dependent(depends) self.dependency_overrides_provider = dependency_overrides_provider self.dependent = get_dependent(func=func, allow_types=self.allow_types) def __repr__(self) -> str: return ( - f"" + f"" ) def __str__(self) -> str: return repr(self) - async def __call__( - self, - matcher: "Matcher", - bot: "Bot", - event: "Event", - state: T_State, - *, - stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[Dict[Callable[..., Any], - Any]] = None) -> Any: + async def __call__(self, + *, + _stack: Optional[AsyncExitStack] = None, + _dependency_cache: Optional[Dict[Callable[..., Any], + Any]] = None, + **params) -> Any: values, _, ignored = await solve_dependencies( dependent=self.dependent, - bot=bot, - event=event, - state=state, - matcher=matcher, - stack=stack, + stack=_stack, sub_dependents=[ self.sub_dependents[dependency.dependency] # type: ignore for dependency in self.dependencies ], dependency_overrides_provider=self.dependency_overrides_provider, - dependency_cache=dependency_cache) + dependency_cache=_dependency_cache, + **params) if ignored: return @@ -93,24 +104,24 @@ class Handler: else: await run_sync(self.func)(**values) - def cache_dependent(self, dependency: Depends): + def cache_dependent(self, dependency: DependsWrapper): if not dependency.dependency: raise ValueError(f"{dependency} has no dependency") - if (dependency.dependency,) in self.sub_dependents: + if dependency.dependency in self.sub_dependents: raise ValueError(f"{dependency} is already in dependencies") sub_dependant = get_parameterless_sub_dependant( depends=dependency, allow_types=self.allow_types) self.sub_dependents[dependency.dependency] = sub_dependant - def prepend_dependency(self, dependency: Depends): + def prepend_dependency(self, dependency: DependsWrapper): self.cache_dependent(dependency) self.dependencies.insert(0, dependency) - def append_dependency(self, dependency: Depends): + def append_dependency(self, dependency: DependsWrapper): self.cache_dependent(dependency) self.dependencies.append(dependency) - def remove_dependency(self, dependency: Depends): + def remove_dependency(self, dependency: DependsWrapper): if not dependency.dependency: raise ValueError(f"{dependency} has no dependency") if dependency.dependency in self.sub_dependents: diff --git a/nonebot/processor/matcher.py b/nonebot/matcher.py similarity index 96% rename from nonebot/processor/matcher.py rename to nonebot/matcher.py index 4ae3ca21..8856a3cf 100644 --- a/nonebot/processor/matcher.py +++ b/nonebot/matcher.py @@ -12,12 +12,11 @@ from collections import defaultdict from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, NoReturn, Optional) -from .models import Depends -from .handler import Handler from nonebot.rule import Rule -from .params import ParamTypes -from nonebot import get_driver from nonebot.log import logger +from nonebot.handler import Handler +from nonebot import params, get_driver +from nonebot.dependencies import DependsWrapper from nonebot.permission import USER, Permission from nonebot.adapters import (Bot, Event, Message, MessageSegment, MessageTemplate) @@ -155,7 +154,8 @@ class Matcher(metaclass=MatcherMeta): """ HANDLER_PARAM_TYPES = [ - ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE, ParamTypes.MATCHER + params.BotParam, params.EventParam, params.StateParam, + params.MatcherParam ] def __init__(self): @@ -350,9 +350,10 @@ class Matcher(metaclass=MatcherMeta): return func @classmethod - def append_handler(cls, - handler: T_Handler, - dependencies: Optional[List[Depends]] = None) -> Handler: + def append_handler( + cls, + handler: T_Handler, + dependencies: Optional[List[DependsWrapper]] = None) -> Handler: handler_ = Handler(handler, dependencies=dependencies, dependency_overrides_provider=get_driver(), @@ -398,7 +399,7 @@ class Matcher(metaclass=MatcherMeta): def _decorator(func: T_Handler) -> T_Handler: - depend = Depends(_receive) + depend = DependsWrapper(_receive) if cls.handlers and cls.handlers[-1].func is func: func_handler = cls.handlers[-1] @@ -461,8 +462,8 @@ class Matcher(metaclass=MatcherMeta): def _decorator(func: T_Handler) -> T_Handler: - get_depend = Depends(_key_getter) - parser_depend = Depends(_key_parser) + get_depend = DependsWrapper(_key_getter) + parser_depend = DependsWrapper(_key_parser) if cls.handlers and cls.handlers[-1].func is func: func_handler = cls.handlers[-1] @@ -600,7 +601,10 @@ class Matcher(metaclass=MatcherMeta): while self.handlers: handler = self.handlers.pop(0) logger.debug(f"Running handler {handler}") - await handler(self, bot, event, self.state) + await handler(matcher=self, + bot=bot, + event=event, + state=self.state) except RejectedException: self.handlers.insert(0, handler) # type: ignore diff --git a/nonebot/message.py b/nonebot/message.py index a74003fe..fec00687 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -12,8 +12,10 @@ from typing import TYPE_CHECKING, Set, Type from nonebot.log import logger from nonebot.rule import TrieRule +from nonebot.handler import Handler from nonebot.utils import escape_tag -from nonebot.processor import Matcher, matchers +from nonebot import params, get_driver +from nonebot.matcher import Matcher, matchers from nonebot.exception import NoLogException, StopPropagation, IgnoredException from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor, T_EventPreProcessor, T_EventPostProcessor) @@ -21,10 +23,19 @@ from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor, if TYPE_CHECKING: from nonebot.adapters import Bot, Event -_event_preprocessors: Set[T_EventPreProcessor] = set() -_event_postprocessors: Set[T_EventPostProcessor] = set() -_run_preprocessors: Set[T_RunPreProcessor] = set() -_run_postprocessors: Set[T_RunPostProcessor] = set() +_event_preprocessors: Set[Handler] = set() +_event_postprocessors: Set[Handler] = set() +_run_preprocessors: Set[Handler] = set() +_run_postprocessors: Set[Handler] = set() + +EVENT_PCS_PARAMS = [params.BotParam, params.EventParam, params.StateParam] +RUN_PREPCS_PARAMS = [ + params.MatcherParam, params.BotParam, params.EventParam, params.StateParam +] +RUN_POSTPCS_PARAMS = [ + params.MatcherParam, params.ExceptionParam, params.BotParam, + params.EventParam, params.StateParam +] def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor: @@ -41,7 +52,10 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor: * ``event: Event``: Event 对象 * ``state: T_State``: 当前 State """ - _event_preprocessors.add(func) + _event_preprocessors.add( + Handler(func, + allow_types=EVENT_PCS_PARAMS, + dependency_overrides_provider=get_driver())) return func @@ -59,7 +73,10 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor: * ``event: Event``: Event 对象 * ``state: T_State``: 当前事件运行前 State """ - _event_postprocessors.add(func) + _event_postprocessors.add( + Handler(func, + allow_types=EVENT_PCS_PARAMS, + dependency_overrides_provider=get_driver())) return func @@ -78,7 +95,10 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor: * ``event: Event``: Event 对象 * ``state: T_State``: 当前 State """ - _run_preprocessors.add(func) + _run_preprocessors.add( + Handler(func, + allow_types=RUN_PREPCS_PARAMS, + dependency_overrides_provider=get_driver())) return func @@ -98,7 +118,10 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor: * ``event: Event``: Event 对象 * ``state: T_State``: 当前 State """ - _run_postprocessors.add(func) + _run_postprocessors.add( + Handler(func, + allow_types=RUN_POSTPCS_PARAMS, + dependency_overrides_provider=get_driver())) return func @@ -136,7 +159,8 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event", matcher = Matcher() coros = list( - map(lambda x: x(matcher, bot, event, state), _run_preprocessors)) + map(lambda x: x(matcher=matcher, bot=bot, event=event, state=state), + _run_preprocessors)) if coros: try: await asyncio.gather(*coros) @@ -162,8 +186,12 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event", exception = e coros = list( - map(lambda x: x(matcher, exception, bot, event, state), - _run_postprocessors)) + map( + lambda x: x(matcher=matcher, + exception=exception, + bot=bot, + event=event, + state=state), _run_postprocessors)) if coros: try: await asyncio.gather(*coros) @@ -208,7 +236,9 @@ async def handle_event(bot: "Bot", event: "Event") -> None: # TODO async with AsyncExitStack() as stack: - coros = list(map(lambda x: x(bot, event, state), _event_preprocessors)) + coros = list( + map(lambda x: x(bot=bot, event=event, state=state), + _event_preprocessors)) if coros: try: if show_log: @@ -255,7 +285,9 @@ async def handle_event(bot: "Bot", event: "Event") -> None: "Error when checking Matcher." ) - coros = list(map(lambda x: x(bot, event, state), _event_postprocessors)) + coros = list( + map(lambda x: x(bot=bot, event=event, state=state), + _event_postprocessors)) if coros: try: if show_log: diff --git a/nonebot/processor/params.py b/nonebot/params.py similarity index 53% rename from nonebot/processor/params.py rename to nonebot/params.py index f777889c..8b644f91 100644 --- a/nonebot/processor/params.py +++ b/nonebot/params.py @@ -1,38 +1,19 @@ -import abc import inspect -from enum import Enum from typing import Any, Dict, Optional -from pydantic.fields import FieldInfo - from nonebot.typing import T_State +from nonebot.dependencies import Param from nonebot.adapters import Bot, Event -from .utils import generic_check_issubclass - - -class Param(FieldInfo, abc.ABC): - - def __repr__(self) -> str: - return f"{self.__class__.__name__}" - - def __str__(self) -> str: - return repr(self) - - @classmethod - @abc.abstractmethod - def _check(cls, name: str, param: inspect.Parameter) -> bool: - raise NotImplementedError - - @abc.abstractmethod - def _solve(self, **kwargs: Any) -> Any: - raise NotImplementedError +from nonebot.utils import generic_check_issubclass class BotParam(Param): @classmethod def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Bot) + return generic_check_issubclass( + param.annotation, Bot) or (param.annotation == param.empty and + name == "bot") def _solve(self, bot: Bot, **kwargs: Any) -> Any: return bot @@ -42,7 +23,9 @@ class EventParam(Param): @classmethod def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Event) + return generic_check_issubclass( + param.annotation, Event) or (param.annotation == param.empty and + name == "event") def _solve(self, event: Event, **kwargs: Any) -> Any: return event @@ -52,7 +35,9 @@ class StateParam(Param): @classmethod def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Dict) + return generic_check_issubclass( + param.annotation, Dict) or (param.annotation == param.empty and + name == "state") def _solve(self, state: T_State, **kwargs: Any) -> Any: return state @@ -62,7 +47,9 @@ class MatcherParam(Param): @classmethod def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Matcher) + return generic_check_issubclass( + param.annotation, Matcher) or (param.annotation == param.empty and + name == "matcher") def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any: return matcher @@ -72,7 +59,9 @@ class ExceptionParam(Param): @classmethod def _check(cls, name: str, param: inspect.Parameter) -> bool: - return generic_check_issubclass(param.annotation, Exception) + return generic_check_issubclass( + param.annotation, Exception) or (param.annotation == param.empty and + name == "exception") def _solve(self, exception: Optional[Exception] = None, @@ -80,12 +69,4 @@ class ExceptionParam(Param): return exception -class ParamTypes(Enum): - BOT = BotParam - EVENT = EventParam - STATE = StateParam - MATCHER = MatcherParam - EXCEPTION = ExceptionParam - - -from .matcher import Matcher +from nonebot.matcher import Matcher diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 40038984..626e4ed1 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -4,10 +4,11 @@ import inspect from types import ModuleType from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional +from nonebot.handler import Handler +from nonebot.matcher import Matcher from .manager import _current_plugin from nonebot.adapters import Bot, Event from nonebot.permission import Permission -from nonebot.processor import Handler, Matcher from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, endswith, startswith, shell_command) diff --git a/nonebot/plugin/on.pyi b/nonebot/plugin/on.pyi index fbbf9c90..68e8ad62 100644 --- a/nonebot/plugin/on.pyi +++ b/nonebot/plugin/on.pyi @@ -1,9 +1,10 @@ import re from typing import Set, List, Type, Tuple, Union, Optional +from nonebot.handler import Handler +from nonebot.matcher import Matcher from nonebot.permission import Permission from nonebot.rule import Rule, ArgumentParser -from nonebot.processor import Handler, Matcher from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory diff --git a/nonebot/plugin/plugin.py b/nonebot/plugin/plugin.py index bf7773ac..ee5a7c51 100644 --- a/nonebot/plugin/plugin.py +++ b/nonebot/plugin/plugin.py @@ -3,7 +3,7 @@ from dataclasses import field, dataclass from typing import Set, Dict, Type, Optional from .export import Export -from nonebot.processor import Matcher +from nonebot.matcher import Matcher plugins: Dict[str, "Plugin"] = {} """ diff --git a/nonebot/plugins/echo.py b/nonebot/plugins/echo.py index 5e61fd95..440aa4bf 100644 --- a/nonebot/plugins/echo.py +++ b/nonebot/plugins/echo.py @@ -3,14 +3,14 @@ from functools import reduce from nonebot.rule import to_me from nonebot.plugin import on_command from nonebot.permission import SUPERUSER -from nonebot.adapters.cqhttp import (Bot, Message, MessageEvent, MessageSegment, +from nonebot.adapters.cqhttp import (Message, MessageEvent, MessageSegment, unescape) say = on_command("say", to_me(), permission=SUPERUSER) @say.handle() -async def say_unescape(bot: Bot, event: MessageEvent): +async def say_unescape(event: MessageEvent): def _unescape(message: Message, segment: MessageSegment): if segment.is_text(): @@ -18,12 +18,12 @@ async def say_unescape(bot: Bot, event: MessageEvent): return message.append(segment) message = reduce(_unescape, event.get_message(), Message()) # type: ignore - await bot.send(message=message, event=event) + await say.send(message=message) echo = on_command("echo", to_me()) @echo.handle() -async def echo_escape(bot: Bot, event: MessageEvent): - await bot.send(message=event.get_message(), event=event) +async def echo_escape(event: MessageEvent): + await say.send(message=event.get_message()) diff --git a/nonebot/plugins/single_session.py b/nonebot/plugins/single_session.py index 1b6d4116..81a9a9d1 100644 --- a/nonebot/plugins/single_session.py +++ b/nonebot/plugins/single_session.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from nonebot.typing import T_State -from nonebot.processor import Matcher +from nonebot.matcher import Matcher from nonebot.adapters import Bot, Event from nonebot.message import (IgnoredException, run_preprocessor, run_postprocessor) @@ -10,7 +10,7 @@ _running_matcher: Dict[str, int] = {} @run_preprocessor -async def preprocess(matcher: Matcher, bot: Bot, event: Event, state: T_State): +async def preprocess(event: Event): try: session_id = event.get_session_id() except Exception: @@ -24,8 +24,7 @@ async def preprocess(matcher: Matcher, bot: Bot, event: Event, state: T_State): @run_postprocessor -async def postprocess(matcher: Matcher, exception: Optional[Exception], - bot: Bot, event: Event, state: T_State): +async def postprocess(event: Event): try: session_id = event.get_session_id() except Exception: diff --git a/nonebot/typing.py b/nonebot/typing.py index 053fcff7..3725b301 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -22,7 +22,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable, NoReturn, Optional, Awaitable) if TYPE_CHECKING: - from nonebot.processor import Matcher + from nonebot.matcher import Matcher from nonebot.adapters import Bot, Event from nonebot.permission import Permission @@ -90,7 +90,7 @@ T_CalledAPIHook = Callable[ ``bot.call_api`` 后执行的函数,参数分别为 bot, exception, api, data, result """ -T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] +T_EventPreProcessor = Callable[..., Awaitable[None]] """ :类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` @@ -98,7 +98,7 @@ T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] 事件预处理函数 EventPreProcessor 类型 """ -T_EventPostProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] +T_EventPostProcessor = Callable[..., Awaitable[None]] """ :类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` @@ -106,8 +106,7 @@ T_EventPostProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] 事件预处理函数 EventPostProcessor 类型 """ -T_RunPreProcessor = Callable[["Matcher", "Bot", "Event", T_State], - Awaitable[None]] +T_RunPreProcessor = Callable[..., Awaitable[None]] """ :类型: ``Callable[[Matcher, Bot, Event, T_State], Awaitable[None]]`` @@ -115,8 +114,7 @@ T_RunPreProcessor = Callable[["Matcher", "Bot", "Event", T_State], 事件响应器运行前预处理函数 RunPreProcessor 类型 """ -T_RunPostProcessor = Callable[ - ["Matcher", Optional[Exception], "Bot", "Event", T_State], Awaitable[None]] +T_RunPostProcessor = Callable[..., Awaitable[None]] """ :类型: ``Callable[[Matcher, Optional[Exception], Bot, Event, T_State], Awaitable[None]]`` diff --git a/nonebot/utils.py b/nonebot/utils.py index 2786ef1d..66e905e6 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -4,10 +4,11 @@ import asyncio import inspect import dataclasses from functools import wraps, partial -from typing_extensions import ParamSpec from contextlib import asynccontextmanager -from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator, - ContextManager) +from typing_extensions import GenericAlias # type: ignore +from typing_extensions import ParamSpec, get_args, get_origin +from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional, + Awaitable, AsyncGenerator, ContextManager) from nonebot.log import logger from nonebot.typing import overrides @@ -34,6 +35,24 @@ def escape_tag(s: str) -> str: return re.sub(r"\s]*)>", r"\\\g<0>", s) +def generic_check_issubclass( + cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], + ...]]) -> bool: + try: + return issubclass(cls, class_or_tuple) + except TypeError: + if get_origin(cls) is Union: + for type_ in get_args(cls): + if type_ is not type(None) and not generic_check_issubclass( + type_, class_or_tuple): + return False + return True + elif isinstance(cls, GenericAlias): + origin = get_origin(cls) + return bool(origin and issubclass(origin, class_or_tuple)) + raise + + def is_coroutine_callable(func: Callable[..., Any]) -> bool: if inspect.isroutine(func): return inspect.iscoroutinefunction(func) diff --git a/tests/test_plugins/test_depends.py b/tests/test_plugins/test_depends.py index 77580374..c604169b 100644 --- a/tests/test_plugins/test_depends.py +++ b/tests/test_plugins/test_depends.py @@ -1,6 +1,6 @@ from nonebot import on_command from nonebot.log import logger -from nonebot.processor import Depends +from nonebot.dependencies import Depends test = on_command("123") diff --git a/tests/test_plugins/test_processor.py b/tests/test_plugins/test_processor.py index 5eb6b70e..438dda96 100644 --- a/tests/test_plugins/test_processor.py +++ b/tests/test_plugins/test_processor.py @@ -1,15 +1,15 @@ +from nonebot.adapters import Event from nonebot.typing import T_State -from nonebot.processor import Matcher -from nonebot.adapters import Bot, Event +from nonebot.matcher import Matcher from nonebot.message import run_preprocessor, event_preprocessor @event_preprocessor -async def handle(bot: Bot, event: Event, state: T_State): +async def handle(event: Event, state: T_State): state["preprocessed"] = True print(type(event), event) @run_preprocessor -async def run(matcher: Matcher, bot: Bot, event: Event, state: T_State): +async def run(matcher: Matcher): print(matcher)