mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
⚡ improve dependency cache
This commit is contained in:
parent
fe69735ca0
commit
3d762fcbab
@ -26,7 +26,6 @@ from typing import (
|
||||
from nonebot import params
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import CacheDict
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.permission import USER, Permission
|
||||
from nonebot.adapters import (
|
||||
@ -43,14 +42,6 @@ from nonebot.consts import (
|
||||
REJECT_TARGET,
|
||||
LAST_RECEIVE_KEY,
|
||||
)
|
||||
from nonebot.typing import (
|
||||
Any,
|
||||
T_State,
|
||||
T_Handler,
|
||||
T_ArgsParser,
|
||||
T_TypeUpdater,
|
||||
T_PermissionUpdater,
|
||||
)
|
||||
from nonebot.exception import (
|
||||
PausedException,
|
||||
StopPropagation,
|
||||
@ -58,6 +49,15 @@ from nonebot.exception import (
|
||||
FinishedException,
|
||||
RejectedException,
|
||||
)
|
||||
from nonebot.typing import (
|
||||
Any,
|
||||
T_State,
|
||||
T_Handler,
|
||||
T_ArgsParser,
|
||||
T_TypeUpdater,
|
||||
T_DependencyCache,
|
||||
T_PermissionUpdater,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.plugin import Plugin
|
||||
@ -296,7 +296,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
@ -324,7 +324,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
@ -669,7 +669,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
):
|
||||
b_t = current_bot.set(bot)
|
||||
e_t = current_event.set(event)
|
||||
@ -711,7 +711,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
):
|
||||
try:
|
||||
await self.simple_run(bot, event, state, stack, dependency_cache)
|
||||
|
@ -22,9 +22,9 @@ from typing import (
|
||||
from nonebot import params
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import TrieRule
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.matcher import Matcher, matchers
|
||||
from nonebot.utils import CacheDict, escape_tag
|
||||
from nonebot.exception import (
|
||||
NoLogException,
|
||||
StopPropagation,
|
||||
@ -34,6 +34,7 @@ from nonebot.exception import (
|
||||
from nonebot.typing import (
|
||||
T_State,
|
||||
T_Handler,
|
||||
T_DependencyCache,
|
||||
T_RunPreProcessor,
|
||||
T_RunPostProcessor,
|
||||
T_EventPreProcessor,
|
||||
@ -136,7 +137,7 @@ async def _check_matcher(
|
||||
event: "Event",
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
) -> None:
|
||||
if Matcher.expire_time and datetime.now() > Matcher.expire_time:
|
||||
try:
|
||||
@ -171,7 +172,7 @@ async def _run_matcher(
|
||||
event: "Event",
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
) -> None:
|
||||
logger.info(f"Event will be handled by {Matcher}")
|
||||
|
||||
@ -275,7 +276,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
|
||||
logger.opt(colors=True).success(log_msg)
|
||||
|
||||
state: Dict[Any, Any] = {}
|
||||
dependency_cache: CacheDict[T_Handler, Any] = CacheDict()
|
||||
dependency_cache: T_DependencyCache = {}
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
coros = list(
|
||||
|
@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional, cast
|
||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||
|
||||
from pydantic.fields import Required, Undefined
|
||||
|
||||
from nonebot.typing import T_State, T_Handler
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
from nonebot.dependencies import Param, Dependent
|
||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
||||
from nonebot.consts import (
|
||||
CMD_KEY,
|
||||
PREFIX_KEY,
|
||||
@ -19,7 +20,6 @@ from nonebot.consts import (
|
||||
REGEX_MATCHED,
|
||||
)
|
||||
from nonebot.utils import (
|
||||
CacheDict,
|
||||
get_name,
|
||||
run_sync,
|
||||
is_gen_callable,
|
||||
@ -49,7 +49,7 @@ class DependsInner:
|
||||
def Depends(
|
||||
dependency: Optional[T_Handler] = None,
|
||||
*,
|
||||
use_cache: bool = False,
|
||||
use_cache: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
@ -114,11 +114,11 @@ class DependParam(Param):
|
||||
async def _solve(
|
||||
self,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
use_cache: bool = self.extra["use_cache"]
|
||||
dependency_cache = CacheDict() if dependency_cache is None else dependency_cache
|
||||
dependency_cache = {} if dependency_cache is None else dependency_cache
|
||||
|
||||
sub_dependent: Dependent = self.extra["dependent"]
|
||||
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
||||
@ -132,26 +132,28 @@ class DependParam(Param):
|
||||
)
|
||||
|
||||
# run dependency function
|
||||
async with dependency_cache:
|
||||
if use_cache and call in dependency_cache:
|
||||
solved = dependency_cache[call]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
solved = await stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(call):
|
||||
return await call(**sub_values)
|
||||
task: asyncio.Task[Any]
|
||||
if use_cache and call in dependency_cache:
|
||||
solved = await dependency_cache[call]
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
return await run_sync(call)(**sub_values)
|
||||
|
||||
# save current dependency to cache
|
||||
if call not in dependency_cache:
|
||||
dependency_cache[call] = solved
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
task = asyncio.create_task(stack.enter_async_context(cm))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
elif is_coroutine_callable(call):
|
||||
task = asyncio.create_task(call(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
else:
|
||||
task = asyncio.create_task(run_sync(call)(**sub_values))
|
||||
dependency_cache[call] = task
|
||||
solved = await task
|
||||
|
||||
return solved
|
||||
|
||||
@ -243,7 +245,7 @@ def _command(state=State()) -> Message:
|
||||
|
||||
|
||||
def Command() -> Tuple[str, ...]:
|
||||
return Depends(_command)
|
||||
return Depends(_command, use_cache=False)
|
||||
|
||||
|
||||
def _raw_command(state=State()) -> Message:
|
||||
@ -251,7 +253,7 @@ def _raw_command(state=State()) -> Message:
|
||||
|
||||
|
||||
def RawCommand() -> str:
|
||||
return Depends(_raw_command)
|
||||
return Depends(_raw_command, use_cache=False)
|
||||
|
||||
|
||||
def _command_arg(state=State()) -> Message:
|
||||
@ -259,7 +261,7 @@ def _command_arg(state=State()) -> Message:
|
||||
|
||||
|
||||
def CommandArg() -> Message:
|
||||
return Depends(_command_arg)
|
||||
return Depends(_command_arg, use_cache=False)
|
||||
|
||||
|
||||
def _shell_command_args(state=State()) -> Any:
|
||||
@ -267,7 +269,7 @@ def _shell_command_args(state=State()) -> Any:
|
||||
|
||||
|
||||
def ShellCommandArgs():
|
||||
return Depends(_shell_command_args)
|
||||
return Depends(_shell_command_args, use_cache=False)
|
||||
|
||||
|
||||
def _shell_command_argv(state=State()) -> List[str]:
|
||||
@ -275,7 +277,7 @@ def _shell_command_argv(state=State()) -> List[str]:
|
||||
|
||||
|
||||
def ShellCommandArgv() -> Any:
|
||||
return Depends(_shell_command_argv)
|
||||
return Depends(_shell_command_argv, use_cache=False)
|
||||
|
||||
|
||||
def _regex_matched(state=State()) -> str:
|
||||
@ -283,7 +285,7 @@ def _regex_matched(state=State()) -> str:
|
||||
|
||||
|
||||
def RegexMatched() -> str:
|
||||
return Depends(_regex_matched)
|
||||
return Depends(_regex_matched, use_cache=False)
|
||||
|
||||
|
||||
def _regex_group(state=State()):
|
||||
@ -291,7 +293,7 @@ def _regex_group(state=State()):
|
||||
|
||||
|
||||
def RegexGroup() -> Tuple[Any, ...]:
|
||||
return Depends(_regex_group)
|
||||
return Depends(_regex_group, use_cache=False)
|
||||
|
||||
|
||||
def _regex_dict(state=State()):
|
||||
@ -299,7 +301,7 @@ def _regex_dict(state=State()):
|
||||
|
||||
|
||||
def RegexDict() -> Dict[str, Any]:
|
||||
return Depends(_regex_dict)
|
||||
return Depends(_regex_dict, use_cache=False)
|
||||
|
||||
|
||||
class MatcherParam(Param):
|
||||
@ -320,14 +322,14 @@ def Received(id: str, default: Any = None) -> Any:
|
||||
def _received(matcher: "Matcher"):
|
||||
return matcher.get_receive(id, default)
|
||||
|
||||
return Depends(_received)
|
||||
return Depends(_received, use_cache=False)
|
||||
|
||||
|
||||
def LastReceived(default: Any = None) -> Any:
|
||||
def _last_received(matcher: "Matcher") -> Any:
|
||||
return matcher.get_receive(None, default)
|
||||
|
||||
return Depends(_last_received)
|
||||
return Depends(_last_received, use_cache=False)
|
||||
|
||||
|
||||
class ExceptionParam(Param):
|
||||
|
@ -24,11 +24,10 @@ from typing import (
|
||||
)
|
||||
|
||||
from nonebot import params
|
||||
from nonebot.utils import CacheDict
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.typing import T_Handler, T_PermissionChecker
|
||||
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
|
||||
|
||||
|
||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
||||
@ -93,7 +92,7 @@ class Permission:
|
||||
bot: Bot,
|
||||
event: Event,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
|
@ -22,12 +22,11 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
|
||||
from pygtrie import CharTrie
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import CacheDict
|
||||
from nonebot import params, get_driver
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import ParserExit, SkippedException
|
||||
from nonebot.typing import T_State, T_Handler, T_RuleChecker
|
||||
from nonebot.adapters import Bot, Event, Message, MessageSegment
|
||||
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache
|
||||
from nonebot.consts import (
|
||||
CMD_KEY,
|
||||
PREFIX_KEY,
|
||||
@ -105,7 +104,7 @@ class Rule:
|
||||
event: Event,
|
||||
state: T_State,
|
||||
stack: Optional[AsyncExitStack] = None,
|
||||
dependency_cache: Optional[CacheDict[T_Handler, Any]] = None,
|
||||
dependency_cache: Optional[T_DependencyCache] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
:说明:
|
||||
|
@ -17,7 +17,7 @@
|
||||
.. _typing:
|
||||
https://docs.python.org/3/library/typing.html
|
||||
"""
|
||||
|
||||
from asyncio import Task
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -32,7 +32,6 @@ from typing import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.utils import CacheDict
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.permission import Permission
|
||||
|
||||
@ -250,3 +249,9 @@ T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]]
|
||||
|
||||
PermissionUpdater 在 Matcher.pause, Matcher.reject 时被运行,用于更新会话对象权限。默认会更新为当前事件的触发对象。
|
||||
"""
|
||||
T_DependencyCache = Dict[Callable[..., Any], Task[Any]]
|
||||
"""
|
||||
:类型: ``Dict[Callable[..., Any], Task[Any]]``
|
||||
:说明:
|
||||
依赖缓存, 用于存储依赖函数的返回值
|
||||
"""
|
||||
|
@ -135,33 +135,6 @@ def get_name(obj: Any) -> str:
|
||||
return obj.__class__.__name__
|
||||
|
||||
|
||||
class CacheDict(Dict[K, V], Generic[K, V]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CacheDict, self).__init__(*args, **kwargs)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def locked(self):
|
||||
return self._lock.locked()
|
||||
|
||||
def __repr__(self):
|
||||
extra = "locked" if self.locked else "unlocked"
|
||||
return f"<{self.__class__.__name__} [{extra}]>"
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire()
|
||||
return None
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
self.release()
|
||||
|
||||
async def acquire(self):
|
||||
return await self._lock.acquire()
|
||||
|
||||
def release(self):
|
||||
self._lock.release()
|
||||
|
||||
|
||||
class DataclassEncoder(json.JSONEncoder):
|
||||
"""
|
||||
:说明:
|
||||
|
2
poetry.lock
generated
2
poetry.lock
generated
@ -543,7 +543,7 @@ pytest-order = "^1.0.0"
|
||||
type = "git"
|
||||
url = "https://github.com/nonebot/nonebug.git"
|
||||
reference = "master"
|
||||
resolved_reference = "4584d5a4bc95cd1bafcec08599ab7d72815e268e"
|
||||
resolved_reference = "9c4f21373701ac25bc152cbad5f5527edc5e4c19"
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
|
@ -1,6 +1,8 @@
|
||||
[report]
|
||||
exclude_lines =
|
||||
def __repr__
|
||||
pragma: no cover
|
||||
if TYPE_CHECKING:
|
||||
@(abc\.)?abstractmethod
|
||||
raise NotImplementedError
|
||||
if __name__ == .__main__.:
|
||||
|
8
tests/.isort.cfg
Normal file
8
tests/.isort.cfg
Normal file
@ -0,0 +1,8 @@
|
||||
[settings]
|
||||
profile=black
|
||||
line_length=80
|
||||
length_sort=true
|
||||
skip_gitignore=true
|
||||
force_sort_within_sections=true
|
||||
known_local_folder=plugins
|
||||
extra_standard_library=typing_extensions
|
@ -2,22 +2,23 @@ from nonebot import on_message
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.params import Depends
|
||||
|
||||
test = on_message()
|
||||
test2 = on_message()
|
||||
test_depends = on_message()
|
||||
|
||||
runned = False
|
||||
runned = []
|
||||
|
||||
|
||||
def dependency(event: Event):
|
||||
# test cache
|
||||
global runned
|
||||
assert not runned
|
||||
runned = True
|
||||
runned.append(event)
|
||||
return event
|
||||
|
||||
|
||||
@test.handle()
|
||||
@test2.handle()
|
||||
async def handle(x: Event = Depends(dependency, use_cache=True)):
|
||||
@test_depends.handle()
|
||||
async def depends(x: Event = Depends(dependency)):
|
||||
# test dependency
|
||||
return x
|
||||
|
||||
|
||||
@test_depends.handle()
|
||||
async def depends_cache(y: Event = Depends(dependency, use_cache=True)):
|
||||
return y
|
||||
|
@ -1,2 +0,0 @@
|
||||
[tool.isort]
|
||||
known_local_folder = ["plugins"]
|
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
from re import A
|
||||
from typing import TYPE_CHECKING, Set
|
||||
|
||||
import pytest
|
||||
|
29
tests/test_param.py
Normal file
29
tests/test_param.py
Normal file
@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
from nonebug import App
|
||||
|
||||
from utils import load_plugin, make_fake_event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_depends(app: App, load_plugin):
|
||||
from nonebot.params import EventParam, DependParam
|
||||
|
||||
from plugins.depends import runned, depends, test_depends
|
||||
|
||||
async with app.test_dependent(
|
||||
depends, allow_types=[EventParam, DependParam]
|
||||
) as ctx:
|
||||
event = make_fake_event()()
|
||||
ctx.pass_params(event=event)
|
||||
ctx.should_return(event)
|
||||
|
||||
assert len(runned) == 1 and runned[0] == event
|
||||
|
||||
runned.clear()
|
||||
|
||||
async with app.test_matcher(test_depends) as ctx:
|
||||
bot = ctx.create_bot()
|
||||
event_next = make_fake_event()()
|
||||
ctx.receive_event(bot, event_next)
|
||||
|
||||
assert len(runned) == 1 and runned[0] == event_next
|
@ -1,10 +1,56 @@
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Set
|
||||
from typing import TYPE_CHECKING, Set, Type, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import create_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.plugin import Plugin
|
||||
from nonebot.adapters import Event, Message
|
||||
|
||||
|
||||
def make_fake_event(
|
||||
_type: str = "message",
|
||||
_name: str = "test",
|
||||
_description: str = "test",
|
||||
_user_id: str = "test",
|
||||
_session_id: str = "test",
|
||||
_message: Optional["Message"] = None,
|
||||
_to_me: bool = True,
|
||||
**fields,
|
||||
) -> Type["Event"]:
|
||||
from nonebot.adapters import Event
|
||||
|
||||
_Fake = create_model("_Fake", __base__=Event, **fields)
|
||||
|
||||
class FakeEvent(_Fake):
|
||||
def get_type(self) -> str:
|
||||
return _type
|
||||
|
||||
def get_event_name(self) -> str:
|
||||
return _name
|
||||
|
||||
def get_event_description(self) -> str:
|
||||
return _description
|
||||
|
||||
def get_user_id(self) -> str:
|
||||
return _user_id
|
||||
|
||||
def get_session_id(self) -> str:
|
||||
return _session_id
|
||||
|
||||
def get_message(self) -> "Message":
|
||||
if _message is not None:
|
||||
return _message
|
||||
raise NotImplementedError
|
||||
|
||||
def is_tome(self) -> bool:
|
||||
return _to_me
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
return FakeEvent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
Loading…
Reference in New Issue
Block a user