mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
🚧 add generator dependency support
This commit is contained in:
parent
0a1ae75b70
commit
cafe5c9af0
@ -40,6 +40,11 @@ class Driver(abc.ABC):
|
||||
:类型: ``Set[T_BotDisconnectionHook]``
|
||||
:说明: Bot 连接断开时执行的函数
|
||||
"""
|
||||
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
|
||||
"""
|
||||
:类型: ``Dict[Callable[..., Any], Callable[..., Any]]``
|
||||
:说明: Depends 函数的替换表
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, config: Config):
|
||||
"""
|
||||
|
@ -1,15 +1,17 @@
|
||||
import inspect
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
|
||||
from .models import Dependent
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.adapters import Bot, Event
|
||||
from .models import Depends as DependsClass
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
from .utils import (generic_get_types, get_typed_signature,
|
||||
generic_check_issubclass)
|
||||
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
|
||||
is_async_gen_callable, is_coroutine_callable)
|
||||
|
||||
|
||||
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
|
||||
@ -95,11 +97,12 @@ async def solve_dependencies(
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
state: T_State,
|
||||
matcher: "Matcher",
|
||||
matcher: Optional["Matcher"],
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
sub_dependents: Optional[List[Dependent]] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
|
||||
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
|
||||
values: Dict[str, Any] = {}
|
||||
dependency_cache = dependency_cache or {}
|
||||
|
||||
@ -108,7 +111,7 @@ async def solve_dependencies(
|
||||
for sub_dependent in chain(sub_dependents or tuple(),
|
||||
dependent.dependencies):
|
||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
|
||||
sub_dependent.cache_key = cast(Callable[..., Any],
|
||||
sub_dependent.cache_key)
|
||||
func = sub_dependent.func
|
||||
|
||||
@ -158,6 +161,15 @@ async def solve_dependencies(
|
||||
# run dependency function
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(func):
|
||||
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(func)(**sub_values)
|
||||
solved = await stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
else:
|
||||
|
@ -6,6 +6,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
|
||||
|
||||
from nonebot.log import logger
|
||||
@ -37,15 +38,15 @@ class Handler:
|
||||
self.name = get_name(func) if name is None else name
|
||||
|
||||
self.dependencies = dependencies or []
|
||||
self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {}
|
||||
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
|
||||
if dependencies:
|
||||
for depends in dependencies:
|
||||
if not depends.dependency:
|
||||
raise ValueError(f"{depends} has no dependency")
|
||||
if (depends.dependency,) in self.sub_dependents:
|
||||
if depends.dependency in self.sub_dependents:
|
||||
raise ValueError(f"{depends} is already in dependencies")
|
||||
sub_dependant = get_parameterless_sub_dependant(depends=depends)
|
||||
self.sub_dependents[(depends.dependency,)] = sub_dependant
|
||||
self.sub_dependents[depends.dependency] = sub_dependant
|
||||
self.dependency_overrides_provider = dependency_overrides_provider
|
||||
self.dependent = get_dependent(func=func)
|
||||
|
||||
@ -60,19 +61,29 @@ class Handler:
|
||||
def __str__(self) -> str:
|
||||
return repr(self)
|
||||
|
||||
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
|
||||
state: T_State):
|
||||
async def __call__(
|
||||
self,
|
||||
matcher: "Matcher",
|
||||
bot: "Bot",
|
||||
event: "Event",
|
||||
state: T_State,
|
||||
*,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None) -> Any:
|
||||
values, _, ignored = await solve_dependencies(
|
||||
dependent=self.dependent,
|
||||
bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
matcher=matcher,
|
||||
stack=stack,
|
||||
sub_dependents=[
|
||||
self.sub_dependents[(dependency.dependency,)] # type: ignore
|
||||
self.sub_dependents[dependency.dependency] # type: ignore
|
||||
for dependency in self.dependencies
|
||||
],
|
||||
dependency_overrides_provider=self.dependency_overrides_provider)
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache)
|
||||
|
||||
if ignored:
|
||||
return
|
||||
@ -101,7 +112,7 @@ class Handler:
|
||||
if (dependency.dependency,) in self.sub_dependents:
|
||||
raise ValueError(f"{dependency} is already in dependencies")
|
||||
sub_dependant = get_parameterless_sub_dependant(depends=dependency)
|
||||
self.sub_dependents[(dependency.dependency,)] = sub_dependant
|
||||
self.sub_dependents[dependency.dependency] = sub_dependant
|
||||
|
||||
def prepend_dependency(self, dependency: Depends):
|
||||
self.cache_dependent(dependency)
|
||||
@ -114,7 +125,7 @@ class Handler:
|
||||
def remove_dependency(self, dependency: Depends):
|
||||
if not dependency.dependency:
|
||||
raise ValueError(f"{dependency} has no dependency")
|
||||
if (dependency.dependency,) in self.sub_dependents:
|
||||
del self.sub_dependents[(dependency.dependency,)]
|
||||
if dependency.dependency in self.sub_dependents:
|
||||
del self.sub_dependents[dependency.dependency]
|
||||
if dependency in self.dependencies:
|
||||
self.dependencies.remove(dependency)
|
||||
|
@ -45,4 +45,4 @@ class Dependent:
|
||||
self.matcher_param_name = matcher_param_name
|
||||
self.dependencies = dependencies or []
|
||||
self.use_cache = use_cache
|
||||
self.cache_key = (self.func,)
|
||||
self.cache_key = self.func
|
||||
|
@ -5,13 +5,16 @@ import inspect
|
||||
import dataclasses
|
||||
from functools import wraps, partial
|
||||
from typing_extensions import ParamSpec
|
||||
from typing import Any, TypeVar, Callable, Optional, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator,
|
||||
ContextManager)
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def escape_tag(s: str) -> str:
|
||||
@ -40,6 +43,20 @@ def is_coroutine_callable(func: Callable[..., Any]) -> bool:
|
||||
return inspect.iscoroutinefunction(func_)
|
||||
|
||||
|
||||
def is_gen_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(func):
|
||||
return True
|
||||
func_ = getattr(func, "__call__", None)
|
||||
return inspect.isgeneratorfunction(func_)
|
||||
|
||||
|
||||
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(func):
|
||||
return True
|
||||
func_ = getattr(func, "__call__", None)
|
||||
return inspect.isasyncgenfunction(func_)
|
||||
|
||||
|
||||
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -65,6 +82,19 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
return _wrapper
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_sync_ctx_manager(
|
||||
cm: ContextManager[T],) -> AsyncGenerator[T, None]:
|
||||
try:
|
||||
yield await run_sync(cm.__enter__)()
|
||||
except Exception as e:
|
||||
ok = await run_sync(cm.__exit__)(type(e), e, None)
|
||||
if not ok:
|
||||
raise e
|
||||
else:
|
||||
await run_sync(cm.__exit__)(None, None, None)
|
||||
|
||||
|
||||
def get_name(obj: Any) -> str:
|
||||
if inspect.isfunction(obj) or inspect.isclass(obj):
|
||||
return obj.__name__
|
||||
|
Loading…
Reference in New Issue
Block a user