Feature: 存储 matcher 发送 prompt 的结果 (#3155)

This commit is contained in:
Ju4tCode 2024-12-05 20:55:24 +08:00 committed by GitHub
parent ab8dea5a02
commit 32bc2c314a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 271 additions and 22 deletions

View File

@ -22,6 +22,10 @@ REJECT_TARGET: Literal["_current_target"] = "_current_target"
"""当前 `reject` 目标存储 key""" """当前 `reject` 目标存储 key"""
REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target" REJECT_CACHE_TARGET: Literal["_next_target"] = "_next_target"
"""下一个 `reject` 目标存储 key""" """下一个 `reject` 目标存储 key"""
PAUSE_PROMPT_RESULT_KEY: Literal["_pause_result"] = "_pause_result"
"""`pause` prompt 发送结果存储 key"""
REJECT_PROMPT_RESULT_KEY: Literal["_reject_{key}_result"] = "_reject_{key}_result"
"""`reject` prompt 发送结果存储 key"""
# used by Rule # used by Rule
PREFIX_KEY: Literal["_prefix"] = "_prefix" PREFIX_KEY: Literal["_prefix"] = "_prefix"

View File

@ -27,8 +27,10 @@ from exceptiongroup import BaseExceptionGroup, catch
from nonebot.consts import ( from nonebot.consts import (
ARG_KEY, ARG_KEY,
LAST_RECEIVE_KEY, LAST_RECEIVE_KEY,
PAUSE_PROMPT_RESULT_KEY,
RECEIVE_KEY, RECEIVE_KEY,
REJECT_CACHE_TARGET, REJECT_CACHE_TARGET,
REJECT_PROMPT_RESULT_KEY,
REJECT_TARGET, REJECT_TARGET,
) )
from nonebot.dependencies import Dependent, Param from nonebot.dependencies import Dependent, Param
@ -560,8 +562,8 @@ class Matcher(metaclass=MatcherMeta):
""" """
bot = current_bot.get() bot = current_bot.get()
event = current_event.get() event = current_event.get()
state = current_matcher.get().state
if isinstance(message, MessageTemplate): if isinstance(message, MessageTemplate):
state = current_matcher.get().state
_message = message.format(**state) _message = message.format(**state)
else: else:
_message = message _message = message
@ -597,8 +599,15 @@ class Matcher(metaclass=MatcherMeta):
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数 kwargs: {ref}`nonebot.adapters.Bot.send` 的参数
请参考对应 adapter bot 对象 api 请参考对应 adapter bot 对象 api
""" """
try:
matcher = current_matcher.get()
except Exception:
matcher = None
if prompt is not None: if prompt is not None:
await cls.send(prompt, **kwargs) result = await cls.send(prompt, **kwargs)
if matcher is not None:
matcher.state[PAUSE_PROMPT_RESULT_KEY] = result
raise PausedException raise PausedException
@classmethod @classmethod
@ -615,8 +624,19 @@ class Matcher(metaclass=MatcherMeta):
kwargs: {ref}`nonebot.adapters.Bot.send` 的参数 kwargs: {ref}`nonebot.adapters.Bot.send` 的参数
请参考对应 adapter bot 对象 api 请参考对应 adapter bot 对象 api
""" """
try:
matcher = current_matcher.get()
key = matcher.get_target()
except Exception:
matcher = None
key = None
key = REJECT_PROMPT_RESULT_KEY.format(key=key) if key is not None else None
if prompt is not None: if prompt is not None:
await cls.send(prompt, **kwargs) result = await cls.send(prompt, **kwargs)
if key is not None and matcher:
matcher.state[key] = result
raise RejectedException raise RejectedException
@classmethod @classmethod
@ -636,9 +656,12 @@ class Matcher(metaclass=MatcherMeta):
请参考对应 adapter bot 对象 api 请参考对应 adapter bot 对象 api
""" """
matcher = current_matcher.get() matcher = current_matcher.get()
matcher.set_target(ARG_KEY.format(key=key)) arg_key = ARG_KEY.format(key=key)
matcher.set_target(arg_key)
if prompt is not None: if prompt is not None:
await cls.send(prompt, **kwargs) result = await cls.send(prompt, **kwargs)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=arg_key)] = result
raise RejectedException raise RejectedException
@classmethod @classmethod
@ -658,9 +681,12 @@ class Matcher(metaclass=MatcherMeta):
请参考对应 adapter bot 对象 api 请参考对应 adapter bot 对象 api
""" """
matcher = current_matcher.get() matcher = current_matcher.get()
matcher.set_target(RECEIVE_KEY.format(id=id)) receive_key = RECEIVE_KEY.format(id=id)
matcher.set_target(receive_key)
if prompt is not None: if prompt is not None:
await cls.send(prompt, **kwargs) result = await cls.send(prompt, **kwargs)
matcher.state[REJECT_PROMPT_RESULT_KEY.format(key=receive_key)] = result
raise RejectedException raise RejectedException
@classmethod @classmethod

