improve dependency cache

This commit is contained in:
yanyongyu 2021-12-16 23:22:25 +08:00
parent fe69735ca0
commit 3d762fcbab
15 changed files with 162 additions and 100 deletions

View File

@ -26,7 +26,6 @@ from typing import (
from nonebot import params
from nonebot.rule import Rule
from nonebot.log import logger
from nonebot.utils import CacheDict
from nonebot.dependencies import Dependent
from nonebot.permission import USER, Permission
from nonebot.adapters import (
@ -43,14 +42,6 @@ from nonebot.consts import (
REJECT_TARGET,
LAST_RECEIVE_KEY,
)
from nonebot.typing import (
Any,
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_PermissionUpdater,
)
from nonebot.exception import (
PausedException,
StopPropagation,
@ -58,6 +49,15 @@ from nonebot.exception import (
FinishedException,
RejectedException,
)
from nonebot.typing import (
Any,
T_State,
T_Handler,
T_ArgsParser,
T_TypeUpdater,
T_DependencyCache,
T_PermissionUpdater,
)
if TYPE_CHECKING:
from nonebot.plugin import Plugin
@ -296,7 +296,7 @@ class Matcher(metaclass=MatcherMeta):
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""
:说明:
@ -324,7 +324,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""
:说明:
@ -669,7 +669,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
@ -711,7 +711,7 @@ class Matcher(metaclass=MatcherMeta):
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
try:
await self.simple_run(bot, event, state, stack, dependency_cache)

View File

@ -22,9 +22,9 @@ from typing import (
from nonebot import params
from nonebot.log import logger
from nonebot.rule import TrieRule
from nonebot.utils import escape_tag
from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers
from nonebot.utils import CacheDict, escape_tag
from nonebot.exception import (
NoLogException,
StopPropagation,
@ -34,6 +34,7 @@ from nonebot.exception import (
from nonebot.typing import (
T_State,
T_Handler,
T_DependencyCache,
T_RunPreProcessor,
T_RunPostProcessor,
T_EventPreProcessor,
@ -136,7 +137,7 @@ async def _check_matcher(
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> None:
if Matcher.expire_time and datetime.now() > Matcher.expire_time:
try:
@ -171,7 +172,7 @@ async def _run_matcher(
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> None:
logger.info(f"Event will be handled by {Matcher}")
@ -275,7 +276,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
logger.opt(colors=True).success(log_msg)
state: Dict[Any, Any] = {}
dependency_cache: CacheDict[T_Handler, Any] = CacheDict()
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
coros = list(

View File

@ -1,12 +1,13 @@
import asyncio
import inspect
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined
from nonebot.typing import T_State, T_Handler
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
@ -19,7 +20,6 @@ from nonebot.consts import (
REGEX_MATCHED,
)
from nonebot.utils import (
CacheDict,
get_name,
run_sync,
is_gen_callable,
@ -49,7 +49,7 @@ class DependsInner:
def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = False,
use_cache: bool = True,
) -> Any:
"""
:说明:
@ -114,11 +114,11 @@ class DependParam(Param):
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
**kwargs: Any,
) -> Any:
use_cache: bool = self.extra["use_cache"]
dependency_cache = CacheDict() if dependency_cache is None else dependency_cache
dependency_cache = {} if dependency_cache is None else dependency_cache
sub_dependent: Dependent = self.extra["dependent"]
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
@ -132,26 +132,28 @@ class DependParam(Param):
)
# run dependency function
async with dependency_cache:
if use_cache and call in dependency_cache:
solved = dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
solved = await stack.enter_async_context(cm)
elif is_coroutine_callable(call):
return await call(**sub_values)
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
solved = await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
return await run_sync(call)(**sub_values)
# save current dependency to cache
if call not in dependency_cache:
dependency_cache[call] = solved
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
solved = await task
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
solved = await task
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
solved = await task
return solved
@ -243,7 +245,7 @@ def _command(state=State()) -> Message:
def Command() -> Tuple[str, ...]:
return Depends(_command)
return Depends(_command, use_cache=False)
def _raw_command(state=State()) -> Message:
@ -251,7 +253,7 @@ def _raw_command(state=State()) -> Message:
def RawCommand() -> str:
return Depends(_raw_command)
return Depends(_raw_command, use_cache=False)
def _command_arg(state=State()) -> Message:
@ -259,7 +261,7 @@ def _command_arg(state=State()) -> Message:
def CommandArg() -> Message:
return Depends(_command_arg)
return Depends(_command_arg, use_cache=False)
def _shell_command_args(state=State()) -> Any:
@ -267,7 +269,7 @@ def _shell_command_args(state=State()) -> Any:
def ShellCommandArgs():
return Depends(_shell_command_args)
return Depends(_shell_command_args, use_cache=False)
def _shell_command_argv(state=State()) -> List[str]:
@ -275,7 +277,7 @@ def _shell_command_argv(state=State()) -> List[str]:
def ShellCommandArgv() -> Any:
return Depends(_shell_command_argv)
return Depends(_shell_command_argv, use_cache=False)
def _regex_matched(state=State()) -> str:
@ -283,7 +285,7 @@ def _regex_matched(state=State()) -> str:
def RegexMatched() -> str:
return Depends(_regex_matched)
return Depends(_regex_matched, use_cache=False)
def _regex_group(state=State()):
@ -291,7 +293,7 @@ def _regex_group(state=State()):
def RegexGroup() -> Tuple[Any, ...]:
return Depends(_regex_group)
return Depends(_regex_group, use_cache=False)
def _regex_dict(state=State()):
@ -299,7 +301,7 @@ def _regex_dict(state=State()):
def RegexDict() -> Dict[str, Any]:
return Depends(_regex_dict)
return Depends(_regex_dict, use_cache=False)
class MatcherParam(Param):
@ -320,14 +322,14 @@ def Received(id: str, default: Any = None) -> Any:
def _received(matcher: "Matcher"):
return matcher.get_receive(id, default)
return Depends(_received)
return Depends(_received, use_cache=False)
def LastReceived(default: Any = None) -> Any:
def _last_received(matcher: "Matcher") -> Any:
return matcher.get_receive(None, default)
return Depends(_last_received)
return Depends(_last_received, use_cache=False)
class ExceptionParam(Param):

View File

@ -24,11 +24,10 @@ from typing import (
)
from nonebot import params
from nonebot.utils import CacheDict
from nonebot.adapters import Bot, Event
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_Handler, T_PermissionChecker
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
@ -93,7 +92,7 @@ class Permission:
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""
:说明:

View File

@ -22,12 +22,11 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
from pygtrie import CharTrie
from nonebot.log import logger
from nonebot.utils import CacheDict
from nonebot import params, get_driver
from nonebot.dependencies import Dependent
from nonebot.exception import ParserExit, SkippedException
from nonebot.typing import T_State, T_Handler, T_RuleChecker
from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
@ -105,7 +104,7 @@ class Rule:
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""
:说明:

View File

@ -17,7 +17,7 @@
.. _typing:
https://docs.python.org/3/library/typing.html
"""
from asyncio import Task
from typing import (
TYPE_CHECKING,
Any,
@ -32,7 +32,6 @@ from typing import (
)
if TYPE_CHECKING:
from nonebot.utils import CacheDict
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission
@ -250,3 +249,9 @@ T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]]
PermissionUpdater Matcher.pause, Matcher.reject 时被运行用于更新会话对象权限默认会更新为当前事件的触发对象
"""
T_DependencyCache = Dict[Callable[..., Any], Task[Any]]
"""
:类型: ``Dict[Callable[..., Any], Task[Any]]``
:说明:
依赖缓存, 用于存储依赖函数的返回值
"""

View File

@ -135,33 +135,6 @@ def get_name(obj: Any) -> str:
return obj.__class__.__name__
class CacheDict(Dict[K, V], Generic[K, V]):
def __init__(self, *args, **kwargs):
super(CacheDict, self).__init__(*args, **kwargs)
self._lock = asyncio.Lock()
@property
def locked(self):
return self._lock.locked()
def __repr__(self):
extra = "locked" if self.locked else "unlocked"
return f"<{self.__class__.__name__} [{extra}]>"
async def __aenter__(self) -> None:
await self.acquire()
return None
async def __aexit__(self, exc_type, exc, tb):
self.release()
async def acquire(self):
return await self._lock.acquire()
def release(self):
self._lock.release()
class DataclassEncoder(json.JSONEncoder):
"""
:说明:

2
poetry.lock generated
View File

@ -543,7 +543,7 @@ pytest-order = "^1.0.0"
type = "git"
url = "https://github.com/nonebot/nonebug.git"
reference = "master"
resolved_reference = "4584d5a4bc95cd1bafcec08599ab7d72815e268e"
resolved_reference = "9c4f21373701ac25bc152cbad5f5527edc5e4c19"
[[package]]
name = "packaging"

View File

@ -1,6 +1,8 @@
[report]
exclude_lines =
def __repr__
pragma: no cover
if TYPE_CHECKING:
@(abc\.)?abstractmethod
raise NotImplementedError
if __name__ == .__main__.:

8
tests/.isort.cfg Normal file
View File

@ -0,0 +1,8 @@
[settings]
profile=black
line_length=80
length_sort=true
skip_gitignore=true
force_sort_within_sections=true
known_local_folder=plugins
extra_standard_library=typing_extensions

View File

@ -2,22 +2,23 @@ from nonebot import on_message
from nonebot.adapters import Event
from nonebot.params import Depends
test = on_message()
test2 = on_message()
test_depends = on_message()
runned = False
runned = []
def dependency(event: Event):
# test cache
global runned
assert not runned
runned = True
runned.append(event)
return event
@test.handle()
@test2.handle()
async def handle(x: Event = Depends(dependency, use_cache=True)):
@test_depends.handle()
async def depends(x: Event = Depends(dependency)):
# test dependency
return x
@test_depends.handle()
async def depends_cache(y: Event = Depends(dependency, use_cache=True)):
return y

View File

@ -1,2 +0,0 @@
[tool.isort]
known_local_folder = ["plugins"]

View File

@ -1,6 +1,5 @@
import os
import sys
from re import A
from typing import TYPE_CHECKING, Set
import pytest

29
tests/test_param.py Normal file
View File

@ -0,0 +1,29 @@
import pytest
from nonebug import App
from utils import load_plugin, make_fake_event
@pytest.mark.asyncio
async def test_depends(app: App, load_plugin):
from nonebot.params import EventParam, DependParam
from plugins.depends import runned, depends, test_depends
async with app.test_dependent(
depends, allow_types=[EventParam, DependParam]
) as ctx:
event = make_fake_event()()
ctx.pass_params(event=event)
ctx.should_return(event)
assert len(runned) == 1 and runned[0] == event
runned.clear()
async with app.test_matcher(test_depends) as ctx:
bot = ctx.create_bot()
event_next = make_fake_event()()
ctx.receive_event(bot, event_next)
assert len(runned) == 1 and runned[0] == event_next

View File

@ -1,10 +1,56 @@
from pathlib import Path
from typing import TYPE_CHECKING, Set
from typing import TYPE_CHECKING, Set, Type, Optional
import pytest
from pydantic import create_model
if TYPE_CHECKING:
from nonebot.plugin import Plugin
from nonebot.adapters import Event, Message
def make_fake_event(
_type: str = "message",
_name: str = "test",
_description: str = "test",
_user_id: str = "test",
_session_id: str = "test",
_message: Optional["Message"] = None,
_to_me: bool = True,
**fields,
) -> Type["Event"]:
from nonebot.adapters import Event
_Fake = create_model("_Fake", __base__=Event, **fields)
class FakeEvent(_Fake):
def get_type(self) -> str:
return _type
def get_event_name(self) -> str:
return _name
def get_event_description(self) -> str:
return _description
def get_user_id(self) -> str:
return _user_id
def get_session_id(self) -> str:
return _session_id
def get_message(self) -> "Message":
if _message is not None:
return _message
raise NotImplementedError
def is_tome(self) -> bool:
return _to_me
class Config:
extra = "forbid"
return FakeEvent
@pytest.fixture