mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
🐛 fix cache concurrency
This commit is contained in:
parent
d22630e768
commit
75d4cd9565
@ -21,8 +21,11 @@ from .models import Dependent as Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from .models import DependsWrapper as DependsWrapper
|
||||
from nonebot.typing import T_Handler, T_DependencyCache
|
||||
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
|
||||
is_async_gen_callable, is_coroutine_callable)
|
||||
from nonebot.utils import (CacheLock, run_sync, is_gen_callable,
|
||||
run_sync_ctx_manager, is_async_gen_callable,
|
||||
is_coroutine_callable)
|
||||
|
||||
cache_lock = CacheLock()
|
||||
|
||||
|
||||
class CustomConfig(BaseConfig):
|
||||
@ -93,7 +96,7 @@ def get_dependent(*,
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}"
|
||||
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
||||
)
|
||||
|
||||
annotation: Any = Any
|
||||
@ -122,7 +125,7 @@ async def solve_dependencies(
|
||||
_dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
|
||||
values: Dict[str, Any] = {}
|
||||
dependency_cache = _dependency_cache or {}
|
||||
dependency_cache = {} if _dependency_cache is None else _dependency_cache
|
||||
|
||||
# solve sub dependencies
|
||||
sub_dependent: Dependent
|
||||
@ -151,35 +154,37 @@ async def solve_dependencies(
|
||||
solved_result = await solve_dependencies(
|
||||
_dependent=use_sub_dependant,
|
||||
_dependency_overrides_provider=_dependency_overrides_provider,
|
||||
dependency_cache=dependency_cache,
|
||||
_dependency_cache=dependency_cache,
|
||||
**params)
|
||||
sub_values, sub_dependency_cache = solved_result
|
||||
# update cache?
|
||||
dependency_cache.update(sub_dependency_cache)
|
||||
# dependency_cache.update(sub_dependency_cache)
|
||||
|
||||
# run dependency function
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||
assert isinstance(
|
||||
_stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(func):
|
||||
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
||||
async with cache_lock:
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||
assert isinstance(
|
||||
_stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(func):
|
||||
cm = run_sync_ctx_manager(
|
||||
contextmanager(func)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(func)(**sub_values)
|
||||
solved = await _stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
else:
|
||||
cm = asynccontextmanager(func)(**sub_values)
|
||||
solved = await _stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
else:
|
||||
solved = await run_sync(func)(**sub_values)
|
||||
solved = await run_sync(func)(**sub_values)
|
||||
|
||||
# parameter dependency
|
||||
if sub_dependent.name is not None:
|
||||
values[sub_dependent.name] = solved
|
||||
# save current dependency to cache
|
||||
if sub_dependent.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependent.cache_key] = solved
|
||||
# parameter dependency
|
||||
if sub_dependent.name is not None:
|
||||
values[sub_dependent.name] = solved
|
||||
# save current dependency to cache
|
||||
if sub_dependent.cache_key not in dependency_cache:
|
||||
dependency_cache[sub_dependent.cache_key] = solved
|
||||
|
||||
# usual dependency
|
||||
for field in _dependent.params:
|
||||
|
@ -80,7 +80,7 @@ class Handler:
|
||||
_dependency_cache: Optional[Dict[Callable[..., Any],
|
||||
Any]] = None,
|
||||
**params) -> Any:
|
||||
values, cache = await solve_dependencies(
|
||||
values, _ = await solve_dependencies(
|
||||
_dependent=self.dependent,
|
||||
_stack=_stack,
|
||||
_sub_dependents=[
|
||||
|
@ -163,7 +163,7 @@ async def _run_matcher(
|
||||
|
||||
try:
|
||||
logger.debug(f"Running matcher {matcher}")
|
||||
await matcher.run(bot, event, state)
|
||||
await matcher.run(bot, event, state, stack, dependency_cache)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>"
|
||||
@ -260,7 +260,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
|
||||
logger.debug(f"Checking for matchers in priority {priority}...")
|
||||
|
||||
pending_tasks = [
|
||||
_check_matcher(priority, matcher, bot, event, state.copy())
|
||||
_check_matcher(priority, matcher, bot, event, state.copy(),
|
||||
stack, dependency_cache)
|
||||
for matcher in matchers[priority]
|
||||
]
|
||||
|
||||
|
@ -42,17 +42,19 @@ class Permission:
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
*checkers: T_PermissionChecker,
|
||||
*checkers: Union[T_PermissionChecker, Handler],
|
||||
dependency_overrides_provider: Optional[Any] = None) -> None:
|
||||
"""
|
||||
:参数:
|
||||
|
||||
* ``*checkers: T_PermissionChecker``: PermissionChecker
|
||||
* ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker
|
||||
"""
|
||||
|
||||
self.checkers = set(
|
||||
Handler(checker,
|
||||
allow_types=self.HANDLER_PARAM_TYPES,
|
||||
dependency_overrides_provider=dependency_overrides_provider)
|
||||
checker if isinstance(checker, Handler) else Handler(
|
||||
checker,
|
||||
allow_types=self.HANDLER_PARAM_TYPES,
|
||||
dependency_overrides_provider=dependency_overrides_provider)
|
||||
for checker in checkers)
|
||||
"""
|
||||
:说明:
|
||||
@ -90,11 +92,11 @@ class Permission:
|
||||
if not self.checkers:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
checker(bot=bot,
|
||||
event=event,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
for checker in self.checkers)
|
||||
*(checker(bot=bot,
|
||||
event=event,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
for checker in self.checkers))
|
||||
return any(results)
|
||||
|
||||
def __and__(self, other) -> NoReturn:
|
||||
@ -111,19 +113,19 @@ class Permission:
|
||||
return Permission(*self.checkers, other)
|
||||
|
||||
|
||||
async def _message(bot: Bot, event: Event) -> bool:
|
||||
async def _message(event: Event) -> bool:
|
||||
return event.get_type() == "message"
|
||||
|
||||
|
||||
async def _notice(bot: Bot, event: Event) -> bool:
|
||||
async def _notice(event: Event) -> bool:
|
||||
return event.get_type() == "notice"
|
||||
|
||||
|
||||
async def _request(bot: Bot, event: Event) -> bool:
|
||||
async def _request(event: Event) -> bool:
|
||||
return event.get_type() == "request"
|
||||
|
||||
|
||||
async def _metaevent(bot: Bot, event: Event) -> bool:
|
||||
async def _metaevent(event: Event) -> bool:
|
||||
return event.get_type() == "meta_event"
|
||||
|
||||
|
||||
|
@ -69,18 +69,19 @@ class Rule:
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
*checkers: T_RuleChecker,
|
||||
*checkers: Union[T_RuleChecker, Handler],
|
||||
dependency_overrides_provider: Optional[Any] = None) -> None:
|
||||
"""
|
||||
:参数:
|
||||
|
||||
* ``*checkers: T_RuleChecker``: RuleChecker
|
||||
* ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker
|
||||
|
||||
"""
|
||||
self.checkers = set(
|
||||
Handler(checker,
|
||||
allow_types=self.HANDLER_PARAM_TYPES,
|
||||
dependency_overrides_provider=dependency_overrides_provider)
|
||||
checker if isinstance(checker, Handler) else Handler(
|
||||
checker,
|
||||
allow_types=self.HANDLER_PARAM_TYPES,
|
||||
dependency_overrides_provider=dependency_overrides_provider)
|
||||
for checker in checkers)
|
||||
"""
|
||||
:说明:
|
||||
@ -120,12 +121,12 @@ class Rule:
|
||||
if not self.checkers:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
checker(bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
for checker in self.checkers)
|
||||
*(checker(bot=bot,
|
||||
event=event,
|
||||
state=state,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache)
|
||||
for checker in self.checkers))
|
||||
return all(results)
|
||||
|
||||
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
|
||||
|
@ -2,12 +2,13 @@ import re
|
||||
import json
|
||||
import asyncio
|
||||
import inspect
|
||||
import collections
|
||||
import dataclasses
|
||||
from functools import wraps, partial
|
||||
from contextlib import asynccontextmanager
|
||||
from typing_extensions import GenericAlias # type: ignore
|
||||
from typing_extensions import ParamSpec, get_args, get_origin
|
||||
from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional,
|
||||
from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional,
|
||||
Awaitable, AsyncGenerator, ContextManager)
|
||||
|
||||
from nonebot.log import logger
|
||||
@ -120,6 +121,79 @@ def get_name(obj: Any) -> str:
|
||||
return obj.__class__.__name__
|
||||
|
||||
|
||||
class CacheLock:
|
||||
|
||||
def __init__(self):
|
||||
self._waiters: Optional[Deque[asyncio.Future]] = None
|
||||
self._locked = False
|
||||
|
||||
def __repr__(self):
|
||||
extra = "locked" if self._locked else "unlocked"
|
||||
if self._waiters:
|
||||
extra = f"{extra}, waiters: {len(self._waiters)}"
|
||||
return f"<{self.__class__.__name__} [{extra}]>"
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.acquire()
|
||||
return None
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.release()
|
||||
|
||||
def locked(self):
|
||||
return self._locked
|
||||
|
||||
async def acquire(self):
|
||||
if (not self._locked and (self._waiters is None or
|
||||
all(w.cancelled() for w in self._waiters))):
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
if self._waiters is None:
|
||||
self._waiters = collections.deque()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._waiters.append(future)
|
||||
|
||||
# Finally block should be called before the CancelledError
|
||||
# handling as we don't want CancelledError to call
|
||||
# _wake_up_first() and attempt to wake up itself.
|
||||
try:
|
||||
try:
|
||||
await future
|
||||
finally:
|
||||
self._waiters.remove(future)
|
||||
except asyncio.CancelledError:
|
||||
if not self._locked:
|
||||
self._wake_up_first()
|
||||
raise
|
||||
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
def release(self):
|
||||
if self._locked:
|
||||
self._locked = False
|
||||
self._wake_up_first()
|
||||
else:
|
||||
raise RuntimeError("Lock is not acquired.")
|
||||
|
||||
def _wake_up_first(self):
|
||||
if not self._waiters:
|
||||
return
|
||||
try:
|
||||
future = next(iter(self._waiters))
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# .done() necessarily means that a waiter will wake up later on and
|
||||
# either take the lock, or, if it was cancelled and lock wasn't
|
||||
# taken already, will hit this again and wake up a new waiter.
|
||||
if not future.done():
|
||||
future.set_result(True)
|
||||
|
||||
|
||||
class DataclassEncoder(json.JSONEncoder):
|
||||
"""
|
||||
:说明:
|
||||
|
@ -1,26 +1,21 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.permission import Permission
|
||||
|
||||
from .event import PrivateMessageEvent, GroupMessageEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot, Event
|
||||
from .event import GroupMessageEvent, PrivateMessageEvent
|
||||
|
||||
|
||||
async def _private(bot: "Bot", event: "Event") -> bool:
|
||||
async def _private(event: Event) -> bool:
|
||||
return isinstance(event, PrivateMessageEvent)
|
||||
|
||||
|
||||
async def _private_friend(bot: "Bot", event: "Event") -> bool:
|
||||
async def _private_friend(event: Event) -> bool:
|
||||
return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend"
|
||||
|
||||
|
||||
async def _private_group(bot: "Bot", event: "Event") -> bool:
|
||||
async def _private_group(event: Event) -> bool:
|
||||
return isinstance(event, PrivateMessageEvent) and event.sub_type == "group"
|
||||
|
||||
|
||||
async def _private_other(bot: "Bot", event: "Event") -> bool:
|
||||
async def _private_other(event: Event) -> bool:
|
||||
return isinstance(event, PrivateMessageEvent) and event.sub_type == "other"
|
||||
|
||||
|
||||
@ -42,20 +37,20 @@ PRIVATE_OTHER = Permission(_private_other)
|
||||
"""
|
||||
|
||||
|
||||
async def _group(bot: "Bot", event: "Event") -> bool:
|
||||
async def _group(event: Event) -> bool:
|
||||
return isinstance(event, GroupMessageEvent)
|
||||
|
||||
|
||||
async def _group_member(bot: "Bot", event: "Event") -> bool:
|
||||
async def _group_member(event: Event) -> bool:
|
||||
return isinstance(event,
|
||||
GroupMessageEvent) and event.sender.role == "member"
|
||||
|
||||
|
||||
async def _group_admin(bot: "Bot", event: "Event") -> bool:
|
||||
async def _group_admin(event: Event) -> bool:
|
||||
return isinstance(event, GroupMessageEvent) and event.sender.role == "admin"
|
||||
|
||||
|
||||
async def _group_owner(bot: "Bot", event: "Event") -> bool:
|
||||
async def _group_owner(event: Event) -> bool:
|
||||
return isinstance(event, GroupMessageEvent) and event.sender.role == "owner"
|
||||
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
from nonebot import on_command
|
||||
from nonebot.log import logger
|
||||
from nonebot.dependencies import Depends
|
||||
from nonebot import on_command, on_message
|
||||
|
||||
test = on_command("123")
|
||||
|
||||
|
||||
def depend(state: dict):
|
||||
print("==== depends running =====")
|
||||
return state
|
||||
|
||||
|
||||
@ -13,5 +14,15 @@ def depend(state: dict):
|
||||
@test.got("b", prompt="b")
|
||||
@test.receive()
|
||||
@test.got("c", prompt="c")
|
||||
async def _(state: dict = Depends(depend)):
|
||||
logger.info(f"=======, {state}")
|
||||
async def _(x: dict = Depends(depend)):
|
||||
logger.info(f"=======, {x}")
|
||||
|
||||
|
||||
test_cache1 = on_message()
|
||||
test_cache2 = on_message()
|
||||
|
||||
|
||||
@test_cache1.handle()
|
||||
@test_cache2.handle()
|
||||
async def _(x: dict = Depends(depend)):
|
||||
logger.info(f"======= test, {x}")
|
||||
|
Loading…
Reference in New Issue
Block a user