improve dependency injection params (#2034)

This commit is contained in:
Ju4tCode 2023-05-21 16:01:55 +08:00 committed by GitHub
parent dd80191761
commit aa48299d5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 165 additions and 54 deletions

View File

@ -71,7 +71,12 @@ def Depends(
class DependParam(Param): class DependParam(Param):
"""子依赖参数""" """子依赖注入参数。
本注入解析所有子依赖注入然后将它们的返回值作为参数值传递给父依赖
本注入应该具有最高优先级因此应该在其他参数之前检查
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Depends({self.extra['dependent']})" return f"Depends({self.extra['dependent']})"
@ -168,7 +173,12 @@ class DependParam(Param):
class BotParam(Param): class BotParam(Param):
"""{ref}`nonebot.adapters.Bot` 参数""" """{ref}`nonebot.adapters.Bot` 注入参数。
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Bot` 及其子类或 `None` 的参数
为保证兼容性本注入还会解析名为 `bot` 且没有类型注解的参数
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -187,21 +197,22 @@ class BotParam(Param):
) -> Optional["BotParam"]: ) -> Optional["BotParam"]:
from nonebot.adapters import Bot from nonebot.adapters import Bot
if param.default == param.empty: # param type is Bot(s) or subclass(es) of Bot or None
if generic_check_issubclass(param.annotation, Bot): if generic_check_issubclass(param.annotation, Bot):
checker: Optional[ModelField] = None checker: Optional[ModelField] = None
if param.annotation is not Bot: if param.annotation is not Bot:
checker = ModelField( checker = ModelField(
name=param.name, name=param.name,
type_=param.annotation, type_=param.annotation,
class_validators=None, class_validators=None,
model_config=CustomConfig, model_config=CustomConfig,
default=None, default=None,
required=True, required=True,
) )
return cls(Required, checker=checker) return cls(Required, checker=checker)
elif param.annotation == param.empty and param.name == "bot": # legacy: param is named "bot" and has no type annotation
return cls(Required) elif param.annotation == param.empty and param.name == "bot":
return cls(Required)
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot return bot
@ -212,7 +223,12 @@ class BotParam(Param):
class EventParam(Param): class EventParam(Param):
"""{ref}`nonebot.adapters.Event` 参数""" """{ref}`nonebot.adapters.Event` 注入参数
本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Event` 及其子类或 `None` 的参数
为保证兼容性本注入还会解析名为 `event` 且没有类型注解的参数
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -231,21 +247,22 @@ class EventParam(Param):
) -> Optional["EventParam"]: ) -> Optional["EventParam"]:
from nonebot.adapters import Event from nonebot.adapters import Event
if param.default == param.empty: # param type is Event(s) or subclass(es) of Event or None
if generic_check_issubclass(param.annotation, Event): if generic_check_issubclass(param.annotation, Event):
checker: Optional[ModelField] = None checker: Optional[ModelField] = None
if param.annotation is not Event: if param.annotation is not Event:
checker = ModelField( checker = ModelField(
name=param.name, name=param.name,
type_=param.annotation, type_=param.annotation,
class_validators=None, class_validators=None,
model_config=CustomConfig, model_config=CustomConfig,
default=None, default=None,
required=True, required=True,
) )
return cls(Required, checker=checker) return cls(Required, checker=checker)
elif param.annotation == param.empty and param.name == "event": # legacy: param is named "event" and has no type annotation
return cls(Required) elif param.annotation == param.empty and param.name == "event":
return cls(Required)
async def _solve(self, event: "Event", **kwargs: Any) -> Any: async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event return event
@ -256,7 +273,12 @@ class EventParam(Param):
class StateParam(Param): class StateParam(Param):
"""事件处理状态参数""" """事件处理状态注入参数
本注入解析所有类型为 `T_State` 的参数
为保证兼容性本注入还会解析名为 `state` 且没有类型注解的参数
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return "StateParam()" return "StateParam()"
@ -265,18 +287,24 @@ class StateParam(Param):
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]: ) -> Optional["StateParam"]:
if param.default == param.empty: # param type is T_State
if param.annotation is T_State: if param.annotation is T_State:
return cls(Required) return cls(Required)
elif param.annotation == param.empty and param.name == "state": # legacy: param is named "state" and has no type annotation
return cls(Required) elif param.annotation == param.empty and param.name == "state":
return cls(Required)
async def _solve(self, state: T_State, **kwargs: Any) -> Any: async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state return state
class MatcherParam(Param): class MatcherParam(Param):
"""事件响应器实例参数""" """事件响应器实例注入参数
本注入解析所有类型为且仅为 {ref}`nonebot.matcher.Matcher` 及其子类或 `None` 的参数
为保证兼容性本注入还会解析名为 `matcher` 且没有类型注解的参数
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return "MatcherParam()" return "MatcherParam()"
@ -287,9 +315,11 @@ class MatcherParam(Param):
) -> Optional["MatcherParam"]: ) -> Optional["MatcherParam"]:
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
if generic_check_issubclass(param.annotation, Matcher) or ( # param type is Matcher(s) or subclass(es) of Matcher or None
param.annotation == param.empty and param.name == "matcher" if generic_check_issubclass(param.annotation, Matcher):
): return cls(Required)
# legacy: param is named "matcher" and has no type annotation
elif param.annotation == param.empty and param.name == "matcher":
return cls(Required) return cls(Required)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
@ -308,22 +338,28 @@ class ArgInner:
def Arg(key: Optional[str] = None) -> Any: def Arg(key: Optional[str] = None) -> Any:
"""`got` 的 Arg 参数消息""" """Arg 参数消息"""
return ArgInner(key, "message") return ArgInner(key, "message")
def ArgStr(key: Optional[str] = None) -> str: def ArgStr(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息文本""" """Arg 参数消息文本"""
return ArgInner(key, "str") # type: ignore return ArgInner(key, "str") # type: ignore
def ArgPlainText(key: Optional[str] = None) -> str: def ArgPlainText(key: Optional[str] = None) -> str:
"""`got` 的 Arg 参数消息纯文本""" """Arg 参数消息纯文本"""
return ArgInner(key, "plaintext") # type: ignore return ArgInner(key, "plaintext") # type: ignore
class ArgParam(Param): class ArgParam(Param):
"""`got` 的 Arg 参数""" """Arg 注入参数
本注入解析事件响应器操作 `got` 所获取的参数
可以通过 `Arg``ArgStr``ArgPlainText` 等函数参数 `key` 指定获取的参数
留空则会根据参数名称获取
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
@ -338,7 +374,8 @@ class ArgParam(Param):
) )
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
message = matcher.get_arg(self.extra["key"]) key: str = self.extra["key"]
message = matcher.get_arg(key)
if message is None: if message is None:
return message return message
if self.extra["type"] == "message": if self.extra["type"] == "message":
@ -350,7 +387,12 @@ class ArgParam(Param):
class ExceptionParam(Param): class ExceptionParam(Param):
"""`run_postprocessor` 的异常参数""" """{ref}`nonebot.message.run_postprocessor` 的异常注入参数
本注入解析所有类型为 `Exception` `None` 的参数
为保证兼容性本注入还会解析名为 `exception` 且没有类型注解的参数
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return "ExceptionParam()" return "ExceptionParam()"
@ -359,9 +401,11 @@ class ExceptionParam(Param):
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]: ) -> Optional["ExceptionParam"]:
if generic_check_issubclass(param.annotation, Exception) or ( # param type is Exception(s) or subclass(es) of Exception or None
param.annotation == param.empty and param.name == "exception" if generic_check_issubclass(param.annotation, Exception):
): return cls(Required)
# legacy: param is named "exception" and has no type annotation
elif param.annotation == param.empty and param.name == "exception":
return cls(Required) return cls(Required)
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
@ -369,7 +413,12 @@ class ExceptionParam(Param):
class DefaultParam(Param): class DefaultParam(Param):
"""默认值参数""" """默认值注入参数
本注入解析所有剩余未能解析且具有默认值的参数
本注入参数应该具有最低优先级因此应该在所有其他注入参数之后使用
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return f"DefaultParam(default={self.default!r})" return f"DefaultParam(default={self.default!r})"

