add di functions

This commit is contained in:
yanyongyu 2021-12-14 01:08:48 +08:00
parent e942f4076c
commit 329a1fd226
6 changed files with 147 additions and 74 deletions

View File

@ -1,4 +1,29 @@
# used by Params
WRAPPER_ASSIGNMENTS = (
"__module__",
"__name__",
"__qualname__",
"__doc__",
"__annotations__",
"__globals__",
)
# used by Matcher
RECEIVE_KEY = "_receive_{id}"
ARG_KEY = "_arg_{key}"
ARG_STR_KEY = "{key}"
REJECT_TARGET = "_current_target"
# used by Rule
PREFIX_KEY = "_prefix"
CMD_KEY = "command"
RAW_CMD_KEY = "raw_command"
CMD_ARG_KEY = "command_arg"
SHELL_ARGS = "_args"
SHELL_ARGV = "_argv"
REGEX_MATCHED = "_matched"
REGEX_GROUP = "_matched_groups"
REGEX_DICT = "_matched_dict"

View File

@ -8,6 +8,7 @@ from pydantic.typing import ForwardRef, evaluate_forwardref
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
print(signature.parameters)
typed_params = [
inspect.Parameter(
name=param.name,

View File

@ -1,12 +1,25 @@
import inspect
from typing import Any, List, Type, Callable, Optional, cast
from functools import wraps, partial
from typing import Any, Tuple, Union, TypeVar, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic.fields import Required, Undefined
from nonebot.adapters import Bot, Event
from nonebot.typing import T_State, T_Handler
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies import Param, Dependent
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
REGEX_DICT,
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
WRAPPER_ASSIGNMENTS,
)
from nonebot.utils import (
CacheDict,
get_name,
@ -18,6 +31,8 @@ from nonebot.utils import (
generic_check_issubclass,
)
T = TypeVar("T")
class DependsInner:
def __init__(
@ -175,12 +190,44 @@ class EventParam(Param):
return event
async def _event_type(event: Event) -> str:
return event.get_type()
def EventType() -> str:
return Depends(_event_type)
async def _event_message(event: Event) -> Message:
return event.get_message()
def EventMessage() -> Message:
return Depends(_event_message)
async def _event_plain_text(event: Event) -> str:
return event.get_plaintext()
def EventPlainText() -> str:
return Depends(_event_plain_text)
async def _event_to_me(event: Event) -> bool:
return event.is_tome()
def EventToMe() -> bool:
return Depends(_event_to_me)
class StateInner:
...
def State() -> Any:
return StateInner()
def State() -> T_State:
return StateInner() # type: ignore
class StateParam(Param):
@ -195,6 +242,30 @@ class StateParam(Param):
return state
def _command(state=State()) -> Message:
return state[PREFIX_KEY][CMD_KEY]
def Command() -> Tuple[str, ...]:
return Depends(_command)
def _raw_command(state=State()) -> Message:
return state[PREFIX_KEY][RAW_CMD_KEY]
def RawCommand() -> str:
return Depends(_raw_command)
def _command_arg(state=State()) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY]
def CommandArg() -> Message:
return Depends(_command_arg)
class MatcherParam(Param):
@classmethod
def _check_param(
@ -209,6 +280,18 @@ class MatcherParam(Param):
return matcher
def _received(matcher: "Matcher", id: str = "", default: T = None) -> Union[Event, T]:
return matcher.get_receive(id, default)
def Received(id: str = "", default: Any = None) -> Any:
return Depends(
wraps(_received, assigned=WRAPPER_ASSIGNMENTS)(
partial(_received, id=id, default=default)
)
)
class ExceptionParam(Param):
@classmethod
def _check_param(

View File

@ -394,27 +394,8 @@ def on_command(
- ``Type[Matcher]``
"""
async def _strip_cmd(event: Event, state: T_State = State()):
message = event.get_message()
if len(message) < 1:
return
segment = message.pop(0)
segment_text = str(segment).lstrip()
if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
return
new_message = message.__class__(
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
)
for new_segment in reversed(new_message):
message.insert(0, new_segment)
handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set())
return on_message(
command(*commands) & rule, handlers=handlers, **kwargs, _depth=_depth + 1
)
return on_message(command(*commands) & rule, **kwargs, _depth=_depth + 1)
def on_shell_command(
@ -452,22 +433,9 @@ def on_shell_command(
- ``Type[Matcher]``
"""
async def _strip_cmd(event: Event, state: T_State = State()):
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].strip()
)
for new_segment in reversed(new_message):
message.insert(0, new_segment)
handlers = kwargs.pop("handlers", [])
handlers.insert(0, _strip_cmd)
commands = set([cmd]) | (aliases or set())
return on_message(
shell_command(*commands, parser=parser) & rule,
handlers=handlers,
**kwargs,
_depth=_depth + 1,
)

View File

@ -25,24 +25,29 @@ from nonebot.log import logger
from nonebot.utils import CacheDict
from nonebot import params, get_driver
from nonebot.dependencies import Dependent
from nonebot.adapters import Bot, Event, MessageSegment
from nonebot.exception import ParserExit, SkippedException
from nonebot.typing import T_State, T_Handler, T_RuleChecker
PREFIX_KEY = "_prefix"
SUFFIX_KEY = "_suffix"
CMD_KEY = "command"
RAW_CMD_KEY = "raw_command"
CMD_RESULT = TypedDict(
"CMD_RESULT", {"command": Optional[Tuple[str, ...]], "raw_command": Optional[str]}
from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
REGEX_DICT,
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
)
SHELL_ARGS = "_args"
SHELL_ARGV = "_argv"
REGEX_MATCHED = "_matched"
REGEX_GROUP = "_matched_groups"
REGEX_DICT = "_matched_dict"
CMD_RESULT = TypedDict(
"CMD_RESULT",
{
"command": Optional[Tuple[str, ...]],
"raw_command": Optional[str],
"command_arg": Optional[Message[MessageSegment]],
},
)
class Rule:
@ -152,7 +157,6 @@ class Rule:
class TrieRule:
prefix: CharTrie = CharTrie()
suffix: CharTrie = CharTrie()
@classmethod
def add_prefix(cls, prefix: str, value: Any):
@ -162,36 +166,28 @@ class TrieRule:
cls.prefix[prefix] = value
@classmethod
def add_suffix(cls, suffix: str, value: Any):
if suffix[::-1] in cls.suffix:
logger.warning(f'Duplicated suffix rule "{suffix}"')
return
cls.suffix[suffix[::-1]] = value
@classmethod
def get_value(
cls, bot: Bot, event: Event, state: T_State
) -> Tuple[CMD_RESULT, CMD_RESULT]:
prefix = CMD_RESULT(command=None, raw_command=None)
suffix = CMD_RESULT(command=None, raw_command=None)
def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None)
state[PREFIX_KEY] = prefix
state[SUFFIX_KEY] = suffix
if event.get_type() != "message":
return prefix, suffix
return prefix
message = event.get_message()
message_seg: MessageSegment = message[0]
if message_seg.is_text():
pf = cls.prefix.longest_prefix(str(message_seg).lstrip())
segment_text = str(message_seg).lstrip()
pf = cls.prefix.longest_prefix(segment_text)
prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_KEY] = pf.value
message_seg_r: MessageSegment = message[-1]
if message_seg_r.is_text():
sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1])
suffix[RAW_CMD_KEY] = sf.key
suffix[CMD_KEY] = sf.value
if pf.key:
msg = message.copy()
msg.pop(0)
new_message = msg.__class__(segment_text[len(pf.key) :].lstrip())
for new_segment in reversed(new_message):
msg.insert(0, new_segment)
prefix[CMD_ARG_KEY] = msg
return prefix, suffix
return prefix
class Startswith:

View File

@ -64,7 +64,7 @@ def generic_check_issubclass(
return True
elif origin:
return issubclass(origin, class_or_tuple)
raise
return False
def is_coroutine_callable(call: Callable[..., Any]) -> bool: