🚧 add generator dependency support

This commit is contained in:
yanyongyu 2021-11-15 01:28:47 +08:00
parent 0a1ae75b70
commit cafe5c9af0
5 changed files with 75 additions and 17 deletions

View File

@ -40,6 +40,11 @@ class Driver(abc.ABC):
:类型: ``Set[T_BotDisconnectionHook]``
:说明: Bot 连接断开时执行的函数
"""
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
"""
:类型: ``Dict[Callable[..., Any], Callable[..., Any]]``
:说明: Depends 函数的替换表
"""
def __init__(self, env: Env, config: Config):
"""

View File

@ -1,15 +1,17 @@
import inspect
from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from .models import Dependent
from nonebot.log import logger
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass
from nonebot.utils import run_sync, is_coroutine_callable
from .utils import (generic_get_types, get_typed_signature,
generic_check_issubclass)
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
@ -95,11 +97,12 @@ async def solve_dependencies(
bot: Bot,
event: Event,
state: T_State,
matcher: "Matcher",
matcher: Optional["Matcher"],
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {}
@ -108,7 +111,7 @@ async def solve_dependencies(
for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key)
func = sub_dependent.func
@ -158,6 +161,15 @@ async def solve_dependencies(
# run dependency function
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else:

View File

@ -6,6 +6,7 @@
"""
import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger
@ -37,15 +38,15 @@ class Handler:
self.name = get_name(func) if name is None else name
self.dependencies = dependencies or []
self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {}
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
if dependencies:
for depends in dependencies:
if not depends.dependency:
raise ValueError(f"{depends} has no dependency")
if (depends.dependency,) in self.sub_dependents:
if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=depends)
self.sub_dependents[(depends.dependency,)] = sub_dependant
self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func)
@ -60,19 +61,29 @@ class Handler:
def __str__(self) -> str:
return repr(self)
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
state: T_State):
async def __call__(
self,
matcher: "Matcher",
bot: "Bot",
event: "Event",
state: T_State,
*,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> Any:
values, _, ignored = await solve_dependencies(
dependent=self.dependent,
bot=bot,
event=event,
state=state,
matcher=matcher,
stack=stack,
sub_dependents=[
self.sub_dependents[(dependency.dependency,)] # type: ignore
self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies
],
dependency_overrides_provider=self.dependency_overrides_provider)
dependency_overrides_provider=self.dependency_overrides_provider,
dependency_cache=dependency_cache)
if ignored:
return
@ -101,7 +112,7 @@ class Handler:
if (dependency.dependency,) in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=dependency)
self.sub_dependents[(dependency.dependency,)] = sub_dependant
self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends):
self.cache_dependent(dependency)
@ -114,7 +125,7 @@ class Handler:
def remove_dependency(self, dependency: Depends):
if not dependency.dependency:
raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents:
del self.sub_dependents[(dependency.dependency,)]
if dependency.dependency in self.sub_dependents:
del self.sub_dependents[dependency.dependency]
if dependency in self.dependencies:
self.dependencies.remove(dependency)

View File

@ -45,4 +45,4 @@ class Dependent:
self.matcher_param_name = matcher_param_name
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = (self.func,)
self.cache_key = self.func

View File

@ -5,13 +5,16 @@ import inspect
import dataclasses
from functools import wraps, partial
from typing_extensions import ParamSpec
from typing import Any, TypeVar, Callable, Optional, Awaitable
from contextlib import asynccontextmanager
from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator,
ContextManager)
from nonebot.log import logger
from nonebot.typing import overrides
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def escape_tag(s: str) -> str:
@ -40,6 +43,20 @@ def is_coroutine_callable(func: Callable[..., Any]) -> bool:
return inspect.iscoroutinefunction(func_)
def is_gen_callable(func: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(func):
return True
func_ = getattr(func, "__call__", None)
return inspect.isgeneratorfunction(func_)
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(func):
return True
func_ = getattr(func, "__call__", None)
return inspect.isasyncgenfunction(func_)
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
"""
:说明:
@ -65,6 +82,19 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
return _wrapper
@asynccontextmanager
async def run_sync_ctx_manager(
cm: ContextManager[T],) -> AsyncGenerator[T, None]:
try:
yield await run_sync(cm.__enter__)()
except Exception as e:
ok = await run_sync(cm.__exit__)(type(e), e, None)
if not ok:
raise e
else:
await run_sync(cm.__exit__)(None, None, None)
def get_name(obj: Any) -> str:
if inspect.isfunction(obj) or inspect.isclass(obj):
return obj.__name__