mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-02-22 02:25:56 +08:00
♻️ use class rule and permission
This commit is contained in:
parent
ca4d7397f8
commit
5b75b72720
@ -69,22 +69,22 @@ def get_sub_dependant(
|
|||||||
allow_types: Optional[List[Type[Param]]] = None,
|
allow_types: Optional[List[Type[Param]]] = None,
|
||||||
) -> Dependent:
|
) -> Dependent:
|
||||||
sub_dependant = get_dependent(
|
sub_dependant = get_dependent(
|
||||||
func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
|
call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
|
||||||
)
|
)
|
||||||
return sub_dependant
|
return sub_dependant
|
||||||
|
|
||||||
|
|
||||||
def get_dependent(
|
def get_dependent(
|
||||||
*,
|
*,
|
||||||
func: T_Handler,
|
call: T_Handler,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
allow_types: Optional[List[Type[Param]]] = None,
|
allow_types: Optional[List[Type[Param]]] = None,
|
||||||
) -> Dependent:
|
) -> Dependent:
|
||||||
signature = get_typed_signature(func)
|
signature = get_typed_signature(call)
|
||||||
params = signature.parameters
|
params = signature.parameters
|
||||||
dependent = Dependent(
|
dependent = Dependent(
|
||||||
func=func, name=name, allow_types=allow_types, use_cache=use_cache
|
call=call, name=name, allow_types=allow_types, use_cache=use_cache
|
||||||
)
|
)
|
||||||
for param_name, param in params.items():
|
for param_name, param in params.items():
|
||||||
if isinstance(param.default, DependsWrapper):
|
if isinstance(param.default, DependsWrapper):
|
||||||
@ -108,7 +108,7 @@ def get_dependent(
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
f"Unknown parameter {param_name} for function {call} with type {param.annotation}"
|
||||||
)
|
)
|
||||||
|
|
||||||
annotation: Any = Any
|
annotation: Any = Any
|
||||||
@ -153,7 +153,7 @@ async def solve_dependencies(
|
|||||||
if errs_:
|
if errs_:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{field_info} "
|
f"{field_info} "
|
||||||
f"type {type(value)} not match depends {_dependent.func} "
|
f"type {type(value)} not match depends {_dependent.call} "
|
||||||
f"annotation {field._type_display()}, ignored"
|
f"annotation {field._type_display()}, ignored"
|
||||||
)
|
)
|
||||||
raise SkippedException(field, value)
|
raise SkippedException(field, value)
|
||||||
@ -163,9 +163,9 @@ async def solve_dependencies(
|
|||||||
# solve sub dependencies
|
# solve sub dependencies
|
||||||
sub_dependent: Dependent
|
sub_dependent: Dependent
|
||||||
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
|
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
|
||||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
||||||
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
|
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
|
||||||
func = sub_dependent.func
|
call = sub_dependent.call
|
||||||
|
|
||||||
# solve sub dependency with current cache
|
# solve sub dependency with current cache
|
||||||
solved_result = await solve_dependencies(
|
solved_result = await solve_dependencies(
|
||||||
@ -179,19 +179,19 @@ async def solve_dependencies(
|
|||||||
async with cache_lock:
|
async with cache_lock:
|
||||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||||
solved = dependency_cache[sub_dependent.cache_key]
|
solved = dependency_cache[sub_dependent.cache_key]
|
||||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
_stack, AsyncExitStack
|
_stack, AsyncExitStack
|
||||||
), "Generator dependency should be called in context"
|
), "Generator dependency should be called in context"
|
||||||
if is_gen_callable(func):
|
if is_gen_callable(call):
|
||||||
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||||
else:
|
else:
|
||||||
cm = asynccontextmanager(func)(**sub_values)
|
cm = asynccontextmanager(call)(**sub_values)
|
||||||
solved = await _stack.enter_async_context(cm)
|
solved = await _stack.enter_async_context(cm)
|
||||||
elif is_coroutine_callable(func):
|
elif is_coroutine_callable(call):
|
||||||
solved = await func(**sub_values)
|
solved = await call(**sub_values)
|
||||||
else:
|
else:
|
||||||
solved = await run_sync(func)(**sub_values)
|
solved = await run_sync(call)(**sub_values)
|
||||||
|
|
||||||
# parameter dependency
|
# parameter dependency
|
||||||
if sub_dependent.name is not None:
|
if sub_dependent.name is not None:
|
||||||
|
@ -36,17 +36,17 @@ class Dependent:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
func: Optional[T_Handler] = None,
|
call: Optional[T_Handler] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
params: Optional[List[ModelField]] = None,
|
params: Optional[List[ModelField]] = None,
|
||||||
allow_types: Optional[List[Type[Param]]] = None,
|
allow_types: Optional[List[Type[Param]]] = None,
|
||||||
dependencies: Optional[List["Dependent"]] = None,
|
dependencies: Optional[List["Dependent"]] = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.func = func
|
self.call = call
|
||||||
self.name = name
|
self.name = name
|
||||||
self.params = params or []
|
self.params = params or []
|
||||||
self.allow_types = allow_types or []
|
self.allow_types = allow_types or []
|
||||||
self.dependencies = dependencies or []
|
self.dependencies = dependencies or []
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.cache_key = self.func
|
self.cache_key = self.call
|
||||||
|
@ -7,9 +7,9 @@ from pydantic.typing import ForwardRef, evaluate_forwardref
|
|||||||
from nonebot.typing import T_Handler
|
from nonebot.typing import T_Handler
|
||||||
|
|
||||||
|
|
||||||
def get_typed_signature(func: T_Handler) -> inspect.Signature:
|
def get_typed_signature(call: T_Handler) -> inspect.Signature:
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(call)
|
||||||
globalns = getattr(func, "__globals__", {})
|
globalns = getattr(call, "__globals__", {})
|
||||||
typed_params = [
|
typed_params = [
|
||||||
inspect.Parameter(
|
inspect.Parameter(
|
||||||
name=param.name,
|
name=param.name,
|
||||||
|
@ -25,7 +25,7 @@ class Handler:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Callable[..., Any],
|
call: Callable[..., Any],
|
||||||
*,
|
*,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
dependencies: Optional[List[DependsWrapper]] = None,
|
dependencies: Optional[List[DependsWrapper]] = None,
|
||||||
@ -38,17 +38,17 @@ class Handler:
|
|||||||
|
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``func: Callable[..., Any]``: 事件处理函数。
|
* ``call: Callable[..., Any]``: 事件处理函数。
|
||||||
* ``name: Optional[str]``: 事件处理器名称。默认为函数名。
|
* ``name: Optional[str]``: 事件处理器名称。默认为函数名。
|
||||||
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。
|
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。
|
||||||
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。
|
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。
|
||||||
"""
|
"""
|
||||||
self.func = func
|
self.call = call
|
||||||
"""
|
"""
|
||||||
:类型: ``Callable[..., Any]``
|
:类型: ``Callable[..., Any]``
|
||||||
:说明: 事件处理函数
|
:说明: 事件处理函数
|
||||||
"""
|
"""
|
||||||
self.name = get_name(func) if name is None else name
|
self.name = get_name(call) if name is None else name
|
||||||
"""
|
"""
|
||||||
:类型: ``str``
|
:类型: ``str``
|
||||||
:说明: 事件处理函数名
|
:说明: 事件处理函数名
|
||||||
@ -68,7 +68,7 @@ class Handler:
|
|||||||
if dependencies:
|
if dependencies:
|
||||||
for depends in dependencies:
|
for depends in dependencies:
|
||||||
self.cache_dependent(depends)
|
self.cache_dependent(depends)
|
||||||
self.dependent = get_dependent(func=func, allow_types=self.allow_types)
|
self.dependent = get_dependent(call=call, allow_types=self.allow_types)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
|
return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
|
||||||
@ -94,10 +94,10 @@ class Handler:
|
|||||||
**params,
|
**params,
|
||||||
)
|
)
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(self.func):
|
if asyncio.iscoroutinefunction(self.call):
|
||||||
return await self.func(**values)
|
return await self.call(**values)
|
||||||
else:
|
else:
|
||||||
return await run_sync(self.func)(**values)
|
return await run_sync(self.call)(**values)
|
||||||
|
|
||||||
def cache_dependent(self, dependency: DependsWrapper):
|
def cache_dependent(self, dependency: DependsWrapper):
|
||||||
if not dependency.dependency:
|
if not dependency.dependency:
|
||||||
|
@ -442,7 +442,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
def _decorator(func: T_Handler) -> T_Handler:
|
def _decorator(func: T_Handler) -> T_Handler:
|
||||||
|
|
||||||
if cls.handlers and cls.handlers[-1].func is func:
|
if cls.handlers and cls.handlers[-1].call is func:
|
||||||
func_handler = cls.handlers[-1]
|
func_handler = cls.handlers[-1]
|
||||||
for depend in reversed(_dependencies):
|
for depend in reversed(_dependencies):
|
||||||
func_handler.prepend_dependency(depend)
|
func_handler.prepend_dependency(depend)
|
||||||
@ -513,7 +513,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
|
|
||||||
def _decorator(func: T_Handler) -> T_Handler:
|
def _decorator(func: T_Handler) -> T_Handler:
|
||||||
|
|
||||||
if cls.handlers and cls.handlers[-1].func is func:
|
if cls.handlers and cls.handlers[-1].call is func:
|
||||||
func_handler = cls.handlers[-1]
|
func_handler = cls.handlers[-1]
|
||||||
for depend in reversed(_dependencies):
|
for depend in reversed(_dependencies):
|
||||||
func_handler.prepend_dependency(depend)
|
func_handler.prepend_dependency(depend)
|
||||||
|
@ -11,7 +11,17 @@ r"""
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Type,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
Callable,
|
||||||
|
NoReturn,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
from nonebot import params
|
from nonebot import params
|
||||||
from nonebot.handler import Handler
|
from nonebot.handler import Handler
|
||||||
@ -119,41 +129,59 @@ class Permission:
|
|||||||
return Permission(*self.checkers, other)
|
return Permission(*self.checkers, other)
|
||||||
|
|
||||||
|
|
||||||
async def _message(event: Event) -> bool:
|
class Message:
|
||||||
|
async def __call__(self, event: Event) -> bool:
|
||||||
return event.get_type() == "message"
|
return event.get_type() == "message"
|
||||||
|
|
||||||
|
|
||||||
async def _notice(event: Event) -> bool:
|
class Notice:
|
||||||
|
async def __call__(self, event: Event) -> bool:
|
||||||
return event.get_type() == "notice"
|
return event.get_type() == "notice"
|
||||||
|
|
||||||
|
|
||||||
async def _request(event: Event) -> bool:
|
class Request:
|
||||||
|
async def __call__(self, event: Event) -> bool:
|
||||||
return event.get_type() == "request"
|
return event.get_type() == "request"
|
||||||
|
|
||||||
|
|
||||||
async def _metaevent(event: Event) -> bool:
|
class MetaEvent:
|
||||||
|
async def __call__(self, event: Event) -> bool:
|
||||||
return event.get_type() == "meta_event"
|
return event.get_type() == "meta_event"
|
||||||
|
|
||||||
|
|
||||||
MESSAGE = Permission(_message)
|
MESSAGE = Permission(Message())
|
||||||
"""
|
"""
|
||||||
- **说明**: 匹配任意 ``message`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 message type 的 Matcher。
|
- **说明**: 匹配任意 ``message`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 message type 的 Matcher。
|
||||||
"""
|
"""
|
||||||
NOTICE = Permission(_notice)
|
NOTICE = Permission(Notice())
|
||||||
"""
|
"""
|
||||||
- **说明**: 匹配任意 ``notice`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 notice type 的 Matcher。
|
- **说明**: 匹配任意 ``notice`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 notice type 的 Matcher。
|
||||||
"""
|
"""
|
||||||
REQUEST = Permission(_request)
|
REQUEST = Permission(Request())
|
||||||
"""
|
"""
|
||||||
- **说明**: 匹配任意 ``request`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 request type 的 Matcher。
|
- **说明**: 匹配任意 ``request`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 request type 的 Matcher。
|
||||||
"""
|
"""
|
||||||
METAEVENT = Permission(_metaevent)
|
METAEVENT = Permission(MetaEvent())
|
||||||
"""
|
"""
|
||||||
- **说明**: 匹配任意 ``meta_event`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 meta_event type 的 Matcher。
|
- **说明**: 匹配任意 ``meta_event`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 meta_event type 的 Matcher。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def USER(*user: str, perm: Optional[Permission] = None):
|
class User:
|
||||||
|
def __init__(
|
||||||
|
self, users: Tuple[str, ...], perm: Optional[Permission] = None
|
||||||
|
) -> None:
|
||||||
|
self.users = users
|
||||||
|
self.perm = perm
|
||||||
|
|
||||||
|
async def __call__(self, bot: Bot, event: Event) -> bool:
|
||||||
|
return bool(
|
||||||
|
event.get_session_id() in self.users
|
||||||
|
and (self.perm is None or await self.perm(bot, event))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def USER(*users: str, perm: Optional[Permission] = None):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
@ -165,21 +193,18 @@ def USER(*user: str, perm: Optional[Permission] = None):
|
|||||||
* ``perm: Optional[Permission]``: 需要同时满足的权限
|
* ``perm: Optional[Permission]``: 需要同时满足的权限
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _user(bot: Bot, event: Event) -> bool:
|
return Permission(User(users, perm))
|
||||||
return bool(
|
|
||||||
event.get_session_id() in user and (perm is None or await perm(bot, event))
|
|
||||||
)
|
|
||||||
|
|
||||||
return Permission(_user)
|
|
||||||
|
|
||||||
|
|
||||||
async def _superuser(bot: Bot, event: Event) -> bool:
|
class SuperUser:
|
||||||
|
async def __call__(self, bot: Bot, event: Event) -> bool:
|
||||||
return (
|
return (
|
||||||
event.get_type() == "message" and event.get_user_id() in bot.config.superusers
|
event.get_type() == "message"
|
||||||
|
and event.get_user_id() in bot.config.superusers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SUPERUSER = Permission(_superuser)
|
SUPERUSER = Permission(SuperUser())
|
||||||
"""
|
"""
|
||||||
- **说明**: 匹配任意超级用户消息类型事件
|
- **说明**: 匹配任意超级用户消息类型事件
|
||||||
"""
|
"""
|
||||||
|
166
nonebot/rule.py
166
nonebot/rule.py
@ -203,6 +203,24 @@ class TrieRule:
|
|||||||
return prefix, suffix
|
return prefix, suffix
|
||||||
|
|
||||||
|
|
||||||
|
class Startswith:
|
||||||
|
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||||
|
self.msg = msg
|
||||||
|
self.ignorecase = ignorecase
|
||||||
|
|
||||||
|
async def __call__(self, event: Event) -> Any:
|
||||||
|
if event.get_type() != "message":
|
||||||
|
return False
|
||||||
|
text = event.get_plaintext()
|
||||||
|
return bool(
|
||||||
|
re.match(
|
||||||
|
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
|
||||||
|
text,
|
||||||
|
re.IGNORECASE if self.ignorecase else 0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
|
def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -216,18 +234,25 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
|
|||||||
if isinstance(msg, str):
|
if isinstance(msg, str):
|
||||||
msg = (msg,)
|
msg = (msg,)
|
||||||
|
|
||||||
pattern = re.compile(
|
return Rule(Startswith(msg, ignorecase))
|
||||||
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
|
|
||||||
re.IGNORECASE if ignorecase else 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
|
|
||||||
|
class Endswith:
|
||||||
|
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||||
|
self.msg = msg
|
||||||
|
self.ignorecase = ignorecase
|
||||||
|
|
||||||
|
async def __call__(self, event: Event) -> Any:
|
||||||
if event.get_type() != "message":
|
if event.get_type() != "message":
|
||||||
return False
|
return False
|
||||||
text = event.get_plaintext()
|
text = event.get_plaintext()
|
||||||
return bool(pattern.match(text))
|
return bool(
|
||||||
|
re.search(
|
||||||
return Rule(_startswith)
|
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
|
||||||
|
text,
|
||||||
|
re.IGNORECASE if self.ignorecase else 0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
|
def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
|
||||||
@ -243,18 +268,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
|
|||||||
if isinstance(msg, str):
|
if isinstance(msg, str):
|
||||||
msg = (msg,)
|
msg = (msg,)
|
||||||
|
|
||||||
pattern = re.compile(
|
return Rule(Endswith(msg, ignorecase))
|
||||||
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
|
|
||||||
re.IGNORECASE if ignorecase else 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
|
|
||||||
|
class Keywords:
|
||||||
|
def __init__(self, *keywords: str):
|
||||||
|
self.keywords = keywords
|
||||||
|
|
||||||
|
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
|
||||||
if event.get_type() != "message":
|
if event.get_type() != "message":
|
||||||
return False
|
return False
|
||||||
text = event.get_plaintext()
|
text = event.get_plaintext()
|
||||||
return bool(pattern.search(text))
|
return bool(text and any(keyword in text for keyword in self.keywords))
|
||||||
|
|
||||||
return Rule(_endswith)
|
|
||||||
|
|
||||||
|
|
||||||
def keyword(*keywords: str) -> Rule:
|
def keyword(*keywords: str) -> Rule:
|
||||||
@ -268,13 +293,18 @@ def keyword(*keywords: str) -> Rule:
|
|||||||
* ``*keywords: str``: 关键词
|
* ``*keywords: str``: 关键词
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _keyword(event: Event) -> bool:
|
return Rule(Keywords(*keywords))
|
||||||
if event.get_type() != "message":
|
|
||||||
return False
|
|
||||||
text = event.get_plaintext()
|
|
||||||
return bool(text and any(keyword in text for keyword in keywords))
|
|
||||||
|
|
||||||
return Rule(_keyword)
|
|
||||||
|
class Command:
|
||||||
|
def __init__(self, cmds: List[Tuple[str, ...]]):
|
||||||
|
self.cmds = cmds
|
||||||
|
|
||||||
|
async def __call__(self, state: T_State) -> bool:
|
||||||
|
return state[PREFIX_KEY][CMD_KEY] in self.cmds
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Command {self.cmds}>"
|
||||||
|
|
||||||
|
|
||||||
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||||
@ -304,10 +334,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
|||||||
config = get_driver().config
|
config = get_driver().config
|
||||||
command_start = config.command_start
|
command_start = config.command_start
|
||||||
command_sep = config.command_sep
|
command_sep = config.command_sep
|
||||||
commands = list(cmds)
|
commands: List[Tuple[str, ...]] = []
|
||||||
for index, command in enumerate(commands):
|
for command in cmds:
|
||||||
if isinstance(command, str):
|
if isinstance(command, str):
|
||||||
commands[index] = command = (command,)
|
command = (command,)
|
||||||
|
|
||||||
|
commands.append(command)
|
||||||
|
|
||||||
if len(command) == 1:
|
if len(command) == 1:
|
||||||
for start in command_start:
|
for start in command_start:
|
||||||
@ -316,10 +348,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
|||||||
for start, sep in product(command_start, command_sep):
|
for start, sep in product(command_start, command_sep):
|
||||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||||
|
|
||||||
async def _command(state: T_State) -> bool:
|
return Rule(Command(commands))
|
||||||
return state[PREFIX_KEY][CMD_KEY] in commands
|
|
||||||
|
|
||||||
return Rule(_command)
|
|
||||||
|
|
||||||
|
|
||||||
class ArgumentParser(ArgParser):
|
class ArgumentParser(ArgParser):
|
||||||
@ -350,6 +379,27 @@ class ArgumentParser(ArgParser):
|
|||||||
return super().parse_args(args=args, namespace=namespace) # type: ignore
|
return super().parse_args(args=args, namespace=namespace) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class ShellCommand:
|
||||||
|
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
|
||||||
|
self.cmds = cmds
|
||||||
|
self.parser = parser
|
||||||
|
|
||||||
|
async def __call__(self, event: Event, state: T_State) -> bool:
|
||||||
|
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
|
||||||
|
message = str(event.get_message())
|
||||||
|
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
||||||
|
state[SHELL_ARGV] = shlex.split(strip_message)
|
||||||
|
if self.parser:
|
||||||
|
try:
|
||||||
|
args = self.parser.parse_args(state[SHELL_ARGV])
|
||||||
|
state[SHELL_ARGS] = args
|
||||||
|
except ParserExit as e:
|
||||||
|
state[SHELL_ARGS] = e
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def shell_command(
|
def shell_command(
|
||||||
*cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
|
*cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
|
||||||
) -> Rule:
|
) -> Rule:
|
||||||
@ -392,10 +442,12 @@ def shell_command(
|
|||||||
config = get_driver().config
|
config = get_driver().config
|
||||||
command_start = config.command_start
|
command_start = config.command_start
|
||||||
command_sep = config.command_sep
|
command_sep = config.command_sep
|
||||||
commands = list(cmds)
|
commands: List[Tuple[str, ...]] = []
|
||||||
for index, command in enumerate(commands):
|
for command in cmds:
|
||||||
if isinstance(command, str):
|
if isinstance(command, str):
|
||||||
commands[index] = command = (command,)
|
command = (command,)
|
||||||
|
|
||||||
|
commands.append(command)
|
||||||
|
|
||||||
if len(command) == 1:
|
if len(command) == 1:
|
||||||
for start in command_start:
|
for start in command_start:
|
||||||
@ -404,23 +456,26 @@ def shell_command(
|
|||||||
for start, sep in product(command_start, command_sep):
|
for start, sep in product(command_start, command_sep):
|
||||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||||
|
|
||||||
async def _shell_command(event: Event, state: T_State) -> bool:
|
return Rule(ShellCommand(commands, parser))
|
||||||
if state[PREFIX_KEY][CMD_KEY] in commands:
|
|
||||||
message = str(event.get_message())
|
|
||||||
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
class Regex:
|
||||||
state[SHELL_ARGV] = shlex.split(strip_message)
|
def __init__(self, regex: str, flags: int = 0):
|
||||||
if parser:
|
self.regex = regex
|
||||||
try:
|
self.flags = flags
|
||||||
args = parser.parse_args(state[SHELL_ARGV])
|
|
||||||
state[SHELL_ARGS] = args
|
async def __call__(self, event: Event, state: T_State) -> bool:
|
||||||
except ParserExit as e:
|
if event.get_type() != "message":
|
||||||
state[SHELL_ARGS] = e
|
return False
|
||||||
|
matched = re.search(self.regex, str(event.get_message()), self.flags)
|
||||||
|
if matched:
|
||||||
|
state[REGEX_MATCHED] = matched.group()
|
||||||
|
state[REGEX_GROUP] = matched.groups()
|
||||||
|
state[REGEX_DICT] = matched.groupdict()
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return Rule(_shell_command)
|
|
||||||
|
|
||||||
|
|
||||||
def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||||
r"""
|
r"""
|
||||||
@ -441,24 +496,11 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
|||||||
\:\:\:
|
\:\:\:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pattern = re.compile(regex, flags)
|
return Rule(Regex(regex, flags))
|
||||||
|
|
||||||
async def _regex(event: Event, state: T_State) -> bool:
|
|
||||||
if event.get_type() != "message":
|
|
||||||
return False
|
|
||||||
matched = pattern.search(str(event.get_message()))
|
|
||||||
if matched:
|
|
||||||
state[REGEX_MATCHED] = matched.group()
|
|
||||||
state[REGEX_GROUP] = matched.groups()
|
|
||||||
state[REGEX_DICT] = matched.groupdict()
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return Rule(_regex)
|
|
||||||
|
|
||||||
|
|
||||||
async def _to_me(event: Event) -> bool:
|
class ToMe:
|
||||||
|
async def __call__(self, event: Event) -> bool:
|
||||||
return event.is_tome()
|
return event.is_tome()
|
||||||
|
|
||||||
|
|
||||||
@ -473,4 +515,4 @@ def to_me() -> Rule:
|
|||||||
* 无
|
* 无
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return Rule(_to_me)
|
return Rule(ToMe())
|
||||||
|
@ -66,30 +66,30 @@ def generic_check_issubclass(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def is_coroutine_callable(func: Callable[..., Any]) -> bool:
|
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||||
if inspect.isroutine(func):
|
if inspect.isroutine(call):
|
||||||
return inspect.iscoroutinefunction(func)
|
return inspect.iscoroutinefunction(call)
|
||||||
if inspect.isclass(func):
|
if inspect.isclass(call):
|
||||||
return False
|
return False
|
||||||
func_ = getattr(func, "__call__", None)
|
func_ = getattr(call, "__call__", None)
|
||||||
return inspect.iscoroutinefunction(func_)
|
return inspect.iscoroutinefunction(func_)
|
||||||
|
|
||||||
|
|
||||||
def is_gen_callable(func: Callable[..., Any]) -> bool:
|
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||||
if inspect.isgeneratorfunction(func):
|
if inspect.isgeneratorfunction(call):
|
||||||
return True
|
return True
|
||||||
func_ = getattr(func, "__call__", None)
|
func_ = getattr(call, "__call__", None)
|
||||||
return inspect.isgeneratorfunction(func_)
|
return inspect.isgeneratorfunction(func_)
|
||||||
|
|
||||||
|
|
||||||
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
|
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||||
if inspect.isasyncgenfunction(func):
|
if inspect.isasyncgenfunction(call):
|
||||||
return True
|
return True
|
||||||
func_ = getattr(func, "__call__", None)
|
func_ = getattr(call, "__call__", None)
|
||||||
return inspect.isasyncgenfunction(func_)
|
return inspect.isasyncgenfunction(func_)
|
||||||
|
|
||||||
|
|
||||||
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
@ -97,17 +97,17 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
|||||||
|
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``func: Callable[P, R]``: 被装饰的同步函数
|
* ``call: Callable[P, R]``: 被装饰的同步函数
|
||||||
|
|
||||||
:返回:
|
:返回:
|
||||||
|
|
||||||
- ``Callable[P, Awaitable[R]]``
|
- ``Callable[P, Awaitable[R]]``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(call)
|
||||||
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
pfunc = partial(func, *args, **kwargs)
|
pfunc = partial(call, *args, **kwargs)
|
||||||
result = await loop.run_in_executor(None, pfunc)
|
result = await loop.run_in_executor(None, pfunc)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user