improve dependency cache

This commit is contained in:
yanyongyu 2021-12-16 23:22:25 +08:00
parent fe69735ca0
commit 3d762fcbab
15 changed files with 162 additions and 100 deletions

View File

@ -26,7 +26,6 @@ from typing import (
from nonebot import params from nonebot import params
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import CacheDict
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.adapters import ( from nonebot.adapters import (
@ -43,14 +42,6 @@ from nonebot.consts import (
REJECT_TARGET, REJECT_TARGET,
LAST_RECEIVE_KEY, LAST_RECEIVE_KEY,
) )
from nonebot.typing import (
Any,
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_PermissionUpdater,
)
from nonebot.exception import ( from nonebot.exception import (
PausedException, PausedException,
StopPropagation, StopPropagation,
@ -58,6 +49,15 @@ from nonebot.exception import (
FinishedException, FinishedException,
RejectedException, RejectedException,
) )
from nonebot.typing import (
Any,
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_DependencyCache,
T_PermissionUpdater,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.plugin import Plugin from nonebot.plugin import Plugin
@ -296,7 +296,7 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot, bot: Bot,
event: Event, event: Event,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
) -> bool: ) -> bool:
""" """
:说明: :说明:
@ -324,7 +324,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
) -> bool: ) -> bool:
""" """
:说明: :说明:
@ -669,7 +669,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
): ):
b_t = current_bot.set(bot) b_t = current_bot.set(bot)
e_t = current_event.set(event) e_t = current_event.set(event)
@ -711,7 +711,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
): ):
try: try:
await self.simple_run(bot, event, state, stack, dependency_cache) await self.simple_run(bot, event, state, stack, dependency_cache)

View File

