diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 5f06f309..382825bd 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -26,7 +26,6 @@ from typing import ( from nonebot import params from nonebot.rule import Rule from nonebot.log import logger -from nonebot.utils import CacheDict from nonebot.dependencies import Dependent from nonebot.permission import USER, Permission from nonebot.adapters import ( @@ -43,14 +42,6 @@ from nonebot.consts import ( REJECT_TARGET, LAST_RECEIVE_KEY, ) -from nonebot.typing import ( - Any, - T_State, - T_Handler, - T_ArgsParser, - T_TypeUpdater, - T_PermissionUpdater, -) from nonebot.exception import ( PausedException, StopPropagation, @@ -58,6 +49,15 @@ from nonebot.exception import ( FinishedException, RejectedException, ) +from nonebot.typing import ( + Any, + T_State, + T_Handler, + T_ArgsParser, + T_TypeUpdater, + T_DependencyCache, + T_PermissionUpdater, +) if TYPE_CHECKING: from nonebot.plugin import Plugin @@ -296,7 +296,7 @@ class Matcher(metaclass=MatcherMeta): bot: Bot, event: Event, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ) -> bool: """ :说明: @@ -324,7 +324,7 @@ class Matcher(metaclass=MatcherMeta): event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ) -> bool: """ :说明: @@ -669,7 +669,7 @@ class Matcher(metaclass=MatcherMeta): event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ): b_t = current_bot.set(bot) e_t = current_event.set(event) @@ -711,7 +711,7 @@ class Matcher(metaclass=MatcherMeta): event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ): try: await self.simple_run(bot, event, state, stack, dependency_cache) diff --git a/nonebot/message.py b/nonebot/message.py index 9b45c62a..a58c0222 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -22,9 +22,9 @@ from typing import ( from nonebot import params from nonebot.log import logger from nonebot.rule import TrieRule +from nonebot.utils import escape_tag from nonebot.dependencies import Dependent from nonebot.matcher import Matcher, matchers -from nonebot.utils import CacheDict, escape_tag from nonebot.exception import ( NoLogException, StopPropagation, @@ -34,6 +34,7 @@ from nonebot.exception import ( from nonebot.typing import ( T_State, T_Handler, + T_DependencyCache, T_RunPreProcessor, T_RunPostProcessor, T_EventPreProcessor, @@ -136,7 +137,7 @@ async def _check_matcher( event: "Event", state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ) -> None: if Matcher.expire_time and datetime.now() > Matcher.expire_time: try: @@ -171,7 +172,7 @@ async def _run_matcher( event: "Event", state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ) -> None: logger.info(f"Event will be handled by {Matcher}") @@ -275,7 +276,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None: logger.opt(colors=True).success(log_msg) state: Dict[Any, Any] = {} - dependency_cache: CacheDict[T_Handler, Any] = CacheDict() + dependency_cache: T_DependencyCache = {} async with AsyncExitStack() as stack: coros = list( diff --git a/nonebot/params.py b/nonebot/params.py index 50358736..822fd2a8 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -1,12 +1,13 @@ +import asyncio import inspect from typing import Any, Dict, List, Tuple, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from pydantic.fields import Required, Undefined -from nonebot.typing import T_State, T_Handler from nonebot.adapters import Bot, Event, Message from nonebot.dependencies import Param, Dependent +from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.consts import ( CMD_KEY, PREFIX_KEY, @@ -19,7 +20,6 @@ from nonebot.consts import ( REGEX_MATCHED, ) from nonebot.utils import ( - CacheDict, get_name, run_sync, is_gen_callable, @@ -49,7 +49,7 @@ class DependsInner: def Depends( dependency: Optional[T_Handler] = None, *, - use_cache: bool = False, + use_cache: bool = True, ) -> Any: """ :说明: @@ -114,11 +114,11 @@ class DependParam(Param): async def _solve( self, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, **kwargs: Any, ) -> Any: use_cache: bool = self.extra["use_cache"] - dependency_cache = CacheDict() if dependency_cache is None else dependency_cache + dependency_cache = {} if dependency_cache is None else dependency_cache sub_dependent: Dependent = self.extra["dependent"] sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) @@ -132,26 +132,28 @@ class DependParam(Param): ) # run dependency function - async with dependency_cache: - if use_cache and call in dependency_cache: - solved = dependency_cache[call] - elif is_gen_callable(call) or is_async_gen_callable(call): - assert isinstance( - stack, AsyncExitStack - ), "Generator dependency should be called in context" - if is_gen_callable(call): - cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) - else: - cm = asynccontextmanager(call)(**sub_values) - solved = await stack.enter_async_context(cm) - elif is_coroutine_callable(call): - return await call(**sub_values) + task: asyncio.Task[Any] + if use_cache and call in dependency_cache: + solved = await dependency_cache[call] + elif is_gen_callable(call) or is_async_gen_callable(call): + assert isinstance( + stack, AsyncExitStack + ), "Generator dependency should be called in context" + if is_gen_callable(call): + cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) else: - return await run_sync(call)(**sub_values) - - # save current dependency to cache - if call not in dependency_cache: - dependency_cache[call] = solved + cm = asynccontextmanager(call)(**sub_values) + task = asyncio.create_task(stack.enter_async_context(cm)) + dependency_cache[call] = task + solved = await task + elif is_coroutine_callable(call): + task = asyncio.create_task(call(**sub_values)) + dependency_cache[call] = task + solved = await task + else: + task = asyncio.create_task(run_sync(call)(**sub_values)) + dependency_cache[call] = task + solved = await task return solved @@ -243,7 +245,7 @@ def _command(state=State()) -> Message: def Command() -> Tuple[str, ...]: - return Depends(_command) + return Depends(_command, use_cache=False) def _raw_command(state=State()) -> Message: @@ -251,7 +253,7 @@ def _raw_command(state=State()) -> Message: def RawCommand() -> str: - return Depends(_raw_command) + return Depends(_raw_command, use_cache=False) def _command_arg(state=State()) -> Message: @@ -259,7 +261,7 @@ def _command_arg(state=State()) -> Message: def CommandArg() -> Message: - return Depends(_command_arg) + return Depends(_command_arg, use_cache=False) def _shell_command_args(state=State()) -> Any: @@ -267,7 +269,7 @@ def _shell_command_args(state=State()) -> Any: def ShellCommandArgs(): - return Depends(_shell_command_args) + return Depends(_shell_command_args, use_cache=False) def _shell_command_argv(state=State()) -> List[str]: @@ -275,7 +277,7 @@ def _shell_command_argv(state=State()) -> List[str]: def ShellCommandArgv() -> Any: - return Depends(_shell_command_argv) + return Depends(_shell_command_argv, use_cache=False) def _regex_matched(state=State()) -> str: @@ -283,7 +285,7 @@ def _regex_matched(state=State()) -> str: def RegexMatched() -> str: - return Depends(_regex_matched) + return Depends(_regex_matched, use_cache=False) def _regex_group(state=State()): @@ -291,7 +293,7 @@ def _regex_group(state=State()): def RegexGroup() -> Tuple[Any, ...]: - return Depends(_regex_group) + return Depends(_regex_group, use_cache=False) def _regex_dict(state=State()): @@ -299,7 +301,7 @@ def _regex_dict(state=State()): def RegexDict() -> Dict[str, Any]: - return Depends(_regex_dict) + return Depends(_regex_dict, use_cache=False) class MatcherParam(Param): @@ -320,14 +322,14 @@ def Received(id: str, default: Any = None) -> Any: def _received(matcher: "Matcher"): return matcher.get_receive(id, default) - return Depends(_received) + return Depends(_received, use_cache=False) def LastReceived(default: Any = None) -> Any: def _last_received(matcher: "Matcher") -> Any: return matcher.get_receive(None, default) - return Depends(_last_received) + return Depends(_last_received, use_cache=False) class ExceptionParam(Param): diff --git a/nonebot/permission.py b/nonebot/permission.py index 9ca25f64..a2ce1778 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -24,11 +24,10 @@ from typing import ( ) from nonebot import params -from nonebot.utils import CacheDict from nonebot.adapters import Bot, Event from nonebot.dependencies import Dependent from nonebot.exception import SkippedException -from nonebot.typing import T_Handler, T_PermissionChecker +from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]): @@ -93,7 +92,7 @@ class Permission: bot: Bot, event: Event, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ) -> bool: """ :说明: diff --git a/nonebot/rule.py b/nonebot/rule.py index 627f18c6..48c60963 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -22,12 +22,11 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence from pygtrie import CharTrie from nonebot.log import logger -from nonebot.utils import CacheDict from nonebot import params, get_driver from nonebot.dependencies import Dependent from nonebot.exception import ParserExit, SkippedException -from nonebot.typing import T_State, T_Handler, T_RuleChecker from nonebot.adapters import Bot, Event, Message, MessageSegment +from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache from nonebot.consts import ( CMD_KEY, PREFIX_KEY, @@ -105,7 +104,7 @@ class Rule: event: Event, state: T_State, stack: Optional[AsyncExitStack] = None, - dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, + dependency_cache: Optional[T_DependencyCache] = None, ) -> bool: """ :说明: diff --git a/nonebot/typing.py b/nonebot/typing.py index bf9d1042..e5a275f8 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -17,7 +17,7 @@ .. _typing: https://docs.python.org/3/library/typing.html """ - +from asyncio import Task from typing import ( TYPE_CHECKING, Any, @@ -32,7 +32,6 @@ from typing import ( ) if TYPE_CHECKING: - from nonebot.utils import CacheDict from nonebot.adapters import Bot, Event from nonebot.permission import Permission @@ -250,3 +249,9 @@ T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]] PermissionUpdater 在 Matcher.pause, Matcher.reject 时被运行,用于更新会话对象权限。默认会更新为当前事件的触发对象。 """ +T_DependencyCache = Dict[Callable[..., Any], Task[Any]] +""" +:类型: ``Dict[Callable[..., Any], Task[Any]]`` +:说明: + 依赖缓存, 用于存储依赖函数的返回值 +""" diff --git a/nonebot/utils.py b/nonebot/utils.py index 57756f58..bd4a92a0 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -135,33 +135,6 @@ def get_name(obj: Any) -> str: return obj.__class__.__name__ -class CacheDict(Dict[K, V], Generic[K, V]): - def __init__(self, *args, **kwargs): - super(CacheDict, self).__init__(*args, **kwargs) - self._lock = asyncio.Lock() - - @property - def locked(self): - return self._lock.locked() - - def __repr__(self): - extra = "locked" if self.locked else "unlocked" - return f"<{self.__class__.__name__} [{extra}]>" - - async def __aenter__(self) -> None: - await self.acquire() - return None - - async def __aexit__(self, exc_type, exc, tb): - self.release() - - async def acquire(self): - return await self._lock.acquire() - - def release(self): - self._lock.release() - - class DataclassEncoder(json.JSONEncoder): """ :说明: diff --git a/poetry.lock b/poetry.lock index bfcc66d2..fcb6dc22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -543,7 +543,7 @@ pytest-order = "^1.0.0" type = "git" url = "https://github.com/nonebot/nonebug.git" reference = "master" -resolved_reference = "4584d5a4bc95cd1bafcec08599ab7d72815e268e" +resolved_reference = "9c4f21373701ac25bc152cbad5f5527edc5e4c19" [[package]] name = "packaging" diff --git a/tests/.coveragerc b/tests/.coveragerc index 97998b7f..4bf3ebe2 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -1,6 +1,8 @@ [report] exclude_lines = + def __repr__ pragma: no cover if TYPE_CHECKING: @(abc\.)?abstractmethod raise NotImplementedError + if __name__ == .__main__.: diff --git a/tests/.isort.cfg b/tests/.isort.cfg new file mode 100644 index 00000000..68396bdd --- /dev/null +++ b/tests/.isort.cfg @@ -0,0 +1,8 @@ +[settings] +profile=black +line_length=80 +length_sort=true +skip_gitignore=true +force_sort_within_sections=true +known_local_folder=plugins +extra_standard_library=typing_extensions diff --git a/tests/plugins/depends.py b/tests/plugins/depends.py index e9690785..51226d20 100644 --- a/tests/plugins/depends.py +++ b/tests/plugins/depends.py @@ -2,22 +2,23 @@ from nonebot import on_message from nonebot.adapters import Event from nonebot.params import Depends -test = on_message() -test2 = on_message() +test_depends = on_message() -runned = False +runned = [] def dependency(event: Event): # test cache - global runned - assert not runned - runned = True + runned.append(event) return event -@test.handle() -@test2.handle() -async def handle(x: Event = Depends(dependency, use_cache=True)): +@test_depends.handle() +async def depends(x: Event = Depends(dependency)): # test dependency return x + + +@test_depends.handle() +async def depends_cache(y: Event = Depends(dependency, use_cache=True)): + return y diff --git a/tests/pyproject.toml b/tests/pyproject.toml deleted file mode 100644 index 8f7f477a..00000000 --- a/tests/pyproject.toml +++ /dev/null @@ -1,2 +0,0 @@ -[tool.isort] -known_local_folder = ["plugins"] diff --git a/tests/test_init.py b/tests/test_init.py index 8038be9d..ef76676c 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,6 +1,5 @@ import os import sys -from re import A from typing import TYPE_CHECKING, Set import pytest diff --git a/tests/test_param.py b/tests/test_param.py new file mode 100644 index 00000000..35a67daa --- /dev/null +++ b/tests/test_param.py @@ -0,0 +1,29 @@ +import pytest +from nonebug import App + +from utils import load_plugin, make_fake_event + + +@pytest.mark.asyncio +async def test_depends(app: App, load_plugin): + from nonebot.params import EventParam, DependParam + + from plugins.depends import runned, depends, test_depends + + async with app.test_dependent( + depends, allow_types=[EventParam, DependParam] + ) as ctx: + event = make_fake_event()() + ctx.pass_params(event=event) + ctx.should_return(event) + + assert len(runned) == 1 and runned[0] == event + + runned.clear() + + async with app.test_matcher(test_depends) as ctx: + bot = ctx.create_bot() + event_next = make_fake_event()() + ctx.receive_event(bot, event_next) + + assert len(runned) == 1 and runned[0] == event_next diff --git a/tests/utils.py b/tests/utils.py index 0223b365..2e90ca7e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,56 @@ from pathlib import Path -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING, Set, Type, Optional import pytest +from pydantic import create_model if TYPE_CHECKING: from nonebot.plugin import Plugin + from nonebot.adapters import Event, Message + + +def make_fake_event( + _type: str = "message", + _name: str = "test", + _description: str = "test", + _user_id: str = "test", + _session_id: str = "test", + _message: Optional["Message"] = None, + _to_me: bool = True, + **fields, +) -> Type["Event"]: + from nonebot.adapters import Event + + _Fake = create_model("_Fake", __base__=Event, **fields) + + class FakeEvent(_Fake): + def get_type(self) -> str: + return _type + + def get_event_name(self) -> str: + return _name + + def get_event_description(self) -> str: + return _description + + def get_user_id(self) -> str: + return _user_id + + def get_session_id(self) -> str: + return _session_id + + def get_message(self) -> "Message": + if _message is not None: + return _message + raise NotImplementedError + + def is_tome(self) -> bool: + return _to_me + + class Config: + extra = "forbid" + + return FakeEvent @pytest.fixture