🐛 fix cannot reject preset arg

This commit is contained in:
yanyongyu 2021-12-24 14:09:43 +08:00
parent 17f3c8fd09
commit 6643f951ef
4 changed files with 49 additions and 10 deletions

View File

@ -3,6 +3,7 @@ RECEIVE_KEY = "_receive_{id}"
LAST_RECEIVE_KEY = "_last_receive"
ARG_KEY = "{key}"
REJECT_TARGET = "_current_target"
REJECT_CACHE_TARGET = "_next_target"
# used by Rule
PREFIX_KEY = "_prefix"

View File

@ -28,7 +28,6 @@ from nonebot.rule import Rule
from nonebot.log import logger
from nonebot.dependencies import Dependent
from nonebot.permission import USER, Permission
from nonebot.consts import ARG_KEY, RECEIVE_KEY, REJECT_TARGET, LAST_RECEIVE_KEY
from nonebot.adapters import (
Bot,
Event,
@ -36,6 +35,13 @@ from nonebot.adapters import (
MessageSegment,
MessageTemplate,
)
from nonebot.consts import (
ARG_KEY,
RECEIVE_KEY,
REJECT_TARGET,
LAST_RECEIVE_KEY,
REJECT_CACHE_TARGET,
)
from nonebot.exception import (
PausedException,
StopPropagation,
@ -432,12 +438,12 @@ class Matcher(metaclass=MatcherMeta):
"""
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
matcher.set_target(RECEIVE_KEY.format(id=id))
if matcher.get_target() == RECEIVE_KEY.format(id=id):
matcher.set_receive(id, event)
return
if matcher.get_receive(id):
return
matcher.set_target(RECEIVE_KEY.format(id=id))
raise RejectedException
_parameterless = [params.Depends(_receive), *(parameterless or [])]
@ -476,12 +482,13 @@ class Matcher(metaclass=MatcherMeta):
"""
async def _key_getter(event: Event, matcher: "Matcher"):
print(key, matcher.state)
matcher.set_target(ARG_KEY.format(key=key))
if matcher.get_target() == ARG_KEY.format(key=key):
matcher.set_arg(key, event.get_message())
return
if matcher.get_arg(key):
return
matcher.set_target(ARG_KEY.format(key=key))
if prompt is not None:
await matcher.send(prompt)
raise RejectedException
@ -654,8 +661,11 @@ class Matcher(metaclass=MatcherMeta):
def set_arg(self, key: str, message: Message) -> None:
self.state[ARG_KEY.format(key=key)] = message
def set_target(self, target: str) -> None:
self.state[REJECT_TARGET] = target
def set_target(self, target: str, cache: bool = True) -> None:
if cache:
self.state[REJECT_CACHE_TARGET] = target
else:
self.state[REJECT_TARGET] = target
def get_target(self, default: T = None) -> Union[str, T]:
return self.state.get(REJECT_TARGET, default)
@ -680,6 +690,11 @@ class Matcher(metaclass=MatcherMeta):
return USER(event.get_session_id(), perm=self.permission)
return await updater(bot=bot, event=event, state=self.state, matcher=self)
async def resolve_reject(self):
handler = current_handler.get()
self.handlers.insert(0, handler)
self.state[REJECT_TARGET] = self.state[REJECT_CACHE_TARGET]
async def simple_run(
self,
bot: Bot,
@ -734,9 +749,7 @@ class Matcher(metaclass=MatcherMeta):
await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException:
handler = current_handler.get()
self.handlers.insert(0, handler)
await self.resolve_reject()
type_ = await self.update_type(bot, event)
permission = await self.update_permission(bot, event)

View File

@ -1,6 +1,7 @@
from nonebot import on_message
from nonebot.adapters import Event
from nonebot.params import ArgStr, Received, LastReceived
from nonebot.matcher import Matcher
from nonebot.adapters import Event, Message
from nonebot.params import ArgStr, Received, EventMessage, LastReceived
test_handle = on_message()
@ -54,3 +55,19 @@ async def combine(a: str = ArgStr(), b: str = ArgStr(), r: Event = Received()):
assert a == "text_next"
assert b == "text_next"
assert str(r.get_message()) == "text_next"
test_preset = on_message()
@test_preset.handle()
async def preset(matcher: Matcher, message: Message = EventMessage()):
matcher.set_arg("a", message)
@test_preset.got("a")
async def reject_preset(a: str = ArgStr()):
if a == "text":
await test_preset.reject_arg("a")
assert a == "text_next"

View File

@ -9,6 +9,7 @@ async def test_matcher(app: App, load_plugin):
from plugins.matcher import (
test_got,
test_handle,
test_preset,
test_combine,
test_receive,
)
@ -58,3 +59,10 @@ async def test_matcher(app: App, load_plugin):
ctx.receive_event(bot, event_next)
ctx.should_rejected()
ctx.receive_event(bot, event_next)
assert len(test_preset.handlers) == 2
async with app.test_matcher(test_preset) as ctx:
bot = ctx.create_bot()
ctx.receive_event(bot, event)
ctx.should_rejected()
ctx.receive_event(bot, event_next)