nonebot2/nonebot/params.py

465 lines
13 KiB
Python
Raw Normal View History

2021-12-16 23:22:25 +08:00
import asyncio
2021-11-15 21:44:24 +08:00
import inspect
2022-01-10 22:24:45 +08:00
import warnings
2021-12-20 00:28:02 +08:00
from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
2021-11-15 21:44:24 +08:00
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
2021-12-14 01:08:48 +08:00
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent, CustomConfig
2021-12-16 23:22:25 +08:00
from nonebot.typing import T_State, T_Handler, T_DependencyCache
2021-12-14 01:08:48 +08:00
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
REGEX_DICT,
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
)
from nonebot.utils import (
get_name,
run_sync,
is_gen_callable,
run_sync_ctx_manager,
is_async_gen_callable,
is_coroutine_callable,
generic_check_issubclass,
)
class DependsInner:
    """Marker object produced by :func:`Depends`.

    It is stored as a parameter default and later recognised by
    ``DependParam`` through an ``isinstance`` check.
    """

    def __init__(
        self,
        dependency: Optional[T_Handler] = None,
        *,
        use_cache: bool = True,
    ) -> None:
        # ``None`` means "use the parameter's type annotation as the dependency".
        self.dependency = dependency
        self.use_cache = use_cache

    def __repr__(self) -> str:
        text = f"{get_name(self.dependency)}"
        if not self.use_cache:
            text += ", use_cache=False"
        return f"{type(self).__name__}({text})"
def Depends(
    dependency: Optional[T_Handler] = None,
    *,
    use_cache: bool = True,
) -> Any:
    """Parameter dependency-injection marker.

    Use the returned object as a handler parameter's default value so that the
    parameter is filled with the result of ``dependency``.

    Args:
        dependency: the dependency callable; defaults to the parameter's type annotation.
        use_cache: whether to use the dependency cache; defaults to ``True``.

    Usage:
        ```python
        def depend_func() -> Any:
            return ...

        def depend_gen_func():
            try:
                yield ...
            finally:
                ...

        async def handler(param_name: Any = Depends(depend_func), gen: Any = Depends(depend_gen_func)):
            ...
        ```
    """
    marker = DependsInner(dependency, use_cache=use_cache)
    return marker
class DependParam(Param):
    """Resolves parameters whose default value is a :class:`DependsInner` marker.

    Also handles markers passed outside the signature through
    ``_check_parameterless``.
    """

    @classmethod
    def _check_param(
        cls,
        dependent: Dependent,
        name: str,
        param: inspect.Parameter,
    ) -> Optional["DependParam"]:
        marker = param.default
        if not isinstance(marker, DependsInner):
            return None

        target: T_Handler
        if marker.dependency is not None:
            target = marker.dependency
        else:
            # ``Depends()`` without a callable: the annotation is the dependency.
            assert param.annotation is not param.empty, "Dependency cannot be empty"
            target = param.annotation

        sub_dependent = Dependent[Any].parse(
            call=target,
            allow_types=dependent.allow_types,
        )
        # Hoist the child's pre-checkers onto the parent dependent and clear
        # them on the child so they are held in one place only.
        dependent.pre_checkers.extend(sub_dependent.pre_checkers)
        sub_dependent.pre_checkers.clear()
        return cls(Required, use_cache=marker.use_cache, dependent=sub_dependent)

    @classmethod
    def _check_parameterless(
        cls, dependent: "Dependent", value: Any
    ) -> Optional["Param"]:
        if not isinstance(value, DependsInner):
            return None
        assert value.dependency, "Dependency cannot be empty"
        sub_dependent = Dependent[Any].parse(
            call=value.dependency, allow_types=dependent.allow_types
        )
        return cls(Required, use_cache=value.use_cache, dependent=sub_dependent)

    async def _solve(
        self,
        stack: Optional[AsyncExitStack] = None,
        dependency_cache: Optional[T_DependencyCache] = None,
        **kwargs: Any,
    ) -> Any:
        """Solve the sub-dependency's own parameters, then run the dependency.

        The running task is stored in ``dependency_cache`` keyed by the
        callable. When ``use_cache`` is set and the callable already has a
        cached task, that task is awaited instead of running again. Generator
        dependencies are entered on ``stack`` as async context managers.
        """
        use_cache: bool = self.extra["use_cache"]
        cache = {} if dependency_cache is None else dependency_cache
        sub_dependent: Dependent = self.extra["dependent"]
        call = cast(Callable[..., Any], sub_dependent.call)

        # Sub-dependencies share the same cache as the current solve.
        sub_values = await sub_dependent.solve(
            stack=stack, dependency_cache=cache, **kwargs
        )

        if use_cache and call in cache:
            return await cache[call]

        if is_gen_callable(call) or is_async_gen_callable(call):
            assert isinstance(
                stack, AsyncExitStack
            ), "Generator dependency should be called in context"
            if is_gen_callable(call):
                cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
            else:
                cm = asynccontextmanager(call)(**sub_values)
            awaitable = stack.enter_async_context(cm)
        elif is_coroutine_callable(call):
            awaitable = call(**sub_values)
        else:
            awaitable = run_sync(call)(**sub_values)

        # NOTE: the task is cached even when use_cache is False, so later
        # cached consumers of the same callable see this newest result.
        task = asyncio.create_task(awaitable)
        cache[call] = task
        return await task
