♻️ use class rule and permission

This commit is contained in:
yanyongyu 2021-12-06 10:10:51 +08:00
parent ca4d7397f8
commit 5b75b72720
8 changed files with 202 additions and 135 deletions

View File

@ -69,22 +69,22 @@ def get_sub_dependant(
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
sub_dependant = get_dependent(
func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
)
return sub_dependant
def get_dependent(
*,
func: T_Handler,
call: T_Handler,
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None,
) -> Dependent:
signature = get_typed_signature(func)
signature = get_typed_signature(call)
params = signature.parameters
dependent = Dependent(
func=func, name=name, allow_types=allow_types, use_cache=use_cache
call=call, name=name, allow_types=allow_types, use_cache=use_cache
)
for param_name, param in params.items():
if isinstance(param.default, DependsWrapper):
@ -108,7 +108,7 @@ def get_dependent(
break
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
f"Unknown parameter {param_name} for function {call} with type {param.annotation}"
)
annotation: Any = Any
@ -153,7 +153,7 @@ async def solve_dependencies(
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {_dependent.func} "
f"type {type(value)} not match depends {_dependent.call} "
f"annotation {field._type_display()}, ignored"
)
raise SkippedException(field, value)
@ -163,9 +163,9 @@ async def solve_dependencies(
# solve sub dependencies
sub_dependent: Dependent
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
func = sub_dependent.func
call = sub_dependent.call
# solve sub dependency with current cache
solved_result = await solve_dependencies(
@ -179,19 +179,19 @@ async def solve_dependencies(
async with cache_lock:
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
elif is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
if is_gen_callable(call):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
cm = asynccontextmanager(call)(**sub_values)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
else:
solved = await run_sync(func)(**sub_values)
solved = await run_sync(call)(**sub_values)
# parameter dependency
if sub_dependent.name is not None:

View File

@ -36,17 +36,17 @@ class Dependent:
def __init__(
self,
*,
func: Optional[T_Handler] = None,
call: Optional[T_Handler] = None,
name: Optional[str] = None,
params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None,
dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True,
) -> None:
self.func = func
self.call = call
self.name = name
self.params = params or []
self.allow_types = allow_types or []
self.dependencies = dependencies or []
self.use_cache = use_cache
self.cache_key = self.func
self.cache_key = self.call

View File

@ -7,9 +7,9 @@ from pydantic.typing import ForwardRef, evaluate_forwardref
from nonebot.typing import T_Handler
def get_typed_signature(func: T_Handler) -> inspect.Signature:
signature = inspect.signature(func)
globalns = getattr(func, "__globals__", {})
def get_typed_signature(call: T_Handler) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,

View File

@ -25,7 +25,7 @@ class Handler:
def __init__(
self,
func: Callable[..., Any],
call: Callable[..., Any],
*,
name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None,
@ -38,17 +38,17 @@ class Handler:
:参数:
* ``func: Callable[..., Any]``: 事件处理函数
* ``call: Callable[..., Any]``: 事件处理函数
* ``name: Optional[str]``: 事件处理器名称默认为函数名
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型
"""
self.func = func
self.call = call
"""
:类型: ``Callable[..., Any]``
:说明: 事件处理函数
"""
self.name = get_name(func) if name is None else name
self.name = get_name(call) if name is None else name
"""
:类型: ``str``
:说明: 事件处理函数名
@ -68,7 +68,7 @@ class Handler:
if dependencies:
for depends in dependencies:
self.cache_dependent(depends)
self.dependent = get_dependent(func=func, allow_types=self.allow_types)
self.dependent = get_dependent(call=call, allow_types=self.allow_types)
def __repr__(self) -> str:
return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
@ -94,10 +94,10 @@ class Handler:
**params,
)
if asyncio.iscoroutinefunction(self.func):
return await self.func(**values)
if asyncio.iscoroutinefunction(self.call):
return await self.call(**values)
else:
return await run_sync(self.func)(**values)
return await run_sync(self.call)(**values)
def cache_dependent(self, dependency: DependsWrapper):
if not dependency.dependency:

View File

@ -442,7 +442,7 @@ class Matcher(metaclass=MatcherMeta):
def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].func is func:
if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1]
for depend in reversed(_dependencies):
func_handler.prepend_dependency(depend)
@ -513,7 +513,7 @@ class Matcher(metaclass=MatcherMeta):
def _decorator(func: T_Handler) -> T_Handler:
if cls.handlers and cls.handlers[-1].func is func:
if cls.handlers and cls.handlers[-1].call is func:
func_handler = cls.handlers[-1]
for depend in reversed(_dependencies):
func_handler.prepend_dependency(depend)

View File

@ -11,7 +11,17 @@ r"""
import asyncio
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Callable,
NoReturn,
Optional,
)
from nonebot import params
from nonebot.handler import Handler
@ -119,41 +129,59 @@ class Permission:
return Permission(*self.checkers, other)
async def _message(event: Event) -> bool:
return event.get_type() == "message"
class Message:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "message"
async def _notice(event: Event) -> bool:
return event.get_type() == "notice"
class Notice:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "notice"
async def _request(event: Event) -> bool:
return event.get_type() == "request"
class Request:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "request"
async def _metaevent(event: Event) -> bool:
return event.get_type() == "meta_event"
class MetaEvent:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "meta_event"
MESSAGE = Permission(_message)
MESSAGE = Permission(Message())
"""
- **说明**: 匹配任意 ``message`` 类型事件仅在需要同时捕获不同类型事件时使用优先使用 message type Matcher
"""
NOTICE = Permission(_notice)
NOTICE = Permission(Notice())
"""
- **说明**: 匹配任意 ``notice`` 类型事件仅在需要同时捕获不同类型事件时使用优先使用 notice type Matcher
"""
REQUEST = Permission(_request)
REQUEST = Permission(Request())
"""
- **说明**: 匹配任意 ``request`` 类型事件仅在需要同时捕获不同类型事件时使用优先使用 request type Matcher
"""
METAEVENT = Permission(_metaevent)
METAEVENT = Permission(MetaEvent())
"""
- **说明**: 匹配任意 ``meta_event`` 类型事件仅在需要同时捕获不同类型事件时使用优先使用 meta_event type Matcher
"""
def USER(*user: str, perm: Optional[Permission] = None):
class User:
def __init__(
self, users: Tuple[str, ...], perm: Optional[Permission] = None
) -> None:
self.users = users
self.perm = perm
async def __call__(self, bot: Bot, event: Event) -> bool:
return bool(
event.get_session_id() in self.users
and (self.perm is None or await self.perm(bot, event))
)
def USER(*users: str, perm: Optional[Permission] = None):
"""
:说明:
@ -165,21 +193,18 @@ def USER(*user: str, perm: Optional[Permission] = None):
* ``perm: Optional[Permission]``: 需要同时满足的权限
"""
async def _user(bot: Bot, event: Event) -> bool:
return bool(
event.get_session_id() in user and (perm is None or await perm(bot, event))
return Permission(User(users, perm))
class SuperUser:
async def __call__(self, bot: Bot, event: Event) -> bool:
return (
event.get_type() == "message"
and event.get_user_id() in bot.config.superusers
)
return Permission(_user)
async def _superuser(bot: Bot, event: Event) -> bool:
return (
event.get_type() == "message" and event.get_user_id() in bot.config.superusers
)
SUPERUSER = Permission(_superuser)
SUPERUSER = Permission(SuperUser())
"""
- **说明**: 匹配任意超级用户消息类型事件
"""

View File

@ -203,6 +203,24 @@ class TrieRule:
return prefix, suffix
class Startswith:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg
self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(
re.match(
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
text,
re.IGNORECASE if self.ignorecase else 0,
)
)
def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
"""
:说明:
@ -216,18 +234,25 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
if isinstance(msg, str):
msg = (msg,)
pattern = re.compile(
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
re.IGNORECASE if ignorecase else 0,
)
return Rule(Startswith(msg, ignorecase))
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
class Endswith:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg
self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(pattern.match(text))
return Rule(_startswith)
return bool(
re.search(
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
text,
re.IGNORECASE if self.ignorecase else 0,
)
)
def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
@ -243,18 +268,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
if isinstance(msg, str):
msg = (msg,)
pattern = re.compile(
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
re.IGNORECASE if ignorecase else 0,
)
return Rule(Endswith(msg, ignorecase))
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
class Keywords:
def __init__(self, *keywords: str):
self.keywords = keywords
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(pattern.search(text))
return Rule(_endswith)
return bool(text and any(keyword in text for keyword in self.keywords))
def keyword(*keywords: str) -> Rule:
@ -268,13 +293,18 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
async def _keyword(event: Event) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
return bool(text and any(keyword in text for keyword in keywords))
return Rule(Keywords(*keywords))
return Rule(_keyword)
class Command:
def __init__(self, cmds: List[Tuple[str, ...]]):
self.cmds = cmds
async def __call__(self, state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in self.cmds
def __repr__(self):
return f"<Command {self.cmds}>"
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
@ -304,10 +334,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
config = get_driver().config
command_start = config.command_start
command_sep = config.command_sep
commands = list(cmds)
for index, command in enumerate(commands):
commands: List[Tuple[str, ...]] = []
for command in cmds:
if isinstance(command, str):
commands[index] = command = (command,)
command = (command,)
commands.append(command)
if len(command) == 1:
for start in command_start:
@ -316,10 +348,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command)
return Rule(Command(commands))
class ArgumentParser(ArgParser):
@ -350,6 +379,27 @@ class ArgumentParser(ArgParser):
return super().parse_args(args=args, namespace=namespace) # type: ignore
class ShellCommand:
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
self.cmds = cmds
self.parser = parser
async def __call__(self, event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message)
if self.parser:
try:
args = self.parser.parse_args(state[SHELL_ARGV])
state[SHELL_ARGS] = args
except ParserExit as e:
state[SHELL_ARGS] = e
return True
else:
return False
def shell_command(
*cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
) -> Rule:
@ -392,10 +442,12 @@ def shell_command(
config = get_driver().config
command_start = config.command_start
command_sep = config.command_sep
commands = list(cmds)
for index, command in enumerate(commands):
commands: List[Tuple[str, ...]] = []
for command in cmds:
if isinstance(command, str):
commands[index] = command = (command,)
command = (command,)
commands.append(command)
if len(command) == 1:
for start in command_start:
@ -404,23 +456,26 @@ def shell_command(
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message)
if parser:
try:
args = parser.parse_args(state[SHELL_ARGV])
state[SHELL_ARGS] = args
except ParserExit as e:
state[SHELL_ARGS] = e
return Rule(ShellCommand(commands, parser))
class Regex:
def __init__(self, regex: str, flags: int = 0):
self.regex = regex
self.flags = flags
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = re.search(self.regex, str(event.get_message()), self.flags)
if matched:
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict()
return True
else:
return False
return Rule(_shell_command)
def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
r"""
@ -441,25 +496,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
\:\:\:
"""
pattern = re.compile(regex, flags)
async def _regex(event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = pattern.search(str(event.get_message()))
if matched:
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict()
return True
else:
return False
return Rule(_regex)
return Rule(Regex(regex, flags))
async def _to_me(event: Event) -> bool:
return event.is_tome()
class ToMe:
async def __call__(self, event: Event) -> bool:
return event.is_tome()
def to_me() -> Rule:
@ -473,4 +515,4 @@ def to_me() -> Rule:
*
"""
return Rule(_to_me)
return Rule(ToMe())

View File

@ -66,30 +66,30 @@ def generic_check_issubclass(
raise
def is_coroutine_callable(func: Callable[..., Any]) -> bool:
if inspect.isroutine(func):
return inspect.iscoroutinefunction(func)
if inspect.isclass(func):
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
func_ = getattr(func, "__call__", None)
func_ = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(func_)
def is_gen_callable(func: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(func):
def is_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isgeneratorfunction(call):
return True
func_ = getattr(func, "__call__", None)
func_ = getattr(call, "__call__", None)
return inspect.isgeneratorfunction(func_)
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(func):
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
if inspect.isasyncgenfunction(call):
return True
func_ = getattr(func, "__call__", None)
func_ = getattr(call, "__call__", None)
return inspect.isasyncgenfunction(func_)
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]:
"""
:说明:
@ -97,17 +97,17 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
:参数:
* ``func: Callable[P, R]``: 被装饰的同步函数
* ``call: Callable[P, R]``: 被装饰的同步函数
:返回:
- ``Callable[P, Awaitable[R]]``
"""
@wraps(func)
@wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop()
pfunc = partial(func, *args, **kwargs)
pfunc = partial(call, *args, **kwargs)
result = await loop.run_in_executor(None, pfunc)
return result