diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index be17a054..656566fb 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -6,15 +6,19 @@ FrontMatter: """ import abc +import asyncio import inspect +from dataclasses import field, dataclass from typing import ( Any, Dict, List, Type, + Tuple, Generic, TypeVar, Callable, + Iterable, Optional, Awaitable, cast, @@ -25,7 +29,6 @@ from pydantic.schema import get_annotation_from_field_info from pydantic.fields import Required, FieldInfo, Undefined, ModelField from nonebot.log import logger -from nonebot.exception import TypeMisMatch from nonebot.typing import _DependentCallable from nonebot.utils import run_sync, is_coroutine_callable @@ -43,25 +46,29 @@ class Param(abc.ABC, FieldInfo): @classmethod def _check_param( - cls, dependent: "Dependent", name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...] ) -> Optional["Param"]: - return None + return @classmethod def _check_parameterless( - cls, dependent: "Dependent", value: Any + cls, value: Any, allow_types: Tuple[Type["Param"], ...] ) -> Optional["Param"]: - return None + return @abc.abstractmethod async def _solve(self, **kwargs: Any) -> Any: raise NotImplementedError + async def _check(self, **kwargs: Any) -> None: + return + class CustomConfig(BaseConfig): arbitrary_types_allowed = True +@dataclass(frozen=True) class Dependent(Generic[R]): """依赖注入容器 @@ -73,76 +80,34 @@ class Dependent(Generic[R]): allow_types: 允许的参数类型 """ - def __init__( - self, - *, - call: _DependentCallable[R], - pre_checkers: Optional[List[Param]] = None, - params: Optional[List[ModelField]] = None, - parameterless: Optional[List[Param]] = None, - allow_types: Optional[List[Type[Param]]] = None, - ) -> None: - self.call = call - self.pre_checkers = pre_checkers or [] - self.params = params or [] - self.parameterless = parameterless or [] - self.allow_types = allow_types or [] + call: _DependentCallable[R] + params: Tuple[ModelField] = field(default_factory=tuple) + parameterless: Tuple[Param] = field(default_factory=tuple) def __repr__(self) -> str: - return ( - f"" - ) - - def __str__(self) -> str: - return self.__repr__() + return f"" async def __call__(self, **kwargs: Any) -> R: + # do pre-check + await self.check(**kwargs) + + # solve param values values = await self.solve(**kwargs) + # call function if is_coroutine_callable(self.call): return await cast(Callable[..., Awaitable[R]], self.call)(**values) else: return await run_sync(cast(Callable[..., R], self.call))(**values) - def parse_param(self, name: str, param: inspect.Parameter) -> Param: - for allow_type in self.allow_types: - if field_info := allow_type._check_param(self, name, param): - return field_info - raise ValueError( - f"Unknown parameter {name} for function {self.call} with type {param.annotation}" - ) + @staticmethod + def parse_params( + call: _DependentCallable[R], allow_types: Tuple[Type[Param], ...] + ) -> Tuple[ModelField]: + fields: List[ModelField] = [] + params = get_typed_signature(call).parameters.values() - def parse_parameterless(self, value: Any) -> Param: - for allow_type in self.allow_types: - if field_info := allow_type._check_parameterless(self, value): - return field_info - raise ValueError( - f"Unknown parameterless {value} for function {self.call} with type {type(value)}" - ) - - def prepend_parameterless(self, value: Any) -> None: - self.parameterless.insert(0, self.parse_parameterless(value)) - - def append_parameterless(self, value: Any) -> None: - self.parameterless.append(self.parse_parameterless(value)) - - @classmethod - def parse( - cls, - *, - call: _DependentCallable[R], - parameterless: Optional[List[Any]] = None, - allow_types: Optional[List[Type[Param]]] = None, - ) -> "Dependent[R]": - signature = get_typed_signature(call) - params = signature.parameters - dependent = cls( - call=call, - allow_types=allow_types, - ) - - for param_name, param in params.items(): + for param in params: default_value = Required if param.default != param.empty: default_value = param.default @@ -150,7 +115,13 @@ class Dependent(Generic[R]): if isinstance(default_value, Param): field_info = default_value else: - field_info = dependent.parse_param(param_name, param) + for allow_type in allow_types: + if field_info := allow_type._check_param(param, allow_types): + break + else: + raise ValueError( + f"Unknown parameter {param.name} for function {call} with type {param.annotation}" + ) default_value = field_info.default @@ -159,11 +130,12 @@ class Dependent(Generic[R]): if param.annotation != param.empty: annotation = param.annotation annotation = get_annotation_from_field_info( - annotation, field_info, param_name + annotation, field_info, param.name ) - dependent.params.append( + + fields.append( ModelField( - name=param_name, + name=param.name, type_=annotation, class_validators=None, model_config=CustomConfig, @@ -173,49 +145,69 @@ class Dependent(Generic[R]): ) ) - parameterless_params = [ - dependent.parse_parameterless(param) for param in (parameterless or []) - ] - dependent.parameterless.extend(parameterless_params) + return tuple(fields) + + @staticmethod + def parse_parameterless( + parameterless: Tuple[Any, ...], allow_types: Tuple[Type[Param], ...] + ) -> Tuple[Param, ...]: + parameterless_params: List[Param] = [] + for value in parameterless: + for allow_type in allow_types: + if param := allow_type._check_parameterless(value, allow_types): + break + else: + raise ValueError(f"Unknown parameterless {value}") + parameterless_params.append(param) + return tuple(parameterless_params) + + @classmethod + def parse( + cls, + *, + call: _DependentCallable[R], + parameterless: Optional[Iterable[Any]] = None, + allow_types: Iterable[Type[Param]], + ) -> "Dependent[R]": + allow_types = tuple(allow_types) + + params = cls.parse_params(call, allow_types) + parameterless_params = ( + tuple() + if parameterless is None + else cls.parse_parameterless(tuple(parameterless), allow_types) + ) logger.trace( f"Parsed dependent with call={call}, " - f"params={[param.field_info for param in dependent.params]}, " - f"parameterless={dependent.parameterless}" + f"params={params}, " + f"parameterless={parameterless_params}" ) - return dependent + return cls(call, params, parameterless_params) - async def solve( - self, - **params: Any, - ) -> Dict[str, Any]: - values: Dict[str, Any] = {} + async def check(self, **params: Any) -> None: + await asyncio.gather(*(param._check(**params) for param in self.parameterless)) + await asyncio.gather( + *(cast(Param, param.field_info)._check(**params) for param in self.params) + ) - for checker in self.pre_checkers: - await checker._solve(**params) + async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any: + value = await cast(Param, field.field_info)._solve(**params) + if value is Undefined: + value = field.get_default() + return check_field_type(field, value) + async def solve(self, **params: Any) -> Dict[str, Any]: + # solve parameterless for param in self.parameterless: await param._solve(**params) - for field in self.params: - field_info = field.field_info - assert isinstance(field_info, Param), "Params must be subclasses of Param" - value = await field_info._solve(**params) - if value is Undefined: - value = field.get_default() - - try: - values[field.name] = check_field_type(field, value) - except TypeMisMatch: - logger.debug( - f"{field_info} " - f"type {type(value)} not match depends {self.call} " - f"annotation {field._type_display()}, ignored" - ) - raise - - return values + # solve param values + values = await asyncio.gather( + *(self._solve_field(field, params) for field in self.params) + ) + return {field.name: value for field, value in zip(self.params, values)} __autodoc__ = {"CustomConfig": False} diff --git a/nonebot/internal/matcher.py b/nonebot/internal/matcher.py index d9b31f1e..525d22e8 100644 --- a/nonebot/internal/matcher.py +++ b/nonebot/internal/matcher.py @@ -12,6 +12,7 @@ from typing import ( Union, TypeVar, Callable, + Iterable, NoReturn, Optional, overload, @@ -133,7 +134,7 @@ class Matcher(metaclass=MatcherMeta): _default_permission_updater: Optional[Dependent[Permission]] = None """事件响应器权限更新函数""" - HANDLER_PARAM_TYPES = [ + HANDLER_PARAM_TYPES = ( DependParam, BotParam, EventParam, @@ -141,7 +142,7 @@ class Matcher(metaclass=MatcherMeta): ArgParam, MatcherParam, DefaultParam, - ] + ) def __init__(self): self.handlers = self.handlers.copy() @@ -153,9 +154,6 @@ class Matcher(metaclass=MatcherMeta): f"priority={self.priority}, temp={self.temp}>" ) - def __str__(self) -> str: - return repr(self) - @classmethod def new( cls, @@ -219,27 +217,35 @@ class Matcher(metaclass=MatcherMeta): "temp": temp, "expire_time": ( expire_time - if isinstance(expire_time, datetime) - else expire_time and datetime.now() + expire_time + and ( + expire_time + if isinstance(expire_time, datetime) + else datetime.now() + expire_time + ) ), "priority": priority, "block": block, "_default_state": default_state or {}, "_default_type_updater": ( default_type_updater - if isinstance(default_type_updater, Dependent) - else default_type_updater - and Dependent[str].parse( - call=default_type_updater, allow_types=cls.HANDLER_PARAM_TYPES + and ( + default_type_updater + if isinstance(default_type_updater, Dependent) + else Dependent[str].parse( + call=default_type_updater, + allow_types=cls.HANDLER_PARAM_TYPES, + ) ) ), "_default_permission_updater": ( default_permission_updater - if isinstance(default_permission_updater, Dependent) - else default_permission_updater - and Dependent[Permission].parse( - call=default_permission_updater, - allow_types=cls.HANDLER_PARAM_TYPES, + and ( + default_permission_updater + if isinstance(default_permission_updater, Dependent) + else Dependent[Permission].parse( + call=default_permission_updater, + allow_types=cls.HANDLER_PARAM_TYPES, + ) ) ), }, @@ -327,7 +333,7 @@ class Matcher(metaclass=MatcherMeta): @classmethod def append_handler( - cls, handler: T_Handler, parameterless: Optional[List[Any]] = None + cls, handler: T_Handler, parameterless: Optional[Iterable[Any]] = None ) -> Dependent[Any]: handler_ = Dependent[Any].parse( call=handler, @@ -339,7 +345,7 @@ class Matcher(metaclass=MatcherMeta): @classmethod def handle( - cls, parameterless: Optional[List[Any]] = None + cls, parameterless: Optional[Iterable[Any]] = None ) -> Callable[[T_Handler], T_Handler]: """装饰一个函数来向事件响应器直接添加一个处理函数 @@ -355,7 +361,7 @@ class Matcher(metaclass=MatcherMeta): @classmethod def receive( - cls, id: str = "", parameterless: Optional[List[Any]] = None + cls, id: str = "", parameterless: Optional[Iterable[Any]] = None ) -> Callable[[T_Handler], T_Handler]: """装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数 @@ -373,14 +379,21 @@ class Matcher(metaclass=MatcherMeta): return await matcher.reject() - _parameterless = [Depends(_receive), *(parameterless or [])] + _parameterless = (Depends(_receive), *(parameterless or tuple())) def _decorator(func: T_Handler) -> T_Handler: if cls.handlers and cls.handlers[-1].call is func: func_handler = cls.handlers[-1] - for depend in reversed(_parameterless): - func_handler.prepend_parameterless(depend) + new_handler = Dependent( + call=func_handler.call, + params=func_handler.params, + parameterless=Dependent.parse_parameterless( + tuple(_parameterless), cls.HANDLER_PARAM_TYPES + ) + + func_handler.parameterless, + ) + cls.handlers[-1] = new_handler else: cls.append_handler(func, parameterless=_parameterless) @@ -393,7 +406,7 @@ class Matcher(metaclass=MatcherMeta): cls, key: str, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, - parameterless: Optional[List[Any]] = None, + parameterless: Optional[Iterable[Any]] = None, ) -> Callable[[T_Handler], T_Handler]: """装饰一个函数来指示 NoneBot 获取一个参数 `key` @@ -414,17 +427,21 @@ class Matcher(metaclass=MatcherMeta): return await matcher.reject(prompt) - _parameterless = [ - Depends(_key_getter), - *(parameterless or []), - ] + _parameterless = (Depends(_key_getter), *(parameterless or tuple())) def _decorator(func: T_Handler) -> T_Handler: if cls.handlers and cls.handlers[-1].call is func: func_handler = cls.handlers[-1] - for depend in reversed(_parameterless): - func_handler.prepend_parameterless(depend) + new_handler = Dependent( + call=func_handler.call, + params=func_handler.params, + parameterless=Dependent.parse_parameterless( + tuple(_parameterless), cls.HANDLER_PARAM_TYPES + ) + + func_handler.parameterless, + ) + cls.handlers[-1] = new_handler else: cls.append_handler(func, parameterless=_parameterless) diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 67800b60..8cbef42b 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -1,8 +1,7 @@ import asyncio import inspect -import warnings -from typing import TYPE_CHECKING, Any, Literal, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager +from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast from pydantic.fields import Required, Undefined, ModelField @@ -76,10 +75,7 @@ class DependParam(Param): @classmethod def _check_param( - cls, - dependent: Dependent, - name: str, - param: inspect.Parameter, + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["DependParam"]: if isinstance(param.default, DependsInner): dependency: T_Handler @@ -90,22 +86,20 @@ class DependParam(Param): dependency = param.default.dependency sub_dependent = Dependent[Any].parse( call=dependency, - allow_types=dependent.allow_types, + allow_types=allow_types, ) - dependent.pre_checkers.extend(sub_dependent.pre_checkers) - sub_dependent.pre_checkers.clear() return cls( Required, use_cache=param.default.use_cache, dependent=sub_dependent ) @classmethod def _check_parameterless( - cls, dependent: "Dependent", value: Any + cls, value: Any, allow_types: Tuple[Type[Param], ...] ) -> Optional["Param"]: if isinstance(value, DependsInner): assert value.dependency, "Dependency cannot be empty" dependent = Dependent[Any].parse( - call=value.dependency, allow_types=dependent.allow_types + call=value.dependency, allow_types=allow_types ) return cls(Required, use_cache=value.use_cache, dependent=dependent) @@ -119,8 +113,7 @@ class DependParam(Param): dependency_cache = {} if dependency_cache is None else dependency_cache sub_dependent: Dependent = self.extra["dependent"] - sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) - call = sub_dependent.call + call = cast(Callable[..., Any], sub_dependent.call) # solve sub dependency with current cache sub_values = await sub_dependent.solve( @@ -132,7 +125,7 @@ class DependParam(Param): # run dependency function task: asyncio.Task[Any] if use_cache and call in dependency_cache: - solved = await dependency_cache[call] + return await dependency_cache[call] elif is_gen_callable(call) or is_async_gen_callable(call): assert isinstance( stack, AsyncExitStack @@ -143,30 +136,20 @@ class DependParam(Param): cm = asynccontextmanager(call)(**sub_values) task = asyncio.create_task(stack.enter_async_context(cm)) dependency_cache[call] = task - solved = await task + return await task elif is_coroutine_callable(call): task = asyncio.create_task(call(**sub_values)) dependency_cache[call] = task - solved = await task + return await task else: task = asyncio.create_task(run_sync(call)(**sub_values)) dependency_cache[call] = task - solved = await task + return await task - return solved - - -class _BotChecker(Param): - async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: - field: ModelField = self.extra["field"] - try: - return check_field_type(field, bot) - except TypeMisMatch: - logger.debug( - f"Bot type {type(bot)} not match " - f"annotation {field._type_display()}, ignored" - ) - raise + async def _check(self, **kwargs: Any) -> None: + # run sub dependent pre-checkers + sub_dependent: Dependent = self.extra["dependent"] + await sub_dependent.check(**kwargs) class BotParam(Param): @@ -174,45 +157,32 @@ class BotParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["BotParam"]: from nonebot.adapters import Bot if param.default == param.empty: if generic_check_issubclass(param.annotation, Bot): + checker: Optional[ModelField] = None if param.annotation is not Bot: - dependent.pre_checkers.append( - _BotChecker( - Required, - field=ModelField( - name=name, - type_=param.annotation, - class_validators=None, - model_config=CustomConfig, - default=None, - required=True, - ), - ) + checker = ModelField( + name=param.name, + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, ) - return cls(Required) - elif param.annotation == param.empty and name == "bot": + return cls(Required, checker=checker) + elif param.annotation == param.empty and param.name == "bot": return cls(Required) async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: return bot - -class _EventChecker(Param): - async def _solve(self, event: "Event", **kwargs: Any) -> Any: - field: ModelField = self.extra["field"] - try: - return check_field_type(field, event) - except TypeMisMatch: - logger.debug( - f"Event type {type(event)} not match " - f"annotation {field._type_display()}, ignored" - ) - raise + async def _check(self, bot: "Bot", **kwargs: Any) -> None: + if checker := self.extra.get("checker", None): + check_field_type(checker, bot) class EventParam(Param): @@ -220,33 +190,33 @@ class EventParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["EventParam"]: from nonebot.adapters import Event if param.default == param.empty: if generic_check_issubclass(param.annotation, Event): + checker: Optional[ModelField] = None if param.annotation is not Event: - dependent.pre_checkers.append( - _EventChecker( - Required, - field=ModelField( - name=name, - type_=param.annotation, - class_validators=None, - model_config=CustomConfig, - default=None, - required=True, - ), - ) + checker = ModelField( + name=param.name, + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, ) - return cls(Required) - elif param.annotation == param.empty and name == "event": + return cls(Required, checker=checker) + elif param.annotation == param.empty and param.name == "event": return cls(Required) async def _solve(self, event: "Event", **kwargs: Any) -> Any: return event + async def _check(self, event: "Event", **kwargs: Any) -> Any: + if checker := self.extra.get("checker", None): + check_field_type(checker, event) + class StateInner(T_State): ... @@ -257,14 +227,14 @@ class StateParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["StateParam"]: if isinstance(param.default, StateInner): return cls(Required) elif param.default == param.empty: if param.annotation is T_State: return cls(Required) - elif param.annotation == param.empty and name == "state": + elif param.annotation == param.empty and param.name == "state": return cls(Required) async def _solve(self, state: T_State, **kwargs: Any) -> Any: @@ -276,12 +246,12 @@ class MatcherParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["MatcherParam"]: from nonebot.matcher import Matcher if generic_check_issubclass(param.annotation, Matcher) or ( - param.annotation == param.empty and name == "matcher" + param.annotation == param.empty and param.name == "matcher" ): return cls(Required) @@ -317,10 +287,12 @@ class ArgParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["ArgParam"]: if isinstance(param.default, ArgInner): - return cls(Required, key=param.default.key or name, type=param.default.type) + return cls( + Required, key=param.default.key or param.name, type=param.default.type + ) async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: message = matcher.get_arg(self.extra["key"]) @@ -339,10 +311,10 @@ class ExceptionParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["ExceptionParam"]: if generic_check_issubclass(param.annotation, Exception) or ( - param.annotation == param.empty and name == "exception" + param.annotation == param.empty and param.name == "exception" ): return cls(Required) @@ -355,7 +327,7 @@ class DefaultParam(Param): @classmethod def _check_param( - cls, dependent: Dependent, name: str, param: inspect.Parameter + cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["DefaultParam"]: if param.default != param.empty: return cls(param.default)