class _BotChecker(Param):
    """Pre-checker that validates the bot against a narrowed ``Bot`` annotation.

    Raises ``TypeMisMatch`` when ``field.validate`` reports errors.
    """

    async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
        field: ModelField = self.extra["field"]
        _, errors = field.validate(bot, {}, loc=("bot",))
        if not errors:
            return None
        logger.debug(
            f"Bot type {type(bot)} not match "
            f"annotation {field._type_display()}, ignored"
        )
        raise TypeMisMatch(field, bot)
class BotParam(Param):
    """Injects the current :class:`Bot` into a handler parameter.

    A parameter matches when it has no default and either its annotation is a
    ``Bot`` subclass, or it is unannotated and named ``bot``. For annotations
    narrower than ``Bot`` itself, a ``_BotChecker`` is appended to
    ``dependent.pre_checkers``; it raises ``TypeMisMatch`` when the bot does
    not validate against the annotation.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["BotParam"]:
        # ``param.empty`` is a sentinel class: compare by identity so that a
        # default with a custom ``__eq__`` (e.g. an array-like object) cannot
        # raise or hijack this check.
        if param.default is not param.empty:
            return None
        if generic_check_issubclass(param.annotation, Bot):
            if param.annotation is not Bot:
                dependent.pre_checkers.append(
                    _BotChecker(
                        Required,
                        field=ModelField(
                            name=name,
                            type_=param.annotation,
                            class_validators=None,
                            model_config=CustomConfig,
                            default=None,
                            required=True,
                        ),
                    )
                )
            return cls(Required)
        if param.annotation is param.empty and name == "bot":
            return cls(Required)
        return None

    async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
        return bot
class _EventChecker(Param):
    """Pre-checker that validates the event against a narrowed ``Event`` annotation.

    Raises ``TypeMisMatch`` when ``field.validate`` reports errors.
    """

    async def _solve(self, event: Event, **kwargs: Any) -> Any:
        field: ModelField = self.extra["field"]
        _, errors = field.validate(event, {}, loc=("event",))
        if not errors:
            return None
        logger.debug(
            f"Event type {type(event)} not match "
            f"annotation {field._type_display()}, ignored"
        )
        raise TypeMisMatch(field, event)
class EventParam(Param):
    """Injects the current :class:`Event` into a handler parameter.

    A parameter matches when it has no default and either its annotation is an
    ``Event`` subclass, or it is unannotated and named ``event``. For
    annotations narrower than ``Event`` itself, an ``_EventChecker`` is
    appended to ``dependent.pre_checkers``; it raises ``TypeMisMatch`` when the
    event does not validate against the annotation.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["EventParam"]:
        # Identity comparison against the ``param.empty`` sentinel: ``==``
        # would call a default's own ``__eq__``, which may raise or misreport.
        if param.default is not param.empty:
            return None
        if generic_check_issubclass(param.annotation, Event):
            if param.annotation is not Event:
                dependent.pre_checkers.append(
                    _EventChecker(
                        Required,
                        field=ModelField(
                            name=name,
                            type_=param.annotation,
                            class_validators=None,
                            model_config=CustomConfig,
                            default=None,
                            required=True,
                        ),
                    )
                )
            return cls(Required)
        if param.annotation is param.empty and name == "event":
            return cls(Required)
        return None

    async def _solve(self, event: Event, **kwargs: Any) -> Any:
        return event
async def _event_type(event: Event) -> str:
    """Return ``event.get_type()`` for the current event."""
    event_type = event.get_type()
    return event_type


def EventType() -> str:
    """Inject the current event's ``get_type()`` result (default caching)."""
    return Depends(dependency=_event_type)
async def _event_message(event: Event) -> Message:
    """Return ``event.get_message()`` for the current event."""
    message = event.get_message()
    return message


def EventMessage() -> Any:
    """Inject the current event's ``get_message()`` result (default caching)."""
    return Depends(dependency=_event_message)
async def _event_plain_text(event: Event) -> str:
    """Return ``event.get_plaintext()`` for the current event."""
    text = event.get_plaintext()
    return text


def EventPlainText() -> str:
    """Inject the current event's ``get_plaintext()`` result (default caching)."""
    return Depends(dependency=_event_plain_text)
