nonebot2/tests/plugins/param/param_depend.py

132 lines
2.4 KiB
Python
Raw Normal View History

from dataclasses import dataclass
from typing import Annotated
import anyio
from pydantic import Field
2021-12-20 00:28:02 +08:00
from nonebot import on_message
from nonebot.adapters import Bot
2021-12-20 00:28:02 +08:00
from nonebot.params import Depends
# Matcher whose handlers (registered below via ``test_depends.handle``)
# exercise the dependency-injection features under test.
test_depends = on_message()
# Shared call log: ``parameterless`` and ``dependency`` each append ``1`` here,
# so tests can check how many times (and in which order) they were invoked.
runned: list[int] = []
def dependency():
    """Record one call in ``runned`` and provide the value ``1``.

    Every execution appends exactly one entry, so ``len(runned)`` reveals how
    many times the framework actually ran this dependency (for example when
    ``use_cache=True`` is requested elsewhere).
    """
    provided = 1
    runned.append(provided)
    return provided
def parameterless():
    """Parameterless dependency that must run before anything logs to ``runned``.

    It fails its assertion if ``runned`` already holds an entry, then records
    its own call so later dependencies see a non-empty log.
    """
    assert not runned
    runned.append(1)
def gen_sync():
    """Synchronous generator dependency that yields ``1`` exactly once."""
    value = 1
    yield value
async def gen_async():
    """Asynchronous generator dependency that yields ``2`` exactly once."""
    value = 2
    yield value
@dataclass
class ClassDependency:
    """Dataclass requested as a class dependency via a bare ``Depends()``.

    Each field's default is itself a ``Depends``; presumably the framework
    resolves them (one sync generator, one async generator) when building the
    instance. RUF009 (function call in a dataclass default) is suppressed on
    purpose because those calls are the point of the fixture.
    """

    # Provided by the synchronous generator dependency ``gen_sync``.
    x: int = Depends(gen_sync) # noqa: RUF009
    # Provided by the asynchronous generator dependency ``gen_async``.
    y: int = Depends(gen_async) # noqa: RUF009
class FooBot(Bot):
    """``Bot`` subclass with no additions, used only as a narrower annotation.

    ``sub_bot`` and ``sub_type_mismatch`` annotate their bot as ``FooBot`` to
    exercise sub-dependency type checking.
    """
async def sub_bot(b: FooBot) -> FooBot:
    """Sub-dependency that hands the injected ``FooBot`` straight back.

    Used by ``sub_type_mismatch``. The ``FooBot`` annotation on ``b`` is
    presumably what triggers the mismatch when the current bot is some other
    ``Bot`` subclass; confirm against the test that drives it.
    """
    bot = b
    return bot
2021-12-20 00:28:02 +08:00
# test parameterless
@test_depends.handle(parameterless=[Depends(parameterless)])
async def depends(x: int = Depends(dependency)):
    """Handler that receives ``x`` from ``dependency`` (test dependency).

    ``parameterless`` is attached as a parameterless dependency. It asserts
    that ``runned`` is still empty, so it has to run before ``dependency``.
    """
    injected = x
    return injected
@test_depends.handle()
async def depends_cache(y: int = Depends(dependency, use_cache=True)):
    """Handler that requests ``dependency`` again, with ``use_cache=True`` (test cache).

    It is registered on the same matcher as ``depends``. Presumably the cached
    value is reused, so ``dependency`` (and its ``runned`` entry) runs only once
    per event; confirm against the test assertions.
    """
    cached = y
    return cached
# test class dependency
async def class_depend(c: ClassDependency = Depends()):
    """Receive a ``ClassDependency`` built from a bare ``Depends()`` default.

    No callable is given to ``Depends``, so the annotated class itself is the
    dependency.
    """
    instance = c
    return instance
# test annotated dependency
async def annotated_depend(x: Annotated[int, Depends(dependency)]):
    """Receive ``x`` from ``dependency``, declared through ``Annotated`` metadata.

    The parameter has no default, so the ``Depends`` lives only in the
    annotation.
    """
    injected = x
    return injected