View File

@ -59,7 +59,7 @@ def generic_check_issubclass(
"""检查 cls 是否是 class_or_tuple 中的一个类型子类。 """检查 cls 是否是 class_or_tuple 中的一个类型子类。
特别的如果 cls `typing.Union` `types.UnionType` 类型 特别的如果 cls `typing.Union` `types.UnionType` 类型
则会检查其中的类型是否是 class_or_tuple 的一个类型子类None 会被忽略 则会检查其中的所有类型是否是 class_or_tuple 一个类型的子类或 None
""" """
try: try:
return issubclass(cls, class_or_tuple) return issubclass(cls, class_or_tuple)

View File

@ -0,0 +1,23 @@
from typing import Optional
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, Depends
from nonebot.adapters import Bot, Event, Message
def dependency():
return 1
async def complex_priority(
sub: int = Depends(dependency),
bot: Optional[Bot] = None,
event: Optional[Event] = None,
state: T_State = {},
matcher: Optional[Matcher] = None,
arg: Message = Arg(),
exception: Optional[Exception] = None,
default: int = 1,
):
...

View File

@ -4,6 +4,7 @@ import pytest
from nonebug import App from nonebug import App
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.dependencies import Dependent
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from utils import make_fake_event, make_fake_message from utils import make_fake_event, make_fake_message
from nonebot.params import ( from nonebot.params import (
@ -413,3 +414,41 @@ async def test_default(app: App):
async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx: async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx:
ctx.should_return(1) ctx.should_return(1)
@pytest.mark.asyncio
async def test_priority():
from plugins.param.priority import complex_priority
dependent = Dependent.parse(
call=complex_priority,
allow_types=[
DependParam,
BotParam,
EventParam,
StateParam,
MatcherParam,
ArgParam,
ExceptionParam,
DefaultParam,
],
)
for param in dependent.params:
if param.name == "sub":
assert isinstance(param.field_info, DependParam)
elif param.name == "bot":
assert isinstance(param.field_info, BotParam)
elif param.name == "event":
assert isinstance(param.field_info, EventParam)
elif param.name == "state":
assert isinstance(param.field_info, StateParam)
elif param.name == "matcher":
assert isinstance(param.field_info, MatcherParam)
elif param.name == "arg":
assert isinstance(param.field_info, ArgParam)
elif param.name == "exception":
assert isinstance(param.field_info, ExceptionParam)
elif param.name == "default":
assert isinstance(param.field_info, DefaultParam)
else:
raise ValueError(f"unknown param {param.name}")