diff --git a/nonebot/message.py b/nonebot/message.py index b853e481..f03a7c6d 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Set, Type, Optional from nonebot.log import logger from nonebot.rule import TrieRule from nonebot.utils import escape_tag -from nonebot.matcher import Matcher, matchers +from nonebot.processor import Matcher, matchers from nonebot.exception import NoLogException, StopPropagation, IgnoredException from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor, T_EventPreProcessor, T_EventPostProcessor) diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index e4da8345..46a5ecfe 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -5,10 +5,9 @@ from types import ModuleType from typing import (TYPE_CHECKING, Any, Set, Dict, List, Type, Tuple, Union, Optional) -from nonebot.handler import Handler -from nonebot.matcher import Matcher from .manager import _current_plugin from nonebot.permission import Permission +from nonebot.processor import Handler, Matcher from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, endswith, startswith, shell_command) diff --git a/nonebot/plugin/on.pyi b/nonebot/plugin/on.pyi index 68e8ad62..fbbf9c90 100644 --- a/nonebot/plugin/on.pyi +++ b/nonebot/plugin/on.pyi @@ -1,10 +1,9 @@ import re from typing import Set, List, Type, Tuple, Union, Optional -from nonebot.handler import Handler -from nonebot.matcher import Matcher from nonebot.permission import Permission from nonebot.rule import Rule, ArgumentParser +from nonebot.processor import Handler, Matcher from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory diff --git a/nonebot/plugin/plugin.py b/nonebot/plugin/plugin.py index ee5a7c51..bf7773ac 100644 --- a/nonebot/plugin/plugin.py +++ b/nonebot/plugin/plugin.py @@ -3,7 +3,7 @@ from dataclasses import field, dataclass from typing import Set, Dict, Type, Optional from .export import Export -from nonebot.matcher import Matcher +from nonebot.processor import Matcher plugins: Dict[str, "Plugin"] = {} """ diff --git a/nonebot/plugins/single_session.py b/nonebot/plugins/single_session.py index 8a8b3cfb..1b6d4116 100644 --- a/nonebot/plugins/single_session.py +++ b/nonebot/plugins/single_session.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from nonebot.typing import T_State -from nonebot.matcher import Matcher +from nonebot.processor import Matcher from nonebot.adapters import Bot, Event from nonebot.message import (IgnoredException, run_preprocessor, run_postprocessor) diff --git a/nonebot/processor/__init__.py b/nonebot/processor/__init__.py index 73b86bbc..e4f0a09c 100644 --- a/nonebot/processor/__init__.py +++ b/nonebot/processor/__init__.py @@ -1,14 +1,18 @@ import inspect -from typing import Any, Callable, Optional +from itertools import chain +from typing import Any, Dict, List, Tuple, Callable, Optional, cast from .models import Dependent -from .models import Depends as Depends +from nonebot.typing import T_State from nonebot.adapters import Bot, Event -from .utils import get_typed_signature, generic_check_issubclass +from .models import Depends as DependsClass +from nonebot.utils import run_sync, is_coroutine_callable +from .utils import (generic_get_types, get_typed_signature, + generic_check_issubclass) def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: - depends: Depends = param.default + depends: DependsClass = param.default if depends.dependency: dependency = depends.dependency else: @@ -20,7 +24,7 @@ def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: ) -def get_parameterless_sub_dependant(*, depends: Depends) -> Dependent: +def get_parameterless_sub_dependant(*, depends: DependsClass) -> Dependent: assert callable( depends.dependency ), "A parameter-less dependency must have a callable dependency" @@ -29,7 +33,7 @@ def get_parameterless_sub_dependant(*, depends: Depends) -> Dependent: def get_sub_dependant( *, - depends: Depends, + depends: DependsClass, dependency: Callable[..., Any], name: Optional[str] = None, ) -> Dependent: @@ -49,29 +53,124 @@ def get_dependent(*, params = signature.parameters dependent = Dependent(func=func, name=name, use_cache=use_cache) for param_name, param in params.items(): - if isinstance(param.default, Depends): + if isinstance(param.default, DependsClass): sub_dependent = get_param_sub_dependent(param=param) dependent.dependencies.append(sub_dependent) continue if generic_check_issubclass(param.annotation, Bot): + if dependent.bot_param_name is not None: + raise ValueError(f"{func} has more than one Bot parameter: " + f"{dependent.bot_param_name} / {param_name}") dependent.bot_param_name = param_name - continue + dependent.bot_param_type = generic_get_types(param.annotation) elif generic_check_issubclass(param.annotation, Event): + if dependent.event_param_name is not None: + raise ValueError(f"{func} has more than one Event parameter: " + f"{dependent.event_param_name} / {param_name}") dependent.event_param_name = param_name - continue + dependent.event_param_type = generic_get_types(param.annotation) elif generic_check_issubclass(param.annotation, dict): + if dependent.state_param_name is not None: + raise ValueError(f"{func} has more than one State parameter: " + f"{dependent.state_param_name} / {param_name}") dependent.state_param_name = param_name - continue elif generic_check_issubclass(param.annotation, Matcher): + if dependent.matcher_param_name is not None: + raise ValueError( + f"{func} has more than one Matcher parameter: " + f"{dependent.matcher_param_name} / {param_name}") dependent.matcher_param_name = param_name - continue - - raise ValueError( - f"Unknown parameter {param_name} with type {param.annotation}") + else: + raise ValueError( + f"Unknown parameter {param_name} with type {param.annotation}") return dependent +async def solve_dependencies( + *, + dependent: Dependent, + bot: Bot, + event: Event, + state: T_State, + matcher: "Matcher", + sub_dependents: Optional[List[Dependent]] = None, + dependency_overrides_provider: Optional[Any] = None, + dependency_cache: Optional[Dict[Tuple[Callable[..., Any]], Any]] = None, +) -> Tuple[Dict[str, Any], Dict[Tuple[Callable[..., Any]], Any]]: + values: Dict[str, Any] = {} + dependency_cache = dependency_cache or {} + + # solve sub dependencies + sub_dependant: Dependent + for sub_dependant in chain(sub_dependents or tuple(), + dependent.dependencies): + sub_dependant.func = cast(Callable[..., Any], sub_dependant.func) + sub_dependant.cache_key = cast(Tuple[Callable[..., Any]], + sub_dependant.cache_key) + func = sub_dependant.func + + # dependency overrides + use_sub_dependant = sub_dependant + if (dependency_overrides_provider and + hasattr(dependency_overrides_provider, "dependency_overrides")): + original_call = sub_dependant.func + func = getattr(dependency_overrides_provider, + "dependency_overrides", + {}).get(original_call, original_call) + use_sub_dependant = get_dependent( + func=func, + name=sub_dependant.name, + ) + + # solve sub dependency with current cache + solved_result = await solve_dependencies( + dependent=use_sub_dependant, + bot=bot, + event=event, + state=state, + matcher=matcher, + dependency_overrides_provider=dependency_overrides_provider, + dependency_cache=dependency_cache, + ) + sub_values, sub_dependency_cache = solved_result + # update cache? + dependency_cache.update(sub_dependency_cache) + + # run dependency function + if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: + solved = dependency_cache[sub_dependant.cache_key] + elif is_coroutine_callable(func): + solved = await func(**sub_values) + else: + solved = await run_sync(func)(**sub_values) + + # parameter dependency + if sub_dependant.name is not None: + values[sub_dependant.name] = solved + # save current dependency to cache + if sub_dependant.cache_key not in dependency_cache: + dependency_cache[sub_dependant.cache_key] = solved + + # usual dependency + if dependent.bot_param_name is not None: + values[dependent.bot_param_name] = bot + if dependent.event_param_name is not None: + values[dependent.event_param_name] = event + if dependent.state_param_name is not None: + values[dependent.state_param_name] = state + if dependent.matcher_param_name is not None: + values[dependent.matcher_param_name] = matcher + return values, dependency_cache + + +def Depends(dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True) -> Any: + return DependsClass(dependency=dependency, use_cache=use_cache) + + from .handler import Handler as Handler from .matcher import Matcher as Matcher +from .matcher import matchers as matchers diff --git a/nonebot/processor/handler.py b/nonebot/processor/handler.py index 66366cb7..03e8685b 100644 --- a/nonebot/processor/handler.py +++ b/nonebot/processor/handler.py @@ -4,12 +4,14 @@ 该模块实现事件处理函数的封装,以实现动态参数等功能。 """ -from typing import TYPE_CHECKING, List, Optional -from .models import Depends -from nonebot.utils import get_name +import asyncio +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional + +from .models import Depends, Dependent +from nonebot.utils import get_name, run_sync from nonebot.typing import T_State, T_Handler -from . import get_dependent, get_parameterless_sub_dependant +from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant if TYPE_CHECKING: from .matcher import Matcher @@ -23,7 +25,8 @@ class Handler: func: T_Handler, *, name: Optional[str] = None, - dependencies: Optional[List[Depends]] = None): + dependencies: Optional[List[Depends]] = None, + dependency_overrides_provider: Optional[Any] = None): """装饰事件处理函数以便根据动态参数运行""" self.func: T_Handler = func """ @@ -33,11 +36,57 @@ class Handler: self.name = get_name(func) if name is None else name self.dependencies = dependencies or [] + self.sub_dependents: Dict[Tuple[Callable[..., Any]], Dependent] = {} + if dependencies: + for depends in dependencies: + if not depends.dependency: + raise ValueError(f"{depends} has no dependency") + if (depends.dependency,) in self.sub_dependents: + raise ValueError(f"{depends} is already in dependencies") + sub_dependant = get_parameterless_sub_dependant(depends=depends) + self.sub_dependents[(depends.dependency,)] = sub_dependant + self.dependency_overrides_provider = dependency_overrides_provider self.dependent = get_dependent(func=func) - for depends in self.dependencies[::-1]: - self.dependent.dependencies.insert( - 0, get_parameterless_sub_dependant(depends=depends)) - def __call__(self, bot: Bot, event: Event, state: T_State, - matcher: "Matcher"): - ... + async def __call__(self, matcher: "Matcher", bot: Bot, event: Event, + state: T_State): + values, _ = await solve_dependencies( + dependent=self.dependent, + bot=bot, + event=event, + state=state, + matcher=matcher, + sub_dependents=[ + self.sub_dependents[(dependency.dependency,)] # type: ignore + for dependency in self.dependencies + ], + dependency_overrides_provider=self.dependency_overrides_provider) + + if asyncio.iscoroutinefunction(self.func): + await self.func(**values) + else: + await run_sync(self.func)(**values) + + def cache_dependent(self, dependency: Depends): + if not dependency.dependency: + raise ValueError(f"{dependency} has no dependency") + if (dependency.dependency,) in self.sub_dependents: + raise ValueError(f"{dependency} is already in dependencies") + sub_dependant = get_parameterless_sub_dependant(depends=dependency) + self.sub_dependents[(dependency.dependency,)] = sub_dependant + + def prepend_dependency(self, dependency: Depends): + self.cache_dependent(dependency) + self.dependencies.insert(0, dependency) + + def append_dependency(self, dependency: Depends): + self.cache_dependent(dependency) + self.dependencies.append(dependency) + + def remove_dependency(self, dependency: Depends): + if not dependency.dependency: + raise ValueError(f"{dependency} has no dependency") + if (dependency.dependency,) in self.sub_dependents: + del self.sub_dependents[(dependency.dependency,)] + if dependency in self.dependencies: + self.dependencies.remove(dependency) diff --git a/nonebot/processor/matcher.py b/nonebot/processor/matcher.py index 256c787d..612b6d36 100644 --- a/nonebot/processor/matcher.py +++ b/nonebot/processor/matcher.py @@ -5,7 +5,6 @@ 该模块实现事件响应器的创建与运行,并提供一些快捷方法来帮助用户更好的与机器人进行对话 。 """ -from functools import wraps from types import ModuleType from datetime import datetime from contextvars import ContextVar @@ -13,8 +12,10 @@ from collections import defaultdict from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, NoReturn, Optional) +from .models import Depends from .handler import Handler from nonebot.rule import Rule +from nonebot import get_driver from nonebot.log import logger from nonebot.adapters import MessageTemplate from nonebot.permission import USER, Permission @@ -228,8 +229,8 @@ class Matcher(metaclass=MatcherMeta): "permission": permission or Permission(), "handlers": [ - handler - if isinstance(handler, Handler) else Handler(handler) + handler if isinstance(handler, Handler) else Handler( + handler, dependency_overrides_provider=get_driver()) for handler in handlers ] if handlers else [], "temp": @@ -343,8 +344,12 @@ class Matcher(metaclass=MatcherMeta): return func @classmethod - def append_handler(cls, handler: T_Handler) -> Handler: - handler_ = Handler(handler) + def append_handler(cls, + handler: T_Handler, + dependencies: Optional[List[Depends]] = None) -> Handler: + handler_ = Handler(handler, + dependencies=dependencies, + dependency_overrides_provider=get_driver()) cls.handlers.append(handler_) return handler_ @@ -378,22 +383,19 @@ class Matcher(metaclass=MatcherMeta): * 无 """ - async def _receive(bot: "Bot", event: "Event") -> NoReturn: - raise PausedException - - if cls.handlers: - # 已有前置handlers则接受一条新的消息,否则视为接收初始消息 - receive_handler = cls.append_handler(_receive) - else: - receive_handler = None - def _decorator(func: T_Handler) -> T_Handler: - if not cls.handlers or cls.handlers[-1] is not func: - func_handler = cls.append_handler(func) - if receive_handler: - receive_handler.update_signature( - bot=func_handler.bot_type, - event=func_handler.event_type) + + async def _receive() -> NoReturn: + func_handler.remove_dependency(depend) + raise PausedException + + depend = Depends(_receive) + if cls.handlers and cls.handlers[-1].func is func: + func_handler = cls.handlers[-1] + func_handler.prepend_dependency(depend) + else: + func_handler = cls.append_handler( + func, dependencies=[depend] if cls.handlers else []) return func @@ -419,54 +421,42 @@ class Matcher(metaclass=MatcherMeta): * ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数,空则使用默认解析函数 """ - async def _key_getter(bot: "Bot", event: "Event", state: T_State): - state["_current_key"] = key - if key not in state: - if prompt is not None: - if isinstance(prompt, MessageTemplate): - _prompt = prompt.format(**state) - else: - _prompt = prompt - await bot.send(event=event, message=_prompt) - raise PausedException - else: - state["_skip_key"] = True - - async def _key_parser(bot: "Bot", event: "Event", state: T_State): - if key in state and state.get("_skip_key"): - del state["_skip_key"] - return - parser = args_parser or cls._default_parser - if parser: - # parser = cast(T_ArgsParser["Bot", "Event"], parser) - await parser(bot, event, state) - else: - state[state["_current_key"]] = str(event.get_message()) - - getter_handler = cls.append_handler(_key_getter) - parser_handler = cls.append_handler(_key_parser) - def _decorator(func: T_Handler) -> T_Handler: - if not hasattr(cls.handlers[-1].func, "__wrapped__"): - parser = cls.handlers.pop() - func_handler = Handler(func) - @wraps(func) - async def wrapper(bot: "Bot", event: "Event", state: T_State, - matcher: Matcher): - await parser(matcher, bot, event, state) - await func_handler(matcher, bot, event, state) - if "_current_key" in state: - del state["_current_key"] + async def _key_getter(bot: "Bot", event: "Event", state: T_State): + func_handler.remove_dependency(get_depend) + state["_current_key"] = key + if key not in state: + if prompt is not None: + if isinstance(prompt, MessageTemplate): + _prompt = prompt.format(**state) + else: + _prompt = prompt + await bot.send(event=event, message=_prompt) + raise PausedException + else: + state["_skip_key"] = True - wrapper_handler = cls.append_handler(wrapper) + async def _key_parser(bot: "Bot", event: "Event", state: T_State): + if key in state and state.get("_skip_key"): + del state["_skip_key"] + return + parser = args_parser or cls._default_parser + if parser: + await parser(bot, event, state) + else: + state[state["_current_key"]] = str(event.get_message()) - getter_handler.update_signature( - bot=wrapper_handler.bot_type, - event=wrapper_handler.event_type) - parser_handler.update_signature( - bot=wrapper_handler.bot_type, - event=wrapper_handler.event_type) + get_depend = Depends(_key_getter) + parser_depend = Depends(_key_parser) + + if cls.handlers and cls.handlers[-1].func is func: + func_handler = cls.handlers[-1] + func_handler.prepend_dependency(parser_depend) + func_handler.prepend_dependency(get_depend) + else: + func_handler = cls.append_handler( + func, dependencies=[get_depend, parser_depend]) return func diff --git a/nonebot/processor/models.py b/nonebot/processor/models.py index 14af1699..996c3be7 100644 --- a/nonebot/processor/models.py +++ b/nonebot/processor/models.py @@ -1,7 +1,10 @@ -from typing import Any, List, Callable, Optional +from typing import TYPE_CHECKING, Any, List, Tuple, Callable, Optional from nonebot.utils import get_name +if TYPE_CHECKING: + from nonebot.adapters import Bot, Event + class Depends: @@ -25,7 +28,9 @@ class Dependent: func: Optional[Callable[..., Any]] = None, name: Optional[str] = None, bot_param_name: Optional[str] = None, + bot_param_type: Optional[Tuple["Bot", ...]] = None, event_param_name: Optional[str] = None, + event_param_type: Optional[Tuple["Event", ...]] = None, state_param_name: Optional[str] = None, matcher_param_name: Optional[str] = None, dependencies: Optional[List["Dependent"]] = None, @@ -33,7 +38,9 @@ class Dependent: self.func = func self.name = name self.bot_param_name = bot_param_name + self.bot_param_type = bot_param_type self.event_param_name = event_param_name + self.event_param_type = event_param_type self.state_param_name = state_param_name self.matcher_param_name = matcher_param_name self.dependencies = dependencies or [] diff --git a/nonebot/processor/utils.py b/nonebot/processor/utils.py index 7c3729da..bdae3aea 100644 --- a/nonebot/processor/utils.py +++ b/nonebot/processor/utils.py @@ -1,7 +1,7 @@ import inspect from typing import Any, Dict, Type, Tuple, Union, Callable -from pydantic.typing import (ForwardRef, GenericAlias, get_args, get_origin, +from pydantic.typing import (ForwardRef, get_args, get_origin, evaluate_forwardref) @@ -41,3 +41,9 @@ def generic_check_issubclass( return False return True raise + + +def generic_get_types(cls: Any) -> Tuple[Type[Any], ...]: + if get_origin(cls) is Union: + return get_args(cls) + return (cls,) diff --git a/nonebot/typing.py b/nonebot/typing.py index 970fbcc7..053fcff7 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -22,7 +22,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable, NoReturn, Optional, Awaitable) if TYPE_CHECKING: - from nonebot.matcher import Matcher + from nonebot.processor import Matcher from nonebot.adapters import Bot, Event from nonebot.permission import Permission diff --git a/nonebot/utils.py b/nonebot/utils.py index 71df0d4b..f7d78457 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -4,11 +4,15 @@ import asyncio import inspect import dataclasses from functools import wraps, partial -from typing import Any, Callable, Optional, Awaitable +from typing_extensions import ParamSpec +from typing import Any, TypeVar, Callable, Optional, Awaitable from nonebot.log import logger from nonebot.typing import overrides +P = ParamSpec("P") +R = TypeVar("R") + def escape_tag(s: str) -> str: """ @@ -27,7 +31,16 @@ def escape_tag(s: str) -> str: return re.sub(r"\s]*)>", r"\\\g<0>", s) -def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: +def is_coroutine_callable(func: Callable[..., Any]) -> bool: + if inspect.isroutine(func): + return inspect.iscoroutinefunction(func) + if inspect.isclass(func): + return False + func_ = getattr(func, "__call__", None) + return inspect.iscoroutinefunction(func_) + + +def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: """ :说明: @@ -35,15 +48,15 @@ def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: :参数: - * ``func: Callable[..., Any]``: 被装饰的同步函数 + * ``func: Callable[P, R]``: 被装饰的同步函数 :返回: - - ``Callable[..., Awaitable[Any]]`` + - ``Callable[P, Awaitable[R]]`` """ @wraps(func) - async def _wrapper(*args: Any, **kwargs: Any) -> Any: + async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: loop = asyncio.get_running_loop() pfunc = partial(func, *args, **kwargs) result = await loop.run_in_executor(None, pfunc) diff --git a/poetry.lock b/poetry.lock index 13e3204d..91c140fc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1099,7 +1099,7 @@ quart = ["Quart"] [metadata] lock-version = "1.1" python-versions = "^3.7.3" -content-hash = "51f4f0ce5ced234a65cae790c4f57486e42d7120972657a3f51e733cb4e7c639" +content-hash = "81edd95f4289e55d7cfe632664c930846bde723cb8fa0359fa1e18474853f454" [metadata.files] aiocache = [ diff --git a/pyproject.toml b/pyproject.toml index 6d9a8330..4153931a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ pygtrie = "^2.4.1" tomlkit = "^0.7.0" fastapi = "^0.70.0" websockets = ">=9.1" +typing-extensions = "^3.10.0" Quart = { version = "^0.15.0", optional = true } httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] } pydantic = { version = "~1.8.0", extras = ["dotenv"] } diff --git a/tests/test_plugins/test_processor.py b/tests/test_plugins/test_processor.py index b5aaa298..5eb6b70e 100644 --- a/tests/test_plugins/test_processor.py +++ b/tests/test_plugins/test_processor.py @@ -1,7 +1,7 @@ from nonebot.typing import T_State -from nonebot.matcher import Matcher +from nonebot.processor import Matcher from nonebot.adapters import Bot, Event -from nonebot.message import event_preprocessor, run_preprocessor +from nonebot.message import run_preprocessor, event_preprocessor @event_preprocessor