add di functions

This commit is contained in:
yanyongyu 2021-12-14 01:08:48 +08:00
parent e942f4076c
commit 329a1fd226
6 changed files with 147 additions and 74 deletions

View File

@ -1,4 +1,29 @@
# used by Params
WRAPPER_ASSIGNMENTS = (
"__module__",
"__name__",
"__qualname__",
"__doc__",
"__annotations__",
"__globals__",
)
# used by Matcher
RECEIVE_KEY = "_receive_{id}" RECEIVE_KEY = "_receive_{id}"
ARG_KEY = "_arg_{key}" ARG_KEY = "_arg_{key}"
ARG_STR_KEY = "{key}" ARG_STR_KEY = "{key}"
REJECT_TARGET = "_current_target" REJECT_TARGET = "_current_target"
# used by Rule
PREFIX_KEY = "_prefix"
CMD_KEY = "command"
RAW_CMD_KEY = "raw_command"
CMD_ARG_KEY = "command_arg"
SHELL_ARGS = "_args"
SHELL_ARGV = "_argv"
REGEX_MATCHED = "_matched"
REGEX_GROUP = "_matched_groups"
REGEX_DICT = "_matched_dict"

View File

@ -8,6 +8,7 @@ from pydantic.typing import ForwardRef, evaluate_forwardref
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call) signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {}) globalns = getattr(call, "__globals__", {})
print(signature.parameters)
typed_params = [ typed_params = [
inspect.Parameter( inspect.Parameter(
name=param.name, name=param.name,

View File

@ -1,12 +1,25 @@
import inspect import inspect
from typing import Any, List, Type, Callable, Optional, cast from functools import wraps, partial
from typing import Any, Tuple, Union, TypeVar, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined from pydantic.fields import Required, Undefined
from nonebot.adapters import Bot, Event
from nonebot.typing import T_State, T_Handler from nonebot.typing import T_State, T_Handler
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent from nonebot.dependencies import Param, Dependent
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
REGEX_DICT,
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
WRAPPER_ASSIGNMENTS,
)
from nonebot.utils import ( from nonebot.utils import (
CacheDict, CacheDict,
get_name, get_name,
@ -18,6 +31,8 @@ from nonebot.utils import (
generic_check_issubclass, generic_check_issubclass,
) )
T = TypeVar("T")
class DependsInner: class DependsInner:
def __init__( def __init__(
@ -175,12 +190,44 @@ class EventParam(Param):
return event return event
async def _event_type(event: Event) -> str:
return event.get_type()
def EventType() -> str:
return Depends(_event_type)
async def _event_message(event: Event) -> Message:
return event.get_message()
def EventMessage() -> Message:
return Depends(_event_message)
async def _event_plain_text(event: Event) -> str:
return event.get_plaintext()
def EventPlainText() -> str:
return Depends(_event_plain_text)
async def _event_to_me(event: Event) -> bool:
return event.is_tome()
def EventToMe() -> bool:
return Depends(_event_to_me)
class StateInner: class StateInner:
... ...
def State() -> Any: def State() -> T_State:
return StateInner() return StateInner() # type: ignore
class StateParam(Param): class StateParam(Param):
@ -195,6 +242,30 @@ class StateParam(Param):
return state return state
def _command(state=State()) -> Message:
return state[PREFIX_KEY][CMD_KEY]
def Command() -> Tuple[str, ...]:
return Depends(_command)
def _raw_command(state=State()) -> Message:
return state[PREFIX_KEY][RAW_CMD_KEY]
def RawCommand() -> str:
return Depends(_raw_command)
def _command_arg(state=State()) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY]
def CommandArg() -> Message:
return Depends(_command_arg)
class MatcherParam(Param): class MatcherParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
@ -209,6 +280,18 @@ class MatcherParam(Param):
return matcher return matcher
def _received(matcher: "Matcher", id: str = "", default: T = None) -> Union[Event, T]:
return matcher.get_receive(id, default)
def Received(id: str = "", default: Any = None) -> Any:
return Depends(
wraps(_received, assigned=WRAPPER_ASSIGNMENTS)(
partial(_received, id=id, default=default)
)
)
class ExceptionParam(Param): class ExceptionParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(

View File

@ -394,27 +394,8 @@ def on_command(
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
async def _strip_cmd(event: Event, state: T_State = State()):
message = event.get_message()
if len(message) < 1:
return
segment = message.pop(0)
segment_text = str(segment).lstrip()
if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
return
new_message = message.__class__(
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
)
for new_segment in reversed(new_message):
message.insert(0, new_segment)
handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set()) commands = set([cmd]) | (aliases or set())
return on_message( return on_message(command(*commands) & rule, **kwargs, _depth=_depth + 1)
command(*commands) & rule, handlers=handlers, **kwargs, _depth=_depth + 1
)
def on_shell_command( def on_shell_command(
@ -452,22 +433,9 @@ def on_shell_command(
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
async def _strip_cmd(event: Event, state: T_State = State()):
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].strip()
)
for new_segment in reversed(new_message):
message.insert(0, new_segment)
handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set()) commands = set([cmd]) | (aliases or set())
return on_message( return on_message(
shell_command(*commands, parser=parser) & rule, shell_command(*commands, parser=parser) & rule,
handlers=handlers,
**kwargs, **kwargs,
_depth=_depth + 1, _depth=_depth + 1,
) )

View File

@ -25,24 +25,29 @@ from nonebot.log import logger
from nonebot.utils import CacheDict from nonebot.utils import CacheDict
from nonebot import params, get_driver from nonebot import params, get_driver
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.adapters import Bot, Event, MessageSegment
from nonebot.exception import ParserExit, SkippedException from nonebot.exception import ParserExit, SkippedException
from nonebot.typing import T_State, T_Handler, T_RuleChecker from nonebot.typing import T_State, T_Handler, T_RuleChecker
from nonebot.adapters import Bot, Event, Message, MessageSegment
PREFIX_KEY = "_prefix" from nonebot.consts import (
SUFFIX_KEY = "_suffix" CMD_KEY,
CMD_KEY = "command" PREFIX_KEY,
RAW_CMD_KEY = "raw_command" REGEX_DICT,
CMD_RESULT = TypedDict( SHELL_ARGS,
"CMD_RESULT", {"command": Optional[Tuple[str, ...]], "raw_command": Optional[str]} SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
) )
SHELL_ARGS = "_args" CMD_RESULT = TypedDict(
SHELL_ARGV = "_argv" "CMD_RESULT",
{
REGEX_MATCHED = "_matched" "command": Optional[Tuple[str, ...]],
REGEX_GROUP = "_matched_groups" "raw_command": Optional[str],
REGEX_DICT = "_matched_dict" "command_arg": Optional[Message[MessageSegment]],
},
)
class Rule: class Rule:
@ -152,7 +157,6 @@ class Rule:
class TrieRule: class TrieRule:
prefix: CharTrie = CharTrie() prefix: CharTrie = CharTrie()
suffix: CharTrie = CharTrie()
@classmethod @classmethod
def add_prefix(cls, prefix: str, value: Any): def add_prefix(cls, prefix: str, value: Any):
@ -162,36 +166,28 @@ class TrieRule:
cls.prefix[prefix] = value cls.prefix[prefix] = value
@classmethod @classmethod
def add_suffix(cls, suffix: str, value: Any): def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
if suffix[::-1] in cls.suffix: prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None)
logger.warning(f'Duplicated suffix rule "{suffix}"')
return
cls.suffix[suffix[::-1]] = value
@classmethod
def get_value(
cls, bot: Bot, event: Event, state: T_State
) -> Tuple[CMD_RESULT, CMD_RESULT]:
prefix = CMD_RESULT(command=None, raw_command=None)
suffix = CMD_RESULT(command=None, raw_command=None)
state[PREFIX_KEY] = prefix state[PREFIX_KEY] = prefix
state[SUFFIX_KEY] = suffix
if event.get_type() != "message": if event.get_type() != "message":
return prefix, suffix return prefix
message = event.get_message() message = event.get_message()
message_seg: MessageSegment = message[0] message_seg: MessageSegment = message[0]
if message_seg.is_text(): if message_seg.is_text():
pf = cls.prefix.longest_prefix(str(message_seg).lstrip()) segment_text = str(message_seg).lstrip()
pf = cls.prefix.longest_prefix(segment_text)
prefix[RAW_CMD_KEY] = pf.key prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_KEY] = pf.value prefix[CMD_KEY] = pf.value
message_seg_r: MessageSegment = message[-1] if pf.key:
if message_seg_r.is_text(): msg = message.copy()
sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1]) msg.pop(0)
suffix[RAW_CMD_KEY] = sf.key new_message = msg.__class__(segment_text[len(pf.key) :].lstrip())
suffix[CMD_KEY] = sf.value for new_segment in reversed(new_message):
msg.insert(0, new_segment)
prefix[CMD_ARG_KEY] = msg
return prefix, suffix return prefix
class Startswith: class Startswith:

View File

@ -64,7 +64,7 @@ def generic_check_issubclass(
return True return True
elif origin: elif origin:
return issubclass(origin, class_or_tuple) return issubclass(origin, class_or_tuple)
raise return False
def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_coroutine_callable(call: Callable[..., Any]) -> bool: