♻️ improve dependent structure (#1227)

This commit is contained in:
Ju4tCode 2022-09-07 09:59:05 +08:00 committed by GitHub
parent 595c64e760
commit a0b186aff3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 191 additions and 210 deletions

View File

@ -6,15 +6,19 @@ FrontMatter:
""" """
import abc import abc
import asyncio
import inspect import inspect
from dataclasses import field, dataclass
from typing import ( from typing import (
Any, Any,
Dict, Dict,
List, List,
Type, Type,
Tuple,
Generic, Generic,
TypeVar, TypeVar,
Callable, Callable,
Iterable,
Optional, Optional,
Awaitable, Awaitable,
cast, cast,
@ -25,7 +29,6 @@ from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, FieldInfo, Undefined, ModelField from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from nonebot.log import logger from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.typing import _DependentCallable from nonebot.typing import _DependentCallable
from nonebot.utils import run_sync, is_coroutine_callable from nonebot.utils import run_sync, is_coroutine_callable
@ -43,25 +46,29 @@ class Param(abc.ABC, FieldInfo):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: "Dependent", name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
) -> Optional["Param"]: ) -> Optional["Param"]:
return None return
@classmethod @classmethod
def _check_parameterless( def _check_parameterless(
cls, dependent: "Dependent", value: Any cls, value: Any, allow_types: Tuple[Type["Param"], ...]
) -> Optional["Param"]: ) -> Optional["Param"]:
return None return
@abc.abstractmethod @abc.abstractmethod
async def _solve(self, **kwargs: Any) -> Any: async def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError raise NotImplementedError
async def _check(self, **kwargs: Any) -> None:
return
class CustomConfig(BaseConfig): class CustomConfig(BaseConfig):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@dataclass(frozen=True)
class Dependent(Generic[R]): class Dependent(Generic[R]):
"""依赖注入容器 """依赖注入容器
@ -73,76 +80,34 @@ class Dependent(Generic[R]):
allow_types: 允许的参数类型 allow_types: 允许的参数类型
""" """
def __init__( call: _DependentCallable[R]
self, params: Tuple[ModelField] = field(default_factory=tuple)
*, parameterless: Tuple[Param] = field(default_factory=tuple)
call: _DependentCallable[R],
pre_checkers: Optional[List[Param]] = None,
params: Optional[List[ModelField]] = None,
parameterless: Optional[List[Param]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> None:
self.call = call
self.pre_checkers = pre_checkers or []
self.params = params or []
self.parameterless = parameterless or []
self.allow_types = allow_types or []
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"<Dependent call={self.call}>"
f"<Dependent call={self.call}, params={self.params},"
f" parameterless={self.parameterless}>"
)
def __str__(self) -> str:
return self.__repr__()
async def __call__(self, **kwargs: Any) -> R: async def __call__(self, **kwargs: Any) -> R:
# do pre-check
await self.check(**kwargs)
# solve param values
values = await self.solve(**kwargs) values = await self.solve(**kwargs)
# call function
if is_coroutine_callable(self.call): if is_coroutine_callable(self.call):
return await cast(Callable[..., Awaitable[R]], self.call)(**values) return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else: else:
return await run_sync(cast(Callable[..., R], self.call))(**values) return await run_sync(cast(Callable[..., R], self.call))(**values)
def parse_param(self, name: str, param: inspect.Parameter) -> Param: @staticmethod
for allow_type in self.allow_types: def parse_params(
if field_info := allow_type._check_param(self, name, param): call: _DependentCallable[R], allow_types: Tuple[Type[Param], ...]
return field_info ) -> Tuple[ModelField]:
raise ValueError( fields: List[ModelField] = []
f"Unknown parameter {name} for function {self.call} with type {param.annotation}" params = get_typed_signature(call).parameters.values()
)
def parse_parameterless(self, value: Any) -> Param: for param in params:
for allow_type in self.allow_types:
if field_info := allow_type._check_parameterless(self, value):
return field_info
raise ValueError(
f"Unknown parameterless {value} for function {self.call} with type {type(value)}"
)
def prepend_parameterless(self, value: Any) -> None:
self.parameterless.insert(0, self.parse_parameterless(value))
def append_parameterless(self, value: Any) -> None:
self.parameterless.append(self.parse_parameterless(value))
@classmethod
def parse(
cls,
*,
call: _DependentCallable[R],
parameterless: Optional[List[Any]] = None,
allow_types: Optional[List[Type[Param]]] = None,
) -> "Dependent[R]":
signature = get_typed_signature(call)
params = signature.parameters
dependent = cls(
call=call,
allow_types=allow_types,
)
for param_name, param in params.items():
default_value = Required default_value = Required
if param.default != param.empty: if param.default != param.empty:
default_value = param.default default_value = param.default
@ -150,7 +115,13 @@ class Dependent(Generic[R]):
if isinstance(default_value, Param): if isinstance(default_value, Param):
field_info = default_value field_info = default_value
else: else:
field_info = dependent.parse_param(param_name, param) for allow_type in allow_types:
if field_info := allow_type._check_param(param, allow_types):
break
else:
raise ValueError(
f"Unknown parameter {param.name} for function {call} with type {param.annotation}"
)
default_value = field_info.default default_value = field_info.default
@ -159,11 +130,12 @@ class Dependent(Generic[R]):
if param.annotation != param.empty: if param.annotation != param.empty:
annotation = param.annotation annotation = param.annotation
annotation = get_annotation_from_field_info( annotation = get_annotation_from_field_info(
annotation, field_info, param_name annotation, field_info, param.name
) )
dependent.params.append(
fields.append(
ModelField( ModelField(
name=param_name, name=param.name,
type_=annotation, type_=annotation,
class_validators=None, class_validators=None,
model_config=CustomConfig, model_config=CustomConfig,
@ -173,49 +145,69 @@ class Dependent(Generic[R]):
) )
) )
parameterless_params = [ return tuple(fields)
dependent.parse_parameterless(param) for param in (parameterless or [])
] @staticmethod
dependent.parameterless.extend(parameterless_params) def parse_parameterless(
parameterless: Tuple[Any, ...], allow_types: Tuple[Type[Param], ...]
) -> Tuple[Param, ...]:
parameterless_params: List[Param] = []
for value in parameterless:
for allow_type in allow_types:
if param := allow_type._check_parameterless(value, allow_types):
break
else:
raise ValueError(f"Unknown parameterless {value}")
parameterless_params.append(param)
return tuple(parameterless_params)
@classmethod
def parse(
cls,
*,
call: _DependentCallable[R],
parameterless: Optional[Iterable[Any]] = None,
allow_types: Iterable[Type[Param]],
) -> "Dependent[R]":
allow_types = tuple(allow_types)
params = cls.parse_params(call, allow_types)
parameterless_params = (
tuple()
if parameterless is None
else cls.parse_parameterless(tuple(parameterless), allow_types)
)
logger.trace( logger.trace(
f"Parsed dependent with call={call}, " f"Parsed dependent with call={call}, "
f"params={[param.field_info for param in dependent.params]}, " f"params={params}, "
f"parameterless={dependent.parameterless}" f"parameterless={parameterless_params}"
) )
return dependent return cls(call, params, parameterless_params)
async def solve( async def check(self, **params: Any) -> None:
self, await asyncio.gather(*(param._check(**params) for param in self.parameterless))
**params: Any, await asyncio.gather(
) -> Dict[str, Any]: *(cast(Param, param.field_info)._check(**params) for param in self.params)
values: Dict[str, Any] = {} )
for checker in self.pre_checkers: async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
await checker._solve(**params) value = await cast(Param, field.field_info)._solve(**params)
if value is Undefined:
value = field.get_default()
return check_field_type(field, value)
async def solve(self, **params: Any) -> Dict[str, Any]:
# solve parameterless
for param in self.parameterless: for param in self.parameterless:
await param._solve(**params) await param._solve(**params)
for field in self.params: # solve param values
field_info = field.field_info values = await asyncio.gather(
assert isinstance(field_info, Param), "Params must be subclasses of Param" *(self._solve_field(field, params) for field in self.params)
value = await field_info._solve(**params) )
if value is Undefined: return {field.name: value for field, value in zip(self.params, values)}
value = field.get_default()
try:
values[field.name] = check_field_type(field, value)
except TypeMisMatch:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {self.call} "
f"annotation {field._type_display()}, ignored"
)
raise
return values
__autodoc__ = {"CustomConfig": False} __autodoc__ = {"CustomConfig": False}

View File

@ -12,6 +12,7 @@ from typing import (
Union, Union,
TypeVar, TypeVar,
Callable, Callable,
Iterable,
NoReturn, NoReturn,
Optional, Optional,
overload, overload,
@ -133,7 +134,7 @@ class Matcher(metaclass=MatcherMeta):
_default_permission_updater: Optional[Dependent[Permission]] = None _default_permission_updater: Optional[Dependent[Permission]] = None
"""事件响应器权限更新函数""" """事件响应器权限更新函数"""
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = (
DependParam, DependParam,
BotParam, BotParam,
EventParam, EventParam,
@ -141,7 +142,7 @@ class Matcher(metaclass=MatcherMeta):
ArgParam, ArgParam,
MatcherParam, MatcherParam,
DefaultParam, DefaultParam,
] )
def __init__(self): def __init__(self):
self.handlers = self.handlers.copy() self.handlers = self.handlers.copy()
@ -153,9 +154,6 @@ class Matcher(metaclass=MatcherMeta):
f"priority={self.priority}, temp={self.temp}>" f"priority={self.priority}, temp={self.temp}>"
) )
def __str__(self) -> str:
return repr(self)
@classmethod @classmethod
def new( def new(
cls, cls,
@ -219,27 +217,35 @@ class Matcher(metaclass=MatcherMeta):
"temp": temp, "temp": temp,
"expire_time": ( "expire_time": (
expire_time expire_time
if isinstance(expire_time, datetime) and (
else expire_time and datetime.now() + expire_time expire_time
if isinstance(expire_time, datetime)
else datetime.now() + expire_time
)
), ),
"priority": priority, "priority": priority,
"block": block, "block": block,
"_default_state": default_state or {}, "_default_state": default_state or {},
"_default_type_updater": ( "_default_type_updater": (
default_type_updater default_type_updater
if isinstance(default_type_updater, Dependent) and (
else default_type_updater default_type_updater
and Dependent[str].parse( if isinstance(default_type_updater, Dependent)
call=default_type_updater, allow_types=cls.HANDLER_PARAM_TYPES else Dependent[str].parse(
call=default_type_updater,
allow_types=cls.HANDLER_PARAM_TYPES,
)
) )
), ),
"_default_permission_updater": ( "_default_permission_updater": (
default_permission_updater default_permission_updater
if isinstance(default_permission_updater, Dependent) and (
else default_permission_updater default_permission_updater
and Dependent[Permission].parse( if isinstance(default_permission_updater, Dependent)
call=default_permission_updater, else Dependent[Permission].parse(
allow_types=cls.HANDLER_PARAM_TYPES, call=default_permission_updater,
allow_types=cls.HANDLER_PARAM_TYPES,
)
) )
), ),
}, },
@ -327,7 +333,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
def append_handler( def append_handler(
cls, handler: T_Handler, parameterless: Optional[List[Any]] = None cls, handler: T_Handler, parameterless: Optional[Iterable[Any]] = None
) -> Dependent[Any]: ) -> Dependent[Any]:
handler_ = Dependent[Any].parse( handler_ = Dependent[Any].parse(
call=handler, call=handler,
@ -339,7 +345,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
def handle( def handle(
cls, parameterless: Optional[List[Any]] = None cls, parameterless: Optional[Iterable[Any]] = None
) -> Callable[[T_Handler], T_Handler]: ) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来向事件响应器直接添加一个处理函数 """装饰一个函数来向事件响应器直接添加一个处理函数
@ -355,7 +361,7 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
def receive( def receive(
cls, id: str = "", parameterless: Optional[List[Any]] = None cls, id: str = "", parameterless: Optional[Iterable[Any]] = None
) -> Callable[[T_Handler], T_Handler]: ) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数 """装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数
@ -373,14 +379,21 @@ class Matcher(metaclass=MatcherMeta):
return return
await matcher.reject() await matcher.reject()
_parameterless = [Depends(_receive), *(parameterless or [])] _parameterless = (Depends(_receive), *(parameterless or tuple()))
def _decorator(func: T_Handler) -> T_Handler: def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].call is func: if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1] func_handler = cls.handlers[-1]
for depend in reversed(_parameterless): new_handler = Dependent(
func_handler.prepend_parameterless(depend) call=func_handler.call,
params=func_handler.params,
parameterless=Dependent.parse_parameterless(
tuple(_parameterless), cls.HANDLER_PARAM_TYPES
)
+ func_handler.parameterless,
)
cls.handlers[-1] = new_handler
else: else:
cls.append_handler(func, parameterless=_parameterless) cls.append_handler(func, parameterless=_parameterless)
@ -393,7 +406,7 @@ class Matcher(metaclass=MatcherMeta):
cls, cls,
key: str, key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
parameterless: Optional[List[Any]] = None, parameterless: Optional[Iterable[Any]] = None,
) -> Callable[[T_Handler], T_Handler]: ) -> Callable[[T_Handler], T_Handler]:
"""装饰一个函数来指示 NoneBot 获取一个参数 `key` """装饰一个函数来指示 NoneBot 获取一个参数 `key`
@ -414,17 +427,21 @@ class Matcher(metaclass=MatcherMeta):
return return
await matcher.reject(prompt) await matcher.reject(prompt)
_parameterless = [ _parameterless = (Depends(_key_getter), *(parameterless or tuple()))
Depends(_key_getter),
*(parameterless or []),
]
def _decorator(func: T_Handler) -> T_Handler: def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].call is func: if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1] func_handler = cls.handlers[-1]
for depend in reversed(_parameterless): new_handler = Dependent(
func_handler.prepend_parameterless(depend) call=func_handler.call,
params=func_handler.params,
parameterless=Dependent.parse_parameterless(
tuple(_parameterless), cls.HANDLER_PARAM_TYPES
)
+ func_handler.parameterless,
)
cls.handlers[-1] = new_handler
else: else:
cls.append_handler(func, parameterless=_parameterless) cls.append_handler(func, parameterless=_parameterless)

