From 17f3c8fd09a4c7f629efb94c2b1738fd83dd39b1 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Thu, 23 Dec 2021 22:16:55 +0800 Subject: [PATCH] :bug: fix arg message store --- nonebot/consts.py | 3 +-- nonebot/matcher.py | 20 +++++--------------- nonebot/params.py | 24 +++++++++++++----------- nonebot/plugin/on.py | 6 +++--- tests/plugins/param/param_arg.py | 4 ++-- tests/test_param.py | 9 ++++----- 6 files changed, 28 insertions(+), 38 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index 47ee68ab..8e0cf903 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -1,8 +1,7 @@ # used by Matcher RECEIVE_KEY = "_receive_{id}" LAST_RECEIVE_KEY = "_last_receive" -ARG_KEY = "_arg_{key}" -ARG_STR_KEY = "{key}" +ARG_KEY = "{key}" REJECT_TARGET = "_current_target" # used by Rule diff --git a/nonebot/matcher.py b/nonebot/matcher.py index f52f8190..b0e69998 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -28,6 +28,7 @@ from nonebot.rule import Rule from nonebot.log import logger from nonebot.dependencies import Dependent from nonebot.permission import USER, Permission +from nonebot.consts import ARG_KEY, RECEIVE_KEY, REJECT_TARGET, LAST_RECEIVE_KEY from nonebot.adapters import ( Bot, Event, @@ -35,13 +36,6 @@ from nonebot.adapters import ( MessageSegment, MessageTemplate, ) -from nonebot.consts import ( - ARG_KEY, - ARG_STR_KEY, - RECEIVE_KEY, - REJECT_TARGET, - LAST_RECEIVE_KEY, -) from nonebot.exception import ( PausedException, StopPropagation, @@ -483,7 +477,7 @@ class Matcher(metaclass=MatcherMeta): async def _key_getter(event: Event, matcher: "Matcher"): if matcher.get_target() == ARG_KEY.format(key=key): - matcher.set_arg(key, event) + matcher.set_arg(key, event.get_message()) return if matcher.get_arg(key): return @@ -654,15 +648,11 @@ class Matcher(metaclass=MatcherMeta): def get_last_receive(self, default: T = None) -> Union[Event, T]: return self.state.get(LAST_RECEIVE_KEY, default) - def get_arg(self, key: str, default: T = None) -> Union[Event, T]: + def get_arg(self, key: str, default: T = None) -> Union[Message, T]: return self.state.get(ARG_KEY.format(key=key), default) - def get_arg_str(self, key: str, default: T = None) -> Union[str, T]: - return self.state.get(ARG_STR_KEY.format(key=key), default) - - def set_arg(self, key: str, event: Event) -> None: - self.state[ARG_KEY.format(key=key)] = event - self.state[ARG_STR_KEY.format(key=key)] = str(event.get_message()) + def set_arg(self, key: str, message: Message) -> None: + self.state[ARG_KEY.format(key=key)] = message def set_target(self, target: str) -> None: self.state[REJECT_TARGET] = target diff --git a/nonebot/params.py b/nonebot/params.py index c067fb13..1e62d039 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -334,7 +334,7 @@ def LastReceived(default: Any = None) -> Any: class ArgInner: def __init__( - self, key: Optional[str], type: Literal["event", "message", "str"] + self, key: Optional[str], type: Literal["message", "str", "plaintext"] ) -> None: self.key = key self.type = type @@ -344,12 +344,12 @@ def Arg(key: Optional[str] = None) -> Any: return ArgInner(key, "message") -def ArgEvent(key: Optional[str] = None) -> Any: - return ArgInner(key, "event") +def ArgStr(key: Optional[str] = None) -> str: + return ArgInner(key, "str") # type: ignore -def ArgStr(key: Optional[str] = None) -> Any: - return ArgInner(key, "str") +def ArgPlainText(key: Optional[str] = None) -> str: + return ArgInner(key, "plaintext") # type: ignore class ArgParam(Param): @@ -361,13 +361,15 @@ class ArgParam(Param): return cls(Required, key=param.default.key or name, type=param.default.type) async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: - event = matcher.get_arg(self.extra["key"]) - if self.extra["type"] == "event": - return event - elif self.extra["type"] == "message": - return event.get_message() + message = matcher.get_arg(self.extra["key"]) + if message is None: + return message + if self.extra["type"] == "message": + return message + elif self.extra["type"] == "str": + return str(message) else: - return matcher.get_arg_str(self.extra["key"]) + return message.extract_plain_text() class ExceptionParam(Param): diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index d302cfdf..0520488a 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -17,8 +17,6 @@ from nonebot.typing import ( T_PermissionChecker, ) from nonebot.rule import ( - PREFIX_KEY, - RAW_CMD_KEY, Rule, ArgumentParser, regex, @@ -395,7 +393,9 @@ def on_command( """ commands = set([cmd]) | (aliases or set()) - return on_message(command(*commands) & rule, **kwargs, _depth=_depth + 1) + return on_message( + command(*commands) & rule, block=False, **kwargs, _depth=_depth + 1 + ) def on_shell_command( diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index bd3e2bcf..b2c7c4b0 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -1,5 +1,5 @@ from nonebot.adapters import Event, Message -from nonebot.params import Arg, ArgStr, ArgEvent +from nonebot.params import Arg, ArgStr, ArgPlainText async def arg(key: Message = Arg()) -> Message: @@ -10,5 +10,5 @@ async def arg_str(key: str = ArgStr()) -> str: return key -async def arg_event(key: Event = ArgEvent()) -> Event: +async def arg_plain_text(key: str = ArgPlainText()) -> str: return key diff --git a/tests/test_param.py b/tests/test_param.py index 606ea08a..d23f61ff 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -216,12 +216,11 @@ async def test_arg(app: App, load_plugin): from nonebot.matcher import Matcher from nonebot.params import ArgParam - from plugins.param.param_arg import arg, arg_str, arg_event + from plugins.param.param_arg import arg, arg_str, arg_plain_text matcher = Matcher() message = make_fake_message()("text") - event = make_fake_event(_message=message)() - matcher.set_arg("key", event) + matcher.set_arg("key", message) async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) @@ -231,9 +230,9 @@ async def test_arg(app: App, load_plugin): ctx.pass_params(matcher=matcher) ctx.should_return(str(message)) - async with app.test_dependent(arg_event, allow_types=[ArgParam]) as ctx: + async with app.test_dependent(arg_plain_text, allow_types=[ArgParam]) as ctx: ctx.pass_params(matcher=matcher) - ctx.should_return(event) + ctx.should_return(message.extract_plain_text()) @pytest.mark.asyncio