nonebot2/nonebot/utils.py

228 lines
6.1 KiB
Python
Raw Normal View History

2022-01-19 16:16:56 +08:00
"""本模块包含了 NoneBot 的一些工具函数
2022-01-16 11:30:09 +08:00
FrontMatter:
sidebar_position: 8
description: nonebot.utils 模块
"""
2020-09-30 18:01:31 +08:00
import re
2020-08-10 13:06:02 +08:00
import json
2020-08-14 17:41:24 +08:00
import asyncio
import inspect
import importlib
2020-08-10 13:06:02 +08:00
import dataclasses
from pathlib import Path
from contextvars import copy_context
2020-08-14 17:41:24 +08:00
from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, get_args, get_origin
from typing import (
Any,
Type,
Tuple,
Union,
TypeVar,
Callable,
Optional,
2022-02-11 11:25:31 +08:00
Coroutine,
AsyncGenerator,
ContextManager,
overload,
)
2020-08-10 13:06:02 +08:00
from pydantic.typing import is_union, is_none_type
2020-12-02 19:52:45 +08:00
from nonebot.log import logger
2020-12-06 02:30:19 +08:00
from nonebot.typing import overrides
2020-08-14 17:41:24 +08:00
2021-11-13 19:38:01 +08:00
# Generic placeholders referenced by the type annotations in this module.
P = ParamSpec("P")  # parameter specification of a wrapped callable (see run_sync)
R = TypeVar("R")  # return type of a wrapped callable / fallback value type
T = TypeVar("T")  # value produced by a context manager or coroutine
K = TypeVar("K")
V = TypeVar("V")
2021-11-13 19:38:01 +08:00
2020-08-14 17:41:24 +08:00
2020-09-30 18:01:31 +08:00
def escape_tag(s: str) -> str:
    """Escape `<tag>`-style markup so colored log output prints it literally.

    Reference: [loguru color tags](https://loguru.readthedocs.io/en/stable/api/logger.html#color)

    Args:
        s: the string to escape
    """
    tag_pattern = re.compile(r"</?((?:[fb]g\s)?[^<>\s]*)>")
    # The template inserts one literal backslash in front of the whole match.
    return tag_pattern.sub(r"\\\g<0>", s)