async def _event_to_me(event: Event) -> bool:
    """Return ``event.is_tome()`` for the current event."""
    to_me = event.is_tome()
    return to_me


def EventToMe() -> bool:
    """Inject the current event's ``is_tome()`` result (default caching)."""
    return Depends(dependency=_event_to_me)
class StateInner(T_State):
    """Marker returned by the deprecated ``State()`` factory.

    ``StateParam`` recognises parameters whose default is an instance of this
    class.
    """
def State() -> T_State:
    """Deprecated state marker; annotate the parameter with ``T_State`` instead.

    Emits a ``DeprecationWarning`` and returns a ``StateInner`` marker.
    """
    # stacklevel=2 attributes the warning to the caller's line (where the
    # handler default is written) rather than to this module.
    warnings.warn(
        "State() is deprecated, use T_State instead",
        DeprecationWarning,
        stacklevel=2,
    )
    return StateInner()
class StateParam(Param):
    """Injects the handler state into a parameter.

    A parameter matches when its default is a ``StateInner`` marker, or when it
    has no default and is either annotated exactly ``T_State`` or unannotated
    and named ``state``.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["StateParam"]:
        if isinstance(param.default, StateInner):
            return cls(Required)
        # Compare against the ``param.empty`` sentinel by identity so that a
        # default with a custom ``__eq__`` cannot raise or misreport here.
        if param.default is not param.empty:
            return None
        if param.annotation is T_State:
            return cls(Required)
        if param.annotation is param.empty and name == "state":
            return cls(Required)
        return None

    async def _solve(self, state: T_State, **kwargs: Any) -> Any:
        return state
def _command(state: T_State) -> Tuple[str, ...]:
    """Return the command stored at ``state[PREFIX_KEY][CMD_KEY]``.

    Annotated as ``Tuple[str, ...]`` to agree with ``Command()``; the previous
    ``Message`` annotation did not match that declared type.
    """
    return state[PREFIX_KEY][CMD_KEY]


def Command() -> Tuple[str, ...]:
    """Inject the matched command (not cached between uses)."""
    return Depends(_command, use_cache=False)
def _raw_command(state: T_State) -> str:
    """Return the raw command text stored at ``state[PREFIX_KEY][RAW_CMD_KEY]``.

    Annotated as ``str`` to agree with ``RawCommand()``; the previous
    ``Message`` annotation did not match that declared type.
    """
    return state[PREFIX_KEY][RAW_CMD_KEY]


def RawCommand() -> str:
    """Inject the raw matched command text (not cached between uses)."""
    return Depends(_raw_command, use_cache=False)
def _command_arg(state: T_State) -> Message:
    """Return the command argument stored at ``state[PREFIX_KEY][CMD_ARG_KEY]``."""
    prefix = state[PREFIX_KEY]
    return prefix[CMD_ARG_KEY]


def CommandArg() -> Any:
    """Inject the message following the matched command (not cached)."""
    return Depends(dependency=_command_arg, use_cache=False)
def _shell_command_args(state: T_State) -> Any:
    """Return the parsed shell-command arguments stored at ``state[SHELL_ARGS]``."""
    return state[SHELL_ARGS]


def ShellCommandArgs() -> Any:
    """Inject the parsed shell-command arguments (not cached between uses).

    The return annotation was missing; ``Any`` matches sibling factories
    such as ``ShellCommandArgv``.
    """
    return Depends(_shell_command_args, use_cache=False)
def _shell_command_argv(state: T_State) -> List[str]:
    """Return the raw shell-command argv list stored at ``state[SHELL_ARGV]``."""
    argv = state[SHELL_ARGV]
    return argv


def ShellCommandArgv() -> Any:
    """Inject the raw shell-command argv list (not cached between uses)."""
    return Depends(dependency=_shell_command_argv, use_cache=False)
def _regex_matched(state: T_State) -> str:
    """Return the regex match stored at ``state[REGEX_MATCHED]``."""
    matched = state[REGEX_MATCHED]
    return matched


def RegexMatched() -> str:
    """Inject the regex-matched text (not cached between uses)."""
    return Depends(dependency=_regex_matched, use_cache=False)
def _regex_group(state: T_State) -> Tuple[Any, ...]:
    """Return the regex groups stored at ``state[REGEX_GROUP]``.

    The return annotation was missing; it mirrors ``RegexGroup()``.
    """
    return state[REGEX_GROUP]


def RegexGroup() -> Tuple[Any, ...]:
    """Inject the regex match groups (not cached between uses)."""
    return Depends(_regex_group, use_cache=False)
def _regex_dict(state: T_State) -> Dict[str, Any]:
    """Return the named regex groups stored at ``state[REGEX_DICT]``.

    The return annotation was missing; it mirrors ``RegexDict()``.
    """
    return state[REGEX_DICT]


def RegexDict() -> Dict[str, Any]:
    """Inject the named regex match groups (not cached between uses)."""
    return Depends(_regex_dict, use_cache=False)
class MatcherParam(Param):
    """Injects the running matcher into a handler parameter.

    A parameter matches when its annotation is a ``Matcher`` subclass, or when
    it is unannotated and named ``matcher``.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["MatcherParam"]:
        if generic_check_issubclass(param.annotation, Matcher):
            return cls(Required)
        # Identity check against the ``param.empty`` sentinel; ``==`` would
        # defer to the annotation object's own ``__eq__``.
        if param.annotation is param.empty and name == "matcher":
            return cls(Required)
        return None

    async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
        return matcher