View File

@ -1,8 +1,7 @@
import asyncio import asyncio
import inspect import inspect
import warnings
from typing import TYPE_CHECKING, Any, Literal, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
from pydantic.fields import Required, Undefined, ModelField from pydantic.fields import Required, Undefined, ModelField
@ -76,10 +75,7 @@ class DependParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
dependent: Dependent,
name: str,
param: inspect.Parameter,
) -> Optional["DependParam"]: ) -> Optional["DependParam"]:
if isinstance(param.default, DependsInner): if isinstance(param.default, DependsInner):
dependency: T_Handler dependency: T_Handler
@ -90,22 +86,20 @@ class DependParam(Param):
dependency = param.default.dependency dependency = param.default.dependency
sub_dependent = Dependent[Any].parse( sub_dependent = Dependent[Any].parse(
call=dependency, call=dependency,
allow_types=dependent.allow_types, allow_types=allow_types,
) )
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
sub_dependent.pre_checkers.clear()
return cls( return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent Required, use_cache=param.default.use_cache, dependent=sub_dependent
) )
@classmethod @classmethod
def _check_parameterless( def _check_parameterless(
cls, dependent: "Dependent", value: Any cls, value: Any, allow_types: Tuple[Type[Param], ...]
) -> Optional["Param"]: ) -> Optional["Param"]:
if isinstance(value, DependsInner): if isinstance(value, DependsInner):
assert value.dependency, "Dependency cannot be empty" assert value.dependency, "Dependency cannot be empty"
dependent = Dependent[Any].parse( dependent = Dependent[Any].parse(
call=value.dependency, allow_types=dependent.allow_types call=value.dependency, allow_types=allow_types
) )
return cls(Required, use_cache=value.use_cache, dependent=dependent) return cls(Required, use_cache=value.use_cache, dependent=dependent)
@ -119,8 +113,7 @@ class DependParam(Param):
dependency_cache = {} if dependency_cache is None else dependency_cache dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"] sub_dependent: Dependent = self.extra["dependent"]
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) call = cast(Callable[..., Any], sub_dependent.call)
call = sub_dependent.call
# solve sub dependency with current cache # solve sub dependency with current cache
sub_values = await sub_dependent.solve( sub_values = await sub_dependent.solve(
@ -132,7 +125,7 @@ class DependParam(Param):
# run dependency function # run dependency function
task: asyncio.Task[Any] task: asyncio.Task[Any]
if use_cache and call in dependency_cache: if use_cache and call in dependency_cache:
solved = await dependency_cache[call] return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call): elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance( assert isinstance(
stack, AsyncExitStack stack, AsyncExitStack
@ -143,30 +136,20 @@ class DependParam(Param):
cm = asynccontextmanager(call)(**sub_values) cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm)) task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task dependency_cache[call] = task
solved = await task return await task
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values)) task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task dependency_cache[call] = task
solved = await task return await task
else: else:
task = asyncio.create_task(run_sync(call)(**sub_values)) task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task dependency_cache[call] = task
solved = await task return await task
return solved async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"]
class _BotChecker(Param): await sub_dependent.check(**kwargs)
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
try:
return check_field_type(field, bot)
except TypeMisMatch:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise
class BotParam(Param): class BotParam(Param):
@ -174,45 +157,32 @@ class BotParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["BotParam"]: ) -> Optional["BotParam"]:
from nonebot.adapters import Bot from nonebot.adapters import Bot
if param.default == param.empty: if param.default == param.empty:
if generic_check_issubclass(param.annotation, Bot): if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None
if param.annotation is not Bot: if param.annotation is not Bot:
dependent.pre_checkers.append( checker = ModelField(
_BotChecker( name=param.name,
Required, type_=param.annotation,
field=ModelField( class_validators=None,
name=name, model_config=CustomConfig,
type_=param.annotation, default=None,
class_validators=None, required=True,
model_config=CustomConfig,
default=None,
required=True,
),
)
) )
return cls(Required) return cls(Required, checker=checker)
elif param.annotation == param.empty and name == "bot": elif param.annotation == param.empty and param.name == "bot":
return cls(Required) return cls(Required)
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot return bot
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
class _EventChecker(Param): if checker := self.extra.get("checker", None):
async def _solve(self, event: "Event", **kwargs: Any) -> Any: check_field_type(checker, bot)
field: ModelField = self.extra["field"]
try:
return check_field_type(field, event)
except TypeMisMatch:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise
class EventParam(Param): class EventParam(Param):
@ -220,33 +190,33 @@ class EventParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["EventParam"]: ) -> Optional["EventParam"]:
from nonebot.adapters import Event from nonebot.adapters import Event
if param.default == param.empty: if param.default == param.empty:
if generic_check_issubclass(param.annotation, Event): if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None
if param.annotation is not Event: if param.annotation is not Event:
dependent.pre_checkers.append( checker = ModelField(
_EventChecker( name=param.name,
Required, type_=param.annotation,
field=ModelField( class_validators=None,
name=name, model_config=CustomConfig,
type_=param.annotation, default=None,
class_validators=None, required=True,
model_config=CustomConfig,
default=None,
required=True,
),
)
) )
return cls(Required) return cls(Required, checker=checker)
elif param.annotation == param.empty and name == "event": elif param.annotation == param.empty and param.name == "event":
return cls(Required) return cls(Required)
async def _solve(self, event: "Event", **kwargs: Any) -> Any: async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event return event
async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, event)
class StateInner(T_State): class StateInner(T_State):
... ...
@ -257,14 +227,14 @@ class StateParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]: ) -> Optional["StateParam"]:
if isinstance(param.default, StateInner): if isinstance(param.default, StateInner):
return cls(Required) return cls(Required)
elif param.default == param.empty: elif param.default == param.empty:
if param.annotation is T_State: if param.annotation is T_State:
return cls(Required) return cls(Required)
elif param.annotation == param.empty and name == "state": elif param.annotation == param.empty and param.name == "state":
return cls(Required) return cls(Required)
async def _solve(self, state: T_State, **kwargs: Any) -> Any: async def _solve(self, state: T_State, **kwargs: Any) -> Any:
@ -276,12 +246,12 @@ class MatcherParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["MatcherParam"]: ) -> Optional["MatcherParam"]:
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
if generic_check_issubclass(param.annotation, Matcher) or ( if generic_check_issubclass(param.annotation, Matcher) or (
param.annotation == param.empty and name == "matcher" param.annotation == param.empty and param.name == "matcher"
): ):
return cls(Required) return cls(Required)
@ -317,10 +287,12 @@ class ArgParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ArgParam"]: ) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner): if isinstance(param.default, ArgInner):
return cls(Required, key=param.default.key or name, type=param.default.type) return cls(
Required, key=param.default.key or param.name, type=param.default.type
)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
message = matcher.get_arg(self.extra["key"]) message = matcher.get_arg(self.extra["key"])
@ -339,10 +311,10 @@ class ExceptionParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]: ) -> Optional["ExceptionParam"]:
if generic_check_issubclass(param.annotation, Exception) or ( if generic_check_issubclass(param.annotation, Exception) or (
param.annotation == param.empty and name == "exception" param.annotation == param.empty and param.name == "exception"
): ):
return cls(Required) return cls(Required)
@ -355,7 +327,7 @@ class DefaultParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DefaultParam"]: ) -> Optional["DefaultParam"]:
if param.default != param.empty: if param.default != param.empty:
return cls(param.default) return cls(param.default)