def generic_check_issubclass(
    cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]
) -> bool:
    """Check whether ``cls`` is a subclass of one of the types in ``class_or_tuple``.

    If ``cls`` is a `typing.Union` / `types.UnionType`, every member of the
    union must be either ``None`` or a subclass of one of the given types.
    For any other subscripted generic, the check is applied to its origin.

    Args:
        cls: the class or typing construct to inspect
        class_or_tuple: a class, or a tuple of classes, to test against

    Returns:
        ``True`` when the check succeeds, ``False`` otherwise (including when
        ``cls`` or its origin is not something ``issubclass`` can handle).
    """
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        # ``cls`` is not a plain class (e.g. a typing construct); fall back
        # to inspecting its origin below.
        origin = get_origin(cls)

    if is_union(origin):
        return all(
            is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple)
            for type_ in get_args(cls)
        )
    if origin:
        try:
            return issubclass(origin, class_or_tuple)
        except TypeError:
            # The origin itself may be a special form rather than a class
            # (e.g. ``typing.Literal``); report "not a subclass" instead of
            # letting the TypeError escape from a bool-returning check.
            return False
    return False
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Check whether ``call`` is a callable coroutine function.

    Routines are tested directly, classes are never considered coroutine
    callables, and any other object is judged by its ``__call__`` attribute.
    """
    if inspect.isroutine(call):
        candidate = call
    elif inspect.isclass(call):
        return False
    else:
        # Callable instance: look at the ``__call__`` it exposes, if any.
        candidate = getattr(call, "__call__", None)
    return inspect.iscoroutinefunction(candidate)
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Check whether ``call`` is a generator function.

    Either ``call`` itself or its ``__call__`` attribute must be a generator
    function.
    """
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        getattr(call, "__call__", None)
    )
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Check whether ``call`` is an async generator function.

    Either ``call`` itself or its ``__call__`` attribute must be an async
    generator function.
    """
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)
    )
2022-02-11 11:25:31 +08:00
def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
    """Decorator that wraps a sync function into an async function.

    The wrapped call is handed to the running event loop's
    ``run_in_executor`` (with ``None`` as the executor) and executed inside
    a copy of the caller's context.

    Args:
        call: the synchronous function being decorated
    """

    @wraps(call)
    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        running_loop = asyncio.get_running_loop()
        bound_call = partial(call, *args, **kwargs)
        # Copy the context at call time so the executor runs ``call`` with
        # the caller's context variables.
        run_in_context = partial(copy_context().run, bound_call)
        return await running_loop.run_in_executor(None, run_in_context)

    return _wrapper
2020-08-10 13:06:02 +08:00
@asynccontextmanager
async def run_sync_ctx_manager(
    cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
    """Drive a sync context manager as an async context manager.

    ``cm.__enter__`` and ``cm.__exit__`` are each executed through
    ``run_sync``.

    Args:
        cm: the synchronous context manager to wrap

    Yields:
        The value returned by ``cm.__enter__()``.

    Raises:
        Exception: an exception raised in the managed body is re-raised
            unless ``cm.__exit__`` returns a truthy value. Only ``Exception``
            subclasses are forwarded to ``cm.__exit__``.
    """
    # Enter outside the ``try`` so that a failing ``__enter__`` never calls
    # ``__exit__`` on a manager that was not entered (same as ``with``).
    entered = await run_sync(cm.__enter__)()
    try:
        yield entered
    except Exception as e:
        # Hand ``__exit__`` the real traceback rather than ``None``.
        suppressed = await run_sync(cm.__exit__)(type(e), e, e.__traceback__)
        if not suppressed:
            raise
    else:
        await run_sync(cm.__exit__)(None, None, None)
@overload
async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: Tuple[Type[Exception], ...],
    return_on_err: None = None,
) -> Union[T, None]:
    ...


@overload
async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: Tuple[Type[Exception], ...],
    return_on_err: R,
) -> Union[T, R]:
    ...


async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: Tuple[Type[Exception], ...],
    return_on_err: Optional[R] = None,
) -> Optional[Union[T, R]]:
    """Await ``coro`` and swallow the listed exception types.

    Args:
        coro: the coroutine to await
        exc: exception types to catch
        return_on_err: value returned when one of ``exc`` is raised

    Returns:
        The result of ``coro``, or ``return_on_err`` if a listed exception
        was caught.
    """
    try:
        outcome = await coro
    except exc:
        return return_on_err
    else:
        return outcome
def get_name(obj: Any) -> str:
    """Return the name of an object.

    Functions and classes report their own ``__name__``; any other object
    reports the ``__name__`` of its class.
    """
    is_named = inspect.isfunction(obj) or inspect.isclass(obj)
    named = obj if is_named else obj.__class__
    return named.__name__
def path_to_module_name(path: Path) -> str:
    """Convert a path into a dotted module name.

    The path is resolved and taken relative to the resolved current working
    directory. A file named ``__init__`` maps to the name of its package.
    """
    relative = path.resolve().relative_to(Path.cwd().resolve())
    segments = relative.parts[:-1]
    if relative.stem != "__init__":
        # Regular module: append the file name without its suffix.
        segments += (relative.stem,)
    return ".".join(segments)
def resolve_dot_notation(
    obj_str: str, default_attr: str, default_prefix: Optional[str] = None
) -> Any:
    """Parse a ``module:attr.path`` string, import the module and return the object.

    Args:
        obj_str: ``"module"`` or ``"module:attr.sub_attr"``; a leading ``~``
            in the module part is replaced by ``default_prefix`` when given
        default_attr: attribute fetched from the module when no attribute
            path follows the colon
        default_prefix: replacement for a leading ``~`` in the module part
    """
    module_path, _, attr_path = obj_str.partition(":")
    if default_prefix is not None and module_path.startswith("~"):
        module_path = default_prefix + module_path[1:]
    target = importlib.import_module(module_path)
    # No explicit attribute path: fall back to the single default attribute.
    attr_names = attr_path.split(".") if attr_path else [default_attr]
    for attr_name in attr_names:
        target = getattr(target, attr_name)
    return target
2020-08-10 13:06:02 +08:00
class DataclassEncoder(json.JSONEncoder):
    """`JSONEncoder` used when JSON-serializing {ref}`nonebot.adapters.Message` (List[Dataclass])."""

    @overrides(json.JSONEncoder)
    def default(self, o):
        if not dataclasses.is_dataclass(o):
            # Not a dataclass: defer to the base encoder.
            return super().default(o)
        # Map each declared field name to its current attribute value.
        payload = {}
        for field in dataclasses.fields(o):
            payload[field.name] = getattr(o, field.name)
        return payload
2020-12-02 19:52:45 +08:00
def logger_wrapper(logger_name: str):
    """Build a function for printing an adapter's logs.

    Args:
        logger_name: name of the adapter

    Returns:
        A logging function taking:

        - level: log level
        - message: log message
        - exception: exception information
    """

    def log(level: str, message: str, exception: Optional[Exception] = None):
        colored_logger = logger.opt(colors=True, exception=exception)
        # The adapter name is escaped so it cannot be parsed as a color tag.
        line = f"<m>{escape_tag(logger_name)}</m> | {message}"
        colored_logger.log(level, line)

    return log