diff --git a/nonebot/consts.py b/nonebot/consts.py index a8867395..ae70ea7f 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -1,4 +1,29 @@ +# used by Params +WRAPPER_ASSIGNMENTS = ( + "__module__", + "__name__", + "__qualname__", + "__doc__", + "__annotations__", + "__globals__", +) + +# used by Matcher RECEIVE_KEY = "_receive_{id}" ARG_KEY = "_arg_{key}" ARG_STR_KEY = "{key}" REJECT_TARGET = "_current_target" + +# used by Rule +PREFIX_KEY = "_prefix" + +CMD_KEY = "command" +RAW_CMD_KEY = "raw_command" +CMD_ARG_KEY = "command_arg" + +SHELL_ARGS = "_args" +SHELL_ARGV = "_argv" + +REGEX_MATCHED = "_matched" +REGEX_GROUP = "_matched_groups" +REGEX_DICT = "_matched_dict" diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index 56a815ff..4f587803 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -8,6 +8,7 @@ from pydantic.typing import ForwardRef, evaluate_forwardref def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: signature = inspect.signature(call) globalns = getattr(call, "__globals__", {}) + print(signature.parameters) typed_params = [ inspect.Parameter( name=param.name, diff --git a/nonebot/params.py b/nonebot/params.py index 4023f03d..d21c633e 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -1,12 +1,25 @@ import inspect -from typing import Any, List, Type, Callable, Optional, cast +from functools import wraps, partial +from typing import Any, Tuple, Union, TypeVar, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from pydantic.fields import Required, Undefined -from nonebot.adapters import Bot, Event from nonebot.typing import T_State, T_Handler +from nonebot.adapters import Bot, Event, Message from nonebot.dependencies import Param, Dependent +from nonebot.consts import ( + CMD_KEY, + PREFIX_KEY, + REGEX_DICT, + SHELL_ARGS, + SHELL_ARGV, + CMD_ARG_KEY, + RAW_CMD_KEY, + REGEX_GROUP, + REGEX_MATCHED, + WRAPPER_ASSIGNMENTS, +) from nonebot.utils import ( CacheDict, get_name, @@ -18,6 +31,8 @@ from nonebot.utils import ( generic_check_issubclass, ) +T = TypeVar("T") + class DependsInner: def __init__( @@ -175,12 +190,44 @@ class EventParam(Param): return event +async def _event_type(event: Event) -> str: + return event.get_type() + + +def EventType() -> str: + return Depends(_event_type) + + +async def _event_message(event: Event) -> Message: + return event.get_message() + + +def EventMessage() -> Message: + return Depends(_event_message) + + +async def _event_plain_text(event: Event) -> str: + return event.get_plaintext() + + +def EventPlainText() -> str: + return Depends(_event_plain_text) + + +async def _event_to_me(event: Event) -> bool: + return event.is_tome() + + +def EventToMe() -> bool: + return Depends(_event_to_me) + + class StateInner: ... -def State() -> Any: - return StateInner() +def State() -> T_State: + return StateInner() # type: ignore class StateParam(Param): @@ -195,6 +242,30 @@ class StateParam(Param): return state +def _command(state=State()) -> Message: + return state[PREFIX_KEY][CMD_KEY] + + +def Command() -> Tuple[str, ...]: + return Depends(_command) + + +def _raw_command(state=State()) -> Message: + return state[PREFIX_KEY][RAW_CMD_KEY] + + +def RawCommand() -> str: + return Depends(_raw_command) + + +def _command_arg(state=State()) -> Message: + return state[PREFIX_KEY][CMD_ARG_KEY] + + +def CommandArg() -> Message: + return Depends(_command_arg) + + class MatcherParam(Param): @classmethod def _check_param( @@ -209,6 +280,18 @@ class MatcherParam(Param): return matcher +def _received(matcher: "Matcher", id: str = "", default: T = None) -> Union[Event, T]: + return matcher.get_receive(id, default) + + +def Received(id: str = "", default: Any = None) -> Any: + return Depends( + wraps(_received, assigned=WRAPPER_ASSIGNMENTS)( + partial(_received, id=id, default=default) + ) + ) + + class ExceptionParam(Param): @classmethod def _check_param( diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 5f65a21b..d302cfdf 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -394,27 +394,8 @@ def on_command( - ``Type[Matcher]`` """ - async def _strip_cmd(event: Event, state: T_State = State()): - message = event.get_message() - if len(message) < 1: - return - segment = message.pop(0) - segment_text = str(segment).lstrip() - if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]): - return - new_message = message.__class__( - segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() - ) - for new_segment in reversed(new_message): - message.insert(0, new_segment) - - handlers = kwargs.pop("handlers", []) - handlers.insert(0, _strip_cmd) - commands = set([cmd]) | (aliases or set()) - return on_message( - command(*commands) & rule, handlers=handlers, **kwargs, _depth=_depth + 1 - ) + return on_message(command(*commands) & rule, **kwargs, _depth=_depth + 1) def on_shell_command( @@ -452,22 +433,9 @@ def on_shell_command( - ``Type[Matcher]`` """ - async def _strip_cmd(event: Event, state: T_State = State()): - message = event.get_message() - segment = message.pop(0) - new_message = message.__class__( - str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].strip() - ) - for new_segment in reversed(new_message): - message.insert(0, new_segment) - - handlers = kwargs.pop("handlers", []) - handlers.insert(0, _strip_cmd) - commands = set([cmd]) | (aliases or set()) return on_message( shell_command(*commands, parser=parser) & rule, - handlers=handlers, **kwargs, _depth=_depth + 1, ) diff --git a/nonebot/rule.py b/nonebot/rule.py index a58d3125..627f18c6 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -25,24 +25,29 @@ from nonebot.log import logger from nonebot.utils import CacheDict from nonebot import params, get_driver from nonebot.dependencies import Dependent -from nonebot.adapters import Bot, Event, MessageSegment from nonebot.exception import ParserExit, SkippedException from nonebot.typing import T_State, T_Handler, T_RuleChecker - -PREFIX_KEY = "_prefix" -SUFFIX_KEY = "_suffix" -CMD_KEY = "command" -RAW_CMD_KEY = "raw_command" -CMD_RESULT = TypedDict( - "CMD_RESULT", {"command": Optional[Tuple[str, ...]], "raw_command": Optional[str]} +from nonebot.adapters import Bot, Event, Message, MessageSegment +from nonebot.consts import ( + CMD_KEY, + PREFIX_KEY, + REGEX_DICT, + SHELL_ARGS, + SHELL_ARGV, + CMD_ARG_KEY, + RAW_CMD_KEY, + REGEX_GROUP, + REGEX_MATCHED, ) -SHELL_ARGS = "_args" -SHELL_ARGV = "_argv" - -REGEX_MATCHED = "_matched" -REGEX_GROUP = "_matched_groups" -REGEX_DICT = "_matched_dict" +CMD_RESULT = TypedDict( + "CMD_RESULT", + { + "command": Optional[Tuple[str, ...]], + "raw_command": Optional[str], + "command_arg": Optional[Message[MessageSegment]], + }, +) class Rule: @@ -152,7 +157,6 @@ class Rule: class TrieRule: prefix: CharTrie = CharTrie() - suffix: CharTrie = CharTrie() @classmethod def add_prefix(cls, prefix: str, value: Any): @@ -162,36 +166,28 @@ class TrieRule: cls.prefix[prefix] = value @classmethod - def add_suffix(cls, suffix: str, value: Any): - if suffix[::-1] in cls.suffix: - logger.warning(f'Duplicated suffix rule "{suffix}"') - return - cls.suffix[suffix[::-1]] = value - - @classmethod - def get_value( - cls, bot: Bot, event: Event, state: T_State - ) -> Tuple[CMD_RESULT, CMD_RESULT]: - prefix = CMD_RESULT(command=None, raw_command=None) - suffix = CMD_RESULT(command=None, raw_command=None) + def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT: + prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None) state[PREFIX_KEY] = prefix - state[SUFFIX_KEY] = suffix if event.get_type() != "message": - return prefix, suffix + return prefix message = event.get_message() message_seg: MessageSegment = message[0] if message_seg.is_text(): - pf = cls.prefix.longest_prefix(str(message_seg).lstrip()) + segment_text = str(message_seg).lstrip() + pf = cls.prefix.longest_prefix(segment_text) prefix[RAW_CMD_KEY] = pf.key prefix[CMD_KEY] = pf.value - message_seg_r: MessageSegment = message[-1] - if message_seg_r.is_text(): - sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1]) - suffix[RAW_CMD_KEY] = sf.key - suffix[CMD_KEY] = sf.value + if pf.key: + msg = message.copy() + msg.pop(0) + new_message = msg.__class__(segment_text[len(pf.key) :].lstrip()) + for new_segment in reversed(new_message): + msg.insert(0, new_segment) + prefix[CMD_ARG_KEY] = msg - return prefix, suffix + return prefix class Startswith: diff --git a/nonebot/utils.py b/nonebot/utils.py index 075ba667..57756f58 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -64,7 +64,7 @@ def generic_check_issubclass( return True elif origin: return issubclass(origin, class_or_tuple) - raise + return False def is_coroutine_callable(call: Callable[..., Any]) -> bool: