improve dependency injection params (#2034)

This commit is contained in:
Ju4tCode 2023-05-21 16:01:55 +08:00 committed by GitHub
parent dd80191761
commit aa48299d5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 165 additions and 54 deletions

View File

@ -71,7 +71,12 @@ def Depends(
class DependParam(Param):
"""子依赖参数"""
"""子依赖注入参数。
本注入解析所有子依赖注入然后将它们的返回值作为参数值传递给父依赖
本注入应该具有最高优先级因此应该在其他参数之前检查
"""
def __repr__(self) -> str:
return f"Depends({self.extra['dependent']})"
@ -168,7 +173,12 @@ class DependParam(Param):
class BotParam(Param):
"""{ref}`nonebot.adapters.Bot` 参数"""
"""{ref}`nonebot.adapters.Bot` 注入参数。
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Bot` 及其子类或 `None` 的参数
为保证兼容性本注入还会解析名为 `bot` 且没有类型注解的参数
"""
def __repr__(self) -> str:
return (
@ -187,7 +197,7 @@ class BotParam(Param):
) -> Optional["BotParam"]:
from nonebot.adapters import Bot
if param.default == param.empty:
# param type is Bot(s) or subclass(es) of Bot or None
if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None
if param.annotation is not Bot:
@ -200,6 +210,7 @@ class BotParam(Param):
required=True,
)
return cls(Required, checker=checker)
# legacy: param is named "bot" and has no type annotation
elif param.annotation == param.empty and param.name == "bot":
return cls(Required)
@ -212,7 +223,12 @@ class BotParam(Param):
class EventParam(Param):
"""{ref}`nonebot.adapters.Event` 参数"""
"""{ref}`nonebot.adapters.Event` 注入参数
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Event` 及其子类或 `None` 的参数
为保证兼容性本注入还会解析名为 `event` 且没有类型注解的参数
"""
def __repr__(self) -> str:
return (
@ -231,7 +247,7 @@ class EventParam(Param):
) -> Optional["EventParam"]:
from nonebot.adapters import Event
if param.default == param.empty:
# param type is Event(s) or subclass(es) of Event or None
if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None
if param.annotation is not Event:
@ -244,6 +260,7 @@ class EventParam(Param):
required=True,
)
return cls(Required, checker=checker)
# legacy: param is named "event" and has no type annotation
elif param.annotation == param.empty and param.name == "event":
return cls(Required)
@ -256,7 +273,12 @@ class EventParam(Param):
class StateParam(Param):
"""事件处理状态参数"""
"""事件处理状态注入参数
本注入解析所有类型为 `T_State` 的参数
为保证兼容性本注入还会解析名为 `state` 且没有类型注解的参数
"""
def __repr__(self) -> str:
return "StateParam()"
@ -265,9 +287,10 @@ class StateParam(Param):
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]:
if param.default == param.empty:
# param type is T_State
if param.annotation is T_State:
return cls(Required)
# legacy: param is named "state" and has no type annotation
elif param.annotation == param.empty and param.name == "state":
return cls(Required)
@ -276,7 +299,12 @@ class StateParam(Param):
class MatcherParam(Param):
"""事件响应器实例参数"""
"""事件响应器实例注入参数
本注入解析所有类型为且仅为 {ref}`nonebot.matcher.Matcher` 及其子类或 `None` 的参数
为保证兼容性本注入还会解析名为 `matcher` 且没有类型注解的参数
"""
def __repr__(self) -> str:
return "MatcherParam()"
@ -287,9 +315,11 @@ class MatcherParam(Param):
) -> Optional["MatcherParam"]:
from nonebot.matcher import Matcher
if generic_check_issubclass(param.annotation, Matcher) or (
param.annotation == param.empty and param.name == "matcher"
):
# param type is Matcher(s) or subclass(es) of Matcher or None
if generic_check_issubclass(param.annotation, Matcher):
return cls(Required)
# legacy: param is named "matcher" and has no type annotation
elif param.annotation == param.empty and param.name == "matcher":
return cls(Required)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
@ -308,22 +338,28 @@ class ArgInner:
def Arg(key: Optional[str] = None) -> Any:
"""`got` 的 Arg 参数消息"""
"""Arg 参数消息"""
return ArgInner(key, "message")
def ArgStr(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息文本"""
"""Arg 参数消息文本"""
return ArgInner(key, "str") # type: ignore
def ArgPlainText(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息纯文本"""
"""Arg 参数消息纯文本"""
return ArgInner(key, "plaintext") # type: ignore
class ArgParam(Param):
"""`got` 的 Arg 参数"""
"""Arg 注入参数
本注入解析事件响应器操作 `got` 所获取的参数
可以通过 `Arg``ArgStr``ArgPlainText` 等函数参数 `key` 指定获取的参数
留空则会根据参数名称获取
"""
def __repr__(self) -> str:
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
@ -338,7 +374,8 @@ class ArgParam(Param):
)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
message = matcher.get_arg(self.extra["key"])
key: str = self.extra["key"]
message = matcher.get_arg(key)
if message is None:
return message
if self.extra["type"] == "message":
@ -350,7 +387,12 @@ class ArgParam(Param):
class ExceptionParam(Param):
"""`run_postprocessor` 的异常参数"""
"""{ref}`nonebot.message.run_postprocessor` 的异常注入参数
本注入解析所有类型为 `Exception` `None` 的参数
为保证兼容性本注入还会解析名为 `exception` 且没有类型注解的参数
"""
def __repr__(self) -> str:
return "ExceptionParam()"
@ -359,9 +401,11 @@ class ExceptionParam(Param):
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]:
if generic_check_issubclass(param.annotation, Exception) or (
param.annotation == param.empty and param.name == "exception"
):
# param type is Exception(s) or subclass(es) of Exception or None
if generic_check_issubclass(param.annotation, Exception):
return cls(Required)
# legacy: param is named "exception" and has no type annotation
elif param.annotation == param.empty and param.name == "exception":
return cls(Required)
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
@ -369,7 +413,12 @@ class ExceptionParam(Param):
class DefaultParam(Param):
"""默认值参数"""
"""默认值注入参数
本注入解析所有剩余未能解析且具有默认值的参数
本注入参数应该具有最低优先级因此应该在所有其他注入参数之后使用
"""
def __repr__(self) -> str:
return f"DefaultParam(default={self.default!r})"

View File

@ -59,7 +59,7 @@ def generic_check_issubclass(
"""检查 cls 是否是 class_or_tuple 中的一个类型子类。
特别的如果 cls `typing.Union` `types.UnionType` 类型
则会检查其中的类型是否是 class_or_tuple 的一个类型子类None 会被忽略
则会检查其中的所有类型是否是 class_or_tuple 一个类型的子类或 None
"""
try:
return issubclass(cls, class_or_tuple)

View File

@ -0,0 +1,23 @@
from typing import Optional
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, Depends
from nonebot.adapters import Bot, Event, Message
def dependency():
return 1
async def complex_priority(
sub: int = Depends(dependency),
bot: Optional[Bot] = None,
event: Optional[Event] = None,
state: T_State = {},
matcher: Optional[Matcher] = None,
arg: Message = Arg(),
exception: Optional[Exception] = None,
default: int = 1,
):
...

View File

@ -4,6 +4,7 @@ import pytest
from nonebug import App
from nonebot.matcher import Matcher
from nonebot.dependencies import Dependent
from nonebot.exception import TypeMisMatch
from utils import make_fake_event, make_fake_message
from nonebot.params import (
@ -413,3 +414,41 @@ async def test_default(app: App):
async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx:
ctx.should_return(1)
@pytest.mark.asyncio
async def test_priority():
from plugins.param.priority import complex_priority
dependent = Dependent.parse(
call=complex_priority,
allow_types=[
DependParam,
BotParam,
EventParam,
StateParam,
MatcherParam,
ArgParam,
ExceptionParam,
DefaultParam,
],
)
for param in dependent.params:
if param.name == "sub":
assert isinstance(param.field_info, DependParam)
elif param.name == "bot":
assert isinstance(param.field_info, BotParam)
elif param.name == "event":
assert isinstance(param.field_info, EventParam)
elif param.name == "state":
assert isinstance(param.field_info, StateParam)
elif param.name == "matcher":
assert isinstance(param.field_info, MatcherParam)
elif param.name == "arg":
assert isinstance(param.field_info, ArgParam)
elif param.name == "exception":
assert isinstance(param.field_info, ExceptionParam)
elif param.name == "default":
assert isinstance(param.field_info, DefaultParam)
else:
raise ValueError(f"unknown param {param.name}")