@ -22,9 +22,9 @@ from typing import (
from nonebot import params from nonebot import params
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import TrieRule from nonebot.rule import TrieRule
from nonebot.utils import escape_tag
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers from nonebot.matcher import Matcher, matchers
from nonebot.utils import CacheDict, escape_tag
from nonebot.exception import ( from nonebot.exception import (
NoLogException, NoLogException,
StopPropagation, StopPropagation,
@ -34,6 +34,7 @@ from nonebot.exception import (
from nonebot.typing import ( from nonebot.typing import (
T_State, T_State,
T_Handler, T_Handler,
T_DependencyCache,
T_RunPreProcessor, T_RunPreProcessor,
T_RunPostProcessor, T_RunPostProcessor,
T_EventPreProcessor, T_EventPreProcessor,
@ -136,7 +137,7 @@ async def _check_matcher(
event: "Event", event: "Event",
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
) -> None: ) -> None:
if Matcher.expire_time and datetime.now() > Matcher.expire_time: if Matcher.expire_time and datetime.now() > Matcher.expire_time:
try: try:
@ -171,7 +172,7 @@ async def _run_matcher(
event: "Event", event: "Event",
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
) -> None: ) -> None:
logger.info(f"Event will be handled by {Matcher}") logger.info(f"Event will be handled by {Matcher}")
@ -275,7 +276,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.opt(colors=True).success(log_msg) logger.opt(colors=True).success(log_msg)
state: Dict[Any, Any] = {} state: Dict[Any, Any] = {}
dependency_cache: CacheDict[T_Handler, Any] = CacheDict() dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
coros = list( coros = list(

View File

@ -1,12 +1,13 @@
import asyncio
import inspect import inspect
from typing import Any, Dict, List, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined from pydantic.fields import Required, Undefined
from nonebot.typing import T_State, T_Handler
from nonebot.adapters import Bot, Event, Message from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent from nonebot.dependencies import Param, Dependent
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
PREFIX_KEY, PREFIX_KEY,
@ -19,7 +20,6 @@ from nonebot.consts import (
REGEX_MATCHED, REGEX_MATCHED,
) )
from nonebot.utils import ( from nonebot.utils import (
CacheDict,
get_name, get_name,
run_sync, run_sync,
is_gen_callable, is_gen_callable,
@ -49,7 +49,7 @@ class DependsInner:
def Depends( def Depends(
dependency: Optional[T_Handler] = None, dependency: Optional[T_Handler] = None,
*, *,
use_cache: bool = False, use_cache: bool = True,
) -> Any: ) -> Any:
""" """
:说明: :说明:
@ -114,11 +114,11 @@ class DependParam(Param):
async def _solve( async def _solve(
self, self,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
use_cache: bool = self.extra["use_cache"] use_cache: bool = self.extra["use_cache"]
dependency_cache = CacheDict() if dependency_cache is None else dependency_cache dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"] sub_dependent: Dependent = self.extra["dependent"]
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
@ -132,26 +132,28 @@ class DependParam(Param):
) )
# run dependency function # run dependency function
async with dependency_cache: task: asyncio.Task[Any]
if use_cache and call in dependency_cache: if use_cache and call in dependency_cache:
solved = dependency_cache[call] solved = await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call): elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance( assert isinstance(
stack, AsyncExitStack stack, AsyncExitStack
), "Generator dependency should be called in context" ), "Generator dependency should be called in context"
if is_gen_callable(call): if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
solved = await stack.enter_async_context(cm)
elif is_coroutine_callable(call):
return await call(**sub_values)
else: else:
return await run_sync(call)(**sub_values) cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
# save current dependency to cache dependency_cache[call] = task
if call not in dependency_cache: solved = await task
dependency_cache[call] = solved elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
solved = await task
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
solved = await task
return solved return solved
@ -243,7 +245,7 @@ def _command(state=State()) -> Message:
def Command() -> Tuple[str, ...]: def Command() -> Tuple[str, ...]:
return Depends(_command) return Depends(_command, use_cache=False)
def _raw_command(state=State()) -> Message: def _raw_command(state=State()) -> Message:
@ -251,7 +253,7 @@ def _raw_command(state=State()) -> Message:
def RawCommand() -> str: def RawCommand() -> str:
return Depends(_raw_command) return Depends(_raw_command, use_cache=False)
def _command_arg(state=State()) -> Message: def _command_arg(state=State()) -> Message:
@ -259,7 +261,7 @@ def _command_arg(state=State()) -> Message:
def CommandArg() -> Message: def CommandArg() -> Message:
return Depends(_command_arg) return Depends(_command_arg, use_cache=False)
def _shell_command_args(state=State()) -> Any: def _shell_command_args(state=State()) -> Any:
@ -267,7 +269,7 @@ def _shell_command_args(state=State()) -> Any:
def ShellCommandArgs(): def ShellCommandArgs():
return Depends(_shell_command_args) return Depends(_shell_command_args, use_cache=False)
def _shell_command_argv(state=State()) -> List[str]: def _shell_command_argv(state=State()) -> List[str]:
@ -275,7 +277,7 @@ def _shell_command_argv(state=State()) -> List[str]:
def ShellCommandArgv() -> Any: def ShellCommandArgv() -> Any:
return Depends(_shell_command_argv) return Depends(_shell_command_argv, use_cache=False)
def _regex_matched(state=State()) -> str: def _regex_matched(state=State()) -> str:
@ -283,7 +285,7 @@ def _regex_matched(state=State()) -> str:
def RegexMatched() -> str: def RegexMatched() -> str:
return Depends(_regex_matched) return Depends(_regex_matched, use_cache=False)
def _regex_group(state=State()): def _regex_group(state=State()):
@ -291,7 +293,7 @@ def _regex_group(state=State()):
def RegexGroup() -> Tuple[Any, ...]: def RegexGroup() -> Tuple[Any, ...]:
return Depends(_regex_group) return Depends(_regex_group, use_cache=False)
def _regex_dict(state=State()): def _regex_dict(state=State()):
@ -299,7 +301,7 @@ def _regex_dict(state=State()):
def RegexDict() -> Dict[str, Any]: def RegexDict() -> Dict[str, Any]:
return Depends(_regex_dict) return Depends(_regex_dict, use_cache=False)
class MatcherParam(Param): class MatcherParam(Param):
@ -320,14 +322,14 @@ def Received(id: str, default: Any = None) -> Any:
def _received(matcher: "Matcher"): def _received(matcher: "Matcher"):
return matcher.get_receive(id, default) return matcher.get_receive(id, default)
return Depends(_received) return Depends(_received, use_cache=False)
def LastReceived(default: Any = None) -> Any: def LastReceived(default: Any = None) -> Any:
def _last_received(matcher: "Matcher") -> Any: def _last_received(matcher: "Matcher") -> Any:
return matcher.get_receive(None, default) return matcher.get_receive(None, default)
return Depends(_last_received) return Depends(_last_received, use_cache=False)
class ExceptionParam(Param): class ExceptionParam(Param):

View File

@ -24,11 +24,10 @@ from typing import (
) )
from nonebot import params from nonebot import params
from nonebot.utils import CacheDict
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from nonebot.typing import T_Handler, T_PermissionChecker from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]): async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
@ -93,7 +92,7 @@ class Permission:
bot: Bot, bot: Bot,
event: Event, event: Event,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
) -> bool: ) -> bool:
""" """
:说明: :说明:

View File

@ -22,12 +22,11 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
from pygtrie import CharTrie from pygtrie import CharTrie
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import CacheDict
from nonebot import params, get_driver from nonebot import params, get_driver
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import ParserExit, SkippedException from nonebot.exception import ParserExit, SkippedException
from nonebot.typing import T_State, T_Handler, T_RuleChecker
from nonebot.adapters import Bot, Event, Message, MessageSegment from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
PREFIX_KEY, PREFIX_KEY,
@ -105,7 +104,7 @@ class Rule:
event: Event, event: Event,
state: T_State, state: T_State,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None, dependency_cache: Optional[T_DependencyCache] = None,
) -> bool: ) -> bool:
""" """
:说明: :说明:

