2021-11-12 20:55:59 +08:00
|
|
|
import inspect
|
2021-11-13 19:38:01 +08:00
|
|
|
from itertools import chain
|
|
|
|
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
2021-11-15 01:28:47 +08:00
|
|
|
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
2021-11-12 20:55:59 +08:00
|
|
|
|
|
|
|
from .models import Dependent
|
2021-11-14 18:51:23 +08:00
|
|
|
from nonebot.log import logger
|
2021-11-13 19:38:01 +08:00
|
|
|
from nonebot.typing import T_State
|
2021-11-12 20:55:59 +08:00
|
|
|
from nonebot.adapters import Bot, Event
|
2021-11-13 19:38:01 +08:00
|
|
|
from .models import Depends as DependsClass
|
|
|
|
from .utils import (generic_get_types, get_typed_signature,
|
|
|
|
generic_check_issubclass)
|
2021-11-15 01:28:47 +08:00
|
|
|
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
|
|
|
|
is_async_gen_callable, is_coroutine_callable)
|
2021-11-12 20:55:59 +08:00
|
|
|
|
|
|
|
|
|
|
|
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent:
    """Build the sub-dependent for a parameter defaulted to ``Depends(...)``.

    The callable to analyse is the one stored on the ``Depends`` marker when
    it is truthy; otherwise the parameter's annotation is used in its place.
    The sub-dependent is named after the parameter.

    Args:
        param: A signature parameter whose ``default`` is a ``Depends`` marker.

    Returns:
        The analysed :class:`Dependent` for that parameter.
    """
    marker: DependsClass = param.default
    # Fall back to the annotation when the marker carries no callable.
    target = marker.dependency or param.annotation
    return get_sub_dependant(depends=marker, dependency=target, name=param.name)
|
|
|
|
|
|
|
|
|
2021-11-13 19:38:01 +08:00
|
|
|
def get_parameterless_sub_dependant(*, depends: DependsClass) -> Dependent:
    """Build an unnamed sub-dependent straight from a ``Depends`` marker.

    No parameter (and therefore no annotation) is involved, so the marker
    itself must hold a callable. The resulting sub-dependent gets no name.

    Args:
        depends: The ``Depends`` marker to analyse.

    Returns:
        The analysed :class:`Dependent`.

    Raises:
        AssertionError: if ``depends.dependency`` is not callable.
    """
    target = depends.dependency
    assert callable(
        target), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=target)
|
|
|
|
|
|
|
|
|
|
|
|
def get_sub_dependant(
    *,
    depends: DependsClass,
    dependency: Callable[..., Any],
    name: Optional[str] = None,
) -> Dependent:
    """Analyse ``dependency`` into a :class:`Dependent`.

    Args:
        depends: The ``Depends`` marker; only its ``use_cache`` flag is read.
        dependency: The callable whose signature is analysed.
        name: Key under which the solved value is stored in the caller's
            keyword arguments, or ``None`` for an unnamed dependency.

    Returns:
        The :class:`Dependent` produced by ``get_dependent``.
    """
    return get_dependent(func=dependency,
                         name=name,
                         use_cache=depends.use_cache)
