mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
✨ improve dependency injection params (#2034)
This commit is contained in:
parent
dd80191761
commit
aa48299d5d
@ -71,7 +71,12 @@ def Depends(
|
||||
|
||||
|
||||
class DependParam(Param):
|
||||
"""子依赖参数"""
|
||||
"""子依赖注入参数。
|
||||
|
||||
本注入解析所有子依赖注入,然后将它们的返回值作为参数值传递给父依赖。
|
||||
|
||||
本注入应该具有最高优先级,因此应该在其他参数之前检查。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Depends({self.extra['dependent']})"
|
||||
@ -168,7 +173,12 @@ class DependParam(Param):
|
||||
|
||||
|
||||
class BotParam(Param):
|
||||
"""{ref}`nonebot.adapters.Bot` 参数"""
|
||||
"""{ref}`nonebot.adapters.Bot` 注入参数。
|
||||
|
||||
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Bot` 及其子类或 `None` 的参数。
|
||||
|
||||
为保证兼容性,本注入还会解析名为 `bot` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@ -187,21 +197,22 @@ class BotParam(Param):
|
||||
) -> Optional["BotParam"]:
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
if param.default == param.empty:
|
||||
if generic_check_issubclass(param.annotation, Bot):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Bot:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
elif param.annotation == param.empty and param.name == "bot":
|
||||
return cls(Required)
|
||||
# param type is Bot(s) or subclass(es) of Bot or None
|
||||
if generic_check_issubclass(param.annotation, Bot):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Bot:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
# legacy: param is named "bot" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "bot":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||
return bot
|
||||
@ -212,7 +223,12 @@ class BotParam(Param):
|
||||
|
||||
|
||||
class EventParam(Param):
|
||||
"""{ref}`nonebot.adapters.Event` 参数"""
|
||||
"""{ref}`nonebot.adapters.Event` 注入参数
|
||||
|
||||
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Event` 及其子类或 `None` 的参数。
|
||||
|
||||
为保证兼容性,本注入还会解析名为 `event` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@ -231,21 +247,22 @@ class EventParam(Param):
|
||||
) -> Optional["EventParam"]:
|
||||
from nonebot.adapters import Event
|
||||
|
||||
if param.default == param.empty:
|
||||
if generic_check_issubclass(param.annotation, Event):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Event:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
elif param.annotation == param.empty and param.name == "event":
|
||||
return cls(Required)
|
||||
# param type is Event(s) or subclass(es) of Event or None
|
||||
if generic_check_issubclass(param.annotation, Event):
|
||||
checker: Optional[ModelField] = None
|
||||
if param.annotation is not Event:
|
||||
checker = ModelField(
|
||||
name=param.name,
|
||||
type_=param.annotation,
|
||||
class_validators=None,
|
||||
model_config=CustomConfig,
|
||||
default=None,
|
||||
required=True,
|
||||
)
|
||||
return cls(Required, checker=checker)
|
||||
# legacy: param is named "event" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "event":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||
return event
|
||||
@ -256,7 +273,12 @@ class EventParam(Param):
|
||||
|
||||
|
||||
class StateParam(Param):
|
||||
"""事件处理状态参数"""
|
||||
"""事件处理状态注入参数
|
||||
|
||||
本注入解析所有类型为 `T_State` 的参数。
|
||||
|
||||
为保证兼容性,本注入还会解析名为 `state` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "StateParam()"
|
||||
@ -265,18 +287,24 @@ class StateParam(Param):
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["StateParam"]:
|
||||
if param.default == param.empty:
|
||||
if param.annotation is T_State:
|
||||
return cls(Required)
|
||||
elif param.annotation == param.empty and param.name == "state":
|
||||
return cls(Required)
|
||||
# param type is T_State
|
||||
if param.annotation is T_State:
|
||||
return cls(Required)
|
||||
# legacy: param is named "state" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "state":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||
return state
|
||||
|
||||
|
||||
class MatcherParam(Param):
|
||||
"""事件响应器实例参数"""
|
||||
"""事件响应器实例注入参数
|
||||
|
||||
本注入解析所有类型为且仅为 {ref}`nonebot.matcher.Matcher` 及其子类或 `None` 的参数。
|
||||
|
||||
为保证兼容性,本注入还会解析名为 `matcher` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "MatcherParam()"
|
||||
@ -287,9 +315,11 @@ class MatcherParam(Param):
|
||||
) -> Optional["MatcherParam"]:
|
||||
from nonebot.matcher import Matcher
|
||||
|
||||
if generic_check_issubclass(param.annotation, Matcher) or (
|
||||
param.annotation == param.empty and param.name == "matcher"
|
||||
):
|
||||
# param type is Matcher(s) or subclass(es) of Matcher or None
|
||||
if generic_check_issubclass(param.annotation, Matcher):
|
||||
return cls(Required)
|
||||
# legacy: param is named "matcher" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "matcher":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
@ -308,22 +338,28 @@ class ArgInner:
|
||||
|
||||
|
||||
def Arg(key: Optional[str] = None) -> Any:
|
||||
"""`got` 的 Arg 参数消息"""
|
||||
"""Arg 参数消息"""
|
||||
return ArgInner(key, "message")
|
||||
|
||||
|
||||
def ArgStr(key: Optional[str] = None) -> str:
|
||||
"""`got` 的 Arg 参数消息文本"""
|
||||
"""Arg 参数消息文本"""
|
||||
return ArgInner(key, "str") # type: ignore
|
||||
|
||||
|
||||
def ArgPlainText(key: Optional[str] = None) -> str:
|
||||
"""`got` 的 Arg 参数消息纯文本"""
|
||||
"""Arg 参数消息纯文本"""
|
||||
return ArgInner(key, "plaintext") # type: ignore
|
||||
|
||||
|
||||
class ArgParam(Param):
|
||||
"""`got` 的 Arg 参数"""
|
||||
"""Arg 注入参数
|
||||
|
||||
本注入解析事件响应器操作 `got` 所获取的参数。
|
||||
|
||||
可以通过 `Arg`、`ArgStr`、`ArgPlainText` 等函数参数 `key` 指定获取的参数,
|
||||
留空则会根据参数名称获取。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
||||
@ -338,7 +374,8 @@ class ArgParam(Param):
|
||||
)
|
||||
|
||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||
message = matcher.get_arg(self.extra["key"])
|
||||
key: str = self.extra["key"]
|
||||
message = matcher.get_arg(key)
|
||||
if message is None:
|
||||
return message
|
||||
if self.extra["type"] == "message":
|
||||
@ -350,7 +387,12 @@ class ArgParam(Param):
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
"""`run_postprocessor` 的异常参数"""
|
||||
"""{ref}`nonebot.message.run_postprocessor` 的异常注入参数
|
||||
|
||||
本注入解析所有类型为 `Exception` 或 `None` 的参数。
|
||||
|
||||
为保证兼容性,本注入还会解析名为 `exception` 且没有类型注解的参数。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "ExceptionParam()"
|
||||
@ -359,9 +401,11 @@ class ExceptionParam(Param):
|
||||
def _check_param(
|
||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||
) -> Optional["ExceptionParam"]:
|
||||
if generic_check_issubclass(param.annotation, Exception) or (
|
||||
param.annotation == param.empty and param.name == "exception"
|
||||
):
|
||||
# param type is Exception(s) or subclass(es) of Exception or None
|
||||
if generic_check_issubclass(param.annotation, Exception):
|
||||
return cls(Required)
|
||||
# legacy: param is named "exception" and has no type annotation
|
||||
elif param.annotation == param.empty and param.name == "exception":
|
||||
return cls(Required)
|
||||
|
||||
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
||||
@ -369,7 +413,12 @@ class ExceptionParam(Param):
|
||||
|
||||
|
||||
class DefaultParam(Param):
|
||||
"""默认值参数"""
|
||||
"""默认值注入参数
|
||||
|
||||
本注入解析所有剩余未能解析且具有默认值的参数。
|
||||
|
||||
本注入参数应该具有最低优先级,因此应该在所有其他注入参数之后使用。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DefaultParam(default={self.default!r})"
|
||||
|
@ -59,7 +59,7 @@ def generic_check_issubclass(
|
||||
"""检查 cls 是否是 class_or_tuple 中的一个类型子类。
|
||||
|
||||
特别的,如果 cls 是 `typing.Union` 或 `types.UnionType` 类型,
|
||||
则会检查其中的类型是否是 class_or_tuple 中的一个类型子类。(None 会被忽略)
|
||||
则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。
|
||||
"""
|
||||
try:
|
||||
return issubclass(cls, class_or_tuple)
|
||||
|
23
tests/plugins/param/priority.py
Normal file
23
tests/plugins/param/priority.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Optional
|
||||
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Arg, Depends
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
|
||||
|
||||
def dependency():
|
||||
return 1
|
||||
|
||||
|
||||
async def complex_priority(
|
||||
sub: int = Depends(dependency),
|
||||
bot: Optional[Bot] = None,
|
||||
event: Optional[Event] = None,
|
||||
state: T_State = {},
|
||||
matcher: Optional[Matcher] = None,
|
||||
arg: Message = Arg(),
|
||||
exception: Optional[Exception] = None,
|
||||
default: int = 1,
|
||||
):
|
||||
...
|
@ -4,6 +4,7 @@ import pytest
|
||||
from nonebug import App
|
||||
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from utils import make_fake_event, make_fake_message
|
||||
from nonebot.params import (
|
||||
@ -413,3 +414,41 @@ async def test_default(app: App):
|
||||
|
||||
async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx:
|
||||
ctx.should_return(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_priority():
|
||||
from plugins.param.priority import complex_priority
|
||||
|
||||
dependent = Dependent.parse(
|
||||
call=complex_priority,
|
||||
allow_types=[
|
||||
DependParam,
|
||||
BotParam,
|
||||
EventParam,
|
||||
StateParam,
|
||||
MatcherParam,
|
||||
ArgParam,
|
||||
ExceptionParam,
|
||||
DefaultParam,
|
||||
],
|
||||
)
|
||||
for param in dependent.params:
|
||||
if param.name == "sub":
|
||||
assert isinstance(param.field_info, DependParam)
|
||||
elif param.name == "bot":
|
||||
assert isinstance(param.field_info, BotParam)
|
||||
elif param.name == "event":
|
||||
assert isinstance(param.field_info, EventParam)
|
||||
elif param.name == "state":
|
||||
assert isinstance(param.field_info, StateParam)
|
||||
elif param.name == "matcher":
|
||||
assert isinstance(param.field_info, MatcherParam)
|
||||
elif param.name == "arg":
|
||||
assert isinstance(param.field_info, ArgParam)
|
||||
elif param.name == "exception":
|
||||
assert isinstance(param.field_info, ExceptionParam)
|
||||
elif param.name == "default":
|
||||
assert isinstance(param.field_info, DefaultParam)
|
||||
else:
|
||||
raise ValueError(f"unknown param {param.name}")
|
||||
|
Loading…
Reference in New Issue
Block a user