View File

@ -18,6 +18,7 @@ from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
from nonebot.consts import ARG_KEY, REJECT_PROMPT_RESULT_KEY
from nonebot.dependencies import Dependent, Param from nonebot.dependencies import Dependent, Param
from nonebot.dependencies.utils import check_field_type from nonebot.dependencies.utils import check_field_type
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
@ -39,7 +40,7 @@ from nonebot.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
@ -522,10 +523,10 @@ class MatcherParam(Param):
class ArgInner: class ArgInner:
def __init__( def __init__(
self, key: Optional[str], type: Literal["message", "str", "plaintext"] self, key: Optional[str], type: Literal["message", "str", "plaintext", "prompt"]
) -> None: ) -> None:
self.key: Optional[str] = key self.key: Optional[str] = key
self.type: Literal["message", "str", "plaintext"] = type self.type: Literal["message", "str", "plaintext", "prompt"] = type
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ArgInner(key={self.key!r}, type={self.type!r})" return f"ArgInner(key={self.key!r}, type={self.type!r})"
@ -546,6 +547,11 @@ def ArgPlainText(key: Optional[str] = None) -> str:
return ArgInner(key, "plaintext") # type: ignore return ArgInner(key, "plaintext") # type: ignore
def ArgPromptResult(key: Optional[str] = None) -> Any:
"""`arg` prompt 发送结果"""
return ArgInner(key, "prompt")
class ArgParam(Param): class ArgParam(Param):
"""Arg 注入参数 """Arg 注入参数
@ -559,7 +565,7 @@ class ArgParam(Param):
self, self,
*args, *args,
key: str, key: str,
type: Literal["message", "str", "plaintext"], type: Literal["message", "str", "plaintext", "prompt"],
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -584,15 +590,32 @@ class ArgParam(Param):
async def _solve( # pyright: ignore[reportIncompatibleMethodOverride] async def _solve( # pyright: ignore[reportIncompatibleMethodOverride]
self, matcher: "Matcher", **kwargs: Any self, matcher: "Matcher", **kwargs: Any
) -> Any: ) -> Any:
message = matcher.get_arg(self.key)
if message is None:
return message
if self.type == "message": if self.type == "message":
return message return self._solve_message(matcher)
elif self.type == "str": elif self.type == "str":
return str(message) return self._solve_str(matcher)
elif self.type == "plaintext":
return self._solve_plaintext(matcher)
elif self.type == "prompt":
return self._solve_prompt(matcher)
else: else:
return message.extract_plain_text() raise ValueError(f"Unknown Arg type: {self.type}")
def _solve_message(self, matcher: "Matcher") -> Optional["Message"]:
return matcher.get_arg(self.key)
def _solve_str(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return str(message) if message is not None else None
def _solve_plaintext(self, matcher: "Matcher") -> Optional[str]:
message = matcher.get_arg(self.key)
return message.extract_plain_text() if message is not None else None
def _solve_prompt(self, matcher: "Matcher") -> Optional[Any]:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=ARG_KEY.format(key=self.key))
)
class ExceptionParam(Param): class ExceptionParam(Param):

View File