|
|
|
|
|
|
|
|
|
|
|
|
def get_dependent(*,
                  func: Callable[..., Any],
                  name: Optional[str] = None,
                  use_cache: bool = True) -> Dependent:
    """Inspect ``func``'s signature and record how to fill each parameter.

    Each parameter must fall into exactly one category, checked in order:

    * default is a ``Depends`` marker -> appended as a sub-dependency;
    * annotation passes ``generic_check_issubclass`` for ``Bot`` -> the Bot
      parameter (its accepted types are stored too);
    * likewise for ``Event`` -> the Event parameter (with accepted types);
    * likewise for ``Dict`` -> the State parameter;
    * likewise for ``Matcher`` -> the Matcher parameter.

    Args:
        func: The callable to analyse.
        name: Name given to the resulting :class:`Dependent`.
        use_cache: Whether the solved value may be reused from the cache.

    Returns:
        The populated :class:`Dependent`.

    Raises:
        ValueError: if a Bot/Event/State/Matcher parameter appears twice, or
            a parameter matches none of the categories.
    """
    signature = get_typed_signature(func)
    dependent = Dependent(func=func, name=name, use_cache=use_cache)

    def claim(label: str, attr: str, param_name: str) -> None:
        # Record ``param_name`` on ``attr``, refusing a second parameter of
        # the same kind.
        previous = getattr(dependent, attr)
        if previous is not None:
            raise ValueError(f"{func} has more than one {label} parameter: "
                             f"{previous} / {param_name}")
        setattr(dependent, attr, param_name)

    for param_name, param in signature.parameters.items():
        hint = param.annotation
        if isinstance(param.default, DependsClass):
            dependent.dependencies.append(get_param_sub_dependent(param=param))
        elif generic_check_issubclass(hint, Bot):
            claim("Bot", "bot_param_name", param_name)
            dependent.bot_param_type = generic_get_types(hint)
        elif generic_check_issubclass(hint, Event):
            claim("Event", "event_param_name", param_name)
            dependent.event_param_type = generic_get_types(hint)
        elif generic_check_issubclass(hint, Dict):
            claim("State", "state_param_name", param_name)
        # ``Matcher`` is imported at the bottom of this module, so it is only
        # looked up here, at call time.
        elif generic_check_issubclass(hint, Matcher):
            claim("Matcher", "matcher_param_name", param_name)
        else:
            raise ValueError(
                f"Unknown parameter {param_name} with type {param.annotation}")

    return dependent
|
|
|
|
|
|
|
|
|
2021-11-13 19:38:01 +08:00
|
|
|
async def _run_dependency(func: Callable[..., Any], kwargs: Dict[str, Any],
                          stack: Optional[AsyncExitStack]) -> Any:
    """Call one dependency callable with its already-solved keyword args.

    Generator and async-generator callables are wrapped as context managers
    and entered on ``stack`` (their exit runs when the stack is closed).
    Coroutine functions are awaited directly. Any other callable is awaited
    through ``run_sync``.

    Raises:
        AssertionError: if ``func`` is a generator callable and ``stack`` is
            not an ``AsyncExitStack``.
    """
    if is_gen_callable(func) or is_async_gen_callable(func):
        assert isinstance(
            stack, AsyncExitStack
        ), "Generator dependency should be called in context"
        if is_gen_callable(func):
            cm = run_sync_ctx_manager(contextmanager(func)(**kwargs))
        else:
            cm = asynccontextmanager(func)(**kwargs)
        return await stack.enter_async_context(cm)
    if is_coroutine_callable(func):
        return await func(**kwargs)
    return await run_sync(func)(**kwargs)


