🐛 fix cache concurrency

This commit is contained in:
yanyongyu 2021-11-21 15:46:48 +08:00
parent d22630e768
commit 75d4cd9565
8 changed files with 162 additions and 73 deletions

View File

@ -21,8 +21,11 @@ from .models import Dependent as Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, from nonebot.utils import (CacheLock, run_sync, is_gen_callable,
is_async_gen_callable, is_coroutine_callable) run_sync_ctx_manager, is_async_gen_callable,
is_coroutine_callable)
cache_lock = CacheLock()
class CustomConfig(BaseConfig): class CustomConfig(BaseConfig):
@ -93,7 +96,7 @@ def get_dependent(*,
break break
else: else:
raise ValueError( raise ValueError(
f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}" f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
) )
annotation: Any = Any annotation: Any = Any
@ -122,7 +125,7 @@ async def solve_dependencies(
_dependency_cache: Optional[T_DependencyCache] = None, _dependency_cache: Optional[T_DependencyCache] = None,
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]: **params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
dependency_cache = _dependency_cache or {} dependency_cache = {} if _dependency_cache is None else _dependency_cache
# solve sub dependencies # solve sub dependencies
sub_dependent: Dependent sub_dependent: Dependent
@ -151,35 +154,37 @@ async def solve_dependencies(
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
_dependent=use_sub_dependant, _dependent=use_sub_dependant,
_dependency_overrides_provider=_dependency_overrides_provider, _dependency_overrides_provider=_dependency_overrides_provider,
dependency_cache=dependency_cache, _dependency_cache=dependency_cache,
**params) **params)
sub_values, sub_dependency_cache = solved_result sub_values, sub_dependency_cache = solved_result
# update cache? # update cache?
dependency_cache.update(sub_dependency_cache) # dependency_cache.update(sub_dependency_cache)
# run dependency function # run dependency function
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: async with cache_lock:
solved = dependency_cache[sub_dependent.cache_key] if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
elif is_gen_callable(func) or is_async_gen_callable(func): solved = dependency_cache[sub_dependent.cache_key]
assert isinstance( elif is_gen_callable(func) or is_async_gen_callable(func):
_stack, AsyncExitStack assert isinstance(
), "Generator dependency should be called in context" _stack, AsyncExitStack
if is_gen_callable(func): ), "Generator dependency should be called in context"
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) if is_gen_callable(func):
cm = run_sync_ctx_manager(
contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else: else:
cm = asynccontextmanager(func)(**sub_values) solved = await run_sync(func)(**sub_values)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else:
solved = await run_sync(func)(**sub_values)
# parameter dependency # parameter dependency
if sub_dependent.name is not None: if sub_dependent.name is not None:
values[sub_dependent.name] = solved values[sub_dependent.name] = solved
# save current dependency to cache # save current dependency to cache
if sub_dependent.cache_key not in dependency_cache: if sub_dependent.cache_key not in dependency_cache:
dependency_cache[sub_dependent.cache_key] = solved dependency_cache[sub_dependent.cache_key] = solved
# usual dependency # usual dependency
for field in _dependent.params: for field in _dependent.params:

View File

@ -80,7 +80,7 @@ class Handler:
_dependency_cache: Optional[Dict[Callable[..., Any], _dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None, Any]] = None,
**params) -> Any: **params) -> Any:
values, cache = await solve_dependencies( values, _ = await solve_dependencies(
_dependent=self.dependent, _dependent=self.dependent,
_stack=_stack, _stack=_stack,
_sub_dependents=[ _sub_dependents=[

View File

@ -163,7 +163,7 @@ async def _run_matcher(
try: try:
logger.debug(f"Running matcher {matcher}") logger.debug(f"Running matcher {matcher}")
await matcher.run(bot, event, state) await matcher.run(bot, event, state, stack, dependency_cache)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>" f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>"
@ -260,7 +260,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.debug(f"Checking for matchers in priority {priority}...") logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [ pending_tasks = [
_check_matcher(priority, matcher, bot, event, state.copy()) _check_matcher(priority, matcher, bot, event, state.copy(),
stack, dependency_cache)
for matcher in matchers[priority] for matcher in matchers[priority]
] ]

View File

@ -42,17 +42,19 @@ class Permission:
] ]
def __init__(self, def __init__(self,
*checkers: T_PermissionChecker, *checkers: Union[T_PermissionChecker, Handler],
dependency_overrides_provider: Optional[Any] = None) -> None: dependency_overrides_provider: Optional[Any] = None) -> None:
""" """
:参数: :参数:
* ``*checkers: T_PermissionChecker``: PermissionChecker * ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker
""" """
self.checkers = set( self.checkers = set(
Handler(checker, checker if isinstance(checker, Handler) else Handler(
allow_types=self.HANDLER_PARAM_TYPES, checker,
dependency_overrides_provider=dependency_overrides_provider) allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers) for checker in checkers)
""" """
:说明: :说明:
@ -90,11 +92,11 @@ class Permission:
if not self.checkers: if not self.checkers:
return True return True
results = await asyncio.gather( results = await asyncio.gather(
checker(bot=bot, *(checker(bot=bot,
event=event, event=event,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache) _dependency_cache=dependency_cache)
for checker in self.checkers) for checker in self.checkers))
return any(results) return any(results)
def __and__(self, other) -> NoReturn: def __and__(self, other) -> NoReturn:
@ -111,19 +113,19 @@ class Permission:
return Permission(*self.checkers, other) return Permission(*self.checkers, other)
async def _message(bot: Bot, event: Event) -> bool: async def _message(event: Event) -> bool:
return event.get_type() == "message" return event.get_type() == "message"
async def _notice(bot: Bot, event: Event) -> bool: async def _notice(event: Event) -> bool:
return event.get_type() == "notice" return event.get_type() == "notice"
async def _request(bot: Bot, event: Event) -> bool: async def _request(event: Event) -> bool:
return event.get_type() == "request" return event.get_type() == "request"
async def _metaevent(bot: Bot, event: Event) -> bool: async def _metaevent(event: Event) -> bool:
return event.get_type() == "meta_event" return event.get_type() == "meta_event"

View File

@ -69,18 +69,19 @@ class Rule:
] ]
def __init__(self, def __init__(self,
*checkers: T_RuleChecker, *checkers: Union[T_RuleChecker, Handler],
dependency_overrides_provider: Optional[Any] = None) -> None: dependency_overrides_provider: Optional[Any] = None) -> None:
""" """
:参数: :参数:
* ``*checkers: T_RuleChecker``: RuleChecker * ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker
""" """
self.checkers = set( self.checkers = set(
Handler(checker, checker if isinstance(checker, Handler) else Handler(
allow_types=self.HANDLER_PARAM_TYPES, checker,
dependency_overrides_provider=dependency_overrides_provider) allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers) for checker in checkers)
""" """
:说明: :说明:
@ -120,12 +121,12 @@ class Rule:
if not self.checkers: if not self.checkers:
return True return True
results = await asyncio.gather( results = await asyncio.gather(
checker(bot=bot, *(checker(bot=bot,
event=event, event=event,
state=state, state=state,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache) _dependency_cache=dependency_cache)
for checker in self.checkers) for checker in self.checkers))
return all(results) return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule": def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":

View File

@ -2,12 +2,13 @@ import re
import json import json
import asyncio import asyncio
import inspect import inspect
import collections
import dataclasses import dataclasses
from functools import wraps, partial from functools import wraps, partial
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing_extensions import GenericAlias # type: ignore from typing_extensions import GenericAlias # type: ignore
from typing_extensions import ParamSpec, get_args, get_origin from typing_extensions import ParamSpec, get_args, get_origin
from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional, from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional,
Awaitable, AsyncGenerator, ContextManager) Awaitable, AsyncGenerator, ContextManager)
from nonebot.log import logger from nonebot.log import logger
@ -120,6 +121,79 @@ def get_name(obj: Any) -> str:
return obj.__class__.__name__ return obj.__class__.__name__
class CacheLock:
def __init__(self):
self._waiters: Optional[Deque[asyncio.Future]] = None
self._locked = False
def __repr__(self):
extra = "locked" if self._locked else "unlocked"
if self._waiters:
extra = f"{extra}, waiters: {len(self._waiters)}"
return f"<{self.__class__.__name__} [{extra}]>"
async def __aenter__(self):
await self.acquire()
return None
async def __aexit__(self, exc_type, exc, tb):
self.release()
def locked(self):
return self._locked
async def acquire(self):
if (not self._locked and (self._waiters is None or
all(w.cancelled() for w in self._waiters))):
self._locked = True
return True
if self._waiters is None:
self._waiters = collections.deque()
loop = asyncio.get_running_loop()
future = loop.create_future()
self._waiters.append(future)
# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await future
finally:
self._waiters.remove(future)
except asyncio.CancelledError:
if not self._locked:
self._wake_up_first()
raise
self._locked = True
return True
def release(self):
if self._locked:
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock is not acquired.")
def _wake_up_first(self):
if not self._waiters:
return
try:
future = next(iter(self._waiters))
except StopIteration:
return
# .done() necessarily means that a waiter will wake up later on and
# either take the lock, or, if it was cancelled and lock wasn't
# taken already, will hit this again and wake up a new waiter.
if not future.done():
future.set_result(True)
class DataclassEncoder(json.JSONEncoder): class DataclassEncoder(json.JSONEncoder):
""" """
:说明: :说明:

View File

@ -1,26 +1,21 @@
from typing import TYPE_CHECKING from nonebot.adapters import Event
from nonebot.permission import Permission from nonebot.permission import Permission
from .event import GroupMessageEvent, PrivateMessageEvent
from .event import PrivateMessageEvent, GroupMessageEvent
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
async def _private(bot: "Bot", event: "Event") -> bool: async def _private(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) return isinstance(event, PrivateMessageEvent)
async def _private_friend(bot: "Bot", event: "Event") -> bool: async def _private_friend(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend" return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend"
async def _private_group(bot: "Bot", event: "Event") -> bool: async def _private_group(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) and event.sub_type == "group" return isinstance(event, PrivateMessageEvent) and event.sub_type == "group"
async def _private_other(bot: "Bot", event: "Event") -> bool: async def _private_other(event: Event) -> bool:
return isinstance(event, PrivateMessageEvent) and event.sub_type == "other" return isinstance(event, PrivateMessageEvent) and event.sub_type == "other"
@ -42,20 +37,20 @@ PRIVATE_OTHER = Permission(_private_other)
""" """
async def _group(bot: "Bot", event: "Event") -> bool: async def _group(event: Event) -> bool:
return isinstance(event, GroupMessageEvent) return isinstance(event, GroupMessageEvent)
async def _group_member(bot: "Bot", event: "Event") -> bool: async def _group_member(event: Event) -> bool:
return isinstance(event, return isinstance(event,
GroupMessageEvent) and event.sender.role == "member" GroupMessageEvent) and event.sender.role == "member"
async def _group_admin(bot: "Bot", event: "Event") -> bool: async def _group_admin(event: Event) -> bool:
return isinstance(event, GroupMessageEvent) and event.sender.role == "admin" return isinstance(event, GroupMessageEvent) and event.sender.role == "admin"
async def _group_owner(bot: "Bot", event: "Event") -> bool: async def _group_owner(event: Event) -> bool:
return isinstance(event, GroupMessageEvent) and event.sender.role == "owner" return isinstance(event, GroupMessageEvent) and event.sender.role == "owner"

View File

@ -1,11 +1,12 @@
from nonebot import on_command
from nonebot.log import logger from nonebot.log import logger
from nonebot.dependencies import Depends from nonebot.dependencies import Depends
from nonebot import on_command, on_message
test = on_command("123") test = on_command("123")
def depend(state: dict): def depend(state: dict):
print("==== depends running =====")
return state return state
@ -13,5 +14,15 @@ def depend(state: dict):
@test.got("b", prompt="b") @test.got("b", prompt="b")
@test.receive() @test.receive()
@test.got("c", prompt="c") @test.got("c", prompt="c")
async def _(state: dict = Depends(depend)): async def _(x: dict = Depends(depend)):
logger.info(f"=======, {state}") logger.info(f"=======, {x}")
test_cache1 = on_message()
test_cache2 = on_message()
@test_cache1.handle()
@test_cache2.handle()
async def _(x: dict = Depends(depend)):
logger.info(f"======= test, {x}")