From d1c6eeb6c2b731c7fd315c59d58592231a30b322 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 15 Nov 2021 21:44:24 +0800 Subject: [PATCH] :recycle: allow dynamic param types --- nonebot/adapters/_bot.py | 10 +++ nonebot/message.py | 110 +++++++++++++------------- nonebot/processor/__init__.py | 141 ++++++++++++++++++---------------- nonebot/processor/handler.py | 32 +++----- nonebot/processor/matcher.py | 12 ++- nonebot/processor/models.py | 21 ++--- nonebot/processor/params.py | 91 ++++++++++++++++++++++ 7 files changed, 260 insertions(+), 157 deletions(-) create mode 100644 nonebot/processor/params.py diff --git a/nonebot/adapters/_bot.py b/nonebot/adapters/_bot.py index b96e31ae..fba3a134 100644 --- a/nonebot/adapters/_bot.py +++ b/nonebot/adapters/_bot.py @@ -55,6 +55,16 @@ class Bot(abc.ABC): def __getattr__(self, name: str) -> _ApiCall: return partial(self.call_api, name) + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + if not isinstance(v, cls): + raise TypeError(f"{v} is not an instance of {cls}") + return v + @property @abc.abstractmethod def type(self) -> str: diff --git a/nonebot/message.py b/nonebot/message.py index f03a7c6d..a74003fe 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -7,7 +7,8 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供 import asyncio from datetime import datetime -from typing import TYPE_CHECKING, Set, Type, Optional +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Set, Type from nonebot.log import logger from nonebot.rule import TrieRule @@ -204,58 +205,63 @@ async def handle_event(bot: "Bot", event: "Event") -> None: logger.opt(colors=True).success(log_msg) state = {} - coros = list(map(lambda x: x(bot, event, state), _event_preprocessors)) - if coros: - try: - if show_log: - logger.debug("Running PreProcessors...") - await asyncio.gather(*coros) - except IgnoredException as e: - logger.opt(colors=True).info( - f"Event {escape_tag(event.get_event_name())} is ignored") - return - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running EventPreProcessors. " - "Event ignored!") - return - # Trie Match - _, _ = TrieRule.get_value(bot, event, state) - - break_flag = False - for priority in sorted(matchers.keys()): - if break_flag: - break - - if show_log: - logger.debug(f"Checking for matchers in priority {priority}...") - - pending_tasks = [ - _check_matcher(priority, matcher, bot, event, state.copy()) - for matcher in matchers[priority] - ] - - results = await asyncio.gather(*pending_tasks, return_exceptions=True) - - for result in results: - if not isinstance(result, Exception): - continue - if isinstance(result, StopPropagation): - break_flag = True - logger.debug("Stop event propagation") - else: - logger.opt(colors=True, exception=result).error( - "Error when checking Matcher." + # TODO + async with AsyncExitStack() as stack: + coros = list(map(lambda x: x(bot, event, state), _event_preprocessors)) + if coros: + try: + if show_log: + logger.debug("Running PreProcessors...") + await asyncio.gather(*coros) + except IgnoredException as e: + logger.opt(colors=True).info( + f"Event {escape_tag(event.get_event_name())} is ignored" ) + return + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running EventPreProcessors. " + "Event ignored!") + return + + # Trie Match + _, _ = TrieRule.get_value(bot, event, state) + + break_flag = False + for priority in sorted(matchers.keys()): + if break_flag: + break - coros = list(map(lambda x: x(bot, event, state), _event_postprocessors)) - if coros: - try: if show_log: - logger.debug("Running PostProcessors...") - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running EventPostProcessors" - ) + logger.debug(f"Checking for matchers in priority {priority}...") + + pending_tasks = [ + _check_matcher(priority, matcher, bot, event, state.copy()) + for matcher in matchers[priority] + ] + + results = await asyncio.gather(*pending_tasks, + return_exceptions=True) + + for result in results: + if not isinstance(result, Exception): + continue + if isinstance(result, StopPropagation): + break_flag = True + logger.debug("Stop event propagation") + else: + logger.opt(colors=True, exception=result).error( + "Error when checking Matcher." + ) + + coros = list(map(lambda x: x(bot, event, state), _event_postprocessors)) + if coros: + try: + if show_log: + logger.debug("Running PostProcessors...") + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running EventPostProcessors" + ) diff --git a/nonebot/processor/__init__.py b/nonebot/processor/__init__.py index 8b344b01..9860fc8d 100644 --- a/nonebot/processor/__init__.py +++ b/nonebot/processor/__init__.py @@ -1,15 +1,18 @@ import inspect from itertools import chain -from typing import Any, Dict, List, Tuple, Callable, Optional, cast +from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager +from pydantic import BaseConfig +from pydantic.fields import Required, ModelField +from pydantic.schema import get_annotation_from_field_info + from .models import Dependent from nonebot.log import logger from nonebot.typing import T_State +from .utils import get_typed_signature from nonebot.adapters import Bot, Event from .models import Depends as DependsClass -from .utils import (generic_get_types, get_typed_signature, - generic_check_issubclass) from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, is_async_gen_callable, is_coroutine_callable) @@ -27,33 +30,42 @@ def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: ) -def get_parameterless_sub_dependant(*, depends: DependsClass) -> Dependent: +def get_parameterless_sub_dependant( + *, + depends: DependsClass, + allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" - return get_sub_dependant(depends=depends, dependency=depends.dependency) + return get_sub_dependant(depends=depends, + dependency=depends.dependency, + allow_types=allow_types) def get_sub_dependant( - *, - depends: DependsClass, - dependency: Callable[..., Any], - name: Optional[str] = None, -) -> Dependent: - sub_dependant = get_dependent( - func=dependency, - name=name, - use_cache=depends.use_cache, - ) + *, + depends: DependsClass, + dependency: Callable[..., Any], + name: Optional[str] = None, + allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: + sub_dependant = get_dependent(func=dependency, + name=name, + use_cache=depends.use_cache, + allow_types=allow_types) return sub_dependant -def get_dependent(*, - func: Callable[..., Any], - name: Optional[str] = None, - use_cache: bool = True) -> Dependent: +def get_dependent( + *, + func: Callable[..., Any], + name: Optional[str] = None, + use_cache: bool = True, + allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: signature = get_typed_signature(func) params = signature.parameters + allow_types = allow_types or [ + ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE + ] dependent = Dependent(func=func, name=name, use_cache=use_cache) for param_name, param in params.items(): if isinstance(param.default, DependsClass): @@ -61,33 +73,29 @@ def get_dependent(*, dependent.dependencies.append(sub_dependent) continue - if generic_check_issubclass(param.annotation, Bot): - if dependent.bot_param_name is not None: - raise ValueError(f"{func} has more than one Bot parameter: " - f"{dependent.bot_param_name} / {param_name}") - dependent.bot_param_name = param_name - dependent.bot_param_type = generic_get_types(param.annotation) - elif generic_check_issubclass(param.annotation, Event): - if dependent.event_param_name is not None: - raise ValueError(f"{func} has more than one Event parameter: " - f"{dependent.event_param_name} / {param_name}") - dependent.event_param_name = param_name - dependent.event_param_type = generic_get_types(param.annotation) - elif generic_check_issubclass(param.annotation, Dict): - if dependent.state_param_name is not None: - raise ValueError(f"{func} has more than one State parameter: " - f"{dependent.state_param_name} / {param_name}") - dependent.state_param_name = param_name - elif generic_check_issubclass(param.annotation, Matcher): - if dependent.matcher_param_name is not None: - raise ValueError( - f"{func} has more than one Matcher parameter: " - f"{dependent.matcher_param_name} / {param_name}") - dependent.matcher_param_name = param_name + for allow_type in allow_types: + field_info_class: Type[Param] = allow_type.value + if field_info_class._check(param_name, param): + field_info = field_info_class(param.default) + break else: raise ValueError( f"Unknown parameter {param_name} with type {param.annotation}") + annotation: Any = Any + if param.annotation != param.empty: + annotation = param.annotation + annotation = get_annotation_from_field_info(annotation, field_info, + param_name) + dependent.params.append( + ModelField(name=param_name, + type_=annotation, + class_validators=None, + model_config=BaseConfig, + default=Required, + required=True, + field_info=field_info)) + return dependent @@ -97,7 +105,8 @@ async def solve_dependencies( bot: Bot, event: Event, state: T_State, - matcher: Optional["Matcher"], + matcher: Optional["Matcher"] = None, + exception: Optional[Exception] = None, stack: Optional[AsyncExitStack] = None, sub_dependents: Optional[List[Dependent]] = None, dependency_overrides_provider: Optional[Any] = None, @@ -115,20 +124,6 @@ async def solve_dependencies( sub_dependent.cache_key) func = sub_dependent.func - # check bot and event type - if sub_dependent.bot_param_type and not isinstance( - bot, sub_dependent.bot_param_type): - logger.debug( - f"Matcher {matcher} bot type {type(bot)} not match depends {func} " - f"annotation {sub_dependent.bot_param_type}, ignored") - return values, dependency_cache, True - elif sub_dependent.event_param_type and not isinstance( - event, sub_dependent.event_param_type): - logger.debug( - f"Matcher {matcher} event type {type(event)} not match depends {func} " - f"annotation {sub_dependent.event_param_type}, ignored") - return values, dependency_cache, True - # dependency overrides use_sub_dependant = sub_dependent if (dependency_overrides_provider and @@ -183,14 +178,28 @@ async def solve_dependencies( dependency_cache[sub_dependent.cache_key] = solved # usual dependency - if dependent.bot_param_name is not None: - values[dependent.bot_param_name] = bot - if dependent.event_param_name is not None: - values[dependent.event_param_name] = event - if dependent.state_param_name is not None: - values[dependent.state_param_name] = state - if dependent.matcher_param_name is not None: - values[dependent.matcher_param_name] = matcher + for field in dependent.params: + field_info = field.field_info + assert isinstance(field_info, + Param), "Params must be subclasses of Param" + value = field_info._solve(bot=bot, + event=event, + state=state, + matcher=matcher, + exception=exception) + _, errs_ = field.validate(value, + values, + loc=(ParamTypes(type(field_info)).name, + field.alias)) + if errs_: + logger.debug( + f"Matcher {matcher} {ParamTypes(type(field_info)).name} " + f"type {type(value)} not match depends {dependent.func} " + f"annotation {field._type_display()}, ignored") + return values, dependency_cache, True + else: + values[field.name] = value + return values, dependency_cache, False @@ -200,6 +209,8 @@ def Depends(dependency: Optional[Callable[..., Any]] = None, return DependsClass(dependency=dependency, use_cache=use_cache) +from .params import Param from .handler import Handler as Handler from .matcher import Matcher as Matcher from .matcher import matchers as matchers +from .params import ParamTypes as ParamTypes diff --git a/nonebot/processor/handler.py b/nonebot/processor/handler.py index 57e38af5..fed8d92c 100644 --- a/nonebot/processor/handler.py +++ b/nonebot/processor/handler.py @@ -9,7 +9,6 @@ import asyncio from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional -from nonebot.log import logger from .models import Depends, Dependent from nonebot.utils import get_name, run_sync from nonebot.typing import T_State, T_Handler @@ -17,6 +16,7 @@ from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant if TYPE_CHECKING: from .matcher import Matcher + from .params import ParamTypes from nonebot.adapters import Bot, Event @@ -28,6 +28,7 @@ class Handler: *, name: Optional[str] = None, dependencies: Optional[List[Depends]] = None, + allow_types: Optional[List["ParamTypes"]] = None, dependency_overrides_provider: Optional[Any] = None): """装饰事件处理函数以便根据动态参数运行""" self.func: T_Handler = func @@ -36,6 +37,7 @@ class Handler: :说明: 事件处理函数 """ self.name = get_name(func) if name is None else name + self.allow_types = allow_types self.dependencies = dependencies or [] self.sub_dependents: Dict[Callable[..., Any], Dependent] = {} @@ -45,18 +47,16 @@ class Handler: raise ValueError(f"{depends} has no dependency") if depends.dependency in self.sub_dependents: raise ValueError(f"{depends} is already in dependencies") - sub_dependant = get_parameterless_sub_dependant(depends=depends) + sub_dependant = get_parameterless_sub_dependant( + depends=depends, allow_types=self.allow_types) self.sub_dependents[depends.dependency] = sub_dependant self.dependency_overrides_provider = dependency_overrides_provider - self.dependent = get_dependent(func=func) + self.dependent = get_dependent(func=func, allow_types=self.allow_types) def __repr__(self) -> str: return ( - f"") + f"" + ) def __str__(self) -> str: return repr(self) @@ -88,19 +88,6 @@ class Handler: if ignored: return - # check bot and event type - if self.dependent.bot_param_type and not isinstance( - bot, self.dependent.bot_param_type): - logger.debug(f"Matcher {matcher} bot type {type(bot)} not match " - f"annotation {self.dependent.bot_param_type}, ignored") - return - elif self.dependent.event_param_type and not isinstance( - event, self.dependent.event_param_type): - logger.debug( - f"Matcher {matcher} event type {type(event)} not match " - f"annotation {self.dependent.event_param_type}, ignored") - return - if asyncio.iscoroutinefunction(self.func): await self.func(**values) else: @@ -111,7 +98,8 @@ class Handler: raise ValueError(f"{dependency} has no dependency") if (dependency.dependency,) in self.sub_dependents: raise ValueError(f"{dependency} is already in dependencies") - sub_dependant = get_parameterless_sub_dependant(depends=dependency) + sub_dependant = get_parameterless_sub_dependant( + depends=dependency, allow_types=self.allow_types) self.sub_dependents[dependency.dependency] = sub_dependant def prepend_dependency(self, dependency: Depends): diff --git a/nonebot/processor/matcher.py b/nonebot/processor/matcher.py index 9fbceaa9..4ae3ca21 100644 --- a/nonebot/processor/matcher.py +++ b/nonebot/processor/matcher.py @@ -15,6 +15,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, from .models import Depends from .handler import Handler from nonebot.rule import Rule +from .params import ParamTypes from nonebot import get_driver from nonebot.log import logger from nonebot.permission import USER, Permission @@ -153,6 +154,10 @@ class Matcher(metaclass=MatcherMeta): :说明: 事件响应器权限更新函数 """ + HANDLER_PARAM_TYPES = [ + ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE, ParamTypes.MATCHER + ] + def __init__(self): """实例化 Matcher 以便运行""" self.handlers = self.handlers.copy() @@ -230,7 +235,9 @@ class Matcher(metaclass=MatcherMeta): permission or Permission(), "handlers": [ handler if isinstance(handler, Handler) else Handler( - handler, dependency_overrides_provider=get_driver()) + handler, + dependency_overrides_provider=get_driver(), + allow_types=cls.HANDLER_PARAM_TYPES) for handler in handlers ] if handlers else [], "temp": @@ -348,7 +355,8 @@ class Matcher(metaclass=MatcherMeta): dependencies: Optional[List[Depends]] = None) -> Handler: handler_ = Handler(handler, dependencies=dependencies, - dependency_overrides_provider=get_driver()) + dependency_overrides_provider=get_driver(), + allow_types=cls.HANDLER_PARAM_TYPES) cls.handlers.append(handler_) return handler_ diff --git a/nonebot/processor/models.py b/nonebot/processor/models.py index 9413fb8a..06d11890 100644 --- a/nonebot/processor/models.py +++ b/nonebot/processor/models.py @@ -1,10 +1,9 @@ -from typing import TYPE_CHECKING, Any, List, Type, Tuple, Callable, Optional +from typing import Any, List, Callable, Optional + +from pydantic.fields import ModelField from nonebot.utils import get_name -if TYPE_CHECKING: - from nonebot.adapters import Bot, Event - class Depends: @@ -27,22 +26,12 @@ class Dependent: *, func: Optional[Callable[..., Any]] = None, name: Optional[str] = None, - bot_param_name: Optional[str] = None, - bot_param_type: Optional[Tuple[Type["Bot"], ...]] = None, - event_param_name: Optional[str] = None, - event_param_type: Optional[Tuple[Type["Event"], ...]] = None, - state_param_name: Optional[str] = None, - matcher_param_name: Optional[str] = None, + params: Optional[List[ModelField]] = None, dependencies: Optional[List["Dependent"]] = None, use_cache: bool = True) -> None: self.func = func self.name = name - self.bot_param_name = bot_param_name - self.bot_param_type = bot_param_type - self.event_param_name = event_param_name - self.event_param_type = event_param_type - self.state_param_name = state_param_name - self.matcher_param_name = matcher_param_name + self.params = params or [] self.dependencies = dependencies or [] self.use_cache = use_cache self.cache_key = self.func diff --git a/nonebot/processor/params.py b/nonebot/processor/params.py new file mode 100644 index 00000000..f777889c --- /dev/null +++ b/nonebot/processor/params.py @@ -0,0 +1,91 @@ +import abc +import inspect +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic.fields import FieldInfo + +from nonebot.typing import T_State +from nonebot.adapters import Bot, Event +from .utils import generic_check_issubclass + + +class Param(FieldInfo, abc.ABC): + + def __repr__(self) -> str: + return f"{self.__class__.__name__}" + + def __str__(self) -> str: + return repr(self) + + @classmethod + @abc.abstractmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + raise NotImplementedError + + @abc.abstractmethod + def _solve(self, **kwargs: Any) -> Any: + raise NotImplementedError + + +class BotParam(Param): + + @classmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + return generic_check_issubclass(param.annotation, Bot) + + def _solve(self, bot: Bot, **kwargs: Any) -> Any: + return bot + + +class EventParam(Param): + + @classmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + return generic_check_issubclass(param.annotation, Event) + + def _solve(self, event: Event, **kwargs: Any) -> Any: + return event + + +class StateParam(Param): + + @classmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + return generic_check_issubclass(param.annotation, Dict) + + def _solve(self, state: T_State, **kwargs: Any) -> Any: + return state + + +class MatcherParam(Param): + + @classmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + return generic_check_issubclass(param.annotation, Matcher) + + def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any: + return matcher + + +class ExceptionParam(Param): + + @classmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + return generic_check_issubclass(param.annotation, Exception) + + def _solve(self, + exception: Optional[Exception] = None, + **kwargs: Any) -> Any: + return exception + + +class ParamTypes(Enum): + BOT = BotParam + EVENT = EventParam + STATE = StateParam + MATCHER = MatcherParam + EXCEPTION = ExceptionParam + + +from .matcher import Matcher