View File

@ -17,7 +17,7 @@
.. _typing: .. _typing:
https://docs.python.org/3/library/typing.html https://docs.python.org/3/library/typing.html
""" """
from asyncio import Task
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -32,7 +32,6 @@ from typing import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.utils import CacheDict
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
@ -250,3 +249,9 @@ T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]]
PermissionUpdater Matcher.pause, Matcher.reject 时被运行用于更新会话对象权限默认会更新为当前事件的触发对象 PermissionUpdater Matcher.pause, Matcher.reject 时被运行用于更新会话对象权限默认会更新为当前事件的触发对象
""" """
T_DependencyCache = Dict[Callable[..., Any], Task[Any]]
"""
:类型: ``Dict[Callable[..., Any], Task[Any]]``
:说明:
依赖缓存, 用于存储依赖函数的返回值
"""

View File

@ -135,33 +135,6 @@ def get_name(obj: Any) -> str:
return obj.__class__.__name__ return obj.__class__.__name__
class CacheDict(Dict[K, V], Generic[K, V]):
def __init__(self, *args, **kwargs):
super(CacheDict, self).__init__(*args, **kwargs)
self._lock = asyncio.Lock()
@property
def locked(self):
return self._lock.locked()
def __repr__(self):
extra = "locked" if self.locked else "unlocked"
return f"<{self.__class__.__name__} [{extra}]>"
async def __aenter__(self) -> None:
await self.acquire()
return None
async def __aexit__(self, exc_type, exc, tb):
self.release()
async def acquire(self):
return await self._lock.acquire()
def release(self):
self._lock.release()
class DataclassEncoder(json.JSONEncoder): class DataclassEncoder(json.JSONEncoder):
""" """
:说明: :说明:

2
poetry.lock generated
View File