def Received(id: Optional[str] = None, default: Any = None) -> Any:
    """Inject the value returned by ``matcher.get_receive`` (not cached).

    Args:
        id: receive id; ``None`` is passed on as ``""``.
        default: forwarded as the fallback argument of ``matcher.get_receive``.
    """
    receive_id = id or ""

    def _received(matcher: "Matcher"):
        return matcher.get_receive(receive_id, default)

    return Depends(dependency=_received, use_cache=False)
def LastReceived(default: Any = None) -> Any:
    """Inject the value returned by ``matcher.get_last_receive`` (not cached).

    Args:
        default: forwarded as the fallback argument of ``matcher.get_last_receive``.
    """

    def _last_received(matcher: "Matcher") -> Any:
        last = matcher.get_last_receive(default)
        return last

    return Depends(dependency=_last_received, use_cache=False)
class ArgInner:
    """Marker produced by ``Arg``/``ArgStr``/``ArgPlainText``.

    ``ArgParam`` recognises parameters whose default is an instance of this
    class and uses ``key`` and ``type`` to fetch and convert the argument.
    """

    def __init__(
        self, key: Optional[str], type: Literal["message", "str", "plaintext"]
    ) -> None:
        # A ``None`` key makes ArgParam fall back to the parameter's name.
        self.key, self.type = key, type
def Arg(key: Optional[str] = None) -> Any:
    """Inject the argument stored under ``key`` as-is (a message)."""
    return ArgInner(key=key, type="message")


def ArgStr(key: Optional[str] = None) -> str:
    """Inject the argument stored under ``key`` converted with ``str()``."""
    return ArgInner(key=key, type="str")  # type: ignore


def ArgPlainText(key: Optional[str] = None) -> str:
    """Inject the plain text extracted from the argument stored under ``key``."""
    return ArgInner(key=key, type="plaintext")  # type: ignore
class ArgParam(Param):
    """Fills parameters whose default is an ``ArgInner`` marker.

    The value is read through ``matcher.get_arg`` and returned unchanged
    (``"message"``), as ``str(...)`` (``"str"``), or as its extracted plain
    text (any other type). ``None`` is passed through unconverted.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["ArgParam"]:
        marker = param.default
        if not isinstance(marker, ArgInner):
            return None
        # An explicit key wins; otherwise the parameter name is used as key.
        return cls(Required, key=marker.key or name, type=marker.type)

    async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
        arg = matcher.get_arg(self.extra["key"])
        if arg is None:
            return None
        kind = self.extra["type"]
        if kind == "message":
            return arg
        if kind == "str":
            return str(arg)
        return arg.extract_plain_text()
class ExceptionParam(Param):
    """Injects the exception passed to the solver into a handler parameter.

    A parameter matches when its annotation is an ``Exception`` subclass, or
    when it is unannotated and named ``exception``. ``None`` is injected when
    no exception is supplied.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["ExceptionParam"]:
        if generic_check_issubclass(param.annotation, Exception):
            return cls(Required)
        # Identity check against the ``param.empty`` sentinel instead of ``==``.
        if param.annotation is param.empty and name == "exception":
            return cls(Required)
        return None

    async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
        return exception
class DefaultParam(Param):
    """Fallback for any parameter that declares a default value.

    The parameter's default is passed to ``Param`` as the field default, and
    ``_solve`` returns pydantic's ``Undefined``. Presumably this makes the
    caller use that default; NOTE(review): confirm against ``Dependent.solve``.
    """

    @classmethod
    def _check_param(
        cls, dependent: Dependent, name: str, param: inspect.Parameter
    ) -> Optional["DefaultParam"]:
        # ``!=`` would invoke the default's own ``__ne__``. For array-like
        # defaults that returns a non-bool and raises on truth testing, so
        # compare with the ``param.empty`` sentinel by identity.
        if param.default is not param.empty:
            return cls(param.default)
        return None

    async def _solve(self, **kwargs: Any) -> Any:
        return Undefined
from nonebot.matcher import Matcher