mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
♻️ use class rule and permission
This commit is contained in:
parent
ca4d7397f8
commit
5b75b72720
@ -69,22 +69,22 @@ def get_sub_dependant(
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
) -> Dependent:
|
||||
sub_dependant = get_dependent(
|
||||
func=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
|
||||
call=dependency, name=name, use_cache=depends.use_cache, allow_types=allow_types
|
||||
)
|
||||
return sub_dependant
|
||||
|
||||
|
||||
def get_dependent(
|
||||
*,
|
||||
func: T_Handler,
|
||||
call: T_Handler,
|
||||
name: Optional[str] = None,
|
||||
use_cache: bool = True,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
) -> Dependent:
|
||||
signature = get_typed_signature(func)
|
||||
signature = get_typed_signature(call)
|
||||
params = signature.parameters
|
||||
dependent = Dependent(
|
||||
func=func, name=name, allow_types=allow_types, use_cache=use_cache
|
||||
call=call, name=name, allow_types=allow_types, use_cache=use_cache
|
||||
)
|
||||
for param_name, param in params.items():
|
||||
if isinstance(param.default, DependsWrapper):
|
||||
@ -108,7 +108,7 @@ def get_dependent(
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
|
||||
f"Unknown parameter {param_name} for function {call} with type {param.annotation}"
|
||||
)
|
||||
|
||||
annotation: Any = Any
|
||||
@ -153,7 +153,7 @@ async def solve_dependencies(
|
||||
if errs_:
|
||||
logger.debug(
|
||||
f"{field_info} "
|
||||
f"type {type(value)} not match depends {_dependent.func} "
|
||||
f"type {type(value)} not match depends {_dependent.call} "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise SkippedException(field, value)
|
||||
@ -163,9 +163,9 @@ async def solve_dependencies(
|
||||
# solve sub dependencies
|
||||
sub_dependent: Dependent
|
||||
for sub_dependent in chain(_sub_dependents or tuple(), _dependent.dependencies):
|
||||
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
|
||||
sub_dependent.call = cast(Callable[..., Any], sub_dependent.call)
|
||||
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key)
|
||||
func = sub_dependent.func
|
||||
call = sub_dependent.call
|
||||
|
||||
# solve sub dependency with current cache
|
||||
solved_result = await solve_dependencies(
|
||||
@ -179,19 +179,19 @@ async def solve_dependencies(
|
||||
async with cache_lock:
|
||||
if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
|
||||
solved = dependency_cache[sub_dependent.cache_key]
|
||||
elif is_gen_callable(func) or is_async_gen_callable(func):
|
||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
||||
assert isinstance(
|
||||
_stack, AsyncExitStack
|
||||
), "Generator dependency should be called in context"
|
||||
if is_gen_callable(func):
|
||||
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
|
||||
if is_gen_callable(call):
|
||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||
else:
|
||||
cm = asynccontextmanager(func)(**sub_values)
|
||||
cm = asynccontextmanager(call)(**sub_values)
|
||||
solved = await _stack.enter_async_context(cm)
|
||||
elif is_coroutine_callable(func):
|
||||
solved = await func(**sub_values)
|
||||
elif is_coroutine_callable(call):
|
||||
solved = await call(**sub_values)
|
||||
else:
|
||||
solved = await run_sync(func)(**sub_values)
|
||||
solved = await run_sync(call)(**sub_values)
|
||||
|
||||
# parameter dependency
|
||||
if sub_dependent.name is not None:
|
||||
|
@ -36,17 +36,17 @@ class Dependent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
func: Optional[T_Handler] = None,
|
||||
call: Optional[T_Handler] = None,
|
||||
name: Optional[str] = None,
|
||||
params: Optional[List[ModelField]] = None,
|
||||
allow_types: Optional[List[Type[Param]]] = None,
|
||||
dependencies: Optional[List["Dependent"]] = None,
|
||||
use_cache: bool = True,
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.call = call
|
||||
self.name = name
|
||||
self.params = params or []
|
||||
self.allow_types = allow_types or []
|
||||
self.dependencies = dependencies or []
|
||||
self.use_cache = use_cache
|
||||
self.cache_key = self.func
|
||||
self.cache_key = self.call
|
||||
|
@ -7,9 +7,9 @@ from pydantic.typing import ForwardRef, evaluate_forwardref
|
||||
from nonebot.typing import T_Handler
|
||||
|
||||
|
||||
def get_typed_signature(func: T_Handler) -> inspect.Signature:
|
||||
signature = inspect.signature(func)
|
||||
globalns = getattr(func, "__globals__", {})
|
||||
def get_typed_signature(call: T_Handler) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
|
@ -25,7 +25,7 @@ class Handler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
call: Callable[..., Any],
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
dependencies: Optional[List[DependsWrapper]] = None,
|
||||
@ -38,17 +38,17 @@ class Handler:
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[..., Any]``: 事件处理函数。
|
||||
* ``call: Callable[..., Any]``: 事件处理函数。
|
||||
* ``name: Optional[str]``: 事件处理器名称。默认为函数名。
|
||||
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入。
|
||||
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型。
|
||||
"""
|
||||
self.func = func
|
||||
self.call = call
|
||||
"""
|
||||
:类型: ``Callable[..., Any]``
|
||||
:说明: 事件处理函数
|
||||
"""
|
||||
self.name = get_name(func) if name is None else name
|
||||
self.name = get_name(call) if name is None else name
|
||||
"""
|
||||
:类型: ``str``
|
||||
:说明: 事件处理函数名
|
||||
@ -68,7 +68,7 @@ class Handler:
|
||||
if dependencies:
|
||||
for depends in dependencies:
|
||||
self.cache_dependent(depends)
|
||||
self.dependent = get_dependent(func=func, allow_types=self.allow_types)
|
||||
self.dependent = get_dependent(call=call, allow_types=self.allow_types)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
|
||||
@ -94,10 +94,10 @@ class Handler:
|
||||
**params,
|
||||
)
|
||||
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(**values)
|
||||
if asyncio.iscoroutinefunction(self.call):
|
||||
return await self.call(**values)
|
||||
else:
|
||||
return await run_sync(self.func)(**values)
|
||||
return await run_sync(self.call)(**values)
|
||||
|
||||
def cache_dependent(self, dependency: DependsWrapper):
|
||||
if not dependency.dependency:
|
||||
|
@ -442,7 +442,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
def _decorator(func: T_Handler) -> T_Handler:
|
||||
|
||||
if cls.handlers and cls.handlers[-1].func is func:
|
||||
if cls.handlers and cls.handlers[-1].call is func:
|
||||
func_handler = cls.handlers[-1]
|
||||
for depend in reversed(_dependencies):
|
||||
func_handler.prepend_dependency(depend)
|
||||
@ -513,7 +513,7 @@ class Matcher(metaclass=MatcherMeta):
|
||||
|
||||
def _decorator(func: T_Handler) -> T_Handler:
|
||||
|
||||
if cls.handlers and cls.handlers[-1].func is func:
|
||||
if cls.handlers and cls.handlers[-1].call is func:
|
||||
func_handler = cls.handlers[-1]
|
||||
for depend in reversed(_dependencies):
|
||||
func_handler.prepend_dependency(depend)
|
||||
|
@ -11,7 +11,17 @@ r"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Callable,
|
||||
NoReturn,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from nonebot import params
|
||||
from nonebot.handler import Handler
|
||||
@ -119,41 +129,59 @@ class Permission:
|
||||
return Permission(*self.checkers, other)
|
||||
|
||||
|
||||
async def _message(event: Event) -> bool:
|
||||
return event.get_type() == "message"
|
||||
class Message:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "message"
|
||||
|
||||
|
||||
async def _notice(event: Event) -> bool:
|
||||
return event.get_type() == "notice"
|
||||
class Notice:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "notice"
|
||||
|
||||
|
||||
async def _request(event: Event) -> bool:
|
||||
return event.get_type() == "request"
|
||||
class Request:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "request"
|
||||
|
||||
|
||||
async def _metaevent(event: Event) -> bool:
|
||||
return event.get_type() == "meta_event"
|
||||
class MetaEvent:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "meta_event"
|
||||
|
||||
|
||||
MESSAGE = Permission(_message)
|
||||
MESSAGE = Permission(Message())
|
||||
"""
|
||||
- **说明**: 匹配任意 ``message`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 message type 的 Matcher。
|
||||
"""
|
||||
NOTICE = Permission(_notice)
|
||||
NOTICE = Permission(Notice())
|
||||
"""
|
||||
- **说明**: 匹配任意 ``notice`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 notice type 的 Matcher。
|
||||
"""
|
||||
REQUEST = Permission(_request)
|
||||
REQUEST = Permission(Request())
|
||||
"""
|
||||
- **说明**: 匹配任意 ``request`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 request type 的 Matcher。
|
||||
"""
|
||||
METAEVENT = Permission(_metaevent)
|
||||
METAEVENT = Permission(MetaEvent())
|
||||
"""
|
||||
- **说明**: 匹配任意 ``meta_event`` 类型事件,仅在需要同时捕获不同类型事件时使用。优先使用 meta_event type 的 Matcher。
|
||||
"""
|
||||
|
||||
|
||||
def USER(*user: str, perm: Optional[Permission] = None):
|
||||
class User:
|
||||
def __init__(
|
||||
self, users: Tuple[str, ...], perm: Optional[Permission] = None
|
||||
) -> None:
|
||||
self.users = users
|
||||
self.perm = perm
|
||||
|
||||
async def __call__(self, bot: Bot, event: Event) -> bool:
|
||||
return bool(
|
||||
event.get_session_id() in self.users
|
||||
and (self.perm is None or await self.perm(bot, event))
|
||||
)
|
||||
|
||||
|
||||
def USER(*users: str, perm: Optional[Permission] = None):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -165,21 +193,18 @@ def USER(*user: str, perm: Optional[Permission] = None):
|
||||
* ``perm: Optional[Permission]``: 需要同时满足的权限
|
||||
"""
|
||||
|
||||
async def _user(bot: Bot, event: Event) -> bool:
|
||||
return bool(
|
||||
event.get_session_id() in user and (perm is None or await perm(bot, event))
|
||||
return Permission(User(users, perm))
|
||||
|
||||
|
||||
class SuperUser:
|
||||
async def __call__(self, bot: Bot, event: Event) -> bool:
|
||||
return (
|
||||
event.get_type() == "message"
|
||||
and event.get_user_id() in bot.config.superusers
|
||||
)
|
||||
|
||||
return Permission(_user)
|
||||
|
||||
|
||||
async def _superuser(bot: Bot, event: Event) -> bool:
|
||||
return (
|
||||
event.get_type() == "message" and event.get_user_id() in bot.config.superusers
|
||||
)
|
||||
|
||||
|
||||
SUPERUSER = Permission(_superuser)
|
||||
SUPERUSER = Permission(SuperUser())
|
||||
"""
|
||||
- **说明**: 匹配任意超级用户消息类型事件
|
||||
"""
|
||||
|
168
nonebot/rule.py
168
nonebot/rule.py
@ -203,6 +203,24 @@ class TrieRule:
|
||||
return prefix, suffix
|
||||
|
||||
|
||||
class Startswith:
|
||||
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||
self.msg = msg
|
||||
self.ignorecase = ignorecase
|
||||
|
||||
async def __call__(self, event: Event) -> Any:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(
|
||||
re.match(
|
||||
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
|
||||
text,
|
||||
re.IGNORECASE if self.ignorecase else 0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
|
||||
"""
|
||||
:说明:
|
||||
@ -216,18 +234,25 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
|
||||
if isinstance(msg, str):
|
||||
msg = (msg,)
|
||||
|
||||
pattern = re.compile(
|
||||
f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})",
|
||||
re.IGNORECASE if ignorecase else 0,
|
||||
)
|
||||
return Rule(Startswith(msg, ignorecase))
|
||||
|
||||
async def _startswith(bot: Bot, event: Event, state: T_State) -> bool:
|
||||
|
||||
class Endswith:
|
||||
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||
self.msg = msg
|
||||
self.ignorecase = ignorecase
|
||||
|
||||
async def __call__(self, event: Event) -> Any:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(pattern.match(text))
|
||||
|
||||
return Rule(_startswith)
|
||||
return bool(
|
||||
re.search(
|
||||
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
|
||||
text,
|
||||
re.IGNORECASE if self.ignorecase else 0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
|
||||
@ -243,18 +268,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
|
||||
if isinstance(msg, str):
|
||||
msg = (msg,)
|
||||
|
||||
pattern = re.compile(
|
||||
f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$",
|
||||
re.IGNORECASE if ignorecase else 0,
|
||||
)
|
||||
return Rule(Endswith(msg, ignorecase))
|
||||
|
||||
async def _endswith(bot: Bot, event: Event, state: T_State) -> bool:
|
||||
|
||||
class Keywords:
|
||||
def __init__(self, *keywords: str):
|
||||
self.keywords = keywords
|
||||
|
||||
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(pattern.search(text))
|
||||
|
||||
return Rule(_endswith)
|
||||
return bool(text and any(keyword in text for keyword in self.keywords))
|
||||
|
||||
|
||||
def keyword(*keywords: str) -> Rule:
|
||||
@ -268,13 +293,18 @@ def keyword(*keywords: str) -> Rule:
|
||||
* ``*keywords: str``: 关键词
|
||||
"""
|
||||
|
||||
async def _keyword(event: Event) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(text and any(keyword in text for keyword in keywords))
|
||||
return Rule(Keywords(*keywords))
|
||||
|
||||
return Rule(_keyword)
|
||||
|
||||
class Command:
|
||||
def __init__(self, cmds: List[Tuple[str, ...]]):
|
||||
self.cmds = cmds
|
||||
|
||||
async def __call__(self, state: T_State) -> bool:
|
||||
return state[PREFIX_KEY][CMD_KEY] in self.cmds
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Command {self.cmds}>"
|
||||
|
||||
|
||||
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
@ -304,10 +334,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
config = get_driver().config
|
||||
command_start = config.command_start
|
||||
command_sep = config.command_sep
|
||||
commands = list(cmds)
|
||||
for index, command in enumerate(commands):
|
||||
commands: List[Tuple[str, ...]] = []
|
||||
for command in cmds:
|
||||
if isinstance(command, str):
|
||||
commands[index] = command = (command,)
|
||||
command = (command,)
|
||||
|
||||
commands.append(command)
|
||||
|
||||
if len(command) == 1:
|
||||
for start in command_start:
|
||||
@ -316,10 +348,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
async def _command(state: T_State) -> bool:
|
||||
return state[PREFIX_KEY][CMD_KEY] in commands
|
||||
|
||||
return Rule(_command)
|
||||
return Rule(Command(commands))
|
||||
|
||||
|
||||
class ArgumentParser(ArgParser):
|
||||
@ -350,6 +379,27 @@ class ArgumentParser(ArgParser):
|
||||
return super().parse_args(args=args, namespace=namespace) # type: ignore
|
||||
|
||||
|
||||
class ShellCommand:
|
||||
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
|
||||
self.cmds = cmds
|
||||
self.parser = parser
|
||||
|
||||
async def __call__(self, event: Event, state: T_State) -> bool:
|
||||
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
|
||||
message = str(event.get_message())
|
||||
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
||||
state[SHELL_ARGV] = shlex.split(strip_message)
|
||||
if self.parser:
|
||||
try:
|
||||
args = self.parser.parse_args(state[SHELL_ARGV])
|
||||
state[SHELL_ARGS] = args
|
||||
except ParserExit as e:
|
||||
state[SHELL_ARGS] = e
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def shell_command(
|
||||
*cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None
|
||||
) -> Rule:
|
||||
@ -392,10 +442,12 @@ def shell_command(
|
||||
config = get_driver().config
|
||||
command_start = config.command_start
|
||||
command_sep = config.command_sep
|
||||
commands = list(cmds)
|
||||
for index, command in enumerate(commands):
|
||||
commands: List[Tuple[str, ...]] = []
|
||||
for command in cmds:
|
||||
if isinstance(command, str):
|
||||
commands[index] = command = (command,)
|
||||
command = (command,)
|
||||
|
||||
commands.append(command)
|
||||
|
||||
if len(command) == 1:
|
||||
for start in command_start:
|
||||
@ -404,23 +456,26 @@ def shell_command(
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
async def _shell_command(event: Event, state: T_State) -> bool:
|
||||
if state[PREFIX_KEY][CMD_KEY] in commands:
|
||||
message = str(event.get_message())
|
||||
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
||||
state[SHELL_ARGV] = shlex.split(strip_message)
|
||||
if parser:
|
||||
try:
|
||||
args = parser.parse_args(state[SHELL_ARGV])
|
||||
state[SHELL_ARGS] = args
|
||||
except ParserExit as e:
|
||||
state[SHELL_ARGS] = e
|
||||
return Rule(ShellCommand(commands, parser))
|
||||
|
||||
|
||||
class Regex:
|
||||
def __init__(self, regex: str, flags: int = 0):
|
||||
self.regex = regex
|
||||
self.flags = flags
|
||||
|
||||
async def __call__(self, event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
matched = re.search(self.regex, str(event.get_message()), self.flags)
|
||||
if matched:
|
||||
state[REGEX_MATCHED] = matched.group()
|
||||
state[REGEX_GROUP] = matched.groups()
|
||||
state[REGEX_DICT] = matched.groupdict()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
return Rule(_shell_command)
|
||||
|
||||
|
||||
def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||
r"""
|
||||
@ -441,25 +496,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||
\:\:\:
|
||||
"""
|
||||
|
||||
pattern = re.compile(regex, flags)
|
||||
|
||||
async def _regex(event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
return False
|
||||
matched = pattern.search(str(event.get_message()))
|
||||
if matched:
|
||||
state[REGEX_MATCHED] = matched.group()
|
||||
state[REGEX_GROUP] = matched.groups()
|
||||
state[REGEX_DICT] = matched.groupdict()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
return Rule(_regex)
|
||||
return Rule(Regex(regex, flags))
|
||||
|
||||
|
||||
async def _to_me(event: Event) -> bool:
|
||||
return event.is_tome()
|
||||
class ToMe:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.is_tome()
|
||||
|
||||
|
||||
def to_me() -> Rule:
|
||||
@ -473,4 +515,4 @@ def to_me() -> Rule:
|
||||
* 无
|
||||
"""
|
||||
|
||||
return Rule(_to_me)
|
||||
return Rule(ToMe())
|
||||
|
@ -66,30 +66,30 @@ def generic_check_issubclass(
|
||||
raise
|
||||
|
||||
|
||||
def is_coroutine_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(func):
|
||||
return inspect.iscoroutinefunction(func)
|
||||
if inspect.isclass(func):
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isroutine(call):
|
||||
return inspect.iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
func_ = getattr(func, "__call__", None)
|
||||
func_ = getattr(call, "__call__", None)
|
||||
return inspect.iscoroutinefunction(func_)
|
||||
|
||||
|
||||
def is_gen_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(func):
|
||||
def is_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isgeneratorfunction(call):
|
||||
return True
|
||||
func_ = getattr(func, "__call__", None)
|
||||
func_ = getattr(call, "__call__", None)
|
||||
return inspect.isgeneratorfunction(func_)
|
||||
|
||||
|
||||
def is_async_gen_callable(func: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(func):
|
||||
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
|
||||
if inspect.isasyncgenfunction(call):
|
||||
return True
|
||||
func_ = getattr(func, "__call__", None)
|
||||
func_ = getattr(call, "__call__", None)
|
||||
return inspect.isasyncgenfunction(func_)
|
||||
|
||||
|
||||
def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -97,17 +97,17 @@ def run_sync(func: Callable[P, R]) -> Callable[P, Awaitable[R]]:
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[P, R]``: 被装饰的同步函数
|
||||
* ``call: Callable[P, R]``: 被装饰的同步函数
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Callable[P, Awaitable[R]]``
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
@wraps(call)
|
||||
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
loop = asyncio.get_running_loop()
|
||||
pfunc = partial(func, *args, **kwargs)
|
||||
pfunc = partial(call, *args, **kwargs)
|
||||
result = await loop.run_in_executor(None, pfunc)
|
||||
return result
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user