@ -19,9 +19,12 @@ from nonebot.consts import (
ENDSWITH_KEY, ENDSWITH_KEY,
FULLMATCH_KEY, FULLMATCH_KEY,
KEYWORD_KEY, KEYWORD_KEY,
PAUSE_PROMPT_RESULT_KEY,
PREFIX_KEY, PREFIX_KEY,
RAW_CMD_KEY, RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED, REGEX_MATCHED,
REJECT_PROMPT_RESULT_KEY,
SHELL_ARGS, SHELL_ARGS,
SHELL_ARGV, SHELL_ARGV,
STARTSWITH_KEY, STARTSWITH_KEY,
@ -29,6 +32,7 @@ from nonebot.consts import (
from nonebot.internal.params import Arg as Arg from nonebot.internal.params import Arg as Arg
from nonebot.internal.params import ArgParam as ArgParam from nonebot.internal.params import ArgParam as ArgParam
from nonebot.internal.params import ArgPlainText as ArgPlainText from nonebot.internal.params import ArgPlainText as ArgPlainText
from nonebot.internal.params import ArgPromptResult as ArgPromptResult
from nonebot.internal.params import ArgStr as ArgStr from nonebot.internal.params import ArgStr as ArgStr
from nonebot.internal.params import BotParam as BotParam from nonebot.internal.params import BotParam as BotParam
from nonebot.internal.params import DefaultParam as DefaultParam from nonebot.internal.params import DefaultParam as DefaultParam
@ -252,6 +256,26 @@ def LastReceived(default: Any = None) -> Any:
return Depends(_last_received, use_cache=False) return Depends(_last_received, use_cache=False)
def ReceivePromptResult(id: Optional[str] = None) -> Any:
"""`receive` prompt 发送结果"""
def _receive_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(
REJECT_PROMPT_RESULT_KEY.format(key=RECEIVE_KEY.format(id=id))
)
return Depends(_receive_prompt_result, use_cache=False)
def PausePromptResult() -> Any:
"""`pause` prompt 发送结果"""
def _pause_prompt_result(matcher: "Matcher") -> Any:
return matcher.state.get(PAUSE_PROMPT_RESULT_KEY)
return Depends(_pause_prompt_result, use_cache=False)
__autodoc__ = { __autodoc__ = {
"Arg": True, "Arg": True,
"ArgStr": True, "ArgStr": True,
@ -265,4 +289,5 @@ __autodoc__ = {
"DefaultParam": True, "DefaultParam": True,
"MatcherParam": True, "MatcherParam": True,
"ExceptionParam": True, "ExceptionParam": True,
"ArgPromptResult": True,
} }

View File

@ -1,7 +1,7 @@
from typing import Annotated from typing import Annotated, Any
from nonebot.adapters import Message from nonebot.adapters import Message
from nonebot.params import Arg, ArgPlainText, ArgStr from nonebot.params import Arg, ArgPlainText, ArgPromptResult, ArgStr
async def arg(key: Message = Arg()) -> Message: async def arg(key: Message = Arg()) -> Message:
@ -28,6 +28,10 @@ async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key return key
async def annotated_arg_prompt_result(key: Annotated[Any, ArgPromptResult()]) -> Any:
return key
# test dependency priority # test dependency priority
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()): async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
return key return key

View File

@ -1,8 +1,13 @@
from typing import TypeVar, Union from typing import Any, TypeVar, Union
from nonebot.adapters import Event from nonebot.adapters import Event
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import LastReceived, Received from nonebot.params import (
LastReceived,
PausePromptResult,
Received,
ReceivePromptResult,
)
async def matcher(m: Matcher) -> Matcher: async def matcher(m: Matcher) -> Matcher:
@ -59,3 +64,11 @@ async def receive(e: Event = Received("test")) -> Event:
async def last_receive(e: Event = LastReceived()) -> Event: async def last_receive(e: Event = LastReceived()) -> Event:
return e return e
async def receive_prompt_result(result: Any = ReceivePromptResult("test")) -> Any:
return result
async def pause_prompt_result(result: Any = PausePromptResult()) -> Any:
return result

View File

@ -1,3 +1,4 @@
from contextlib import suppress
import re import re
from exceptiongroup import BaseExceptionGroup from exceptiongroup import BaseExceptionGroup
@ -5,6 +6,7 @@ from nonebug import App
import pytest import pytest
from nonebot.consts import ( from nonebot.consts import (
ARG_KEY,
CMD_ARG_KEY, CMD_ARG_KEY,
CMD_KEY, CMD_KEY,
CMD_START_KEY, CMD_START_KEY,
@ -14,13 +16,14 @@ from nonebot.consts import (
KEYWORD_KEY, KEYWORD_KEY,
PREFIX_KEY, PREFIX_KEY,
RAW_CMD_KEY, RAW_CMD_KEY,
RECEIVE_KEY,
REGEX_MATCHED, REGEX_MATCHED,
SHELL_ARGS, SHELL_ARGS,
SHELL_ARGV, SHELL_ARGV,
STARTSWITH_KEY, STARTSWITH_KEY,
) )
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import TypeMisMatch from nonebot.exception import PausedException, RejectedException, TypeMisMatch
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import ( from nonebot.params import (
ArgParam, ArgParam,
@ -469,8 +472,10 @@ async def test_matcher(app: App):
matcher, matcher,
not_legacy_matcher, not_legacy_matcher,
not_matcher, not_matcher,
pause_prompt_result,
postpone_matcher, postpone_matcher,
receive, receive,
receive_prompt_result,
sub_matcher, sub_matcher,
union_matcher, union_matcher,
) )
@ -538,12 +543,42 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=fake_matcher) ctx.pass_params(matcher=fake_matcher)
ctx.should_return(event_next) ctx.should_return(event_next)
fake_matcher.set_target(RECEIVE_KEY.format(id="test"), cache=False)
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_send(event, "test", result=True, bot=bot)
with fake_matcher.ensure_context(bot, event):
with suppress(RejectedException):
await fake_matcher.reject("test")
async with app.test_dependent(
receive_prompt_result, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(True)
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_send(event, "test", result=False, bot=bot)
with fake_matcher.ensure_context(bot, event):
fake_matcher.set_target("test")
with suppress(PausedException):
await fake_matcher.pause("test")
async with app.test_dependent(
pause_prompt_result, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(False)
@pytest.mark.anyio @pytest.mark.anyio
async def test_arg(app: App): async def test_arg(app: App):
from plugins.param.param_arg import ( from plugins.param.param_arg import (
annotated_arg, annotated_arg,
annotated_arg_plain_text, annotated_arg_plain_text,
annotated_arg_prompt_result,
annotated_arg_str, annotated_arg_str,
annotated_multi_arg, annotated_multi_arg,
annotated_prior_arg, annotated_prior_arg,
@ -553,6 +588,7 @@ async def test_arg(app: App):
) )
matcher = Matcher() matcher = Matcher()
event = make_fake_event()()
message = FakeMessage("text") message = FakeMessage("text")
matcher.set_arg("key", message) matcher.set_arg("key", message)
@ -582,6 +618,21 @@ async def test_arg(app: App):
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text()) ctx.should_return(message.extract_plain_text())
matcher.set_target(ARG_KEY.format(key="key"), cache=False)
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_send(event, "test", result="arg", bot=bot)
with matcher.ensure_context(bot, event):
with suppress(RejectedException):
await matcher.reject("test")
async with app.test_dependent(
annotated_arg_prompt_result, allow_types=[ArgParam]
) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return("arg")
async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx: async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text()) ctx.should_return(message.extract_plain_text())

View File

@ -1224,6 +1224,37 @@ async def _(foo: Event = LastReceived()): ...
</TabItem> </TabItem>
</Tabs> </Tabs>
### ReceivePromptResult
获取某次 `receive` 发送提示消息的结果。
<Tabs groupId="annotated">
<TabItem value="annotated" label="Use Annotated" default>
```python {6}
from typing import Any, Annotated
from nonebot.params import ReceivePromptResult
@matcher.receive("id", prompt="prompt")
async def _(result: Annotated[Any, ReceivePromptResult("id")]): ...
```
</TabItem>
<TabItem value="no-annotated" label="Without Annotated">
```python {6}
from typing import Any
from nonebot.params import ReceivePromptResult
@matcher.receive("id", prompt="prompt")
async def _(result: Any = ReceivePromptResult("id")): ...
```
</TabItem>
</Tabs>
### Arg ### Arg
获取某次 `got` 接收的参数。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。 获取某次 `got` 接收的参数。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。
@ -1318,3 +1349,75 @@ async def _(foo: str = ArgPlainText("key")): ...
</TabItem> </TabItem>
</Tabs> </Tabs>
### ArgPromptResult
获取某次 `got` 发送提示消息的结果。如果 `Arg` 参数留空,则使用函数的参数名作为要获取的参数。
<Tabs groupId="annotated">
<TabItem value="annotated" label="Use Annotated" default>
```python {6,7}
from typing import Any, Annotated
from nonebot.params import ArgPromptResult
@matcher.got("key", prompt="prompt")
async def _(result: Annotated[Any, ArgPromptResult()]): ...
async def _(result: Annotated[Any, ArgPromptResult("key")]): ...
```
</TabItem>
<TabItem value="no-annotated" label="Without Annotated">
```python {6,7}
from typing import Any
from nonebot.params import ArgPromptResult
@matcher.got("key", prompt="prompt")
async def _(result: Any = ArgPromptResult()): ...
async def _(result: Any = ArgPromptResult("key")): ...
```
</TabItem>
</Tabs>
### PausePromptResult
获取最近一次 `pause` 发送提示消息的结果。
<Tabs groupId="annotated">
<TabItem value="annotated" label="Use Annotated" default>
```python {6}
from typing import Any, Annotated
from nonebot.params import PausePromptResult
@matcher.handle()
async def _():
await matcher.pause(prompt="prompt")
@matcher.handle()
async def _(result: Annotated[Any, PausePromptResult()]): ...
```
</TabItem>
<TabItem value="no-annotated" label="Without Annotated">
```python {6}
from typing import Any
from nonebot.params import PausePromptResult
@matcher.handle()
async def _():
await matcher.pause(prompt="prompt")
@matcher.handle()
async def _(result: Any = PausePromptResult()): ...
```
</TabItem>
</Tabs>