mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🐛 fix cannot reject preset arg
This commit is contained in:
parent
17f3c8fd09
commit
6643f951ef
@ -3,6 +3,7 @@ RECEIVE_KEY = "_receive_{id}"
|
||||
LAST_RECEIVE_KEY = "_last_receive"
|
||||
ARG_KEY = "{key}"
|
||||
REJECT_TARGET = "_current_target"
|
||||
REJECT_CACHE_TARGET = "_next_target"
|
||||
|
||||
# used by Rule
|
||||
PREFIX_KEY = "_prefix"
|
||||
|
@ -28,7 +28,6 @@ from nonebot.rule import Rule
|
||||
from nonebot.log import logger
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.permission import USER, Permission
|
||||
from nonebot.consts import ARG_KEY, RECEIVE_KEY, REJECT_TARGET, LAST_RECEIVE_KEY
|
||||
from nonebot.adapters import (
|
||||
Bot,
|
||||
Event,
|
||||
@ -36,6 +35,13 @@ from nonebot.adapters import (
|
||||
MessageSegment,
|
||||
MessageTemplate,
|
||||
)
|
||||
from nonebot.consts import (
|
||||
ARG_KEY,
|
||||
RECEIVE_KEY,
|
||||
REJECT_TARGET,
|
||||
LAST_RECEIVE_KEY,
|
||||
REJECT_CACHE_TARGET,
|
||||
)
|
||||
from nonebot.exception import (
|
||||
PausedException,
|
||||
StopPropagation,
|
||||
@ -432,12 +438,12 @@ class Matcher(metaclass=MatcherMeta):
|
||||
"""
|
||||
|
||||
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
|
||||
matcher.set_target(RECEIVE_KEY.format(id=id))
|
||||
if matcher.get_target() == RECEIVE_KEY.format(id=id):
|
||||
matcher.set_receive(id, event)
|
||||
return
|
||||
if matcher.get_receive(id):
|
||||
return
|
||||
matcher.set_target(RECEIVE_KEY.format(id=id))
|
||||
raise RejectedException
|
||||
|
||||
_parameterless = [params.Depends(_receive), *(parameterless or [])]
|
||||
@ -476,12 +482,13 @@ class Matcher(metaclass=MatcherMeta):
|
||||
"""
|
||||
|
||||
async def _key_getter(event: Event, matcher: "Matcher"):
|
||||
print(key, matcher.state)
|
||||
matcher.set_target(ARG_KEY.format(key=key))
|
||||
if matcher.get_target() == ARG_KEY.format(key=key):
|
||||
matcher.set_arg(key, event.get_message())
|
||||
return
|
||||
if matcher.get_arg(key):
|
||||
return
|
||||
matcher.set_target(ARG_KEY.format(key=key))
|
||||
if prompt is not None:
|
||||
await matcher.send(prompt)
|
||||
raise RejectedException
|
||||
@ -654,8 +661,11 @@ class Matcher(metaclass=MatcherMeta):
|
||||
def set_arg(self, key: str, message: Message) -> None:
|
||||
self.state[ARG_KEY.format(key=key)] = message
|
||||
|
||||
def set_target(self, target: str) -> None:
|
||||
self.state[REJECT_TARGET] = target
|
||||
def set_target(self, target: str, cache: bool = True) -> None:
|
||||
if cache:
|
||||
self.state[REJECT_CACHE_TARGET] = target
|
||||
else:
|
||||
self.state[REJECT_TARGET] = target
|
||||
|
||||
def get_target(self, default: T = None) -> Union[str, T]:
|
||||
return self.state.get(REJECT_TARGET, default)
|
||||
@ -680,6 +690,11 @@ class Matcher(metaclass=MatcherMeta):
|
||||
return USER(event.get_session_id(), perm=self.permission)
|
||||
return await updater(bot=bot, event=event, state=self.state, matcher=self)
|
||||
|
||||
async def resolve_reject(self):
|
||||
handler = current_handler.get()
|
||||
self.handlers.insert(0, handler)
|
||||
self.state[REJECT_TARGET] = self.state[REJECT_CACHE_TARGET]
|
||||
|
||||
async def simple_run(
|
||||
self,
|
||||
bot: Bot,
|
||||
@ -734,9 +749,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
await self.simple_run(bot, event, state, stack, dependency_cache)
|
||||
|
||||
except RejectedException:
|
||||
handler = current_handler.get()
|
||||
self.handlers.insert(0, handler)
|
||||
|
||||
await self.resolve_reject()
|
||||
type_ = await self.update_type(bot, event)
|
||||
permission = await self.update_permission(bot, event)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from nonebot import on_message
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.params import ArgStr, Received, LastReceived
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.adapters import Event, Message
|
||||
from nonebot.params import ArgStr, Received, EventMessage, LastReceived
|
||||
|
||||
test_handle = on_message()
|
||||
|
||||
@ -54,3 +55,19 @@ async def combine(a: str = ArgStr(), b: str = ArgStr(), r: Event = Received()):
|
||||
assert a == "text_next"
|
||||
assert b == "text_next"
|
||||
assert str(r.get_message()) == "text_next"
|
||||
|
||||
|
||||
test_preset = on_message()
|
||||
|
||||
|
||||
@test_preset.handle()
|
||||
async def preset(matcher: Matcher, message: Message = EventMessage()):
|
||||
matcher.set_arg("a", message)
|
||||
|
||||
|
||||
@test_preset.got("a")
|
||||
async def reject_preset(a: str = ArgStr()):
|
||||
if a == "text":
|
||||
await test_preset.reject_arg("a")
|
||||
|
||||
assert a == "text_next"
|
||||
|
@ -9,6 +9,7 @@ async def test_matcher(app: App, load_plugin):
|
||||
from plugins.matcher import (
|
||||
test_got,
|
||||
test_handle,
|
||||
test_preset,
|
||||
test_combine,
|
||||
test_receive,
|
||||
)
|
||||
@ -58,3 +59,10 @@ async def test_matcher(app: App, load_plugin):
|
||||
ctx.receive_event(bot, event_next)
|
||||
ctx.should_rejected()
|
||||
ctx.receive_event(bot, event_next)
|
||||
|
||||
assert len(test_preset.handlers) == 2
|
||||
async with app.test_matcher(test_preset) as ctx:
|
||||
bot = ctx.create_bot()
|
||||
ctx.receive_event(bot, event)
|
||||
ctx.should_rejected()
|
||||
ctx.receive_event(bot, event_next)
|
||||
|
Loading…
Reference in New Issue
Block a user