@ -543,7 +543,7 @@ pytest-order = "^1.0.0"
type = "git" type = "git"
url = "https://github.com/nonebot/nonebug.git" url = "https://github.com/nonebot/nonebug.git"
reference = "master" reference = "master"
resolved_reference = "4584d5a4bc95cd1bafcec08599ab7d72815e268e" resolved_reference = "9c4f21373701ac25bc152cbad5f5527edc5e4c19"
[[package]] [[package]]
name = "packaging" name = "packaging"

View File

@ -1,6 +1,8 @@
[report] [report]
exclude_lines = exclude_lines =
def __repr__
pragma: no cover pragma: no cover
if TYPE_CHECKING: if TYPE_CHECKING:
@(abc\.)?abstractmethod @(abc\.)?abstractmethod
raise NotImplementedError raise NotImplementedError
if __name__ == .__main__.:

8
tests/.isort.cfg Normal file
View File

@ -0,0 +1,8 @@
[settings]
profile=black
line_length=80
length_sort=true
skip_gitignore=true
force_sort_within_sections=true
known_local_folder=plugins
extra_standard_library=typing_extensions

View File

@ -2,22 +2,23 @@ from nonebot import on_message
from nonebot.adapters import Event from nonebot.adapters import Event
from nonebot.params import Depends from nonebot.params import Depends
test = on_message() test_depends = on_message()
test2 = on_message()
runned = False runned = []
def dependency(event: Event): def dependency(event: Event):
# test cache # test cache
global runned runned.append(event)
assert not runned
runned = True
return event return event
@test.handle() @test_depends.handle()
@test2.handle() async def depends(x: Event = Depends(dependency)):
async def handle(x: Event = Depends(dependency, use_cache=True)):
# test dependency # test dependency
return x return x
@test_depends.handle()
async def depends_cache(y: Event = Depends(dependency, use_cache=True)):
return y

View File

@ -1,2 +0,0 @@
[tool.isort]
known_local_folder = ["plugins"]

View File

@ -1,6 +1,5 @@
import os import os
import sys import sys
from re import A
from typing import TYPE_CHECKING, Set from typing import TYPE_CHECKING, Set
import pytest import pytest

29
tests/test_param.py Normal file
View File

@ -0,0 +1,29 @@
import pytest
from nonebug import App
from utils import load_plugin, make_fake_event
@pytest.mark.asyncio
async def test_depends(app: App, load_plugin):
from nonebot.params import EventParam, DependParam
from plugins.depends import runned, depends, test_depends
async with app.test_dependent(
depends, allow_types=[EventParam, DependParam]
) as ctx:
event = make_fake_event()()
ctx.pass_params(event=event)
ctx.should_return(event)
assert len(runned) == 1 and runned[0] == event
runned.clear()
async with app.test_matcher(test_depends) as ctx:
bot = ctx.create_bot()
event_next = make_fake_event()()
ctx.receive_event(bot, event_next)
assert len(runned) == 1 and runned[0] == event_next

View File

@ -1,10 +1,56 @@
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Set from typing import TYPE_CHECKING, Set, Type, Optional
import pytest import pytest
from pydantic import create_model
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.plugin import Plugin from nonebot.plugin import Plugin
from nonebot.adapters import Event, Message
def make_fake_event(
_type: str = "message",
_name: str = "test",
_description: str = "test",
_user_id: str = "test",
_session_id: str = "test",
_message: Optional["Message"] = None,
_to_me: bool = True,
**fields,
) -> Type["Event"]:
from nonebot.adapters import Event
_Fake = create_model("_Fake", __base__=Event, **fields)
class FakeEvent(_Fake):
def get_type(self) -> str:
return _type
def get_event_name(self) -> str:
return _name
def get_event_description(self) -> str:
return _description
def get_user_id(self) -> str:
return _user_id
def get_session_id(self) -> str:
return _session_id
def get_message(self) -> "Message":
if _message is not None:
return _message
raise NotImplementedError
def is_tome(self) -> bool:
return _to_me
class Config:
extra = "forbid"
return FakeEvent
@pytest.fixture @pytest.fixture