diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index 04d91d15..776a25e5 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -69,22 +69,22 @@ def get_sub_dependant( allow_types: Optional[List[Type[Param]]] = None, ) -> Dependent: sub_dependant = get_dependent( - func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types + call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types ) return sub_dependant def get_dependent( *, - func: T_Handler, + call: T_Handler, name: Optional[str] = None, use_cache: bool = True, allow_types: Optional[List[Type[Param]]] = None, ) -> Dependent: - signature = get_typed_signature(func) + signature = get_typed_signature(call) params = signature.parameters dependent = Dependent( - func=func, name=name, allow_types=allow_types, use_cache=use_cache + call=call, name=name, allow_types=allow_types, use_cache=use_cache ) for param_name, param in params.items(): if isinstance(param.default, DependsWrapper): @@ -108,7 +108,7 @@ def get_dependent( break else: raise ValueError( - f"Unknown parameter {param_name} for function {func} with type {param.annotation}" + f"Unknown parameter {param_name} for function {call} with type {param.annotation}" ) annotation: Any = Any @@ -153,7 +153,7 @@ async def solve_dependencies( if errs_: logger.debug( f"{field_info} " - f"type {type(value)} not match depends {_dependent.func} " + f"type {type(value)} not match depends {_dependent.call} " f"annotation {field._type_display()}, ignored" ) raise SkippedException(field, value) @@ -163,9 +163,9 @@ async def solve_dependencies( # solve sub dependencies sub_dependent: Dependent for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies): - sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) + sub_dependent.call = cast(Callable[..., Any], sub_dependent.call) sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key) - func = sub_dependent.func + call = sub_dependent.call # solve sub dependency with current cache solved_result = await solve_dependencies( @@ -179,19 +179,19 @@ async def solve_dependencies( async with cache_lock: if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache: solved = dependency_cache[sub_dependent.cache_key] - elif is_gen_callable(func) or is_async_gen_callable(func): + elif is_gen_callable(call) or is_async_gen_callable(call): assert isinstance( _stack, AsyncExitStack ), "Generator dependency should be called in context" - if is_gen_callable(func): - cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) + if is_gen_callable(call): + cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) else: - cm = asynccontextmanager(func)(**sub_values) + cm = asynccontextmanager(call)(**sub_values) solved = await _stack.enter_async_context(cm) - elif is_coroutine_callable(func): - solved = await func(**sub_values) + elif is_coroutine_callable(call): + solved = await call(**sub_values) else: - solved = await run_sync(func)(**sub_values) + solved = await run_sync(call)(**sub_values) # parameter dependency if sub_dependent.name is not None: diff --git a/nonebot/dependencies/models.py b/nonebot/dependencies/models.py index 3fdb8e81..3875a6d7 100644 --- a/nonebot/dependencies/models.py +++ b/nonebot/dependencies/models.py @@ -36,17 +36,17 @@ class Dependent: def __init__( self, *, - func: Optional[T_Handler] = None, + call: Optional[T_Handler] = None, name: Optional[str] = None, params: Optional[List[ModelField]] = None, allow_types: Optional[List[Type[Param]]] = None, dependencies: Optional[List["Dependent"]] = None, use_cache: bool = True, ) -> None: - self.func = func + self.call = call self.name = name self.params = params or [] self.allow_types = allow_types or [] self.dependencies = dependencies or [] self.use_cache = use_cache - self.cache_key = self.func + self.cache_key = self.call diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index 44a6b4c9..d82976c5 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -7,9 +7,9 @@ from pydantic.typing import ForwardRef, evaluate_forwardref from nonebot.typing import T_Handler -def get_typed_signature(func: T_Handler) -> inspect.Signature: - signature = inspect.signature(func) - globalns = getattr(func, "__globals__", {}) +def get_typed_signature(call: T_Handler) -> inspect.Signature: + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) typed_params = [ inspect.Parameter( name=param.name, diff --git a/nonebot/handler.py b/nonebot/handler.py index 78ec21b9..30e701fd 100644 --- a/nonebot/handler.py +++ b/nonebot/handler.py @@ -25,7 +25,7 @@ class Handler: def __init__( self, - func: Callable[..., Any], + call: Callable[..., Any], *, name: Optional[str] = None, dependencies: Optional[List[DependsWrapper]] = None, @@ -38,17 +38,17 @@ class Handler: :参数: - * ``func: Callable[..., Any]``: 事件处理函数。 + * ``call: Callable[..., Any]``: 事件处理函数。 * ``name: Optional[str]``: 事件处理器名称。默认为函数名。 * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。 * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。 """ - self.func = func + self.call = call """ :类型: ``Callable[..., Any]`` :说明: 事件处理函数 """ - self.name = get_name(func) if name is None else name + self.name = get_name(call) if name is None else name """ :类型: ``str`` :说明: 事件处理函数名 @@ -68,7 +68,7 @@ class Handler: if dependencies: for depends in dependencies: self.cache_dependent(depends) - self.dependent = get_dependent(func=func, allow_types=self.allow_types) + self.dependent = get_dependent(call=call, allow_types=self.allow_types) def __repr__(self) -> str: return f"" @@ -94,10 +94,10 @@ class Handler: **params, ) - if asyncio.iscoroutinefunction(self.func): - return await self.func(**values) + if asyncio.iscoroutinefunction(self.call): + return await self.call(**values) else: - return await run_sync(self.func)(**values) + return await run_sync(self.call)(**values) def cache_dependent(self, dependency: DependsWrapper): if not dependency.dependency: diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 19c7e553..47ebebd1 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -442,7 +442,7 @@ class Matcher(metaclass=MatcherMeta): def _decorator(func: T_Handler) -> T_Handler: - if cls.handlers and cls.handlers[-1].func is func: + if cls.handlers and cls.handlers[-1].call is func: func_handler = cls.handlers[-1] for depend in reversed(_dependencies): func_handler.prepend_dependency(depend) @@ -513,7 +513,7 @@ class Matcher(metaclass=MatcherMeta): def _decorator(func: T_Handler) -> T_Handler: - if cls.handlers and cls.handlers[-1].func is func: + if cls.handlers and cls.handlers[-1].call is func: func_handler = cls.handlers[-1] for depend in reversed(_dependencies): func_handler.prepend_dependency(depend) diff --git a/nonebot/permission.py b/nonebot/permission.py index 185022e0..f1f863db 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -11,7 +11,17 @@ r""" import asyncio from contextlib import AsyncExitStack -from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional +from typing import ( + Any, + Dict, + List, + Type, + Tuple, + Union, + Callable, + NoReturn, + Optional, +) from nonebot import params from nonebot.handler import Handler @@ -119,41 +129,59 @@ class Permission: return Permission(*self.checkers, other) -async def _message(event: Event) -> bool: - return event.get_type() == "message" +class Message: + async def __call__(self, event: Event) -> bool: + return event.get_type() == "message" -async def _notice(event: Event) -> bool: - return event.get_type() == "notice" +class Notice: + async def __call__(self, event: Event) -> bool: + return event.get_type() == "notice" -async def _request(event: Event) -> bool: - return event.get_type() == "request" +class Request: + async def __call__(self, event: Event) -> bool: + return event.get_type() == "request" -async def _metaevent(event: Event) -> bool: - return event.get_type() == "meta_event" +class MetaEvent: + async def __call__(self, event: Event) -> bool: + return event.get_type() == "meta_event" -MESSAGE = Permission(_message) +MESSAGE = Permission(Message()) """ - **说明**: 匹配任意 ``message`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 message type 的 Matcher。 """ -NOTICE = Permission(_notice) +NOTICE = Permission(Notice()) """ - **说明**: 匹配任意 ``notice`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 notice type 的 Matcher。 """ -REQUEST = Permission(_request) +REQUEST = Permission(Request()) """ - **说明**: 匹配任意 ``request`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 request type 的 Matcher。 """ -METAEVENT = Permission(_metaevent) +METAEVENT = Permission(MetaEvent()) """ - **说明**: 匹配任意 ``meta_event`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 meta_event type 的 Matcher。 """ -def USER(*user: str, perm: Optional[Permission] = None): +class User: + def __init__( + self, users: Tuple[str, ...], perm: Optional[Permission] = None + ) -> None: + self.users = users + self.perm = perm + + async def __call__(self, bot: Bot, event: Event) -> bool: + return bool( + event.get_session_id() in self.users + and (self.perm is None or await self.perm(bot, event)) + ) + + +def USER(*users: str, perm: Optional[Permission] = None): """ :说明: @@ -165,21 +193,18 @@ def USER(*user: str, perm: Optional[Permission] = None): * ``perm: Optional[Permission]``: 需要同时满足的权限 """ - async def _user(bot: Bot, event: Event) -> bool: - return bool( - event.get_session_id() in user and (perm is None or await perm(bot, event)) + return Permission(User(users, perm)) + + +class SuperUser: + async def __call__(self, bot: Bot, event: Event) -> bool: + return ( + event.get_type() == "message" + and event.get_user_id() in bot.config.superusers ) - return Permission(_user) - -async def _superuser(bot: Bot, event: Event) -> bool: - return ( - event.get_type() == "message" and event.get_user_id() in bot.config.superusers - ) - - -SUPERUSER = Permission(_superuser) +SUPERUSER = Permission(SuperUser()) """ - **说明**: 匹配任意超级用户消息类型事件 """ diff --git a/nonebot/rule.py b/nonebot/rule.py index 9469c517..2408e4eb 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -203,6 +203,24 @@ class TrieRule: return prefix, suffix +class Startswith: + def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): + self.msg = msg + self.ignorecase = ignorecase + + async def __call__(self, event: Event) -> Any: + if event.get_type() != "message": + return False + text = event.get_plaintext() + return bool( + re.match( + f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})", + text, + re.IGNORECASE if self.ignorecase else 0, + ) + ) + + def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: """ :说明: @@ -216,18 +234,25 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru if isinstance(msg, str): msg = (msg,) - pattern = re.compile( - f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})", - re.IGNORECASE if ignorecase else 0, - ) + return Rule(Startswith(msg, ignorecase)) - async def _startswith(bot: Bot, event: Event, state: T_State) -> bool: + +class Endswith: + def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): + self.msg = msg + self.ignorecase = ignorecase + + async def __call__(self, event: Event) -> Any: if event.get_type() != "message": return False text = event.get_plaintext() - return bool(pattern.match(text)) - - return Rule(_startswith) + return bool( + re.search( + f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$", + text, + re.IGNORECASE if self.ignorecase else 0, + ) + ) def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: @@ -243,18 +268,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule if isinstance(msg, str): msg = (msg,) - pattern = re.compile( - f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$", - re.IGNORECASE if ignorecase else 0, - ) + return Rule(Endswith(msg, ignorecase)) - async def _endswith(bot: Bot, event: Event, state: T_State) -> bool: + +class Keywords: + def __init__(self, *keywords: str): + self.keywords = keywords + + async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False text = event.get_plaintext() - return bool(pattern.search(text)) - - return Rule(_endswith) + return bool(text and any(keyword in text for keyword in self.keywords)) def keyword(*keywords: str) -> Rule: @@ -268,13 +293,18 @@ def keyword(*keywords: str) -> Rule: * ``*keywords: str``: 关键词 """ - async def _keyword(event: Event) -> bool: - if event.get_type() != "message": - return False - text = event.get_plaintext() - return bool(text and any(keyword in text for keyword in keywords)) + return Rule(Keywords(*keywords)) - return Rule(_keyword) + +class Command: + def __init__(self, cmds: List[Tuple[str, ...]]): + self.cmds = cmds + + async def __call__(self, state: T_State) -> bool: + return state[PREFIX_KEY][CMD_KEY] in self.cmds + + def __repr__(self): + return f"" def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: @@ -304,10 +334,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: config = get_driver().config command_start = config.command_start command_sep = config.command_sep - commands = list(cmds) - for index, command in enumerate(commands): + commands: List[Tuple[str, ...]] = [] + for command in cmds: if isinstance(command, str): - commands[index] = command = (command,) + command = (command,) + + commands.append(command) if len(command) == 1: for start in command_start: @@ -316,10 +348,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - async def _command(state: T_State) -> bool: - return state[PREFIX_KEY][CMD_KEY] in commands - - return Rule(_command) + return Rule(Command(commands)) class ArgumentParser(ArgParser): @@ -350,6 +379,27 @@ class ArgumentParser(ArgParser): return super().parse_args(args=args, namespace=namespace) # type: ignore +class ShellCommand: + def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]): + self.cmds = cmds + self.parser = parser + + async def __call__(self, event: Event, state: T_State) -> bool: + if state[PREFIX_KEY][CMD_KEY] in self.cmds: + message = str(event.get_message()) + strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() + state[SHELL_ARGV] = shlex.split(strip_message) + if self.parser: + try: + args = self.parser.parse_args(state[SHELL_ARGV]) + state[SHELL_ARGS] = args + except ParserExit as e: + state[SHELL_ARGS] = e + return True + else: + return False + + def shell_command( *cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None ) -> Rule: @@ -392,10 +442,12 @@ def shell_command( config = get_driver().config command_start = config.command_start command_sep = config.command_sep - commands = list(cmds) - for index, command in enumerate(commands): + commands: List[Tuple[str, ...]] = [] + for command in cmds: if isinstance(command, str): - commands[index] = command = (command,) + command = (command,) + + commands.append(command) if len(command) == 1: for start in command_start: @@ -404,23 +456,26 @@ def shell_command( for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - async def _shell_command(event: Event, state: T_State) -> bool: - if state[PREFIX_KEY][CMD_KEY] in commands: - message = str(event.get_message()) - strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() - state[SHELL_ARGV] = shlex.split(strip_message) - if parser: - try: - args = parser.parse_args(state[SHELL_ARGV]) - state[SHELL_ARGS] = args - except ParserExit as e: - state[SHELL_ARGS] = e + return Rule(ShellCommand(commands, parser)) + + +class Regex: + def __init__(self, regex: str, flags: int = 0): + self.regex = regex + self.flags = flags + + async def __call__(self, event: Event, state: T_State) -> bool: + if event.get_type() != "message": + return False + matched = re.search(self.regex, str(event.get_message()), self.flags) + if matched: + state[REGEX_MATCHED] = matched.group() + state[REGEX_GROUP] = matched.groups() + state[REGEX_DICT] = matched.groupdict() return True else: return False - return Rule(_shell_command) - def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: r""" @@ -441,25 +496,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: \:\:\: """ - pattern = re.compile(regex, flags) - - async def _regex(event: Event, state: T_State) -> bool: - if event.get_type() != "message": - return False - matched = pattern.search(str(event.get_message())) - if matched: - state[REGEX_MATCHED] = matched.group() - state[REGEX_GROUP] = matched.groups() - state[REGEX_DICT] = matched.groupdict() - return True - else: - return False - - return Rule(_regex) + return Rule(Regex(regex, flags)) -async def _to_me(event: Event) -> bool: - return event.is_tome() +class ToMe: + async def __call__(self, event: Event) -> bool: + return event.is_tome() def to_me() -> Rule: @@ -473,4 +515,4 @@ def to_me() -> Rule: * 无 """ - return Rule(_to_me) + return Rule(ToMe()) diff --git a/nonebot/utils.py b/nonebot/utils.py index 001ee43e..0c63289c 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -66,30 +66,30 @@ def generic_check_issubclass( raise -def is_coroutine_callable(func: Callable[..., Any]) -> bool: - if inspect.isroutine(func): - return inspect.iscoroutinefunction(func) - if inspect.isclass(func): +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): return False - func_ = getattr(func, "__call__", None) + func_ = getattr(call, "__call__", None) return inspect.iscoroutinefunction(func_) -def is_gen_callable(func: Callable[..., Any]) -> bool: - if inspect.isgeneratorfunction(func): +def is_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isgeneratorfunction(call): return True - func_ = getattr(func, "__call__", None) + func_ = getattr(call, "__call__", None) return inspect.isgeneratorfunction(func_) -def is_async_gen_callable(func: Callable[..., Any]) -> bool: - if inspect.isasyncgenfunction(func): +def is_async_gen_callable(call: Callable[..., Any]) -> bool: + if inspect.isasyncgenfunction(call): return True - func_ = getattr(func, "__call__", None) + func_ = getattr(call, "__call__", None) return inspect.isasyncgenfunction(func_) -def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: +def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]: """ :说明: @@ -97,17 +97,17 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]: :参数: - * ``func: Callable[P, R]``: 被装饰的同步函数 + * ``call: Callable[P, R]``: 被装饰的同步函数 :返回: - ``Callable[P, Awaitable[R]]`` """ - @wraps(func) + @wraps(call) async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: loop = asyncio.get_running_loop() - pfunc = partial(func, *args, **kwargs) + pfunc = partial(call, *args, **kwargs) result = await loop.run_in_executor(None, pfunc) return result