mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 04:26:24 +08:00
⚗️ add more reject case
This commit is contained in:
parent
e9b8515cf1
commit
cf8670c167
@ -71,7 +71,7 @@ matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
|
|||||||
"""
|
"""
|
||||||
current_bot: ContextVar[Bot] = ContextVar("current_bot")
|
current_bot: ContextVar[Bot] = ContextVar("current_bot")
|
||||||
current_event: ContextVar[Event] = ContextVar("current_event")
|
current_event: ContextVar[Event] = ContextVar("current_event")
|
||||||
current_state: ContextVar[T_State] = ContextVar("current_state")
|
current_matcher: ContextVar["Matcher"] = ContextVar("current_matcher")
|
||||||
current_handler: ContextVar[Dependent] = ContextVar("current_handler")
|
current_handler: ContextVar[Dependent] = ContextVar("current_handler")
|
||||||
|
|
||||||
|
|
||||||
@ -424,7 +424,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def receive(
|
def receive(
|
||||||
cls, id: Optional[str] = None, parameterless: Optional[List[Any]] = None
|
cls, id: str = "", parameterless: Optional[List[Any]] = None
|
||||||
) -> Callable[[T_Handler], T_Handler]:
|
) -> Callable[[T_Handler], T_Handler]:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -433,18 +433,17 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
|
* ``id: str``: 消息 ID
|
||||||
* ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表
|
* ``parameterless: Optional[List[Any]]``: 非参数类型依赖列表
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_id = id or ""
|
|
||||||
|
|
||||||
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
|
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
|
||||||
if matcher.get_receive(_id):
|
if matcher.get_target() == RECEIVE_KEY.format(id=id):
|
||||||
|
matcher.set_receive(id, event)
|
||||||
return
|
return
|
||||||
if matcher.get_target() == RECEIVE_KEY.format(id=_id):
|
if matcher.get_receive(id):
|
||||||
matcher.set_receive(_id, event)
|
|
||||||
return
|
return
|
||||||
matcher.set_target(RECEIVE_KEY.format(id=_id))
|
matcher.set_target(RECEIVE_KEY.format(id=id))
|
||||||
raise RejectedException
|
raise RejectedException
|
||||||
|
|
||||||
_parameterless = [params.Depends(_receive), *(parameterless or [])]
|
_parameterless = [params.Depends(_receive), *(parameterless or [])]
|
||||||
@ -483,11 +482,11 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def _key_getter(event: Event, matcher: "Matcher"):
|
async def _key_getter(event: Event, matcher: "Matcher"):
|
||||||
if matcher.get_arg(key):
|
|
||||||
return
|
|
||||||
if matcher.get_target() == ARG_KEY.format(key=key):
|
if matcher.get_target() == ARG_KEY.format(key=key):
|
||||||
matcher.set_arg(key, event)
|
matcher.set_arg(key, event)
|
||||||
return
|
return
|
||||||
|
if matcher.get_arg(key):
|
||||||
|
return
|
||||||
matcher.set_target(ARG_KEY.format(key=key))
|
matcher.set_target(ARG_KEY.format(key=key))
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
await matcher.send(prompt)
|
await matcher.send(prompt)
|
||||||
@ -529,7 +528,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
"""
|
"""
|
||||||
bot = current_bot.get()
|
bot = current_bot.get()
|
||||||
event = current_event.get()
|
event = current_event.get()
|
||||||
state = current_state.get()
|
state = current_matcher.get().state
|
||||||
if isinstance(message, MessageTemplate):
|
if isinstance(message, MessageTemplate):
|
||||||
_message = message.format(**state)
|
_message = message.format(**state)
|
||||||
else:
|
else:
|
||||||
@ -583,7 +582,8 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
发送一条消息给当前交互用户并暂停事件响应器,在接收用户新的一条消息后重新运行当前处理函数
|
最近使用 ``got`` / ``receive`` 接收的消息不符合预期,发送一条消息给当前交互用户并暂停事件响应器,
|
||||||
|
在接收用户新的一条消息后继续当前处理函数
|
||||||
|
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
@ -594,6 +594,56 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
await cls.send(prompt, **kwargs)
|
await cls.send(prompt, **kwargs)
|
||||||
raise RejectedException
|
raise RejectedException
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def reject_arg(
|
||||||
|
cls,
|
||||||
|
key: str,
|
||||||
|
prompt: Optional[Union[str, Message, MessageSegment]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> NoReturn:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
最近使用 ``got`` 接收的消息不符合预期,发送一条消息给当前交互用户并暂停事件响应器,
|
||||||
|
在接收用户新的一条消息后继续当前处理函数
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``key: str``: 参数名
|
||||||
|
* ``prompt: Union[str, Message, MessageSegment]``: 消息内容
|
||||||
|
* ``**kwargs``: 其他传递给 ``bot.send`` 的参数,请参考对应 adapter 的 bot 对象 api
|
||||||
|
"""
|
||||||
|
matcher = current_matcher.get()
|
||||||
|
matcher.set_target(ARG_KEY.format(key=key))
|
||||||
|
if prompt is not None:
|
||||||
|
await cls.send(prompt, **kwargs)
|
||||||
|
raise RejectedException
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def reject_receive(
|
||||||
|
cls,
|
||||||
|
id: str = "",
|
||||||
|
prompt: Optional[Union[str, Message, MessageSegment]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> NoReturn:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
最近使用 ``got`` 接收的消息不符合预期,发送一条消息给当前交互用户并暂停事件响应器,
|
||||||
|
在接收用户新的一条消息后继续当前处理函数
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``id: str``: 消息 id
|
||||||
|
* ``prompt: Union[str, Message, MessageSegment]``: 消息内容
|
||||||
|
* ``**kwargs``: 其他传递给 ``bot.send`` 的参数,请参考对应 adapter 的 bot 对象 api
|
||||||
|
"""
|
||||||
|
matcher = current_matcher.get()
|
||||||
|
matcher.set_target(RECEIVE_KEY.format(id=id))
|
||||||
|
if prompt is not None:
|
||||||
|
await cls.send(prompt, **kwargs)
|
||||||
|
raise RejectedException
|
||||||
|
|
||||||
def get_receive(self, id: str, default: T = None) -> Union[Event, T]:
|
def get_receive(self, id: str, default: T = None) -> Union[Event, T]:
|
||||||
return self.state.get(RECEIVE_KEY.format(id=id), default)
|
return self.state.get(RECEIVE_KEY.format(id=id), default)
|
||||||
|
|
||||||
@ -650,7 +700,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
):
|
):
|
||||||
b_t = current_bot.set(bot)
|
b_t = current_bot.set(bot)
|
||||||
e_t = current_event.set(event)
|
e_t = current_event.set(event)
|
||||||
s_t = current_state.set(self.state)
|
m_t = current_matcher.set(self)
|
||||||
try:
|
try:
|
||||||
# Refresh preprocess state
|
# Refresh preprocess state
|
||||||
self.state.update(state)
|
self.state.update(state)
|
||||||
@ -679,7 +729,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
logger.info(f"Matcher {self} running complete")
|
logger.info(f"Matcher {self} running complete")
|
||||||
current_bot.reset(b_t)
|
current_bot.reset(b_t)
|
||||||
current_event.reset(e_t)
|
current_event.reset(e_t)
|
||||||
current_state.reset(s_t)
|
current_matcher.reset(m_t)
|
||||||
|
|
||||||
# 运行handlers
|
# 运行handlers
|
||||||
async def run(
|
async def run(
|
||||||
|
10
poetry.lock
generated
10
poetry.lock
generated
@ -95,14 +95,14 @@ tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-timeout"
|
name = "async-timeout"
|
||||||
version = "4.0.1"
|
version = "4.0.2"
|
||||||
description = "Timeout context manager for asyncio programs"
|
description = "Timeout context manager for asyncio programs"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
typing-extensions = ">=3.6.5"
|
typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "asynctest"
|
name = "asynctest"
|
||||||
@ -543,7 +543,7 @@ pytest-order = "^1.0.0"
|
|||||||
type = "git"
|
type = "git"
|
||||||
url = "https://github.com/nonebot/nonebug.git"
|
url = "https://github.com/nonebot/nonebug.git"
|
||||||
reference = "master"
|
reference = "master"
|
||||||
resolved_reference = "5cb87c36aac56da3c71dd8c5aaa9b16396afc528"
|
resolved_reference = "0a1132e9dc1803517ded0d485bfbe8c47a1d8585"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "packaging"
|
name = "packaging"
|
||||||
@ -1282,8 +1282,8 @@ asgiref = [
|
|||||||
{file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"},
|
{file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"},
|
||||||
]
|
]
|
||||||
async-timeout = [
|
async-timeout = [
|
||||||
{file = "async-timeout-4.0.1.tar.gz", hash = "sha256:b930cb161a39042f9222f6efb7301399c87eeab394727ec5437924a36d6eef51"},
|
{file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
|
||||||
{file = "async_timeout-4.0.1-py3-none-any.whl", hash = "sha256:a22c0b311af23337eb05fcf05a8b51c3ea53729d46fb5460af62bee033cec690"},
|
{file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"},
|
||||||
]
|
]
|
||||||
asynctest = [
|
asynctest = [
|
||||||
{file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"},
|
{file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"},
|
||||||
|
@ -16,10 +16,12 @@ test_got = on_message()
|
|||||||
@test_got.got("key1", "prompt key1")
|
@test_got.got("key1", "prompt key1")
|
||||||
@test_got.got("key2", "prompt key2")
|
@test_got.got("key2", "prompt key2")
|
||||||
async def got(key1: str = ArgStr(), key2: str = ArgStr()):
|
async def got(key1: str = ArgStr(), key2: str = ArgStr()):
|
||||||
assert key1 == "text"
|
if key2 == "text":
|
||||||
assert key2 == "text"
|
|
||||||
await test_got.reject("reject", at_sender=True)
|
await test_got.reject("reject", at_sender=True)
|
||||||
|
|
||||||
|
assert key1 == "text"
|
||||||
|
assert key2 == "text_next"
|
||||||
|
|
||||||
|
|
||||||
test_receive = on_message()
|
test_receive = on_message()
|
||||||
|
|
||||||
@ -33,3 +35,22 @@ async def receive(
|
|||||||
assert str(z.get_message()) == "text"
|
assert str(z.get_message()) == "text"
|
||||||
assert x is y
|
assert x is y
|
||||||
await test_receive.pause("pause", at_sender=True)
|
await test_receive.pause("pause", at_sender=True)
|
||||||
|
|
||||||
|
|
||||||
|
test_combine = on_message()
|
||||||
|
|
||||||
|
|
||||||
|
@test_combine.got("a")
|
||||||
|
@test_combine.receive()
|
||||||
|
@test_combine.got("b")
|
||||||
|
async def combine(a: str = ArgStr(), b: str = ArgStr(), r: Event = Received()):
|
||||||
|
if a == "text":
|
||||||
|
await test_combine.reject_arg("a")
|
||||||
|
elif b == "text":
|
||||||
|
await test_combine.reject_arg("b")
|
||||||
|
elif str(r.get_message()) == "text":
|
||||||
|
await test_combine.reject_receive()
|
||||||
|
|
||||||
|
assert a == "text_next"
|
||||||
|
assert b == "text_next"
|
||||||
|
assert str(r.get_message()) == "text_next"
|
||||||
|
@ -8,10 +8,17 @@ from utils import load_plugin, make_fake_event, make_fake_message
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_matcher(app: App, load_plugin):
|
async def test_matcher(app: App, load_plugin):
|
||||||
from plugins.matcher import test_got, test_handle, test_receive
|
from plugins.matcher import (
|
||||||
|
test_got,
|
||||||
|
test_handle,
|
||||||
|
test_combine,
|
||||||
|
test_receive,
|
||||||
|
)
|
||||||
|
|
||||||
message = make_fake_message()("text")
|
message = make_fake_message()("text")
|
||||||
event = make_fake_event(_message=message)()
|
event = make_fake_event(_message=message)()
|
||||||
|
message_next = make_fake_message()("text_next")
|
||||||
|
event_next = make_fake_event(_message=message_next)()
|
||||||
|
|
||||||
assert len(test_handle.handlers) == 1
|
assert len(test_handle.handlers) == 1
|
||||||
async with app.test_matcher(test_handle) as ctx:
|
async with app.test_matcher(test_handle) as ctx:
|
||||||
@ -30,6 +37,7 @@ async def test_matcher(app: App, load_plugin):
|
|||||||
ctx.receive_event(bot, event)
|
ctx.receive_event(bot, event)
|
||||||
ctx.should_call_send(event, "reject", "result3", at_sender=True)
|
ctx.should_call_send(event, "reject", "result3", at_sender=True)
|
||||||
ctx.should_rejected()
|
ctx.should_rejected()
|
||||||
|
ctx.receive_event(bot, event_next)
|
||||||
|
|
||||||
assert len(test_receive.handlers) == 1
|
assert len(test_receive.handlers) == 1
|
||||||
async with app.test_matcher(test_receive) as ctx:
|
async with app.test_matcher(test_receive) as ctx:
|
||||||
@ -39,3 +47,16 @@ async def test_matcher(app: App, load_plugin):
|
|||||||
ctx.receive_event(bot, event)
|
ctx.receive_event(bot, event)
|
||||||
ctx.should_call_send(event, "pause", "result", at_sender=True)
|
ctx.should_call_send(event, "pause", "result", at_sender=True)
|
||||||
ctx.should_paused()
|
ctx.should_paused()
|
||||||
|
|
||||||
|
assert len(test_receive.handlers) == 1
|
||||||
|
async with app.test_matcher(test_combine) as ctx:
|
||||||
|
bot = ctx.create_bot()
|
||||||
|
ctx.receive_event(bot, event)
|
||||||
|
ctx.receive_event(bot, event)
|
||||||
|
ctx.receive_event(bot, event)
|
||||||
|
ctx.should_rejected()
|
||||||
|
ctx.receive_event(bot, event_next)
|
||||||
|
ctx.should_rejected()
|
||||||
|
ctx.receive_event(bot, event_next)
|
||||||
|
ctx.should_rejected()
|
||||||
|
ctx.receive_event(bot, event_next)
|
||||||
|
Loading…
Reference in New Issue
Block a user