mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
🚧 add generator dependency support
This commit is contained in:
parent
0a1ae75b70
commit
cafe5c9af0
@ -40,6 +40,11 @@ class Driver(abc.ABC):
|
|||||||
:类型: ``Set[T_BotDisconnectionHook]``
|
:类型: ``Set[T_BotDisconnectionHook]``
|
||||||
:说明: Bot 连接断开时执行的函数
|
:说明: Bot 连接断开时执行的函数
|
||||||
"""
|
"""
|
||||||
|
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
|
||||||
|
"""
|
||||||
|
:类型: ``Dict[Callable[..., Any], Callable[..., Any]]``
|
||||||
|
:说明: Depends 函数的替换表
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
"""
|
"""
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
||||||
|
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||||
|
|
||||||
from .models import Dependent
|
from .models import Dependent
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from .models import Depends as DependsClass
|
from .models import Depends as DependsClass
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable
|
|
||||||
from .utils import (generic_get_types, get_typed_signature,
|
from .utils import (generic_get_types, get_typed_signature,
|
||||||
generic_check_issubclass)
|
generic_check_issubclass)
|
||||||
|
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
|
||||||
|
is_async_gen_callable, is_coroutine_callable)
|
||||||
|
|
||||||
|
|
||||||
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
|
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
|
||||||
@ -95,11 +97,12 @@ async def solve_dependencies(
|
|||||||
bot: Bot,
|
bot: Bot,
|
||||||
event: Event,
|
event: Event,
|
||||||
state: T_State,
|
state: T_State,
|
||||||
matcher: "Matcher",
|
matcher: Optional["Matcher"],
|
||||||
|
stack: Optional[AsyncExitStack] = None,
|
||||||
sub_dependents: Optional[List[Dependent]] = None,
|
sub_dependents: Optional[List[Dependent]] = None,
|
||||||
dependency_overrides_provider: Optional[Any] = None,
|
dependency_overrides_provider: Optional[Any] = None,
|
||||||
dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None,
|
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
|
||||||
) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any], bool]:
|
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
dependency_cache = dependency_cache or {}
|
dependency_cache = dependency_cache or {}
|
||||||
|
|
||||||
@ -108,7 +111,7 @@ async def solve_dependencies(
|
|||||||
for sub_dependent in chain(sub_dependents or tuple(),
|
for sub_dependent in chain(sub_dependents or tuple(),
|
||||||
dependent.dependencies):
|
dependent.dependencies):
|
||||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||||
sub_dependent.cache_key = cast(Tuple[Callable[..., Any]],
|
sub_dependent.cache_key = cast(Callable[..., Any],
|
||||||
sub_dependent.cache_key)
|
sub_dependent.cache_key)
|
||||||
func = sub_dependent.func
|
func = sub_dependent.func
|
||||||
|
|
||||||
@ -158,6 +161,15 @@ async def solve_dependencies(
|
|||||||
# run dependency function
|
# run dependency function
|
||||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependent.cache_key]
|
solved = dependency_cache[sub_dependent.cache_key]
|
||||||
|
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||||
|
assert isinstance(
|
||||||
|
stack, AsyncExitStack
|
||||||
|
), "Generator dependency should be called in context"
|
||||||
|
if is_gen_callable(func):
|
||||||
|
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
||||||
|
else:
|
||||||
|
cm = asynccontextmanager(func)(**sub_values)
|
||||||
|
solved = await stack.enter_async_context(cm)
|
||||||
elif is_coroutine_callable(func):
|
elif is_coroutine_callable(func):
|
||||||
solved = await func(**sub_values)
|
solved = await func(**sub_values)
|
||||||
else:
|
else:
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
@ -37,15 +38,15 @@ class Handler:
|
|||||||
self.name = get_name(func) if name is None else name
|
self.name = get_name(func) if name is None else name
|
||||||
|
|
||||||
self.dependencies = dependencies or []
|
self.dependencies = dependencies or []
|
||||||
self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {}
|
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
|
||||||
if dependencies:
|
if dependencies:
|
||||||
for depends in dependencies:
|
for depends in dependencies:
|
||||||
if not depends.dependency:
|
if not depends.dependency:
|
||||||
raise ValueError(f"{depends} has no dependency")
|
raise ValueError(f"{depends} has no dependency")
|
||||||
if (depends.dependency,) in self.sub_dependents:
|
if depends.dependency in self.sub_dependents:
|
||||||
raise ValueError(f"{depends} is already in dependencies")
|
raise ValueError(f"{depends} is already in dependencies")
|
||||||
sub_dependant = get_parameterless_sub_dependant(depends=depends)
|
sub_dependant = get_parameterless_sub_dependant(depends=depends)
|
||||||
self.sub_dependents[(depends.dependency,)] = sub_dependant
|
self.sub_dependents[depends.dependency] = sub_dependant
|
||||||
self.dependency_overrides_provider = dependency_overrides_provider
|
self.dependency_overrides_provider = dependency_overrides_provider
|
||||||
self.dependent = get_dependent(func=func)
|
self.dependent = get_dependent(func=func)
|
||||||
|
|
||||||
@ -60,19 +61,29 @@ class Handler:
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return repr(self)
|
return repr(self)
|
||||||
|
|
||||||
async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event",
|
async def __call__(
|
||||||
state: T_State):
|
self,
|
||||||
|
matcher: "Matcher",
|
||||||
|
bot: "Bot",
|
||||||
|
event: "Event",
|
||||||
|
state: T_State,
|
||||||
|
*,
|
||||||
|
stack: Optional[AsyncExitStack] = None,
|
||||||
|
dependency_cache: Optional[Dict[Callable[..., Any],
|
||||||
|
Any]] = None) -> Any:
|
||||||
values, _, ignored = await solve_dependencies(
|
values, _, ignored = await solve_dependencies(
|
||||||
dependent=self.dependent,
|
dependent=self.dependent,
|
||||||
bot=bot,
|
bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
state=state,
|
state=state,
|
||||||
matcher=matcher,
|
matcher=matcher,
|
||||||
|
stack=stack,
|
||||||
sub_dependents=[
|
sub_dependents=[
|
||||||
self.sub_dependents[(dependency.dependency,)] # type: ignore
|
self.sub_dependents[dependency.dependency] # type: ignore
|
||||||
for dependency in self.dependencies
|
for dependency in self.dependencies
|
||||||
],
|
],
|
||||||
dependency_overrides_provider=self.dependency_overrides_provider)
|
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||||
|
dependency_cache=dependency_cache)
|
||||||
|
|
||||||
if ignored:
|
if ignored:
|
||||||
return
|
return
|
||||||
@ -101,7 +112,7 @@ class Handler:
|
|||||||
if (dependency.dependency,) in self.sub_dependents:
|
if (dependency.dependency,) in self.sub_dependents:
|
||||||
raise ValueError(f"{dependency} is already in dependencies")
|
raise ValueError(f"{dependency} is already in dependencies")
|
||||||
sub_dependant = get_parameterless_sub_dependant(depends=dependency)
|
sub_dependant = get_parameterless_sub_dependant(depends=dependency)
|
||||||
self.sub_dependents[(dependency.dependency,)] = sub_dependant
|
self.sub_dependents[dependency.dependency] = sub_dependant
|
||||||
|
|
||||||
def prepend_dependency(self, dependency: Depends):
|
def prepend_dependency(self, dependency: Depends):
|
||||||
self.cache_dependent(dependency)
|
self.cache_dependent(dependency)
|
||||||
@ -114,7 +125,7 @@ class Handler:
|
|||||||
def remove_dependency(self, dependency: Depends):
|
def remove_dependency(self, dependency: Depends):
|
||||||
if not dependency.dependency:
|
if not dependency.dependency:
|
||||||
raise ValueError(f"{dependency} has no dependency")
|
raise ValueError(f"{dependency} has no dependency")
|
||||||
if (dependency.dependency,) in self.sub_dependents:
|
if dependency.dependency in self.sub_dependents:
|
||||||
del self.sub_dependents[(dependency.dependency,)]
|
del self.sub_dependents[dependency.dependency]
|
||||||
if dependency in self.dependencies:
|
if dependency in self.dependencies:
|
||||||
self.dependencies.remove(dependency)
|
self.dependencies.remove(dependency)
|
||||||
|
@ -45,4 +45,4 @@ class Dependent:
|
|||||||
self.matcher_param_name = matcher_param_name
|
self.matcher_param_name = matcher_param_name
|
||||||
self.dependencies = dependencies or []
|
self.dependencies = dependencies or []
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.cache_key = (self.func,)
|
self.cache_key = self.func
|
||||||
|
@ -5,13 +5,16 @@ import inspect
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from functools import wraps, partial
|
from functools import wraps, partial
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
from typing import Any, TypeVar, Callable, Optional, Awaitable
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator,
|
||||||
|
ContextManager)
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def escape_tag(s: str) -> str:
|
def escape_tag(s: str) -> str:
|
||||||
@ -40,6 +43,20 @@ def is_coroutine_callable(func: Callable[..., Any]) -> bool:
|
|||||||
return inspect.iscoroutinefunction(func_)
|
return inspect.iscoroutinefunction(func_)
|
||||||
|
|
||||||
|
|
||||||
|
def is_gen_callable(func: Callable[..., Any]) -> bool:
|
||||||
|
if inspect.isgeneratorfunction(func):
|
||||||
|
return True
|
||||||
|
func_ = getattr(func, "__call__", None)
|
||||||
|
return inspect.isgeneratorfunction(func_)
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
|
||||||
|
if inspect.isasyncgenfunction(func):
|
||||||
|
return True
|
||||||
|
func_ = getattr(func, "__call__", None)
|
||||||
|
return inspect.isasyncgenfunction(func_)
|
||||||
|
|
||||||
|
|
||||||
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -65,6 +82,19 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
|||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def run_sync_ctx_manager(
|
||||||
|
cm: ContextManager[T],) -> AsyncGenerator[T, None]:
|
||||||
|
try:
|
||||||
|
yield await run_sync(cm.__enter__)()
|
||||||
|
except Exception as e:
|
||||||
|
ok = await run_sync(cm.__exit__)(type(e), e, None)
|
||||||
|
if not ok:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
await run_sync(cm.__exit__)(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
def get_name(obj: Any) -> str:
|
def get_name(obj: Any) -> str:
|
||||||
if inspect.isfunction(obj) or inspect.isclass(obj):
|
if inspect.isfunction(obj) or inspect.isclass(obj):
|
||||||
return obj.__name__
|
return obj.__name__
|
||||||
|
Loading…
Reference in New Issue
Block a user