mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-05 03:24:53 +08:00
♻️ improve dependent structure (#1227)
This commit is contained in:
parent
595c64e760
commit
a0b186aff3
@ -6,15 +6,19 @@ FrontMatter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
from dataclasses import field, dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Type,
|
Type,
|
||||||
|
Tuple,
|
||||||
Generic,
|
Generic,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Callable,
|
Callable,
|
||||||
|
Iterable,
|
||||||
Optional,
|
Optional,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
cast,
|
cast,
|
||||||
@ -25,7 +29,6 @@ from pydantic.schema import get_annotation_from_field_info
|
|||||||
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.exception import TypeMisMatch
|
|
||||||
from nonebot.typing import _DependentCallable
|
from nonebot.typing import _DependentCallable
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable
|
from nonebot.utils import run_sync, is_coroutine_callable
|
||||||
|
|
||||||
@ -43,25 +46,29 @@ class Param(abc.ABC, FieldInfo):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: "Dependent", name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
|
||||||
) -> Optional["Param"]:
|
) -> Optional["Param"]:
|
||||||
return None
|
return
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_parameterless(
|
def _check_parameterless(
|
||||||
cls, dependent: "Dependent", value: Any
|
cls, value: Any, allow_types: Tuple[Type["Param"], ...]
|
||||||
) -> Optional["Param"]:
|
) -> Optional["Param"]:
|
||||||
return None
|
return
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def _solve(self, **kwargs: Any) -> Any:
|
async def _solve(self, **kwargs: Any) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def _check(self, **kwargs: Any) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class CustomConfig(BaseConfig):
|
class CustomConfig(BaseConfig):
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
class Dependent(Generic[R]):
|
class Dependent(Generic[R]):
|
||||||
"""依赖注入容器
|
"""依赖注入容器
|
||||||
|
|
||||||
@ -73,76 +80,34 @@ class Dependent(Generic[R]):
|
|||||||
allow_types: 允许的参数类型
|
allow_types: 允许的参数类型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
call: _DependentCallable[R]
|
||||||
self,
|
params: Tuple[ModelField] = field(default_factory=tuple)
|
||||||
*,
|
parameterless: Tuple[Param] = field(default_factory=tuple)
|
||||||
call: _DependentCallable[R],
|
|
||||||
pre_checkers: Optional[List[Param]] = None,
|
|
||||||
params: Optional[List[ModelField]] = None,
|
|
||||||
parameterless: Optional[List[Param]] = None,
|
|
||||||
allow_types: Optional[List[Type[Param]]] = None,
|
|
||||||
) -> None:
|
|
||||||
self.call = call
|
|
||||||
self.pre_checkers = pre_checkers or []
|
|
||||||
self.params = params or []
|
|
||||||
self.parameterless = parameterless or []
|
|
||||||
self.allow_types = allow_types or []
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return f"<Dependent call={self.call}>"
|
||||||
f"<Dependent call={self.call}, params={self.params},"
|
|
||||||
f" parameterless={self.parameterless}>"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.__repr__()
|
|
||||||
|
|
||||||
async def __call__(self, **kwargs: Any) -> R:
|
async def __call__(self, **kwargs: Any) -> R:
|
||||||
|
# do pre-check
|
||||||
|
await self.check(**kwargs)
|
||||||
|
|
||||||
|
# solve param values
|
||||||
values = await self.solve(**kwargs)
|
values = await self.solve(**kwargs)
|
||||||
|
|
||||||
|
# call function
|
||||||
if is_coroutine_callable(self.call):
|
if is_coroutine_callable(self.call):
|
||||||
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
|
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
|
||||||
else:
|
else:
|
||||||
return await run_sync(cast(Callable[..., R], self.call))(**values)
|
return await run_sync(cast(Callable[..., R], self.call))(**values)
|
||||||
|
|
||||||
def parse_param(self, name: str, param: inspect.Parameter) -> Param:
|
@staticmethod
|
||||||
for allow_type in self.allow_types:
|
def parse_params(
|
||||||
if field_info := allow_type._check_param(self, name, param):
|
call: _DependentCallable[R], allow_types: Tuple[Type[Param], ...]
|
||||||
return field_info
|
) -> Tuple[ModelField]:
|
||||||
raise ValueError(
|
fields: List[ModelField] = []
|
||||||
f"Unknown parameter {name} for function {self.call} with type {param.annotation}"
|
params = get_typed_signature(call).parameters.values()
|
||||||
)
|
|
||||||
|
|
||||||
def parse_parameterless(self, value: Any) -> Param:
|
for param in params:
|
||||||
for allow_type in self.allow_types:
|
|
||||||
if field_info := allow_type._check_parameterless(self, value):
|
|
||||||
return field_info
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown parameterless {value} for function {self.call} with type {type(value)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepend_parameterless(self, value: Any) -> None:
|
|
||||||
self.parameterless.insert(0, self.parse_parameterless(value))
|
|
||||||
|
|
||||||
def append_parameterless(self, value: Any) -> None:
|
|
||||||
self.parameterless.append(self.parse_parameterless(value))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
call: _DependentCallable[R],
|
|
||||||
parameterless: Optional[List[Any]] = None,
|
|
||||||
allow_types: Optional[List[Type[Param]]] = None,
|
|
||||||
) -> "Dependent[R]":
|
|
||||||
signature = get_typed_signature(call)
|
|
||||||
params = signature.parameters
|
|
||||||
dependent = cls(
|
|
||||||
call=call,
|
|
||||||
allow_types=allow_types,
|
|
||||||
)
|
|
||||||
|
|
||||||
for param_name, param in params.items():
|
|
||||||
default_value = Required
|
default_value = Required
|
||||||
if param.default != param.empty:
|
if param.default != param.empty:
|
||||||
default_value = param.default
|
default_value = param.default
|
||||||
@ -150,7 +115,13 @@ class Dependent(Generic[R]):
|
|||||||
if isinstance(default_value, Param):
|
if isinstance(default_value, Param):
|
||||||
field_info = default_value
|
field_info = default_value
|
||||||
else:
|
else:
|
||||||
field_info = dependent.parse_param(param_name, param)
|
for allow_type in allow_types:
|
||||||
|
if field_info := allow_type._check_param(param, allow_types):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown parameter {param.name} for function {call} with type {param.annotation}"
|
||||||
|
)
|
||||||
|
|
||||||
default_value = field_info.default
|
default_value = field_info.default
|
||||||
|
|
||||||
@ -159,11 +130,12 @@ class Dependent(Generic[R]):
|
|||||||
if param.annotation != param.empty:
|
if param.annotation != param.empty:
|
||||||
annotation = param.annotation
|
annotation = param.annotation
|
||||||
annotation = get_annotation_from_field_info(
|
annotation = get_annotation_from_field_info(
|
||||||
annotation, field_info, param_name
|
annotation, field_info, param.name
|
||||||
)
|
)
|
||||||
dependent.params.append(
|
|
||||||
|
fields.append(
|
||||||
ModelField(
|
ModelField(
|
||||||
name=param_name,
|
name=param.name,
|
||||||
type_=annotation,
|
type_=annotation,
|
||||||
class_validators=None,
|
class_validators=None,
|
||||||
model_config=CustomConfig,
|
model_config=CustomConfig,
|
||||||
@ -173,49 +145,69 @@ class Dependent(Generic[R]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
parameterless_params = [
|
return tuple(fields)
|
||||||
dependent.parse_parameterless(param) for param in (parameterless or [])
|
|
||||||
]
|
@staticmethod
|
||||||
dependent.parameterless.extend(parameterless_params)
|
def parse_parameterless(
|
||||||
|
parameterless: Tuple[Any, ...], allow_types: Tuple[Type[Param], ...]
|
||||||
|
) -> Tuple[Param, ...]:
|
||||||
|
parameterless_params: List[Param] = []
|
||||||
|
for value in parameterless:
|
||||||
|
for allow_type in allow_types:
|
||||||
|
if param := allow_type._check_parameterless(value, allow_types):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown parameterless {value}")
|
||||||
|
parameterless_params.append(param)
|
||||||
|
return tuple(parameterless_params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
call: _DependentCallable[R],
|
||||||
|
parameterless: Optional[Iterable[Any]] = None,
|
||||||
|
allow_types: Iterable[Type[Param]],
|
||||||
|
) -> "Dependent[R]":
|
||||||
|
allow_types = tuple(allow_types)
|
||||||
|
|
||||||
|
params = cls.parse_params(call, allow_types)
|
||||||
|
parameterless_params = (
|
||||||
|
tuple()
|
||||||
|
if parameterless is None
|
||||||
|
else cls.parse_parameterless(tuple(parameterless), allow_types)
|
||||||
|
)
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
f"Parsed dependent with call={call}, "
|
f"Parsed dependent with call={call}, "
|
||||||
f"params={[param.field_info for param in dependent.params]}, "
|
f"params={params}, "
|
||||||
f"parameterless={dependent.parameterless}"
|
f"parameterless={parameterless_params}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return dependent
|
return cls(call, params, parameterless_params)
|
||||||
|
|
||||||
async def solve(
|
async def check(self, **params: Any) -> None:
|
||||||
self,
|
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
|
||||||
**params: Any,
|
await asyncio.gather(
|
||||||
) -> Dict[str, Any]:
|
*(cast(Param, param.field_info)._check(**params) for param in self.params)
|
||||||
values: Dict[str, Any] = {}
|
)
|
||||||
|
|
||||||
for checker in self.pre_checkers:
|
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
|
||||||
await checker._solve(**params)
|
value = await cast(Param, field.field_info)._solve(**params)
|
||||||
|
if value is Undefined:
|
||||||
|
value = field.get_default()
|
||||||
|
return check_field_type(field, value)
|
||||||
|
|
||||||
|
async def solve(self, **params: Any) -> Dict[str, Any]:
|
||||||
|
# solve parameterless
|
||||||
for param in self.parameterless:
|
for param in self.parameterless:
|
||||||
await param._solve(**params)
|
await param._solve(**params)
|
||||||
|
|
||||||
for field in self.params:
|
# solve param values
|
||||||
field_info = field.field_info
|
values = await asyncio.gather(
|
||||||
assert isinstance(field_info, Param), "Params must be subclasses of Param"
|
*(self._solve_field(field, params) for field in self.params)
|
||||||
value = await field_info._solve(**params)
|
)
|
||||||
if value is Undefined:
|
return {field.name: value for field, value in zip(self.params, values)}
|
||||||
value = field.get_default()
|
|
||||||
|
|
||||||
try:
|
|
||||||
values[field.name] = check_field_type(field, value)
|
|
||||||
except TypeMisMatch:
|
|
||||||
logger.debug(
|
|
||||||
f"{field_info} "
|
|
||||||
f"type {type(value)} not match depends {self.call} "
|
|
||||||
f"annotation {field._type_display()}, ignored"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
__autodoc__ = {"CustomConfig": False}
|
__autodoc__ = {"CustomConfig": False}
|
||||||
|
@ -12,6 +12,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Callable,
|
Callable,
|
||||||
|
Iterable,
|
||||||
NoReturn,
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
overload,
|
overload,
|
||||||
@ -133,7 +134,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
_default_permission_updater: Optional[Dependent[Permission]] = None
|
_default_permission_updater: Optional[Dependent[Permission]] = None
|
||||||
"""事件响应器权限更新函数"""
|
"""事件响应器权限更新函数"""
|
||||||
|
|
||||||
HANDLER_PARAM_TYPES = [
|
HANDLER_PARAM_TYPES = (
|
||||||
DependParam,
|
DependParam,
|
||||||
BotParam,
|
BotParam,
|
||||||
EventParam,
|
EventParam,
|
||||||
@ -141,7 +142,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
ArgParam,
|
ArgParam,
|
||||||
MatcherParam,
|
MatcherParam,
|
||||||
DefaultParam,
|
DefaultParam,
|
||||||
]
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.handlers = self.handlers.copy()
|
self.handlers = self.handlers.copy()
|
||||||
@ -153,9 +154,6 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
f"priority={self.priority}, temp={self.temp}>"
|
f"priority={self.priority}, temp={self.temp}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return repr(self)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def new(
|
def new(
|
||||||
cls,
|
cls,
|
||||||
@ -219,27 +217,35 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
"temp": temp,
|
"temp": temp,
|
||||||
"expire_time": (
|
"expire_time": (
|
||||||
expire_time
|
expire_time
|
||||||
if isinstance(expire_time, datetime)
|
and (
|
||||||
else expire_time and datetime.now() + expire_time
|
expire_time
|
||||||
|
if isinstance(expire_time, datetime)
|
||||||
|
else datetime.now() + expire_time
|
||||||
|
)
|
||||||
),
|
),
|
||||||
"priority": priority,
|
"priority": priority,
|
||||||
"block": block,
|
"block": block,
|
||||||
"_default_state": default_state or {},
|
"_default_state": default_state or {},
|
||||||
"_default_type_updater": (
|
"_default_type_updater": (
|
||||||
default_type_updater
|
default_type_updater
|
||||||
if isinstance(default_type_updater, Dependent)
|
and (
|
||||||
else default_type_updater
|
default_type_updater
|
||||||
and Dependent[str].parse(
|
if isinstance(default_type_updater, Dependent)
|
||||||
call=default_type_updater, allow_types=cls.HANDLER_PARAM_TYPES
|
else Dependent[str].parse(
|
||||||
|
call=default_type_updater,
|
||||||
|
allow_types=cls.HANDLER_PARAM_TYPES,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
"_default_permission_updater": (
|
"_default_permission_updater": (
|
||||||
default_permission_updater
|
default_permission_updater
|
||||||
if isinstance(default_permission_updater, Dependent)
|
and (
|
||||||
else default_permission_updater
|
default_permission_updater
|
||||||
and Dependent[Permission].parse(
|
if isinstance(default_permission_updater, Dependent)
|
||||||
call=default_permission_updater,
|
else Dependent[Permission].parse(
|
||||||
allow_types=cls.HANDLER_PARAM_TYPES,
|
call=default_permission_updater,
|
||||||
|
allow_types=cls.HANDLER_PARAM_TYPES,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
@ -327,7 +333,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def append_handler(
|
def append_handler(
|
||||||
cls, handler: T_Handler, parameterless: Optional[List[Any]] = None
|
cls, handler: T_Handler, parameterless: Optional[Iterable[Any]] = None
|
||||||
) -> Dependent[Any]:
|
) -> Dependent[Any]:
|
||||||
handler_ = Dependent[Any].parse(
|
handler_ = Dependent[Any].parse(
|
||||||
call=handler,
|
call=handler,
|
||||||
@ -339,7 +345,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def handle(
|
def handle(
|
||||||
cls, parameterless: Optional[List[Any]] = None
|
cls, parameterless: Optional[Iterable[Any]] = None
|
||||||
) -> Callable[[T_Handler], T_Handler]:
|
) -> Callable[[T_Handler], T_Handler]:
|
||||||
"""装饰一个函数来向事件响应器直接添加一个处理函数
|
"""装饰一个函数来向事件响应器直接添加一个处理函数
|
||||||
|
|
||||||
@ -355,7 +361,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def receive(
|
def receive(
|
||||||
cls, id: str = "", parameterless: Optional[List[Any]] = None
|
cls, id: str = "", parameterless: Optional[Iterable[Any]] = None
|
||||||
) -> Callable[[T_Handler], T_Handler]:
|
) -> Callable[[T_Handler], T_Handler]:
|
||||||
"""装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数
|
"""装饰一个函数来指示 NoneBot 在接收用户新的一条消息后继续运行该函数
|
||||||
|
|
||||||
@ -373,14 +379,21 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
return
|
return
|
||||||
await matcher.reject()
|
await matcher.reject()
|
||||||
|
|
||||||
_parameterless = [Depends(_receive), *(parameterless or [])]
|
_parameterless = (Depends(_receive), *(parameterless or tuple()))
|
||||||
|
|
||||||
def _decorator(func: T_Handler) -> T_Handler:
|
def _decorator(func: T_Handler) -> T_Handler:
|
||||||
|
|
||||||
if cls.handlers and cls.handlers[-1].call is func:
|
if cls.handlers and cls.handlers[-1].call is func:
|
||||||
func_handler = cls.handlers[-1]
|
func_handler = cls.handlers[-1]
|
||||||
for depend in reversed(_parameterless):
|
new_handler = Dependent(
|
||||||
func_handler.prepend_parameterless(depend)
|
call=func_handler.call,
|
||||||
|
params=func_handler.params,
|
||||||
|
parameterless=Dependent.parse_parameterless(
|
||||||
|
tuple(_parameterless), cls.HANDLER_PARAM_TYPES
|
||||||
|
)
|
||||||
|
+ func_handler.parameterless,
|
||||||
|
)
|
||||||
|
cls.handlers[-1] = new_handler
|
||||||
else:
|
else:
|
||||||
cls.append_handler(func, parameterless=_parameterless)
|
cls.append_handler(func, parameterless=_parameterless)
|
||||||
|
|
||||||
@ -393,7 +406,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
cls,
|
cls,
|
||||||
key: str,
|
key: str,
|
||||||
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
|
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
|
||||||
parameterless: Optional[List[Any]] = None,
|
parameterless: Optional[Iterable[Any]] = None,
|
||||||
) -> Callable[[T_Handler], T_Handler]:
|
) -> Callable[[T_Handler], T_Handler]:
|
||||||
"""装饰一个函数来指示 NoneBot 获取一个参数 `key`
|
"""装饰一个函数来指示 NoneBot 获取一个参数 `key`
|
||||||
|
|
||||||
@ -414,17 +427,21 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
return
|
return
|
||||||
await matcher.reject(prompt)
|
await matcher.reject(prompt)
|
||||||
|
|
||||||
_parameterless = [
|
_parameterless = (Depends(_key_getter), *(parameterless or tuple()))
|
||||||
Depends(_key_getter),
|
|
||||||
*(parameterless or []),
|
|
||||||
]
|
|
||||||
|
|
||||||
def _decorator(func: T_Handler) -> T_Handler:
|
def _decorator(func: T_Handler) -> T_Handler:
|
||||||
|
|
||||||
if cls.handlers and cls.handlers[-1].call is func:
|
if cls.handlers and cls.handlers[-1].call is func:
|
||||||
func_handler = cls.handlers[-1]
|
func_handler = cls.handlers[-1]
|
||||||
for depend in reversed(_parameterless):
|
new_handler = Dependent(
|
||||||
func_handler.prepend_parameterless(depend)
|
call=func_handler.call,
|
||||||
|
params=func_handler.params,
|
||||||
|
parameterless=Dependent.parse_parameterless(
|
||||||
|
tuple(_parameterless), cls.HANDLER_PARAM_TYPES
|
||||||
|
)
|
||||||
|
+ func_handler.parameterless,
|
||||||
|
)
|
||||||
|
cls.handlers[-1] = new_handler
|
||||||
else:
|
else:
|
||||||
cls.append_handler(func, parameterless=_parameterless)
|
cls.append_handler(func, parameterless=_parameterless)
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import warnings
|
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Callable, Optional, cast
|
|
||||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||||
|
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
|
||||||
|
|
||||||
from pydantic.fields import Required, Undefined, ModelField
|
from pydantic.fields import Required, Undefined, ModelField
|
||||||
|
|
||||||
@ -76,10 +75,7 @@ class DependParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls,
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
dependent: Dependent,
|
|
||||||
name: str,
|
|
||||||
param: inspect.Parameter,
|
|
||||||
) -> Optional["DependParam"]:
|
) -> Optional["DependParam"]:
|
||||||
if isinstance(param.default, DependsInner):
|
if isinstance(param.default, DependsInner):
|
||||||
dependency: T_Handler
|
dependency: T_Handler
|
||||||
@ -90,22 +86,20 @@ class DependParam(Param):
|
|||||||
dependency = param.default.dependency
|
dependency = param.default.dependency
|
||||||
sub_dependent = Dependent[Any].parse(
|
sub_dependent = Dependent[Any].parse(
|
||||||
call=dependency,
|
call=dependency,
|
||||||
allow_types=dependent.allow_types,
|
allow_types=allow_types,
|
||||||
)
|
)
|
||||||
dependent.pre_checkers.extend(sub_dependent.pre_checkers)
|
|
||||||
sub_dependent.pre_checkers.clear()
|
|
||||||
return cls(
|
return cls(
|
||||||
Required, use_cache=param.default.use_cache, dependent=sub_dependent
|
Required, use_cache=param.default.use_cache, dependent=sub_dependent
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_parameterless(
|
def _check_parameterless(
|
||||||
cls, dependent: "Dependent", value: Any
|
cls, value: Any, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["Param"]:
|
) -> Optional["Param"]:
|
||||||
if isinstance(value, DependsInner):
|
if isinstance(value, DependsInner):
|
||||||
assert value.dependency, "Dependency cannot be empty"
|
assert value.dependency, "Dependency cannot be empty"
|
||||||
dependent = Dependent[Any].parse(
|
dependent = Dependent[Any].parse(
|
||||||
call=value.dependency, allow_types=dependent.allow_types
|
call=value.dependency, allow_types=allow_types
|
||||||
)
|
)
|
||||||
return cls(Required, use_cache=value.use_cache, dependent=dependent)
|
return cls(Required, use_cache=value.use_cache, dependent=dependent)
|
||||||
|
|
||||||
@ -119,8 +113,7 @@ class DependParam(Param):
|
|||||||
dependency_cache = {} if dependency_cache is None else dependency_cache
|
dependency_cache = {} if dependency_cache is None else dependency_cache
|
||||||
|
|
||||||
sub_dependent: Dependent = self.extra["dependent"]
|
sub_dependent: Dependent = self.extra["dependent"]
|
||||||
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
call = cast(Callable[..., Any], sub_dependent.call)
|
||||||
call = sub_dependent.call
|
|
||||||
|
|
||||||
# solve sub dependency with current cache
|
# solve sub dependency with current cache
|
||||||
sub_values = await sub_dependent.solve(
|
sub_values = await sub_dependent.solve(
|
||||||
@ -132,7 +125,7 @@ class DependParam(Param):
|
|||||||
# run dependency function
|
# run dependency function
|
||||||
task: asyncio.Task[Any]
|
task: asyncio.Task[Any]
|
||||||
if use_cache and call in dependency_cache:
|
if use_cache and call in dependency_cache:
|
||||||
solved = await dependency_cache[call]
|
return await dependency_cache[call]
|
||||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
stack, AsyncExitStack
|
stack, AsyncExitStack
|
||||||
@ -143,30 +136,20 @@ class DependParam(Param):
|
|||||||
cm = asynccontextmanager(call)(**sub_values)
|
cm = asynccontextmanager(call)(**sub_values)
|
||||||
task = asyncio.create_task(stack.enter_async_context(cm))
|
task = asyncio.create_task(stack.enter_async_context(cm))
|
||||||
dependency_cache[call] = task
|
dependency_cache[call] = task
|
||||||
solved = await task
|
return await task
|
||||||
elif is_coroutine_callable(call):
|
elif is_coroutine_callable(call):
|
||||||
task = asyncio.create_task(call(**sub_values))
|
task = asyncio.create_task(call(**sub_values))
|
||||||
dependency_cache[call] = task
|
dependency_cache[call] = task
|
||||||
solved = await task
|
return await task
|
||||||
else:
|
else:
|
||||||
task = asyncio.create_task(run_sync(call)(**sub_values))
|
task = asyncio.create_task(run_sync(call)(**sub_values))
|
||||||
dependency_cache[call] = task
|
dependency_cache[call] = task
|
||||||
solved = await task
|
return await task
|
||||||
|
|
||||||
return solved
|
async def _check(self, **kwargs: Any) -> None:
|
||||||
|
# run sub dependent pre-checkers
|
||||||
|
sub_dependent: Dependent = self.extra["dependent"]
|
||||||
class _BotChecker(Param):
|
await sub_dependent.check(**kwargs)
|
||||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
|
||||||
field: ModelField = self.extra["field"]
|
|
||||||
try:
|
|
||||||
return check_field_type(field, bot)
|
|
||||||
except TypeMisMatch:
|
|
||||||
logger.debug(
|
|
||||||
f"Bot type {type(bot)} not match "
|
|
||||||
f"annotation {field._type_display()}, ignored"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class BotParam(Param):
|
class BotParam(Param):
|
||||||
@ -174,45 +157,32 @@ class BotParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["BotParam"]:
|
) -> Optional["BotParam"]:
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
|
|
||||||
if param.default == param.empty:
|
if param.default == param.empty:
|
||||||
if generic_check_issubclass(param.annotation, Bot):
|
if generic_check_issubclass(param.annotation, Bot):
|
||||||
|
checker: Optional[ModelField] = None
|
||||||
if param.annotation is not Bot:
|
if param.annotation is not Bot:
|
||||||
dependent.pre_checkers.append(
|
checker = ModelField(
|
||||||
_BotChecker(
|
name=param.name,
|
||||||
Required,
|
type_=param.annotation,
|
||||||
field=ModelField(
|
class_validators=None,
|
||||||
name=name,
|
model_config=CustomConfig,
|
||||||
type_=param.annotation,
|
default=None,
|
||||||
class_validators=None,
|
required=True,
|
||||||
model_config=CustomConfig,
|
|
||||||
default=None,
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return cls(Required)
|
return cls(Required, checker=checker)
|
||||||
elif param.annotation == param.empty and name == "bot":
|
elif param.annotation == param.empty and param.name == "bot":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
|
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
|
||||||
class _EventChecker(Param):
|
if checker := self.extra.get("checker", None):
|
||||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
check_field_type(checker, bot)
|
||||||
field: ModelField = self.extra["field"]
|
|
||||||
try:
|
|
||||||
return check_field_type(field, event)
|
|
||||||
except TypeMisMatch:
|
|
||||||
logger.debug(
|
|
||||||
f"Event type {type(event)} not match "
|
|
||||||
f"annotation {field._type_display()}, ignored"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class EventParam(Param):
|
class EventParam(Param):
|
||||||
@ -220,33 +190,33 @@ class EventParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["EventParam"]:
|
) -> Optional["EventParam"]:
|
||||||
from nonebot.adapters import Event
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
if param.default == param.empty:
|
if param.default == param.empty:
|
||||||
if generic_check_issubclass(param.annotation, Event):
|
if generic_check_issubclass(param.annotation, Event):
|
||||||
|
checker: Optional[ModelField] = None
|
||||||
if param.annotation is not Event:
|
if param.annotation is not Event:
|
||||||
dependent.pre_checkers.append(
|
checker = ModelField(
|
||||||
_EventChecker(
|
name=param.name,
|
||||||
Required,
|
type_=param.annotation,
|
||||||
field=ModelField(
|
class_validators=None,
|
||||||
name=name,
|
model_config=CustomConfig,
|
||||||
type_=param.annotation,
|
default=None,
|
||||||
class_validators=None,
|
required=True,
|
||||||
model_config=CustomConfig,
|
|
||||||
default=None,
|
|
||||||
required=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return cls(Required)
|
return cls(Required, checker=checker)
|
||||||
elif param.annotation == param.empty and name == "event":
|
elif param.annotation == param.empty and param.name == "event":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
async def _check(self, event: "Event", **kwargs: Any) -> Any:
|
||||||
|
if checker := self.extra.get("checker", None):
|
||||||
|
check_field_type(checker, event)
|
||||||
|
|
||||||
|
|
||||||
class StateInner(T_State):
|
class StateInner(T_State):
|
||||||
...
|
...
|
||||||
@ -257,14 +227,14 @@ class StateParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["StateParam"]:
|
) -> Optional["StateParam"]:
|
||||||
if isinstance(param.default, StateInner):
|
if isinstance(param.default, StateInner):
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
elif param.default == param.empty:
|
elif param.default == param.empty:
|
||||||
if param.annotation is T_State:
|
if param.annotation is T_State:
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
elif param.annotation == param.empty and name == "state":
|
elif param.annotation == param.empty and param.name == "state":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||||
@ -276,12 +246,12 @@ class MatcherParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["MatcherParam"]:
|
) -> Optional["MatcherParam"]:
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
|
|
||||||
if generic_check_issubclass(param.annotation, Matcher) or (
|
if generic_check_issubclass(param.annotation, Matcher) or (
|
||||||
param.annotation == param.empty and name == "matcher"
|
param.annotation == param.empty and param.name == "matcher"
|
||||||
):
|
):
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
@ -317,10 +287,12 @@ class ArgParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["ArgParam"]:
|
) -> Optional["ArgParam"]:
|
||||||
if isinstance(param.default, ArgInner):
|
if isinstance(param.default, ArgInner):
|
||||||
return cls(Required, key=param.default.key or name, type=param.default.type)
|
return cls(
|
||||||
|
Required, key=param.default.key or param.name, type=param.default.type
|
||||||
|
)
|
||||||
|
|
||||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||||
message = matcher.get_arg(self.extra["key"])
|
message = matcher.get_arg(self.extra["key"])
|
||||||
@ -339,10 +311,10 @@ class ExceptionParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["ExceptionParam"]:
|
) -> Optional["ExceptionParam"]:
|
||||||
if generic_check_issubclass(param.annotation, Exception) or (
|
if generic_check_issubclass(param.annotation, Exception) or (
|
||||||
param.annotation == param.empty and name == "exception"
|
param.annotation == param.empty and param.name == "exception"
|
||||||
):
|
):
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
@ -355,7 +327,7 @@ class DefaultParam(Param):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, dependent: Dependent, name: str, param: inspect.Parameter
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["DefaultParam"]:
|
) -> Optional["DefaultParam"]:
|
||||||
if param.default != param.empty:
|
if param.default != param.empty:
|
||||||
return cls(param.default)
|
return cls(param.default)
|
||||||
|
Loading…
Reference in New Issue
Block a user