nonebot2/nonebot/internal/params.py

379 lines
12 KiB
Python
Raw Normal View History

import asyncio
import inspect
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
from pydantic.fields import Required, Undefined, ModelField
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.utils import (
get_name,
run_sync,
is_gen_callable,
run_sync_ctx_manager,
is_async_gen_callable,
is_coroutine_callable,
generic_check_issubclass,
)
if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
class DependsInner:
def __init__(
self,
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
def __repr__(self) -> str:
dep = get_name(self.dependency)
cache = "" if self.use_cache else ", use_cache=False"
return f"DependsInner({dep}{cache})"
def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
) -> Any:
"""子依赖装饰器
参数:
dependency: 依赖函数默认为参数的类型注释
use_cache: 是否使用缓存默认为 `True`
用法:
```python
def depend_func() -> Any:
return ...
def depend_gen_func():
try:
yield ...
finally:
...
async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
...
```
"""
return DependsInner(dependency, use_cache=use_cache)
class DependParam(Param):
"""子依赖参数"""
def __repr__(self) -> str:
return f"Depends({self.extra['dependent']})"
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DependParam"]:
if isinstance(param.default, DependsInner):
dependency: T_Handler
if param.default.dependency is None:
assert param.annotation is not param.empty, "Dependency cannot be empty"
dependency = param.annotation
else:
dependency = param.default.dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=allow_types,
)
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
)
@classmethod
def _check_parameterless(
cls, value: Any, allow_types: Tuple[Type[Param], ...]
) -> Optional["Param"]:
if isinstance(value, DependsInner):
assert value.dependency, "Dependency cannot be empty"
dependent = Dependent[Any].parse(
call=value.dependency, allow_types=allow_types
)
return cls(Required, use_cache=value.use_cache, dependent=dependent)
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
**kwargs: Any,
) -> Any:
use_cache: bool = self.extra["use_cache"]
dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"]
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
return await task
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
return await task
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
return await task
async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"]
await sub_dependent.check(**kwargs)
class BotParam(Param):
"""{ref}`nonebot.adapters.Bot` 参数"""
def __repr__(self) -> str:
return (
"BotParam("
+ (
repr(cast(ModelField, checker).type_)
if (checker := self.extra.get("checker"))
else ""
)
+ ")"
)
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["BotParam"]:
from nonebot.adapters import Bot
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None
if param.annotation is not Bot:
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
)
return cls(Required, checker=checker)
elif param.annotation == param.empty and param.name == "bot":
return cls(Required)
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
if checker := self.extra.get("checker"):
check_field_type(checker, bot)
class EventParam(Param):
"""{ref}`nonebot.adapters.Event` 参数"""
def __repr__(self) -> str:
return (
"EventParam("
+ (
repr(cast(ModelField, checker).type_)
if (checker := self.extra.get("checker"))
else ""
)
+ ")"
)
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["EventParam"]:
from nonebot.adapters import Event
if param.default == param.empty:
if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None
if param.annotation is not Event:
checker = ModelField(
name=param.name,
type_=param.annotation,
class_validators=None,
model_config=CustomConfig,
default=None,
required=True,
)
return cls(Required, checker=checker)
elif param.annotation == param.empty and param.name == "event":
return cls(Required)
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event
async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, event)
class StateParam(Param):
"""事件处理状态参数"""
def __repr__(self) -> str:
return "StateParam()"
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]:
if param.default == param.empty:
if param.annotation is T_State:
return cls(Required)
elif param.annotation == param.empty and param.name == "state":
return cls(Required)
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state
class MatcherParam(Param):
"""事件响应器实例参数"""
def __repr__(self) -> str:
return "MatcherParam()"
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["MatcherParam"]:
from nonebot.matcher import Matcher
if generic_check_issubclass(param.annotation, Matcher) or (
param.annotation == param.empty and param.name == "matcher"
):
return cls(Required)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
return matcher
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"]
) -> None:
self.key = key
self.type = type
def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})"
def Arg(key: Optional[str] = None) -> Any:
"""`got` 的 Arg 参数消息"""
return ArgInner(key, "message")
def ArgStr(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息文本"""
return ArgInner(key, "str") # type: ignore
def ArgPlainText(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息纯文本"""
return ArgInner(key, "plaintext") # type: ignore
class ArgParam(Param):
"""`got` 的 Arg 参数"""
def __repr__(self) -> str:
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner):
return cls(
Required, key=param.default.key or param.name, type=param.default.type
)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
message = matcher.get_arg(self.extra["key"])
if message is None:
return message
if self.extra["type"] == "message":
return message
elif self.extra["type"] == "str":
return str(message)
else:
return message.extract_plain_text()
class ExceptionParam(Param):
"""`run_postprocessor` 的异常参数"""
def __repr__(self) -> str:
return "ExceptionParam()"
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]:
if generic_check_issubclass(param.annotation, Exception) or (
param.annotation == param.empty and param.name == "exception"
):
return cls(Required)
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
return exception
class DefaultParam(Param):
"""默认值参数"""
def __repr__(self) -> str:
return f"DefaultParam(default={self.default!r})"
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DefaultParam"]:
if param.default != param.empty:
return cls(param.default)
async def _solve(self, **kwargs: Any) -> Any:
return Undefined
__autodoc__ = {
"DependsInner": False,
"StateInner": False,
"ArgInner": False,
}