🐛 fix arg message store

This commit is contained in:
yanyongyu 2021-12-23 22:16:55 +08:00
parent 76104d3237
commit 17f3c8fd09
6 changed files with 28 additions and 38 deletions

View File

@ -1,8 +1,7 @@
# used by Matcher # used by Matcher
RECEIVE_KEY = "_receive_{id}" RECEIVE_KEY = "_receive_{id}"
LAST_RECEIVE_KEY = "_last_receive" LAST_RECEIVE_KEY = "_last_receive"
ARG_KEY = "_arg_{key}" ARG_KEY = "{key}"
ARG_STR_KEY = "{key}"
REJECT_TARGET = "_current_target" REJECT_TARGET = "_current_target"
# used by Rule # used by Rule

View File

@ -28,6 +28,7 @@ from nonebot.rule import Rule
from nonebot.log import logger from nonebot.log import logger
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.consts import ARG_KEY, RECEIVE_KEY, REJECT_TARGET, LAST_RECEIVE_KEY
from nonebot.adapters import ( from nonebot.adapters import (
Bot, Bot,
Event, Event,
@ -35,13 +36,6 @@ from nonebot.adapters import (
MessageSegment, MessageSegment,
MessageTemplate, MessageTemplate,
) )
from nonebot.consts import (
ARG_KEY,
ARG_STR_KEY,
RECEIVE_KEY,
REJECT_TARGET,
LAST_RECEIVE_KEY,
)
from nonebot.exception import ( from nonebot.exception import (
PausedException, PausedException,
StopPropagation, StopPropagation,
@ -483,7 +477,7 @@ class Matcher(metaclass=MatcherMeta):
async def _key_getter(event: Event, matcher: "Matcher"): async def _key_getter(event: Event, matcher: "Matcher"):
if matcher.get_target() == ARG_KEY.format(key=key): if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event) matcher.set_arg(key, event.get_message())
return return
if matcher.get_arg(key): if matcher.get_arg(key):
return return
@ -654,15 +648,11 @@ class Matcher(metaclass=MatcherMeta):
def get_last_receive(self, default: T = None) -> Union[Event, T]: def get_last_receive(self, default: T = None) -> Union[Event, T]:
return self.state.get(LAST_RECEIVE_KEY, default) return self.state.get(LAST_RECEIVE_KEY, default)
def get_arg(self, key: str, default: T = None) -> Union[Event, T]: def get_arg(self, key: str, default: T = None) -> Union[Message, T]:
return self.state.get(ARG_KEY.format(key=key), default) return self.state.get(ARG_KEY.format(key=key), default)
def get_arg_str(self, key: str, default: T = None) -> Union[str, T]: def set_arg(self, key: str, message: Message) -> None:
return self.state.get(ARG_STR_KEY.format(key=key), default) self.state[ARG_KEY.format(key=key)] = message
def set_arg(self, key: str, event: Event) -> None:
self.state[ARG_KEY.format(key=key)] = event
self.state[ARG_STR_KEY.format(key=key)] = str(event.get_message())
def set_target(self, target: str) -> None: def set_target(self, target: str) -> None:
self.state[REJECT_TARGET] = target self.state[REJECT_TARGET] = target

View File

@ -334,7 +334,7 @@ def LastReceived(default: Any = None) -> Any:
class ArgInner: class ArgInner:
def __init__( def __init__(
self, key: Optional[str], type: Literal["event", "message", "str"] self, key: Optional[str], type: Literal["message", "str", "plaintext"]
) -> None: ) -> None:
self.key = key self.key = key
self.type = type self.type = type
@ -344,12 +344,12 @@ def Arg(key: Optional[str] = None) -> Any:
return ArgInner(key, "message") return ArgInner(key, "message")
def ArgEvent(key: Optional[str] = None) -> Any: def ArgStr(key: Optional[str] = None) -> str:
return ArgInner(key, "event") return ArgInner(key, "str") # type: ignore
def ArgStr(key: Optional[str] = None) -> Any: def ArgPlainText(key: Optional[str] = None) -> str:
return ArgInner(key, "str") return ArgInner(key, "plaintext") # type: ignore
class ArgParam(Param): class ArgParam(Param):
@ -361,13 +361,15 @@ class ArgParam(Param):
return cls(Required, key=param.default.key or name, type=param.default.type) return cls(Required, key=param.default.key or name, type=param.default.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
event = matcher.get_arg(self.extra["key"]) message = matcher.get_arg(self.extra["key"])
if self.extra["type"] == "event": if message is None:
return event return message
elif self.extra["type"] == "message": if self.extra["type"] == "message":
return event.get_message() return message
elif self.extra["type"] == "str":
return str(message)
else: else:
return matcher.get_arg_str(self.extra["key"]) return message.extract_plain_text()
class ExceptionParam(Param): class ExceptionParam(Param):

View File

@ -17,8 +17,6 @@ from nonebot.typing import (
T_PermissionChecker, T_PermissionChecker,
) )
from nonebot.rule import ( from nonebot.rule import (
PREFIX_KEY,
RAW_CMD_KEY,
Rule, Rule,
ArgumentParser, ArgumentParser,
regex, regex,
@ -395,7 +393,9 @@ def on_command(
""" """
commands = set([cmd]) | (aliases or set()) commands = set([cmd]) | (aliases or set())
return on_message(command(*commands) & rule, **kwargs, _depth=_depth + 1) return on_message(
command(*commands) & rule, block=False, **kwargs, _depth=_depth + 1
)
def on_shell_command( def on_shell_command(

View File

@ -1,5 +1,5 @@
from nonebot.adapters import Event, Message from nonebot.adapters import Event, Message
from nonebot.params import Arg, ArgStr, ArgEvent from nonebot.params import Arg, ArgStr, ArgPlainText
async def arg(key: Message = Arg()) -> Message: async def arg(key: Message = Arg()) -> Message:
@ -10,5 +10,5 @@ async def arg_str(key: str = ArgStr()) -> str:
return key return key
async def arg_event(key: Event = ArgEvent()) -> Event: async def arg_plain_text(key: str = ArgPlainText()) -> str:
return key return key

View File

@ -216,12 +216,11 @@ async def test_arg(app: App, load_plugin):
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import ArgParam from nonebot.params import ArgParam
from plugins.param.param_arg import arg, arg_str, arg_event from plugins.param.param_arg import arg, arg_str, arg_plain_text
matcher = Matcher() matcher = Matcher()
message = make_fake_message()("text") message = make_fake_message()("text")
event = make_fake_event(_message=message)() matcher.set_arg("key", message)
matcher.set_arg("key", event)
async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
@ -231,9 +230,9 @@ async def test_arg(app: App, load_plugin):
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
ctx.should_return(str(message)) ctx.should_return(str(message))
async with app.test_dependent(arg_event, allow_types=[ArgParam]) as ctx: async with app.test_dependent(arg_plain_text, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
ctx.should_return(event) ctx.should_return(message.extract_plain_text())
@pytest.mark.asyncio @pytest.mark.asyncio