diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index d08f6b0c..74c42ff6 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -40,6 +40,11 @@ class Driver(abc.ABC): :类型: ``Set[T_BotDisconnectionHook]`` :说明: Bot 连接断开时执行的函数 """ + dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {} + """ + :类型: ``Dict[Callable[..., Any], Callable[..., Any]]`` + :说明: Depends 函数的替换表 + """ def __init__(self, env: Env, config: Config): """ diff --git a/nonebot/processor/__init__.py b/nonebot/processor/__init__.py index 70689631..8b344b01 100644 --- a/nonebot/processor/__init__.py +++ b/nonebot/processor/__init__.py @@ -1,15 +1,17 @@ import inspect from itertools import chain from typing import Any, Dict, List, Tuple, Callable, Optional, cast +from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from .models import Dependent from nonebot.log import logger from nonebot.typing import T_State from nonebot.adapters import Bot, Event from .models import Depends as DependsClass -from nonebot.utils import run_sync, is_coroutine_callable from .utils import (generic_get_types, get_typed_signature, generic_check_issubclass) +from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, + is_async_gen_callable, is_coroutine_callable) def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: @@ -95,11 +97,12 @@ async def solve_dependencies( bot: Bot, event: Event, state: T_State, - matcher: "Matcher", + matcher: Optional["Matcher"], + stack: Optional[AsyncExitStack] = None, sub_dependents: Optional[List[Dependent]] = None, dependency_overrides_provider: Optional[Any] = None, - dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None, -) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]: + dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, +) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]: values: Dict[str, Any] = {} dependency_cache = dependency_cache or {} @@ -108,7 +111,7 @@ async def solve_dependencies( for sub_dependent in chain(sub_dependents or tuple(), dependent.dependencies): sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) - sub_dependent.cache_key = cast(Tuple[Callable[..., Any]], + sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key) func = sub_dependent.func @@ -158,6 +161,15 @@ async def solve_dependencies( # run dependency function if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: solved = dependency_cache[sub_dependent.cache_key] + elif is_gen_callable(func) or is_async_gen_callable(func): + assert isinstance( + stack, AsyncExitStack + ), "Generator dependency should be called in context" + if is_gen_callable(func): + cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) + else: + cm = asynccontextmanager(func)(**sub_values) + solved = await stack.enter_async_context(cm) elif is_coroutine_callable(func): solved = await func(**sub_values) else: diff --git a/nonebot/processor/handler.py b/nonebot/processor/handler.py index 0aafc8c9..57e38af5 100644 --- a/nonebot/processor/handler.py +++ b/nonebot/processor/handler.py @@ -6,6 +6,7 @@ """ import asyncio +from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional from nonebot.log import logger @@ -37,15 +38,15 @@ class Handler: self.name = get_name(func) if name is None else name self.dependencies = dependencies or [] - self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {} + self.sub_dependents: Dict[Callable[..., Any], Dependent] = {} if dependencies: for depends in dependencies: if not depends.dependency: raise ValueError(f"{depends} has no dependency") - if (depends.dependency,) in self.sub_dependents: + if depends.dependency in self.sub_dependents: raise ValueError(f"{depends} is already in dependencies") sub_dependant = get_parameterless_sub_dependant(depends=depends) - self.sub_dependents[(depends.dependency,)] = sub_dependant + self.sub_dependents[depends.dependency] = sub_dependant self.dependency_overrides_provider = dependency_overrides_provider self.dependent = get_dependent(func=func) @@ -60,19 +61,29 @@ class Handler: def __str__(self) -> str: return repr(self) - async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event", - state: T_State): + async def __call__( + self, + matcher: "Matcher", + bot: "Bot", + event: "Event", + state: T_State, + *, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[Dict[Callable[..., Any], + Any]] = None) -> Any: values, _, ignored = await solve_dependencies( dependent=self.dependent, bot=bot, event=event, state=state, matcher=matcher, + stack=stack, sub_dependents=[ - self.sub_dependents[(dependency.dependency,)] # type: ignore + self.sub_dependents[dependency.dependency] # type: ignore for dependency in self.dependencies ], - dependency_overrides_provider=self.dependency_overrides_provider) + dependency_overrides_provider=self.dependency_overrides_provider, + dependency_cache=dependency_cache) if ignored: return @@ -101,7 +112,7 @@ class Handler: if (dependency.dependency,) in self.sub_dependents: raise ValueError(f"{dependency} is already in dependencies") sub_dependant = get_parameterless_sub_dependant(depends=dependency) - self.sub_dependents[(dependency.dependency,)] = sub_dependant + self.sub_dependents[dependency.dependency] = sub_dependant def prepend_dependency(self, dependency: Depends): self.cache_dependent(dependency) @@ -114,7 +125,7 @@ class Handler: def remove_dependency(self, dependency: Depends): if not dependency.dependency: raise ValueError(f"{dependency} has no dependency") - if (dependency.dependency,) in self.sub_dependents: - del self.sub_dependents[(dependency.dependency,)] + if dependency.dependency in self.sub_dependents: + del self.sub_dependents[dependency.dependency] if dependency in self.dependencies: self.dependencies.remove(dependency) diff --git a/nonebot/processor/models.py b/nonebot/processor/models.py index 0e3cd7e4..9413fb8a 100644 --- a/nonebot/processor/models.py +++ b/nonebot/processor/models.py @@ -45,4 +45,4 @@ class Dependent: self.matcher_param_name = matcher_param_name self.dependencies = dependencies or [] self.use_cache = use_cache - self.cache_key = (self.func,) + self.cache_key = self.func diff --git a/nonebot/utils.py b/nonebot/utils.py index f7d78457..2786ef1d 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -5,13 +5,16 @@ import inspect import dataclasses from functools import wraps, partial from typing_extensions import ParamSpec -from typing import Any, TypeVar, Callable, Optional, Awaitable +from contextlib import asynccontextmanager +from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator, + ContextManager) from nonebot.log import logger from nonebot.typing import overrides P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") def escape_tag(s: str) -> str: @@ -40,6 +43,20 @@ def is_coroutine_callable(func: Callable[..., Any]) -> bool: return inspect.iscoroutinefunction(func_) +def is_gen_callable(func: Callable[..., Any]) -> bool: + if inspect.isgeneratorfunction(func): + return True + func_ = getattr(func, "__call__", None) + return inspect.isgeneratorfunction(func_) + + +def is_async_gen_callable(func: Callable[..., Any]) -> bool: + if inspect.isasyncgenfunction(func): + return True + func_ = getattr(func, "__call__", None) + return inspect.isasyncgenfunction(func_) + + def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: """ :说明: @@ -65,6 +82,19 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: return _wrapper +@asynccontextmanager +async def run_sync_ctx_manager( + cm: ContextManager[T],) -> AsyncGenerator[T, None]: + try: + yield await run_sync(cm.__enter__)() + except Exception as e: + ok = await run_sync(cm.__exit__)(type(e), e, None) + if not ok: + raise e + else: + await run_sync(cm.__exit__)(None, None, None) + + def get_name(obj: Any) -> str: if inspect.isfunction(obj) or inspect.isclass(obj): return obj.__name__