# test annotated class dependency
async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]):
    """Receive a ``ClassDependency`` requested via ``Annotated[..., Depends()]``.

    This mirrors ``class_depend``, but the bare ``Depends()`` is placed in the
    annotation metadata instead of the default value.
    """
    instance = c
    return instance
# test dependency priority
async def annotated_prior_depend(
    x: Annotated[int, Depends(lambda: 2)] = Depends(dependency),
):
    """Declare two competing dependencies for ``x``.

    The ``Annotated`` metadata provides ``2`` and the default provides
    ``dependency`` (``1``). The returned value shows which one the framework
    prioritises; presumably the default wins, so confirm against the test.
    """
    chosen = x
    return chosen
async def annotated_multi_depend(
    x: Annotated[Annotated[int, Depends(lambda: 2)], Depends(dependency)],
):
    """Declare two ``Depends`` through nested ``Annotated`` metadata.

    Nested ``Annotated`` is flattened by ``typing``, so both dependencies end
    up in one metadata tuple. Which one is picked is what this fixture checks;
    presumably it is the outermost (last) one, so confirm against the test.
    """
    chosen = x
    return chosen
# test sub dependency type mismatch
async def sub_type_mismatch(b: FooBot = Depends(sub_bot)):
    """Request ``b`` through the ``sub_bot`` sub-dependency, which itself needs a ``FooBot``.

    The fixture exists to exercise a type mismatch inside the sub-dependency.
    Presumably the handler body is skipped when the current bot is not a
    ``FooBot``.
    """
    bot = b
    return bot
# test type validate
async def validate(x: int = Depends(lambda: "1", validate=True)):
    """Dependency returns the string ``"1"`` for an ``int`` parameter, with ``validate=True``.

    Presumably validation coerces the value to ``int`` before it reaches ``x``;
    confirm against the test expectation.
    """
    checked = x
    return checked
async def validate_fail(x: int = Depends(lambda: "not_number", validate=True)):
    """Dependency returns ``"not_number"`` for an ``int`` parameter, with ``validate=True``.

    As the name indicates, validation is expected to fail here. Presumably the
    body is not reached in that case.
    """
    checked = x
    return checked
# test FieldInfo validate
async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))):
    """Validate the dependency result against a pydantic ``Field(gt=0)`` constraint.

    The dependency returns ``"1"``, which should satisfy ``gt=0`` once it has
    been converted to ``int`` (presumably; confirm against the test).
    """
    checked = x
    return checked
async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
    """Dependency returns ``"0"``, which violates the ``Field(gt=0)`` constraint.

    As the name indicates, validation is expected to fail. Presumably the body
    is not reached in that case.
    """
    checked = x
    return checked
async def _dep():
    """Slow async dependency: wait one second, then provide ``1``.

    The delay presumably keeps this dependency in flight while a sibling
    dependency fails (see the ``cache_exception_*`` fixtures). Do not shorten
    it without checking the test.
    """
    delay_seconds = 1
    await anyio.sleep(delay_seconds)
    return 1
def _dep_mismatch():
return 1
async def cache_exception_func1(
    dep: int = Depends(_dep),
    mismatch: dict = Depends(_dep_mismatch),
):
    """Handler whose dependency resolution is expected to fail before the body.

    ``_dep_mismatch`` returns an ``int`` but ``mismatch`` is annotated
    ``dict``. ``_dep`` sleeps for one second, so presumably the failure happens
    while it is still pending; that interaction is what the cache-exception
    test examines. The body should never run.
    """
    message = "Never reach here"
    raise RuntimeError(message)
async def cache_exception_func2(
    dep: int = Depends(_dep),
    match: int = Depends(_dep_mismatch),
):
    """Companion to ``cache_exception_func1``, using the same two dependencies.

    Here ``_dep_mismatch``'s ``int`` result matches the ``int`` annotation, so
    resolution is expected to succeed. Presumably the fixture checks that the
    earlier failure did not leave a broken cache entry for ``_dep``; confirm
    against the test.
    """
    resolved = dep
    return resolved