diff --git a/nonebot/matcher.py b/nonebot/matcher.py index a90ac63b..f52f8190 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -71,7 +71,7 @@ matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) """ current_bot: ContextVar[Bot] = ContextVar("current_bot") current_event: ContextVar[Event] = ContextVar("current_event") -current_state: ContextVar[T_State] = ContextVar("current_state") +current_matcher: ContextVar["Matcher"] = ContextVar("current_matcher") current_handler: ContextVar[Dependent] = ContextVar("current_handler") @@ -424,7 +424,7 @@ class Matcher(metaclass=MatcherMeta): @classmethod def receive( - cls, id: Optional[str] = None, parameterless: Optional[List[Any]] = None + cls, id: str = "", parameterless: Optional[List[Any]] = None ) -> Callable[[T_Handler], T_Handler]: """ :说明: @@ -433,18 +433,17 @@ class Matcher(metaclass=MatcherMeta): :参数: + * ``id: str``: 消息 ID * ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表 """ - _id = id or "" - async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]: - if matcher.get_receive(_id): + if matcher.get_target() == RECEIVE_KEY.format(id=id): + matcher.set_receive(id, event) return - if matcher.get_target() == RECEIVE_KEY.format(id=_id): - matcher.set_receive(_id, event) + if matcher.get_receive(id): return - matcher.set_target(RECEIVE_KEY.format(id=_id)) + matcher.set_target(RECEIVE_KEY.format(id=id)) raise RejectedException _parameterless = [params.Depends(_receive), *(parameterless or [])] @@ -483,11 +482,11 @@ class Matcher(metaclass=MatcherMeta): """ async def _key_getter(event: Event, matcher: "Matcher"): - if matcher.get_arg(key): - return if matcher.get_target() == ARG_KEY.format(key=key): matcher.set_arg(key, event) return + if matcher.get_arg(key): + return matcher.set_target(ARG_KEY.format(key=key)) if prompt is not None: await matcher.send(prompt) @@ -529,7 +528,7 @@ class Matcher(metaclass=MatcherMeta): """ bot = current_bot.get() event = current_event.get() - state = current_state.get() + state = current_matcher.get().state if isinstance(message, MessageTemplate): _message = message.format(**state) else: @@ -583,7 +582,8 @@ class Matcher(metaclass=MatcherMeta): """ :说明: - 发送一条消息给当前交互用户并暂停事件响应器,在接收用户新的一条消息后重新运行当前处理函数 + 最近使用 ``got`` / ``receive`` 接收的消息不符合预期,发送一条消息给当前交互用户并暂停事件响应器, + 在接收用户新的一条消息后继续当前处理函数 :参数: @@ -594,6 +594,56 @@ class Matcher(metaclass=MatcherMeta): await cls.send(prompt, **kwargs) raise RejectedException + @classmethod + async def reject_arg( + cls, + key: str, + prompt: Optional[Union[str, Message, MessageSegment]] = None, + **kwargs, + ) -> NoReturn: + """ + :说明: + + 最近使用 ``got`` 接收的消息不符合预期,发送一条消息给当前交互用户并暂停事件响应器, + 在接收用户新的一条消息后继续当前处理函数 + + :参数: + + * ``key: str``: 参数名 + * ``prompt: Union[str, Message, MessageSegment]``: 消息内容 + * ``**kwargs``: 其他传递给 ``bot.send`` 的参数,请参考对应 adapter 的 bot 对象 api + """ + matcher = current_matcher.get() + matcher.set_target(ARG_KEY.format(key=key)) + if prompt is not None: + await cls.send(prompt, **kwargs) + raise RejectedException + + @classmethod + async def reject_receive( + cls, + id: str = "", + prompt: Optional[Union[str, Message, MessageSegment]] = None, + **kwargs, + ) -> NoReturn: + """ + :说明: + + 最近使用 ``got`` 接收的消息不符合预期,发送一条消息给当前交互用户并暂停事件响应器, + 在接收用户新的一条消息后继续当前处理函数 + + :参数: + + * ``id: str``: 消息 id + * ``prompt: Union[str, Message, MessageSegment]``: 消息内容 + * ``**kwargs``: 其他传递给 ``bot.send`` 的参数,请参考对应 adapter 的 bot 对象 api + """ + matcher = current_matcher.get() + matcher.set_target(RECEIVE_KEY.format(id=id)) + if prompt is not None: + await cls.send(prompt, **kwargs) + raise RejectedException + def get_receive(self, id: str, default: T = None) -> Union[Event, T]: return self.state.get(RECEIVE_KEY.format(id=id), default) @@ -650,7 +700,7 @@ class Matcher(metaclass=MatcherMeta): ): b_t = current_bot.set(bot) e_t = current_event.set(event) - s_t = current_state.set(self.state) + m_t = current_matcher.set(self) try: # Refresh preprocess state self.state.update(state) @@ -679,7 +729,7 @@ class Matcher(metaclass=MatcherMeta): logger.info(f"Matcher {self} running complete") current_bot.reset(b_t) current_event.reset(e_t) - current_state.reset(s_t) + current_matcher.reset(m_t) # 运行handlers async def run( diff --git a/poetry.lock b/poetry.lock index 5a8b1d57..32507fb9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -95,14 +95,14 @@ tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"] [[package]] name = "async-timeout" -version = "4.0.1" +version = "4.0.2" description = "Timeout context manager for asyncio programs" category = "main" optional = true python-versions = ">=3.6" [package.dependencies] -typing-extensions = ">=3.6.5" +typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""} [[package]] name = "asynctest" @@ -543,7 +543,7 @@ pytest-order = "^1.0.0" type = "git" url = "https://github.com/nonebot/nonebug.git" reference = "master" -resolved_reference = "5cb87c36aac56da3c71dd8c5aaa9b16396afc528" +resolved_reference = "0a1132e9dc1803517ded0d485bfbe8c47a1d8585" [[package]] name = "packaging" @@ -1282,8 +1282,8 @@ asgiref = [ {file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"}, ] async-timeout = [ - {file = "async-timeout-4.0.1.tar.gz", hash = "sha256:b930cb161a39042f9222f6efb7301399c87eeab394727ec5437924a36d6eef51"}, - {file = "async_timeout-4.0.1-py3-none-any.whl", hash = "sha256:a22c0b311af23337eb05fcf05a8b51c3ea53729d46fb5460af62bee033cec690"}, + {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, + {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, ] asynctest = [ {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"}, diff --git a/tests/plugins/matcher.py b/tests/plugins/matcher.py index a15c423a..e9b13c1d 100644 --- a/tests/plugins/matcher.py +++ b/tests/plugins/matcher.py @@ -16,9 +16,11 @@ test_got = on_message() @test_got.got("key1", "prompt key1") @test_got.got("key2", "prompt key2") async def got(key1: str = ArgStr(), key2: str = ArgStr()): + if key2 == "text": + await test_got.reject("reject", at_sender=True) + assert key1 == "text" - assert key2 == "text" - await test_got.reject("reject", at_sender=True) + assert key2 == "text_next" test_receive = on_message() @@ -33,3 +35,22 @@ async def receive( assert str(z.get_message()) == "text" assert x is y await test_receive.pause("pause", at_sender=True) + + +test_combine = on_message() + + +@test_combine.got("a") +@test_combine.receive() +@test_combine.got("b") +async def combine(a: str = ArgStr(), b: str = ArgStr(), r: Event = Received()): + if a == "text": + await test_combine.reject_arg("a") + elif b == "text": + await test_combine.reject_arg("b") + elif str(r.get_message()) == "text": + await test_combine.reject_receive() + + assert a == "text_next" + assert b == "text_next" + assert str(r.get_message()) == "text_next" diff --git a/tests/test_matcher.py b/tests/test_matcher.py index df2488c0..7df74850 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -8,10 +8,17 @@ from utils import load_plugin, make_fake_event, make_fake_message @pytest.mark.asyncio async def test_matcher(app: App, load_plugin): - from plugins.matcher import test_got, test_handle, test_receive + from plugins.matcher import ( + test_got, + test_handle, + test_combine, + test_receive, + ) message = make_fake_message()("text") event = make_fake_event(_message=message)() + message_next = make_fake_message()("text_next") + event_next = make_fake_event(_message=message_next)() assert len(test_handle.handlers) == 1 async with app.test_matcher(test_handle) as ctx: @@ -30,6 +37,7 @@ async def test_matcher(app: App, load_plugin): ctx.receive_event(bot, event) ctx.should_call_send(event, "reject", "result3", at_sender=True) ctx.should_rejected() + ctx.receive_event(bot, event_next) assert len(test_receive.handlers) == 1 async with app.test_matcher(test_receive) as ctx: @@ -39,3 +47,16 @@ async def test_matcher(app: App, load_plugin): ctx.receive_event(bot, event) ctx.should_call_send(event, "pause", "result", at_sender=True) ctx.should_paused() + + assert len(test_receive.handlers) == 1 + async with app.test_matcher(test_combine) as ctx: + bot = ctx.create_bot() + ctx.receive_event(bot, event) + ctx.receive_event(bot, event) + ctx.receive_event(bot, event) + ctx.should_rejected() + ctx.receive_event(bot, event_next) + ctx.should_rejected() + ctx.receive_event(bot, event_next) + ctx.should_rejected() + ctx.receive_event(bot, event_next)