async def solve_dependencies(
    *,
    dependent: Dependent,
    bot: Bot,
    event: Event,
    state: T_State,
    matcher: Optional["Matcher"],
    stack: Optional[AsyncExitStack] = None,
    sub_dependents: Optional[List[Dependent]] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
    """Resolve the keyword arguments needed to call ``dependent.func``.

    Sub-dependencies (``sub_dependents`` first, then
    ``dependent.dependencies``) are solved recursively and executed. Then the
    bot, event, state and matcher are injected under the parameter names that
    ``dependent`` recorded.

    Args:
        dependent: The analysed callable whose arguments are being solved.
        bot, event, state, matcher: Values injected into matching parameters.
        stack: Exit stack that generator dependencies are entered on. It is
            forwarded to nested sub-dependencies as well.
        sub_dependents: Extra sub-dependencies solved before the dependent's
            own ones.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping, if present, replaces a dependency callable with another.
        dependency_cache: Solved values keyed by ``cache_key``. It is updated
            in place when given.

    Returns:
        ``(values, dependency_cache, ignored)``. ``ignored`` is ``True`` when
        a sub-dependency's Bot/Event type annotation does not match ``bot`` or
        ``event``. In that case ``values`` may be incomplete.
    """
    values: Dict[str, Any] = {}
    # Only default when nothing was passed: an empty dict from the caller is
    # a real cache and must be filled in place, not replaced.
    if dependency_cache is None:
        dependency_cache = {}

    # solve sub dependencies
    sub_dependent: Dependent
    for sub_dependent in chain(sub_dependents or (), dependent.dependencies):
        # ``cast`` only narrows the optional attribute types for the checker.
        func = cast(Callable[..., Any], sub_dependent.func)
        cache_key = cast(Callable[..., Any], sub_dependent.cache_key)

        # check bot and event type
        if sub_dependent.bot_param_type and not isinstance(
                bot, sub_dependent.bot_param_type):
            logger.debug(
                f"Matcher {matcher} bot type {type(bot)} not match depends {func} "
                f"annotation {sub_dependent.bot_param_type}, ignored")
            return values, dependency_cache, True
        elif sub_dependent.event_param_type and not isinstance(
                event, sub_dependent.event_param_type):
            logger.debug(
                f"Matcher {matcher} event type {type(event)} not match depends {func} "
                f"annotation {sub_dependent.event_param_type}, ignored")
            return values, dependency_cache, True

        # dependency overrides
        use_sub_dependant = sub_dependent
        if (dependency_overrides_provider and
                hasattr(dependency_overrides_provider, "dependency_overrides")):
            original_call = sub_dependent.func
            func = getattr(dependency_overrides_provider,
                           "dependency_overrides",
                           {}).get(original_call, original_call)
            use_sub_dependant = get_dependent(
                func=func,
                name=sub_dependent.name,
            )

        # solve sub dependency with current cache; the exit stack must be
        # forwarded so nested generator dependencies can be entered on it
        sub_values, sub_dependency_cache, ignored = await solve_dependencies(
            dependent=use_sub_dependant,
            bot=bot,
            event=event,
            state=state,
            matcher=matcher,
            stack=stack,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
        )
        if ignored:
            return values, dependency_cache, True
        # The recursive call fills the same dict; merging keeps this correct
        # even if it ever returns a different mapping.
        dependency_cache.update(sub_dependency_cache)

        # run dependency function
        # NOTE(review): sub-dependencies above are solved even when this
        # value is then taken from the cache.
        if sub_dependent.use_cache and cache_key in dependency_cache:
            solved = dependency_cache[cache_key]
        else:
            solved = await _run_dependency(func, sub_values, stack)

        # parameter dependency
        if sub_dependent.name is not None:
            values[sub_dependent.name] = solved
        # save current dependency to cache
        if cache_key not in dependency_cache:
            dependency_cache[cache_key] = solved

    # usual dependency
    if dependent.bot_param_name is not None:
        values[dependent.bot_param_name] = bot
    if dependent.event_param_name is not None:
        values[dependent.event_param_name] = event
    if dependent.state_param_name is not None:
        values[dependent.state_param_name] = state
    if dependent.matcher_param_name is not None:
        values[dependent.matcher_param_name] = matcher
    return values, dependency_cache, False
|
2021-11-13 19:38:01 +08:00
|
|
|
|
|
|
|
|
|
|
|
def Depends(dependency: Optional[Callable[..., Any]] = None,
            *,
            use_cache: bool = True) -> Any:
    """Create a dependency marker.

    When used as a parameter default, ``get_dependent`` turns the parameter
    into a sub-dependency. If ``dependency`` is omitted, the parameter's
    annotation is analysed instead.

    Args:
        dependency: The callable to resolve, or ``None``.
        use_cache: Whether a previously solved value may be reused.

    Returns:
        The marker object. It is typed ``Any``, presumably so it can be used
        as the default of a parameter with any annotation (TODO confirm).
    """
    marker = DependsClass(dependency=dependency, use_cache=use_cache)
    return marker
|
|
|
|
|
|
|
|
|
2021-11-12 20:55:59 +08:00
|
|
|
from .handler import Handler as Handler
|
|
|
|
from .matcher import Matcher as Matcher
|
2021-11-13 19:38:01 +08:00
|
|
|
from .matcher import matchers as matchers
|