🚧 add generator dependency support

This commit is contained in:
yanyongyu 2021-11-15 01:28:47 +08:00
parent 0a1ae75b70
commit cafe5c9af0
5 changed files with 75 additions and 17 deletions

View File

@ -40,6 +40,11 @@ class Driver(abc.ABC):
:类型: ``Set[T_BotDisconnectionHook]`` :类型: ``Set[T_BotDisconnectionHook]``
:说明: Bot 连接断开时执行的函数 :说明: Bot 连接断开时执行的函数
""" """
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
"""
:类型: ``Dict[Callable[..., Any], Callable[..., Any]]``
:说明: Depends 函数的替换表
"""
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
""" """

View File

@ -1,15 +1,17 @@
import inspect import inspect
from itertools import chain from itertools import chain
from typing import Any, Dict, List, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from .models import Dependent from .models import Dependent
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from .models import Depends as DependsClass from .models import Depends as DependsClass
from nonebot.utils import run_sync, is_coroutine_callable
from .utils import (generic_get_types, get_typed_signature, from .utils import (generic_get_types, get_typed_signature,
generic_check_issubclass) generic_check_issubclass)
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
@ -95,11 +97,12 @@ async def solve_dependencies(
bot: Bot, bot: Bot,
event: Event, event: Event,
state: T_State, state: T_State,
matcher: "Matcher", matcher: Optional["Matcher"],
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None, sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None, dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]: ) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
@ -108,7 +111,7 @@ async def solve_dependencies(
for sub_dependent in chain(sub_dependents or tuple(), for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies): dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]], sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key) sub_dependent.cache_key)
func = sub_dependent.func func = sub_dependent.func
@ -158,6 +161,15 @@ async def solve_dependencies(
# run dependency function # run dependency function
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key] solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await stack.enter_async_context(cm)
elif is_coroutine_callable(func): elif is_coroutine_callable(func):
solved = await func(**sub_values) solved = await func(**sub_values)
else: else:

View File

@ -6,6 +6,7 @@
""" """
import asyncio import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
from nonebot.log import logger from nonebot.log import logger
@ -37,15 +38,15 @@ class Handler:
self.name = get_name(func) if name is None else name self.name = get_name(func) if name is None else name
self.dependencies = dependencies or [] self.dependencies = dependencies or []
self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {} self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
if dependencies: if dependencies:
for depends in dependencies: for depends in dependencies:
if not depends.dependency: if not depends.dependency:
raise ValueError(f"{depends} has no dependency") raise ValueError(f"{depends} has no dependency")
if (depends.dependency,) in self.sub_dependents: if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies") raise ValueError(f"{depends} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=depends) sub_dependant = get_parameterless_sub_dependant(depends=depends)
self.sub_dependents[(depends.dependency,)] = sub_dependant self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func) self.dependent = get_dependent(func=func)
@ -60,19 +61,29 @@ class Handler:
def __str__(self) -> str: def __str__(self) -> str:
return repr(self) return repr(self)
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event", async def __call__(
state: T_State): self,
matcher: "Matcher",
bot: "Bot",
event: "Event",
state: T_State,
*,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> Any:
values, _, ignored = await solve_dependencies( values, _, ignored = await solve_dependencies(
dependent=self.dependent, dependent=self.dependent,
bot=bot, bot=bot,
event=event, event=event,
state=state, state=state,
matcher=matcher, matcher=matcher,
stack=stack,
sub_dependents=[ sub_dependents=[
self.sub_dependents[(dependency.dependency,)] # type: ignore self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies for dependency in self.dependencies
], ],
dependency_overrides_provider=self.dependency_overrides_provider) dependency_overrides_provider=self.dependency_overrides_provider,
dependency_cache=dependency_cache)
if ignored: if ignored:
return return
@ -101,7 +112,7 @@ class Handler:
if (dependency.dependency,) in self.sub_dependents: if (dependency.dependency,) in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies") raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(depends=dependency) sub_dependant = get_parameterless_sub_dependant(depends=dependency)
self.sub_dependents[(dependency.dependency,)] = sub_dependant self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends): def prepend_dependency(self, dependency: Depends):
self.cache_dependent(dependency) self.cache_dependent(dependency)
@ -114,7 +125,7 @@ class Handler:
def remove_dependency(self, dependency: Depends): def remove_dependency(self, dependency: Depends):
if not dependency.dependency: if not dependency.dependency:
raise ValueError(f"{dependency} has no dependency") raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents: if dependency.dependency in self.sub_dependents:
del self.sub_dependents[(dependency.dependency,)] del self.sub_dependents[dependency.dependency]
if dependency in self.dependencies: if dependency in self.dependencies:
self.dependencies.remove(dependency) self.dependencies.remove(dependency)

View File

@ -45,4 +45,4 @@ class Dependent:
self.matcher_param_name = matcher_param_name self.matcher_param_name = matcher_param_name
self.dependencies = dependencies or [] self.dependencies = dependencies or []
self.use_cache = use_cache self.use_cache = use_cache
self.cache_key = (self.func,) self.cache_key = self.func

View File

@ -5,13 +5,16 @@ import inspect
import dataclasses import dataclasses
from functools import wraps, partial from functools import wraps, partial
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from typing import Any, TypeVar, Callable, Optional, Awaitable from contextlib import asynccontextmanager
from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator,
ContextManager)
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T")
def escape_tag(s: str) -> str: def escape_tag(s: str) -> str:
@ -40,6 +43,20 @@ def is_coroutine_callable(func: Callable[..., Any]) -> bool:
return inspect.iscoroutinefunction(func_) return inspect.iscoroutinefunction(func_)
def is_gen_callable(func: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(func):
return True
func_ = getattr(func, "__call__", None)
return inspect.isgeneratorfunction(func_)
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(func):
return True
func_ = getattr(func, "__call__", None)
return inspect.isasyncgenfunction(func_)
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
""" """
:说明: :说明:
@ -65,6 +82,19 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
return _wrapper return _wrapper
@asynccontextmanager
async def run_sync_ctx_manager(
cm: ContextManager[T],) -> AsyncGenerator[T, None]:
try:
yield await run_sync(cm.__enter__)()
except Exception as e:
ok = await run_sync(cm.__exit__)(type(e), e, None)
if not ok:
raise e
else:
await run_sync(cm.__exit__)(None, None, None)
def get_name(obj: Any) -> str: def get_name(obj: Any) -> str:
if inspect.isfunction(obj) or inspect.isclass(obj): if inspect.isfunction(obj) or inspect.isclass(obj):
return obj.__name__ return obj.__name__