diff --git a/nonebot/consts.py b/nonebot/consts.py index 8e0cf903..11e3feaa 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -3,6 +3,7 @@ RECEIVE_KEY = "_receive_{id}" LAST_RECEIVE_KEY = "_last_receive" ARG_KEY = "{key}" REJECT_TARGET = "_current_target" +REJECT_CACHE_TARGET = "_next_target" # used by Rule PREFIX_KEY = "_prefix" diff --git a/nonebot/matcher.py b/nonebot/matcher.py index b0e69998..d33fed71 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -28,7 +28,6 @@ from nonebot.rule import Rule from nonebot.log import logger from nonebot.dependencies import Dependent from nonebot.permission import USER, Permission -from nonebot.consts import ARG_KEY, RECEIVE_KEY, REJECT_TARGET, LAST_RECEIVE_KEY from nonebot.adapters import ( Bot, Event, @@ -36,6 +35,13 @@ from nonebot.adapters import ( MessageSegment, MessageTemplate, ) +from nonebot.consts import ( + ARG_KEY, + RECEIVE_KEY, + REJECT_TARGET, + LAST_RECEIVE_KEY, + REJECT_CACHE_TARGET, +) from nonebot.exception import ( PausedException, StopPropagation, @@ -432,12 +438,12 @@ class Matcher(metaclass=MatcherMeta): """ async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]: + matcher.set_target(RECEIVE_KEY.format(id=id)) if matcher.get_target() == RECEIVE_KEY.format(id=id): matcher.set_receive(id, event) return if matcher.get_receive(id): return - matcher.set_target(RECEIVE_KEY.format(id=id)) raise RejectedException _parameterless = [params.Depends(_receive), *(parameterless or [])] @@ -476,12 +482,13 @@ class Matcher(metaclass=MatcherMeta): """ async def _key_getter(event: Event, matcher: "Matcher"): + print(key, matcher.state) + matcher.set_target(ARG_KEY.format(key=key)) if matcher.get_target() == ARG_KEY.format(key=key): matcher.set_arg(key, event.get_message()) return if matcher.get_arg(key): return - matcher.set_target(ARG_KEY.format(key=key)) if prompt is not None: await matcher.send(prompt) raise RejectedException @@ -654,8 +661,11 @@ class Matcher(metaclass=MatcherMeta): def set_arg(self, key: str, message: Message) -> None: self.state[ARG_KEY.format(key=key)] = message - def set_target(self, target: str) -> None: - self.state[REJECT_TARGET] = target + def set_target(self, target: str, cache: bool = True) -> None: + if cache: + self.state[REJECT_CACHE_TARGET] = target + else: + self.state[REJECT_TARGET] = target def get_target(self, default: T = None) -> Union[str, T]: return self.state.get(REJECT_TARGET, default) @@ -680,6 +690,11 @@ class Matcher(metaclass=MatcherMeta): return USER(event.get_session_id(), perm=self.permission) return await updater(bot=bot, event=event, state=self.state, matcher=self) + async def resolve_reject(self): + handler = current_handler.get() + self.handlers.insert(0, handler) + self.state[REJECT_TARGET] = self.state[REJECT_CACHE_TARGET] + async def simple_run( self, bot: Bot, @@ -734,9 +749,7 @@ class Matcher(metaclass=MatcherMeta): await self.simple_run(bot, event, state, stack, dependency_cache) except RejectedException: - handler = current_handler.get() - self.handlers.insert(0, handler) - + await self.resolve_reject() type_ = await self.update_type(bot, event) permission = await self.update_permission(bot, event) diff --git a/tests/plugins/matcher.py b/tests/plugins/matcher.py index e9b13c1d..3621b850 100644 --- a/tests/plugins/matcher.py +++ b/tests/plugins/matcher.py @@ -1,6 +1,7 @@ from nonebot import on_message -from nonebot.adapters import Event -from nonebot.params import ArgStr, Received, LastReceived +from nonebot.matcher import Matcher +from nonebot.adapters import Event, Message +from nonebot.params import ArgStr, Received, EventMessage, LastReceived test_handle = on_message() @@ -54,3 +55,19 @@ async def combine(a: str = ArgStr(), b: str = ArgStr(), r: Event = Received()): assert a == "text_next" assert b == "text_next" assert str(r.get_message()) == "text_next" + + +test_preset = on_message() + + +@test_preset.handle() +async def preset(matcher: Matcher, message: Message = EventMessage()): + matcher.set_arg("a", message) + + +@test_preset.got("a") +async def reject_preset(a: str = ArgStr()): + if a == "text": + await test_preset.reject_arg("a") + + assert a == "text_next" diff --git a/tests/test_matcher.py b/tests/test_matcher.py index 3340afb8..403aeab1 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -9,6 +9,7 @@ async def test_matcher(app: App, load_plugin): from plugins.matcher import ( test_got, test_handle, + test_preset, test_combine, test_receive, ) @@ -58,3 +59,10 @@ async def test_matcher(app: App, load_plugin): ctx.receive_event(bot, event_next) ctx.should_rejected() ctx.receive_event(bot, event_next) + + assert len(test_preset.handlers) == 2 + async with app.test_matcher(test_preset) as ctx: + bot = ctx.create_bot() + ctx.receive_event(bot, event) + ctx.should_rejected() + ctx.receive_event(bot, event_next)