mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ improve dependency injection params (#2034)
This commit is contained in:
parent
dd80191761
commit
aa48299d5d
@ -71,7 +71,12 @@ def Depends(
|
|||||||
|
|
||||||
|
|
||||||
class DependParam(Param):
|
class DependParam(Param):
|
||||||
"""子依赖参数"""
|
"""子依赖注入参数。
|
||||||
|
|
||||||
|
本注入解析所有子依赖注入,然后将它们的返回值作为参数值传递给父依赖。
|
||||||
|
|
||||||
|
本注入应该具有最高优先级,因此应该在其他参数之前检查。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Depends({self.extra['dependent']})"
|
return f"Depends({self.extra['dependent']})"
|
||||||
@ -168,7 +173,12 @@ class DependParam(Param):
|
|||||||
|
|
||||||
|
|
||||||
class BotParam(Param):
|
class BotParam(Param):
|
||||||
"""{ref}`nonebot.adapters.Bot` 参数"""
|
"""{ref}`nonebot.adapters.Bot` 注入参数。
|
||||||
|
|
||||||
|
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Bot` 及其子类或 `None` 的参数。
|
||||||
|
|
||||||
|
为保证兼容性,本注入还会解析名为 `bot` 且没有类型注解的参数。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
@ -187,21 +197,22 @@ class BotParam(Param):
|
|||||||
) -> Optional["BotParam"]:
|
) -> Optional["BotParam"]:
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
|
|
||||||
if param.default == param.empty:
|
# param type is Bot(s) or subclass(es) of Bot or None
|
||||||
if generic_check_issubclass(param.annotation, Bot):
|
if generic_check_issubclass(param.annotation, Bot):
|
||||||
checker: Optional[ModelField] = None
|
checker: Optional[ModelField] = None
|
||||||
if param.annotation is not Bot:
|
if param.annotation is not Bot:
|
||||||
checker = ModelField(
|
checker = ModelField(
|
||||||
name=param.name,
|
name=param.name,
|
||||||
type_=param.annotation,
|
type_=param.annotation,
|
||||||
class_validators=None,
|
class_validators=None,
|
||||||
model_config=CustomConfig,
|
model_config=CustomConfig,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
return cls(Required, checker=checker)
|
return cls(Required, checker=checker)
|
||||||
elif param.annotation == param.empty and param.name == "bot":
|
# legacy: param is named "bot" and has no type annotation
|
||||||
return cls(Required)
|
elif param.annotation == param.empty and param.name == "bot":
|
||||||
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||||
return bot
|
return bot
|
||||||
@ -212,7 +223,12 @@ class BotParam(Param):
|
|||||||
|
|
||||||
|
|
||||||
class EventParam(Param):
|
class EventParam(Param):
|
||||||
"""{ref}`nonebot.adapters.Event` 参数"""
|
"""{ref}`nonebot.adapters.Event` 注入参数
|
||||||
|
|
||||||
|
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Event` 及其子类或 `None` 的参数。
|
||||||
|
|
||||||
|
为保证兼容性,本注入还会解析名为 `event` 且没有类型注解的参数。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
@ -231,21 +247,22 @@ class EventParam(Param):
|
|||||||
) -> Optional["EventParam"]:
|
) -> Optional["EventParam"]:
|
||||||
from nonebot.adapters import Event
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
if param.default == param.empty:
|
# param type is Event(s) or subclass(es) of Event or None
|
||||||
if generic_check_issubclass(param.annotation, Event):
|
if generic_check_issubclass(param.annotation, Event):
|
||||||
checker: Optional[ModelField] = None
|
checker: Optional[ModelField] = None
|
||||||
if param.annotation is not Event:
|
if param.annotation is not Event:
|
||||||
checker = ModelField(
|
checker = ModelField(
|
||||||
name=param.name,
|
name=param.name,
|
||||||
type_=param.annotation,
|
type_=param.annotation,
|
||||||
class_validators=None,
|
class_validators=None,
|
||||||
model_config=CustomConfig,
|
model_config=CustomConfig,
|
||||||
default=None,
|
default=None,
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
return cls(Required, checker=checker)
|
return cls(Required, checker=checker)
|
||||||
elif param.annotation == param.empty and param.name == "event":
|
# legacy: param is named "event" and has no type annotation
|
||||||
return cls(Required)
|
elif param.annotation == param.empty and param.name == "event":
|
||||||
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||||
return event
|
return event
|
||||||
@ -256,7 +273,12 @@ class EventParam(Param):
|
|||||||
|
|
||||||
|
|
||||||
class StateParam(Param):
|
class StateParam(Param):
|
||||||
"""事件处理状态参数"""
|
"""事件处理状态注入参数
|
||||||
|
|
||||||
|
本注入解析所有类型为 `T_State` 的参数。
|
||||||
|
|
||||||
|
为保证兼容性,本注入还会解析名为 `state` 且没有类型注解的参数。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "StateParam()"
|
return "StateParam()"
|
||||||
@ -265,18 +287,24 @@ class StateParam(Param):
|
|||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["StateParam"]:
|
) -> Optional["StateParam"]:
|
||||||
if param.default == param.empty:
|
# param type is T_State
|
||||||
if param.annotation is T_State:
|
if param.annotation is T_State:
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
elif param.annotation == param.empty and param.name == "state":
|
# legacy: param is named "state" and has no type annotation
|
||||||
return cls(Required)
|
elif param.annotation == param.empty and param.name == "state":
|
||||||
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
class MatcherParam(Param):
|
class MatcherParam(Param):
|
||||||
"""事件响应器实例参数"""
|
"""事件响应器实例注入参数
|
||||||
|
|
||||||
|
本注入解析所有类型为且仅为 {ref}`nonebot.matcher.Matcher` 及其子类或 `None` 的参数。
|
||||||
|
|
||||||
|
为保证兼容性,本注入还会解析名为 `matcher` 且没有类型注解的参数。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "MatcherParam()"
|
return "MatcherParam()"
|
||||||
@ -287,9 +315,11 @@ class MatcherParam(Param):
|
|||||||
) -> Optional["MatcherParam"]:
|
) -> Optional["MatcherParam"]:
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
|
|
||||||
if generic_check_issubclass(param.annotation, Matcher) or (
|
# param type is Matcher(s) or subclass(es) of Matcher or None
|
||||||
param.annotation == param.empty and param.name == "matcher"
|
if generic_check_issubclass(param.annotation, Matcher):
|
||||||
):
|
return cls(Required)
|
||||||
|
# legacy: param is named "matcher" and has no type annotation
|
||||||
|
elif param.annotation == param.empty and param.name == "matcher":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||||
@ -308,22 +338,28 @@ class ArgInner:
|
|||||||
|
|
||||||
|
|
||||||
def Arg(key: Optional[str] = None) -> Any:
|
def Arg(key: Optional[str] = None) -> Any:
|
||||||
"""`got` 的 Arg 参数消息"""
|
"""Arg 参数消息"""
|
||||||
return ArgInner(key, "message")
|
return ArgInner(key, "message")
|
||||||
|
|
||||||
|
|
||||||
def ArgStr(key: Optional[str] = None) -> str:
|
def ArgStr(key: Optional[str] = None) -> str:
|
||||||
"""`got` 的 Arg 参数消息文本"""
|
"""Arg 参数消息文本"""
|
||||||
return ArgInner(key, "str") # type: ignore
|
return ArgInner(key, "str") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def ArgPlainText(key: Optional[str] = None) -> str:
|
def ArgPlainText(key: Optional[str] = None) -> str:
|
||||||
"""`got` 的 Arg 参数消息纯文本"""
|
"""Arg 参数消息纯文本"""
|
||||||
return ArgInner(key, "plaintext") # type: ignore
|
return ArgInner(key, "plaintext") # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ArgParam(Param):
|
class ArgParam(Param):
|
||||||
"""`got` 的 Arg 参数"""
|
"""Arg 注入参数
|
||||||
|
|
||||||
|
本注入解析事件响应器操作 `got` 所获取的参数。
|
||||||
|
|
||||||
|
可以通过 `Arg`、`ArgStr`、`ArgPlainText` 等函数参数 `key` 指定获取的参数,
|
||||||
|
留空则会根据参数名称获取。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
||||||
@ -338,7 +374,8 @@ class ArgParam(Param):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||||
message = matcher.get_arg(self.extra["key"])
|
key: str = self.extra["key"]
|
||||||
|
message = matcher.get_arg(key)
|
||||||
if message is None:
|
if message is None:
|
||||||
return message
|
return message
|
||||||
if self.extra["type"] == "message":
|
if self.extra["type"] == "message":
|
||||||
@ -350,7 +387,12 @@ class ArgParam(Param):
|
|||||||
|
|
||||||
|
|
||||||
class ExceptionParam(Param):
|
class ExceptionParam(Param):
|
||||||
"""`run_postprocessor` 的异常参数"""
|
"""{ref}`nonebot.message.run_postprocessor` 的异常注入参数
|
||||||
|
|
||||||
|
本注入解析所有类型为 `Exception` 或 `None` 的参数。
|
||||||
|
|
||||||
|
为保证兼容性,本注入还会解析名为 `exception` 且没有类型注解的参数。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "ExceptionParam()"
|
return "ExceptionParam()"
|
||||||
@ -359,9 +401,11 @@ class ExceptionParam(Param):
|
|||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["ExceptionParam"]:
|
) -> Optional["ExceptionParam"]:
|
||||||
if generic_check_issubclass(param.annotation, Exception) or (
|
# param type is Exception(s) or subclass(es) of Exception or None
|
||||||
param.annotation == param.empty and param.name == "exception"
|
if generic_check_issubclass(param.annotation, Exception):
|
||||||
):
|
return cls(Required)
|
||||||
|
# legacy: param is named "exception" and has no type annotation
|
||||||
|
elif param.annotation == param.empty and param.name == "exception":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
||||||
@ -369,7 +413,12 @@ class ExceptionParam(Param):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultParam(Param):
|
class DefaultParam(Param):
|
||||||
"""默认值参数"""
|
"""默认值注入参数
|
||||||
|
|
||||||
|
本注入解析所有剩余未能解析且具有默认值的参数。
|
||||||
|
|
||||||
|
本注入参数应该具有最低优先级,因此应该在所有其他注入参数之后使用。
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"DefaultParam(default={self.default!r})"
|
return f"DefaultParam(default={self.default!r})"
|
||||||
|
@ -59,7 +59,7 @@ def generic_check_issubclass(
|
|||||||
"""检查 cls 是否是 class_or_tuple 中的一个类型子类。
|
"""检查 cls 是否是 class_or_tuple 中的一个类型子类。
|
||||||
|
|
||||||
特别的,如果 cls 是 `typing.Union` 或 `types.UnionType` 类型,
|
特别的,如果 cls 是 `typing.Union` 或 `types.UnionType` 类型,
|
||||||
则会检查其中的类型是否是 class_or_tuple 中的一个类型子类。(None 会被忽略)
|
则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return issubclass(cls, class_or_tuple)
|
return issubclass(cls, class_or_tuple)
|
||||||
|
23
tests/plugins/param/priority.py
Normal file
23
tests/plugins/param/priority.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nonebot.typing import T_State
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot.params import Arg, Depends
|
||||||
|
from nonebot.adapters import Bot, Event, Message
|
||||||
|
|
||||||
|
|
||||||
|
def dependency():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
async def complex_priority(
|
||||||
|
sub: int = Depends(dependency),
|
||||||
|
bot: Optional[Bot] = None,
|
||||||
|
event: Optional[Event] = None,
|
||||||
|
state: T_State = {},
|
||||||
|
matcher: Optional[Matcher] = None,
|
||||||
|
arg: Message = Arg(),
|
||||||
|
exception: Optional[Exception] = None,
|
||||||
|
default: int = 1,
|
||||||
|
):
|
||||||
|
...
|
@ -4,6 +4,7 @@ import pytest
|
|||||||
from nonebug import App
|
from nonebug import App
|
||||||
|
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import TypeMisMatch
|
from nonebot.exception import TypeMisMatch
|
||||||
from utils import make_fake_event, make_fake_message
|
from utils import make_fake_event, make_fake_message
|
||||||
from nonebot.params import (
|
from nonebot.params import (
|
||||||
@ -413,3 +414,41 @@ async def test_default(app: App):
|
|||||||
|
|
||||||
async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx:
|
async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx:
|
||||||
ctx.should_return(1)
|
ctx.should_return(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_priority():
|
||||||
|
from plugins.param.priority import complex_priority
|
||||||
|
|
||||||
|
dependent = Dependent.parse(
|
||||||
|
call=complex_priority,
|
||||||
|
allow_types=[
|
||||||
|
DependParam,
|
||||||
|
BotParam,
|
||||||
|
EventParam,
|
||||||
|
StateParam,
|
||||||
|
MatcherParam,
|
||||||
|
ArgParam,
|
||||||
|
ExceptionParam,
|
||||||
|
DefaultParam,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for param in dependent.params:
|
||||||
|
if param.name == "sub":
|
||||||
|
assert isinstance(param.field_info, DependParam)
|
||||||
|
elif param.name == "bot":
|
||||||
|
assert isinstance(param.field_info, BotParam)
|
||||||
|
elif param.name == "event":
|
||||||
|
assert isinstance(param.field_info, EventParam)
|
||||||
|
elif param.name == "state":
|
||||||
|
assert isinstance(param.field_info, StateParam)
|
||||||
|
elif param.name == "matcher":
|
||||||
|
assert isinstance(param.field_info, MatcherParam)
|
||||||
|
elif param.name == "arg":
|
||||||
|
assert isinstance(param.field_info, ArgParam)
|
||||||
|
elif param.name == "exception":
|
||||||
|
assert isinstance(param.field_info, ExceptionParam)
|
||||||
|
elif param.name == "default":
|
||||||
|
assert isinstance(param.field_info, DefaultParam)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown param {param.name}")
|
||||||
|
Loading…
Reference in New Issue
Block a user