From 75d4cd95658053f6b5cf3a7499489e44afde40aa Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Sun, 21 Nov 2021 15:46:48 +0800 Subject: [PATCH] :bug: fix cache concurrency --- nonebot/dependencies/__init__.py | 57 +++++++------- nonebot/handler.py | 2 +- nonebot/message.py | 5 +- nonebot/permission.py | 30 ++++---- nonebot/rule.py | 23 +++--- nonebot/utils.py | 76 ++++++++++++++++++- .../nonebot/adapters/cqhttp/permission.py | 25 +++--- tests/test_plugins/test_depends.py | 17 ++++- 8 files changed, 162 insertions(+), 73 deletions(-) diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index dd27ce9f..f863d5d9 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -21,8 +21,11 @@ from .models import Dependent as Dependent from nonebot.exception import SkippedException from .models import DependsWrapper as DependsWrapper from nonebot.typing import T_Handler, T_DependencyCache -from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, - is_async_gen_callable, is_coroutine_callable) +from nonebot.utils import (CacheLock, run_sync, is_gen_callable, + run_sync_ctx_manager, is_async_gen_callable, + is_coroutine_callable) + +cache_lock = CacheLock() class CustomConfig(BaseConfig): @@ -93,7 +96,7 @@ def get_dependent(*, break else: raise ValueError( - f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}" + f"Unknown parameter {param_name} for function {func} with type {param.annotation}" ) annotation: Any = Any @@ -122,7 +125,7 @@ async def solve_dependencies( _dependency_cache: Optional[T_DependencyCache] = None, **params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]: values: Dict[str, Any] = {} - dependency_cache = _dependency_cache or {} + dependency_cache = {} if _dependency_cache is None else _dependency_cache # solve sub dependencies sub_dependent: Dependent @@ -151,35 +154,37 @@ async def solve_dependencies( solved_result = await solve_dependencies( _dependent=use_sub_dependant, _dependency_overrides_provider=_dependency_overrides_provider, - dependency_cache=dependency_cache, + _dependency_cache=dependency_cache, **params) sub_values, sub_dependency_cache = solved_result # update cache? - dependency_cache.update(sub_dependency_cache) + # dependency_cache.update(sub_dependency_cache) # run dependency function - if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: - solved = dependency_cache[sub_dependent.cache_key] - elif is_gen_callable(func) or is_async_gen_callable(func): - assert isinstance( - _stack, AsyncExitStack - ), "Generator dependency should be called in context" - if is_gen_callable(func): - cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) + async with cache_lock: + if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: + solved = dependency_cache[sub_dependent.cache_key] + elif is_gen_callable(func) or is_async_gen_callable(func): + assert isinstance( + _stack, AsyncExitStack + ), "Generator dependency should be called in context" + if is_gen_callable(func): + cm = run_sync_ctx_manager( + contextmanager(func)(**sub_values)) + else: + cm = asynccontextmanager(func)(**sub_values) + solved = await _stack.enter_async_context(cm) + elif is_coroutine_callable(func): + solved = await func(**sub_values) else: - cm = asynccontextmanager(func)(**sub_values) - solved = await _stack.enter_async_context(cm) - elif is_coroutine_callable(func): - solved = await func(**sub_values) - else: - solved = await run_sync(func)(**sub_values) + solved = await run_sync(func)(**sub_values) - # parameter dependency - if sub_dependent.name is not None: - values[sub_dependent.name] = solved - # save current dependency to cache - if sub_dependent.cache_key not in dependency_cache: - dependency_cache[sub_dependent.cache_key] = solved + # parameter dependency + if sub_dependent.name is not None: + values[sub_dependent.name] = solved + # save current dependency to cache + if sub_dependent.cache_key not in dependency_cache: + dependency_cache[sub_dependent.cache_key] = solved # usual dependency for field in _dependent.params: diff --git a/nonebot/handler.py b/nonebot/handler.py index d6df70dd..a39003b1 100644 --- a/nonebot/handler.py +++ b/nonebot/handler.py @@ -80,7 +80,7 @@ class Handler: _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, **params) -> Any: - values, cache = await solve_dependencies( + values, _ = await solve_dependencies( _dependent=self.dependent, _stack=_stack, _sub_dependents=[ diff --git a/nonebot/message.py b/nonebot/message.py index 66379e77..3c00423b 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -163,7 +163,7 @@ async def _run_matcher( try: logger.debug(f"Running matcher {matcher}") - await matcher.run(bot, event, state) + await matcher.run(bot, event, state, stack, dependency_cache) except Exception as e: logger.opt(colors=True, exception=e).error( f"Running matcher {matcher} failed." @@ -260,7 +260,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None: logger.debug(f"Checking for matchers in priority {priority}...") pending_tasks = [ - _check_matcher(priority, matcher, bot, event, state.copy()) + _check_matcher(priority, matcher, bot, event, state.copy(), + stack, dependency_cache) for matcher in matchers[priority] ] diff --git a/nonebot/permission.py b/nonebot/permission.py index b83be58a..48781e32 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -42,17 +42,19 @@ class Permission: ] def __init__(self, - *checkers: T_PermissionChecker, + *checkers: Union[T_PermissionChecker, Handler], dependency_overrides_provider: Optional[Any] = None) -> None: """ :参数: - * ``*checkers: T_PermissionChecker``: PermissionChecker + * ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker """ + self.checkers = set( - Handler(checker, - allow_types=self.HANDLER_PARAM_TYPES, - dependency_overrides_provider=dependency_overrides_provider) + checker if isinstance(checker, Handler) else Handler( + checker, + allow_types=self.HANDLER_PARAM_TYPES, + dependency_overrides_provider=dependency_overrides_provider) for checker in checkers) """ :说明: @@ -90,11 +92,11 @@ class Permission: if not self.checkers: return True results = await asyncio.gather( - checker(bot=bot, - event=event, - _stack=stack, - _dependency_cache=dependency_cache) - for checker in self.checkers) + *(checker(bot=bot, + event=event, + _stack=stack, + _dependency_cache=dependency_cache) + for checker in self.checkers)) return any(results) def __and__(self, other) -> NoReturn: @@ -111,19 +113,19 @@ class Permission: return Permission(*self.checkers, other) -async def _message(bot: Bot, event: Event) -> bool: +async def _message(event: Event) -> bool: return event.get_type() == "message" -async def _notice(bot: Bot, event: Event) -> bool: +async def _notice(event: Event) -> bool: return event.get_type() == "notice" -async def _request(bot: Bot, event: Event) -> bool: +async def _request(event: Event) -> bool: return event.get_type() == "request" -async def _metaevent(bot: Bot, event: Event) -> bool: +async def _metaevent(event: Event) -> bool: return event.get_type() == "meta_event" diff --git a/nonebot/rule.py b/nonebot/rule.py index c2ac4401..177b59d7 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -69,18 +69,19 @@ class Rule: ] def __init__(self, - *checkers: T_RuleChecker, + *checkers: Union[T_RuleChecker, Handler], dependency_overrides_provider: Optional[Any] = None) -> None: """ :参数: - * ``*checkers: T_RuleChecker``: RuleChecker + * ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker """ self.checkers = set( - Handler(checker, - allow_types=self.HANDLER_PARAM_TYPES, - dependency_overrides_provider=dependency_overrides_provider) + checker if isinstance(checker, Handler) else Handler( + checker, + allow_types=self.HANDLER_PARAM_TYPES, + dependency_overrides_provider=dependency_overrides_provider) for checker in checkers) """ :说明: @@ -120,12 +121,12 @@ class Rule: if not self.checkers: return True results = await asyncio.gather( - checker(bot=bot, - event=event, - state=state, - _stack=stack, - _dependency_cache=dependency_cache) - for checker in self.checkers) + *(checker(bot=bot, + event=event, + state=state, + _stack=stack, + _dependency_cache=dependency_cache) + for checker in self.checkers)) return all(results) def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule": diff --git a/nonebot/utils.py b/nonebot/utils.py index 66e905e6..32eca0db 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -2,12 +2,13 @@ import re import json import asyncio import inspect +import collections import dataclasses from functools import wraps, partial from contextlib import asynccontextmanager from typing_extensions import GenericAlias # type: ignore from typing_extensions import ParamSpec, get_args, get_origin -from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional, +from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional, Awaitable, AsyncGenerator, ContextManager) from nonebot.log import logger @@ -120,6 +121,79 @@ def get_name(obj: Any) -> str: return obj.__class__.__name__ +class CacheLock: + + def __init__(self): + self._waiters: Optional[Deque[asyncio.Future]] = None + self._locked = False + + def __repr__(self): + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters: {len(self._waiters)}" + return f"<{self.__class__.__name__} [{extra}]>" + + async def __aenter__(self): + await self.acquire() + return None + + async def __aexit__(self, exc_type, exc, tb): + self.release() + + def locked(self): + return self._locked + + async def acquire(self): + if (not self._locked and (self._waiters is None or + all(w.cancelled() for w in self._waiters))): + self._locked = True + return True + + if self._waiters is None: + self._waiters = collections.deque() + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._waiters.append(future) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. + try: + try: + await future + finally: + self._waiters.remove(future) + except asyncio.CancelledError: + if not self._locked: + self._wake_up_first() + raise + + self._locked = True + return True + + def release(self): + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock is not acquired.") + + def _wake_up_first(self): + if not self._waiters: + return + try: + future = next(iter(self._waiters)) + except StopIteration: + return + + # .done() necessarily means that a waiter will wake up later on and + # either take the lock, or, if it was cancelled and lock wasn't + # taken already, will hit this again and wake up a new waiter. + if not future.done(): + future.set_result(True) + + class DataclassEncoder(json.JSONEncoder): """ :说明: diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py index 1d3b3f36..09ea7b7a 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py @@ -1,26 +1,21 @@ -from typing import TYPE_CHECKING - +from nonebot.adapters import Event from nonebot.permission import Permission - -from .event import PrivateMessageEvent, GroupMessageEvent - -if TYPE_CHECKING: - from nonebot.adapters import Bot, Event +from .event import GroupMessageEvent, PrivateMessageEvent -async def _private(bot: "Bot", event: "Event") -> bool: +async def _private(event: Event) -> bool: return isinstance(event, PrivateMessageEvent) -async def _private_friend(bot: "Bot", event: "Event") -> bool: +async def _private_friend(event: Event) -> bool: return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend" -async def _private_group(bot: "Bot", event: "Event") -> bool: +async def _private_group(event: Event) -> bool: return isinstance(event, PrivateMessageEvent) and event.sub_type == "group" -async def _private_other(bot: "Bot", event: "Event") -> bool: +async def _private_other(event: Event) -> bool: return isinstance(event, PrivateMessageEvent) and event.sub_type == "other" @@ -42,20 +37,20 @@ PRIVATE_OTHER = Permission(_private_other) """ -async def _group(bot: "Bot", event: "Event") -> bool: +async def _group(event: Event) -> bool: return isinstance(event, GroupMessageEvent) -async def _group_member(bot: "Bot", event: "Event") -> bool: +async def _group_member(event: Event) -> bool: return isinstance(event, GroupMessageEvent) and event.sender.role == "member" -async def _group_admin(bot: "Bot", event: "Event") -> bool: +async def _group_admin(event: Event) -> bool: return isinstance(event, GroupMessageEvent) and event.sender.role == "admin" -async def _group_owner(bot: "Bot", event: "Event") -> bool: +async def _group_owner(event: Event) -> bool: return isinstance(event, GroupMessageEvent) and event.sender.role == "owner" diff --git a/tests/test_plugins/test_depends.py b/tests/test_plugins/test_depends.py index c604169b..e23c9c4e 100644 --- a/tests/test_plugins/test_depends.py +++ b/tests/test_plugins/test_depends.py @@ -1,11 +1,12 @@ -from nonebot import on_command from nonebot.log import logger from nonebot.dependencies import Depends +from nonebot import on_command, on_message test = on_command("123") def depend(state: dict): + print("==== depends running =====") return state @@ -13,5 +14,15 @@ def depend(state: dict): @test.got("b", prompt="b") @test.receive() @test.got("c", prompt="c") -async def _(state: dict = Depends(depend)): - logger.info(f"=======, {state}") +async def _(x: dict = Depends(depend)): + logger.info(f"=======, {x}") + + +test_cache1 = on_message() +test_cache2 = on_message() + + +@test_cache1.handle() +@test_cache2.handle() +async def _(x: dict = Depends(depend)): + logger.info(f"======= test, {x}")