mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 09:05:04 +08:00
🐛 fix cache concurrency
This commit is contained in:
parent
d22630e768
commit
75d4cd9565
@ -21,8 +21,11 @@ from .models import Dependent as Dependent
|
|||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from .models import DependsWrapper as DependsWrapper
|
from .models import DependsWrapper as DependsWrapper
|
||||||
from nonebot.typing import T_Handler, T_DependencyCache
|
from nonebot.typing import T_Handler, T_DependencyCache
|
||||||
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
|
from nonebot.utils import (CacheLock, run_sync, is_gen_callable,
|
||||||
is_async_gen_callable, is_coroutine_callable)
|
run_sync_ctx_manager, is_async_gen_callable,
|
||||||
|
is_coroutine_callable)
|
||||||
|
|
||||||
|
cache_lock = CacheLock()
|
||||||
|
|
||||||
|
|
||||||
class CustomConfig(BaseConfig):
|
class CustomConfig(BaseConfig):
|
||||||
@ -93,7 +96,7 @@ def get_dependent(*,
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}"
|
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
||||||
)
|
)
|
||||||
|
|
||||||
annotation: Any = Any
|
annotation: Any = Any
|
||||||
@ -122,7 +125,7 @@ async def solve_dependencies(
|
|||||||
_dependency_cache: Optional[T_DependencyCache] = None,
|
_dependency_cache: Optional[T_DependencyCache] = None,
|
||||||
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
|
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
|
||||||
values: Dict[str, Any] = {}
|
values: Dict[str, Any] = {}
|
||||||
dependency_cache = _dependency_cache or {}
|
dependency_cache = {} if _dependency_cache is None else _dependency_cache
|
||||||
|
|
||||||
# solve sub dependencies
|
# solve sub dependencies
|
||||||
sub_dependent: Dependent
|
sub_dependent: Dependent
|
||||||
@ -151,35 +154,37 @@ async def solve_dependencies(
|
|||||||
solved_result = await solve_dependencies(
|
solved_result = await solve_dependencies(
|
||||||
_dependent=use_sub_dependant,
|
_dependent=use_sub_dependant,
|
||||||
_dependency_overrides_provider=_dependency_overrides_provider,
|
_dependency_overrides_provider=_dependency_overrides_provider,
|
||||||
dependency_cache=dependency_cache,
|
_dependency_cache=dependency_cache,
|
||||||
**params)
|
**params)
|
||||||
sub_values, sub_dependency_cache = solved_result
|
sub_values, sub_dependency_cache = solved_result
|
||||||
# update cache?
|
# update cache?
|
||||||
dependency_cache.update(sub_dependency_cache)
|
# dependency_cache.update(sub_dependency_cache)
|
||||||
|
|
||||||
# run dependency function
|
# run dependency function
|
||||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
async with cache_lock:
|
||||||
solved = dependency_cache[sub_dependent.cache_key]
|
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
solved = dependency_cache[sub_dependent.cache_key]
|
||||||
assert isinstance(
|
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||||
_stack, AsyncExitStack
|
assert isinstance(
|
||||||
), "Generator dependency should be called in context"
|
_stack, AsyncExitStack
|
||||||
if is_gen_callable(func):
|
), "Generator dependency should be called in context"
|
||||||
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
if is_gen_callable(func):
|
||||||
|
cm = run_sync_ctx_manager(
|
||||||
|
contextmanager(func)(**sub_values))
|
||||||
|
else:
|
||||||
|
cm = asynccontextmanager(func)(**sub_values)
|
||||||
|
solved = await _stack.enter_async_context(cm)
|
||||||
|
elif is_coroutine_callable(func):
|
||||||
|
solved = await func(**sub_values)
|
||||||
else:
|
else:
|
||||||
cm = asynccontextmanager(func)(**sub_values)
|
solved = await run_sync(func)(**sub_values)
|
||||||
solved = await _stack.enter_async_context(cm)
|
|
||||||
elif is_coroutine_callable(func):
|
|
||||||
solved = await func(**sub_values)
|
|
||||||
else:
|
|
||||||
solved = await run_sync(func)(**sub_values)
|
|
||||||
|
|
||||||
# parameter dependency
|
# parameter dependency
|
||||||
if sub_dependent.name is not None:
|
if sub_dependent.name is not None:
|
||||||
values[sub_dependent.name] = solved
|
values[sub_dependent.name] = solved
|
||||||
# save current dependency to cache
|
# save current dependency to cache
|
||||||
if sub_dependent.cache_key not in dependency_cache:
|
if sub_dependent.cache_key not in dependency_cache:
|
||||||
dependency_cache[sub_dependent.cache_key] = solved
|
dependency_cache[sub_dependent.cache_key] = solved
|
||||||
|
|
||||||
# usual dependency
|
# usual dependency
|
||||||
for field in _dependent.params:
|
for field in _dependent.params:
|
||||||
|
@ -80,7 +80,7 @@ class Handler:
|
|||||||
_dependency_cache: Optional[Dict[Callable[..., Any],
|
_dependency_cache: Optional[Dict[Callable[..., Any],
|
||||||
Any]] = None,
|
Any]] = None,
|
||||||
**params) -> Any:
|
**params) -> Any:
|
||||||
values, cache = await solve_dependencies(
|
values, _ = await solve_dependencies(
|
||||||
_dependent=self.dependent,
|
_dependent=self.dependent,
|
||||||
_stack=_stack,
|
_stack=_stack,
|
||||||
_sub_dependents=[
|
_sub_dependents=[
|
||||||
|
@ -163,7 +163,7 @@ async def _run_matcher(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Running matcher {matcher}")
|
logger.debug(f"Running matcher {matcher}")
|
||||||
await matcher.run(bot, event, state)
|
await matcher.run(bot, event, state, stack, dependency_cache)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.opt(colors=True, exception=e).error(
|
logger.opt(colors=True, exception=e).error(
|
||||||
f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>"
|
f"<r><bg #f8bbd0>Running matcher {matcher} failed.</bg #f8bbd0></r>"
|
||||||
@ -260,7 +260,8 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
|
|||||||
logger.debug(f"Checking for matchers in priority {priority}...")
|
logger.debug(f"Checking for matchers in priority {priority}...")
|
||||||
|
|
||||||
pending_tasks = [
|
pending_tasks = [
|
||||||
_check_matcher(priority, matcher, bot, event, state.copy())
|
_check_matcher(priority, matcher, bot, event, state.copy(),
|
||||||
|
stack, dependency_cache)
|
||||||
for matcher in matchers[priority]
|
for matcher in matchers[priority]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -42,17 +42,19 @@ class Permission:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
*checkers: T_PermissionChecker,
|
*checkers: Union[T_PermissionChecker, Handler],
|
||||||
dependency_overrides_provider: Optional[Any] = None) -> None:
|
dependency_overrides_provider: Optional[Any] = None) -> None:
|
||||||
"""
|
"""
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``*checkers: T_PermissionChecker``: PermissionChecker
|
* ``*checkers: Union[T_PermissionChecker, Handler]``: PermissionChecker
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.checkers = set(
|
self.checkers = set(
|
||||||
Handler(checker,
|
checker if isinstance(checker, Handler) else Handler(
|
||||||
allow_types=self.HANDLER_PARAM_TYPES,
|
checker,
|
||||||
dependency_overrides_provider=dependency_overrides_provider)
|
allow_types=self.HANDLER_PARAM_TYPES,
|
||||||
|
dependency_overrides_provider=dependency_overrides_provider)
|
||||||
for checker in checkers)
|
for checker in checkers)
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -90,11 +92,11 @@ class Permission:
|
|||||||
if not self.checkers:
|
if not self.checkers:
|
||||||
return True
|
return True
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
checker(bot=bot,
|
*(checker(bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
_stack=stack,
|
_stack=stack,
|
||||||
_dependency_cache=dependency_cache)
|
_dependency_cache=dependency_cache)
|
||||||
for checker in self.checkers)
|
for checker in self.checkers))
|
||||||
return any(results)
|
return any(results)
|
||||||
|
|
||||||
def __and__(self, other) -> NoReturn:
|
def __and__(self, other) -> NoReturn:
|
||||||
@ -111,19 +113,19 @@ class Permission:
|
|||||||
return Permission(*self.checkers, other)
|
return Permission(*self.checkers, other)
|
||||||
|
|
||||||
|
|
||||||
async def _message(bot: Bot, event: Event) -> bool:
|
async def _message(event: Event) -> bool:
|
||||||
return event.get_type() == "message"
|
return event.get_type() == "message"
|
||||||
|
|
||||||
|
|
||||||
async def _notice(bot: Bot, event: Event) -> bool:
|
async def _notice(event: Event) -> bool:
|
||||||
return event.get_type() == "notice"
|
return event.get_type() == "notice"
|
||||||
|
|
||||||
|
|
||||||
async def _request(bot: Bot, event: Event) -> bool:
|
async def _request(event: Event) -> bool:
|
||||||
return event.get_type() == "request"
|
return event.get_type() == "request"
|
||||||
|
|
||||||
|
|
||||||
async def _metaevent(bot: Bot, event: Event) -> bool:
|
async def _metaevent(event: Event) -> bool:
|
||||||
return event.get_type() == "meta_event"
|
return event.get_type() == "meta_event"
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,18 +69,19 @@ class Rule:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
*checkers: T_RuleChecker,
|
*checkers: Union[T_RuleChecker, Handler],
|
||||||
dependency_overrides_provider: Optional[Any] = None) -> None:
|
dependency_overrides_provider: Optional[Any] = None) -> None:
|
||||||
"""
|
"""
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``*checkers: T_RuleChecker``: RuleChecker
|
* ``*checkers: Union[T_RuleChecker, Handler]``: RuleChecker
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.checkers = set(
|
self.checkers = set(
|
||||||
Handler(checker,
|
checker if isinstance(checker, Handler) else Handler(
|
||||||
allow_types=self.HANDLER_PARAM_TYPES,
|
checker,
|
||||||
dependency_overrides_provider=dependency_overrides_provider)
|
allow_types=self.HANDLER_PARAM_TYPES,
|
||||||
|
dependency_overrides_provider=dependency_overrides_provider)
|
||||||
for checker in checkers)
|
for checker in checkers)
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -120,12 +121,12 @@ class Rule:
|
|||||||
if not self.checkers:
|
if not self.checkers:
|
||||||
return True
|
return True
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
checker(bot=bot,
|
*(checker(bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
state=state,
|
state=state,
|
||||||
_stack=stack,
|
_stack=stack,
|
||||||
_dependency_cache=dependency_cache)
|
_dependency_cache=dependency_cache)
|
||||||
for checker in self.checkers)
|
for checker in self.checkers))
|
||||||
return all(results)
|
return all(results)
|
||||||
|
|
||||||
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
|
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
|
||||||
|
@ -2,12 +2,13 @@ import re
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import collections
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from functools import wraps, partial
|
from functools import wraps, partial
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing_extensions import GenericAlias # type: ignore
|
from typing_extensions import GenericAlias # type: ignore
|
||||||
from typing_extensions import ParamSpec, get_args, get_origin
|
from typing_extensions import ParamSpec, get_args, get_origin
|
||||||
from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional,
|
from typing import (Any, Type, Deque, Tuple, Union, TypeVar, Callable, Optional,
|
||||||
Awaitable, AsyncGenerator, ContextManager)
|
Awaitable, AsyncGenerator, ContextManager)
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
@ -120,6 +121,79 @@ def get_name(obj: Any) -> str:
|
|||||||
return obj.__class__.__name__
|
return obj.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
class CacheLock:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._waiters: Optional[Deque[asyncio.Future]] = None
|
||||||
|
self._locked = False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
extra = "locked" if self._locked else "unlocked"
|
||||||
|
if self._waiters:
|
||||||
|
extra = f"{extra}, waiters: {len(self._waiters)}"
|
||||||
|
return f"<{self.__class__.__name__} [{extra}]>"
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
await self.acquire()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
self.release()
|
||||||
|
|
||||||
|
def locked(self):
|
||||||
|
return self._locked
|
||||||
|
|
||||||
|
async def acquire(self):
|
||||||
|
if (not self._locked and (self._waiters is None or
|
||||||
|
all(w.cancelled() for w in self._waiters))):
|
||||||
|
self._locked = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._waiters is None:
|
||||||
|
self._waiters = collections.deque()
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future = loop.create_future()
|
||||||
|
self._waiters.append(future)
|
||||||
|
|
||||||
|
# Finally block should be called before the CancelledError
|
||||||
|
# handling as we don't want CancelledError to call
|
||||||
|
# _wake_up_first() and attempt to wake up itself.
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
await future
|
||||||
|
finally:
|
||||||
|
self._waiters.remove(future)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
if not self._locked:
|
||||||
|
self._wake_up_first()
|
||||||
|
raise
|
||||||
|
|
||||||
|
self._locked = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
if self._locked:
|
||||||
|
self._locked = False
|
||||||
|
self._wake_up_first()
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Lock is not acquired.")
|
||||||
|
|
||||||
|
def _wake_up_first(self):
|
||||||
|
if not self._waiters:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
future = next(iter(self._waiters))
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
|
||||||
|
# .done() necessarily means that a waiter will wake up later on and
|
||||||
|
# either take the lock, or, if it was cancelled and lock wasn't
|
||||||
|
# taken already, will hit this again and wake up a new waiter.
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(True)
|
||||||
|
|
||||||
|
|
||||||
class DataclassEncoder(json.JSONEncoder):
|
class DataclassEncoder(json.JSONEncoder):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
@ -1,26 +1,21 @@
|
|||||||
from typing import TYPE_CHECKING
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
from nonebot.permission import Permission
|
from nonebot.permission import Permission
|
||||||
|
from .event import GroupMessageEvent, PrivateMessageEvent
|
||||||
from .event import PrivateMessageEvent, GroupMessageEvent
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from nonebot.adapters import Bot, Event
|
|
||||||
|
|
||||||
|
|
||||||
async def _private(bot: "Bot", event: "Event") -> bool:
|
async def _private(event: Event) -> bool:
|
||||||
return isinstance(event, PrivateMessageEvent)
|
return isinstance(event, PrivateMessageEvent)
|
||||||
|
|
||||||
|
|
||||||
async def _private_friend(bot: "Bot", event: "Event") -> bool:
|
async def _private_friend(event: Event) -> bool:
|
||||||
return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend"
|
return isinstance(event, PrivateMessageEvent) and event.sub_type == "friend"
|
||||||
|
|
||||||
|
|
||||||
async def _private_group(bot: "Bot", event: "Event") -> bool:
|
async def _private_group(event: Event) -> bool:
|
||||||
return isinstance(event, PrivateMessageEvent) and event.sub_type == "group"
|
return isinstance(event, PrivateMessageEvent) and event.sub_type == "group"
|
||||||
|
|
||||||
|
|
||||||
async def _private_other(bot: "Bot", event: "Event") -> bool:
|
async def _private_other(event: Event) -> bool:
|
||||||
return isinstance(event, PrivateMessageEvent) and event.sub_type == "other"
|
return isinstance(event, PrivateMessageEvent) and event.sub_type == "other"
|
||||||
|
|
||||||
|
|
||||||
@ -42,20 +37,20 @@ PRIVATE_OTHER = Permission(_private_other)
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def _group(bot: "Bot", event: "Event") -> bool:
|
async def _group(event: Event) -> bool:
|
||||||
return isinstance(event, GroupMessageEvent)
|
return isinstance(event, GroupMessageEvent)
|
||||||
|
|
||||||
|
|
||||||
async def _group_member(bot: "Bot", event: "Event") -> bool:
|
async def _group_member(event: Event) -> bool:
|
||||||
return isinstance(event,
|
return isinstance(event,
|
||||||
GroupMessageEvent) and event.sender.role == "member"
|
GroupMessageEvent) and event.sender.role == "member"
|
||||||
|
|
||||||
|
|
||||||
async def _group_admin(bot: "Bot", event: "Event") -> bool:
|
async def _group_admin(event: Event) -> bool:
|
||||||
return isinstance(event, GroupMessageEvent) and event.sender.role == "admin"
|
return isinstance(event, GroupMessageEvent) and event.sender.role == "admin"
|
||||||
|
|
||||||
|
|
||||||
async def _group_owner(bot: "Bot", event: "Event") -> bool:
|
async def _group_owner(event: Event) -> bool:
|
||||||
return isinstance(event, GroupMessageEvent) and event.sender.role == "owner"
|
return isinstance(event, GroupMessageEvent) and event.sender.role == "owner"
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from nonebot import on_command
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.dependencies import Depends
|
from nonebot.dependencies import Depends
|
||||||
|
from nonebot import on_command, on_message
|
||||||
|
|
||||||
test = on_command("123")
|
test = on_command("123")
|
||||||
|
|
||||||
|
|
||||||
def depend(state: dict):
|
def depend(state: dict):
|
||||||
|
print("==== depends running =====")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
@ -13,5 +14,15 @@ def depend(state: dict):
|
|||||||
@test.got("b", prompt="b")
|
@test.got("b", prompt="b")
|
||||||
@test.receive()
|
@test.receive()
|
||||||
@test.got("c", prompt="c")
|
@test.got("c", prompt="c")
|
||||||
async def _(state: dict = Depends(depend)):
|
async def _(x: dict = Depends(depend)):
|
||||||
logger.info(f"=======, {state}")
|
logger.info(f"=======, {x}")
|
||||||
|
|
||||||
|
|
||||||
|
test_cache1 = on_message()
|
||||||
|
test_cache2 = on_message()
|
||||||
|
|
||||||
|
|
||||||
|
@test_cache1.handle()
|
||||||
|
@test_cache2.handle()
|
||||||
|
async def _(x: dict = Depends(depend)):
|
||||||
|
logger.info(f"======= test, {x}")
|
||||||
|
Loading…
Reference in New Issue
Block a user