From c454cf08743dead351ee8edb94391cdc60afcabd Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Fri, 12 Nov 2021 18:10:40 +0800 Subject: [PATCH] :construction: process handler dependency --- nonebot/dependencies/__init__.py | 2 + nonebot/dependencies/models.py | 89 +++++++++++++++++ nonebot/dependencies/utils.py | 150 ++++++++++++++++++++++++++++ nonebot/handler.py | 161 ++----------------------------- nonebot/typing.py | 13 +-- nonebot/utils.py | 7 ++ 6 files changed, 256 insertions(+), 166 deletions(-) create mode 100644 nonebot/dependencies/__init__.py create mode 100644 nonebot/dependencies/models.py create mode 100644 nonebot/dependencies/utils.py diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py new file mode 100644 index 00000000..95e6a040 --- /dev/null +++ b/nonebot/dependencies/__init__.py @@ -0,0 +1,2 @@ +from .models import Depends as Depends +from .utils import get_dependent as get_dependent diff --git a/nonebot/dependencies/models.py b/nonebot/dependencies/models.py new file mode 100644 index 00000000..4232ccec --- /dev/null +++ b/nonebot/dependencies/models.py @@ -0,0 +1,89 @@ +from enum import Enum +from typing import Any, List, Callable, Optional + +from pydantic.fields import Required, FieldInfo, ModelField + +from nonebot.utils import get_name + + +class Depends: + + def __init__(self, + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True) -> None: + self.dependency = dependency + self.use_cache = use_cache + + def __repr__(self) -> str: + dep = get_name(self.dependency) + cache = "" if self.use_cache else ", use_cache=False" + return f"{self.__class__.__name__}({dep}{cache})" + + +class Dependent: + + def __init__(self, + *, + func: Optional[Callable[..., Any]] = None, + name: Optional[str] = None, + bot_param: Optional[ModelField] = None, + event_param: Optional[ModelField] = None, + state_param: Optional[ModelField] = None, + matcher_param: Optional[ModelField] = None, + simple_params: Optional[List[ModelField]] = None, + dependencies: Optional[List["Dependent"]] = None, + use_cache: bool = True) -> None: + self.func = func + self.name = name + self.bot_param = bot_param + self.event_param = event_param + self.state_param = state_param + self.matcher_param = matcher_param + self.simple_params = simple_params or [] + self.dependencies = dependencies or [] + self.use_cache = use_cache + self.cache_key = (self.func,) + + +class ParamTypes(Enum): + BOT = "bot" + EVENT = "event" + STATE = "state" + MATCHER = "matcher" + SIMPLE = "simple" + + +class Param(FieldInfo): + in_: ParamTypes + + def __init__(self, default: Any): + super().__init__(default=default) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}" + + +class BotParam(Param): + in_ = ParamTypes.BOT + + +class EventParam(Param): + in_ = ParamTypes.EVENT + + +class StateParam(Param): + in_ = ParamTypes.STATE + + +class MatcherParam(Param): + in_ = ParamTypes.MATCHER + + +class SimpleParam(Param): + in_ = ParamTypes.SIMPLE + + def __init__(self, default: Any): + if default is Required: + raise ValueError("SimpleParam should be given a default value") + super().__init__(default) diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py new file mode 100644 index 00000000..9503fbf1 --- /dev/null +++ b/nonebot/dependencies/utils.py @@ -0,0 +1,150 @@ +import inspect +from typing import Any, Dict, Type, Union, Callable, Optional, ForwardRef + +from pydantic import BaseConfig +from pydantic.class_validators import Validator +from pydantic.typing import evaluate_forwardref +from pydantic.schema import get_annotation_from_field_info +from pydantic.fields import Required, FieldInfo, ModelField, UndefinedType + +from .models import Param, Depends, Dependent, ParamTypes, SimpleParam + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param, globalns), + ) for param in signature.parameters.values() + ] + typed_signature = inspect.Signature(typed_params) + return typed_signature + + +def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, + Any]) -> Any: + annotation = param.annotation + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + +def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: + depends: Depends = param.default + if depends.dependency: + dependency = depends.dependency + else: + dependency = param.annotation + return get_sub_dependant( + depends=depends, + dependency=dependency, + name=param.name, + ) + + +def get_parameterless_sub_dependant(*, depends: Depends) -> Dependent: + assert callable( + depends.dependency + ), "A parameter-less dependency must have a callable dependency" + return get_sub_dependant(depends=depends, dependency=depends.dependency) + + +def get_sub_dependant( + *, + depends: Depends, + dependency: Callable[..., Any], + name: Optional[str] = None, +) -> Dependent: + sub_dependant = get_dependent( + func=dependency, + name=name, + use_cache=depends.use_cache, + ) + return sub_dependant + + +def get_dependent(*, + func: Callable[..., Any], + name: Optional[str] = None, + use_cache: bool = True) -> Dependent: + signature = get_typed_signature(func) + params = signature.parameters + dependent = Dependent(func=func, name=name, use_cache=use_cache) + for param_name, param in params.items(): + if isinstance(param.default, Depends): + sub_dependent = get_param_sub_dependent(param=param) + dependent.dependencies.append(sub_dependent) + continue + param_field = get_param_field(param=param, + param_name=param_name, + default_field_info=SimpleParam) + + return dependent + + +def get_param_field(*, + param: inspect.Parameter, + param_name: str, + default_field_info: Type[Param] = Param, + force_type: Optional[ParamTypes] = None, + ignore_default: bool = False) -> ModelField: + default_value = Required + if param.default != param.empty and not ignore_default: + default_value = param.default + if isinstance(default_value, FieldInfo): + field_info = default_value + default_value = field_info.default + if (isinstance(field_info, Param) and + getattr(field_info, "in_", None) is None): + field_info.in_ = default_field_info.in_ + if force_type: + field_info.in_ = force_type # type: ignore + else: + field_info = default_field_info(default_value) + required: bool = default_value == Required + annotation: Any = Any + if param.annotation != param.empty: + annotation = param.annotation + annotation = get_annotation_from_field_info(annotation, field_info, + param_name) + if not field_info.alias and getattr(field_info, "convert_underscores", + None): + alias = param.name.replace("_", "-") + else: + alias = field_info.alias or param.name + field = create_field( + name=param.name, + type_=annotation, + default=None if required else default_value, + alias=alias, + required=required, + field_info=field_info, + ) + # field.required = required + + return field + + +def create_field(name: str, + type_: Type[Any], + class_validators: Optional[Dict[str, Validator]] = None, + default: Optional[Any] = None, + required: Union[bool, UndefinedType] = False, + model_config: Type[BaseConfig] = BaseConfig, + field_info: Optional[FieldInfo] = None, + alias: Optional[str] = None) -> ModelField: + class_validators = class_validators or {} + field_info = field_info or FieldInfo(None) + return ModelField(name=name, + type_=type_, + class_validators=class_validators, + model_config=model_config, + default=default, + required=required, + alias=alias, + field_info=field_info) diff --git a/nonebot/handler.py b/nonebot/handler.py index 055e6f88..ee31d4cf 100644 --- a/nonebot/handler.py +++ b/nonebot/handler.py @@ -6,171 +6,22 @@ """ import inspect -from typing import _eval_type # type: ignore -from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Optional, - ForwardRef) +from typing import Optional -from nonebot.log import logger -from nonebot.typing import T_State, T_Handler +from pydantic.typing import evaluate_forwardref -if TYPE_CHECKING: - from nonebot.matcher import Matcher - from nonebot.adapters import Bot, Event +from nonebot.utils import get_name +from nonebot.typing import T_Handler class Handler: """事件处理函数类""" - def __init__(self, func: T_Handler): + def __init__(self, func: T_Handler, *, name: Optional[str] = None): """装饰事件处理函数以便根据动态参数运行""" self.func: T_Handler = func """ :类型: ``T_Handler`` :说明: 事件处理函数 """ - self.signature: inspect.Signature = self.get_signature() - """ - :类型: ``inspect.Signature`` - :说明: 事件处理函数签名 - """ - - def __repr__(self) -> str: - return (f"") - - def __str__(self) -> str: - return repr(self) - - async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event", - state: T_State): - BotType = ((self.bot_type is not inspect.Parameter.empty) and - inspect.isclass(self.bot_type) and self.bot_type) - if BotType and not isinstance(bot, BotType): - logger.debug( - f"Matcher {matcher} bot type {type(bot)} not match annotation {BotType}, ignored" - ) - return - - EventType = ((self.event_type is not inspect.Parameter.empty) and - inspect.isclass(self.event_type) and self.event_type) - if EventType and not isinstance(event, EventType): - logger.debug( - f"Matcher {matcher} event type {type(event)} not match annotation {EventType}, ignored" - ) - return - - args = {"bot": bot, "event": event, "state": state, "matcher": matcher} - await self.func( - **{ - k: v - for k, v in args.items() - if self.signature.parameters.get(k, None) is not None - }) - - @property - def bot_type(self) -> Union[Type["Bot"], inspect.Parameter.empty]: - """ - :类型: ``Union[Type["Bot"], inspect.Parameter.empty]`` - :说明: 事件处理函数接受的 Bot 对象类型""" - return self.signature.parameters["bot"].annotation - - @property - def event_type( - self) -> Optional[Union[Type["Event"], inspect.Parameter.empty]]: - """ - :类型: ``Optional[Union[Type[Event], inspect.Parameter.empty]]`` - :说明: 事件处理函数接受的 event 类型 / 不需要 event 参数 - """ - if "event" not in self.signature.parameters: - return None - return self.signature.parameters["event"].annotation - - @property - def state_type(self) -> Optional[Union[T_State, inspect.Parameter.empty]]: - """ - :类型: ``Optional[Union[T_State, inspect.Parameter.empty]]`` - :说明: 事件处理函数是否接受 state 参数 - """ - if "state" not in self.signature.parameters: - return None - return self.signature.parameters["state"].annotation - - @property - def matcher_type( - self) -> Optional[Union[Type["Matcher"], inspect.Parameter.empty]]: - """ - :类型: ``Optional[Union[Type["Matcher"], inspect.Parameter.empty]]`` - :说明: 事件处理函数是否接受 matcher 参数 - """ - if "matcher" not in self.signature.parameters: - return None - return self.signature.parameters["matcher"].annotation - - def get_signature(self) -> inspect.Signature: - wrapped_signature = self._get_typed_signature() - signature = self._get_typed_signature(False) - self._check_params(signature) - self._check_bot_param(signature) - self._check_bot_param(wrapped_signature) - signature.parameters["bot"].replace( - annotation=wrapped_signature.parameters["bot"].annotation) - if "event" in wrapped_signature.parameters and "event" in signature.parameters: - signature.parameters["event"].replace( - annotation=wrapped_signature.parameters["event"].annotation) - return signature - - def update_signature( - self, **kwargs: Union[None, Type["Bot"], Type["Event"], Type["Matcher"], - T_State, inspect.Parameter.empty] - ) -> None: - params: List[inspect.Parameter] = [] - for param in ["bot", "event", "state", "matcher"]: - sig = self.signature.parameters.get(param, None) - if param in kwargs: - sig = inspect.Parameter(param, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=kwargs[param]) - if sig: - params.append(sig) - - self.signature = inspect.Signature(params) - - def _get_typed_signature(self, - follow_wrapped: bool = True) -> inspect.Signature: - signature = inspect.signature(self.func, follow_wrapped=follow_wrapped) - globalns = getattr(self.func, "__globals__", {}) - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=param.default, - annotation=param.annotation if follow_wrapped else - self._get_typed_annotation(param, globalns), - ) for param in signature.parameters.values() - ] - typed_signature = inspect.Signature(typed_params) - return typed_signature - - def _get_typed_annotation(self, param: inspect.Parameter, - globalns: Dict[str, Any]) -> Any: - try: - if isinstance(param.annotation, str): - return _eval_type(ForwardRef(param.annotation), globalns, - globalns) - else: - return param.annotation - except Exception: - return param.annotation - - def _check_params(self, signature: inspect.Signature): - if not set(signature.parameters.keys()) <= { - "bot", "event", "state", "matcher" - }: - raise ValueError( - "Handler param names must in `bot`/`event`/`state`/`matcher`") - - def _check_bot_param(self, signature: inspect.Signature): - if not any( - param.name == "bot" for param in signature.parameters.values()): - raise ValueError("Handler missing parameter 'bot'") + self.name = get_name(func) if name is None else name diff --git a/nonebot/typing.py b/nonebot/typing.py index 273661eb..970fbcc7 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -143,20 +143,11 @@ T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]] RuleChecker 即判断是否响应消息的处理函数。 """ -T_Handler = Union[Callable[[Any, Any, Any, Any], Union[Awaitable[None], - Awaitable[NoReturn]]], - Callable[[Any, Any, Any], Union[Awaitable[None], - Awaitable[NoReturn]]], - Callable[[Any, Any], Union[Awaitable[None], - Awaitable[NoReturn]]], - Callable[[Any], Union[Awaitable[None], Awaitable[NoReturn]]]] +T_Handler = Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]] """ :类型: - * ``Callable[[Bot, Event, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` - * ``Callable[[Bot, Event], Union[Awaitable[None], Awaitable[NoReturn]]]`` - * ``Callable[[Bot, T_State], Union[Awaitable[None], Awaitable[NoReturn]]]`` - * ``Callable[[Bot], Union[Awaitable[None], Awaitable[NoReturn]]]`` + * ``Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]`` :说明: diff --git a/nonebot/utils.py b/nonebot/utils.py index 8183986d..71df0d4b 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -1,6 +1,7 @@ import re import json import asyncio +import inspect import dataclasses from functools import wraps, partial from typing import Any, Callable, Optional, Awaitable @@ -51,6 +52,12 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: return _wrapper +def get_name(obj: Any) -> str: + if inspect.isfunction(obj) or inspect.isclass(obj): + return obj.__name__ + return obj.__class__.__name__ + + class DataclassEncoder(json.JSONEncoder): """ :说明: