# Mirror of https://github.com/nonebot/nonebot2.git
# Synced 2024-11-24 09:05:04 +08:00
# 317 lines, 9.3 KiB, Python
"""This module contains utility functions for NoneBot.

FrontMatter:
    mdx:
        format: md
    sidebar_position: 8
    description: nonebot.utils 模块
"""
|
|
|
|
import re
|
|
import json
|
|
import asyncio
|
|
import inspect
|
|
import importlib
|
|
import contextlib
|
|
import dataclasses
|
|
from pathlib import Path
|
|
from collections import deque
|
|
from contextvars import copy_context
|
|
from functools import wraps, partial
|
|
from contextlib import AbstractContextManager, asynccontextmanager
|
|
from typing_extensions import ParamSpec, get_args, override, get_origin
|
|
from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
|
|
from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from nonebot.log import logger
|
|
from nonebot.typing import (
|
|
is_none_type,
|
|
type_has_args,
|
|
origin_is_union,
|
|
origin_is_literal,
|
|
all_literal_values,
|
|
)
|
|
|
|
P = ParamSpec("P")  # parameter specification of a wrapped callable (used by run_sync)
|
|
R = TypeVar("R")  # return type of a wrapped callable / type of a fallback value
|
|
T = TypeVar("T")  # generic value type (context manager value, coroutine result, ...)
|
|
K = TypeVar("K")  # dictionary key type (used by deep_update)
|
|
V = TypeVar("V")  # generic type variable; not referenced in the visible part of this file
|
|
|
|
|
|
def escape_tag(s: str) -> str:
    """Escape ``<tag>``-style markup so it is emitted literally in colored logs.

    Reference: [loguru color markup](https://loguru.readthedocs.io/en/stable/api/logger.html#color)

    Args:
        s: The string to escape.

    Returns:
        ``s`` with a backslash inserted in front of every substring that looks
        like an opening or closing tag (``<name>``, ``</name>``, ``<fg ...>``).
    """
    tag_pattern = re.compile(r"</?((?:[fb]g\s)?[^<>\s]*)>")
    # ``\\`` is one literal backslash, ``\g<0>`` re-inserts the whole matched tag.
    return tag_pattern.sub(r"\\\g<0>", s)
|
|
|
|
|
|
def deep_update(
    mapping: dict[K, Any], *updating_mappings: dict[K, Any]
) -> dict[K, Any]:
    """Recursively merge dictionaries into a new dictionary.

    Args:
        mapping: The base dictionary; it is copied (shallowly), not modified.
        updating_mappings: Dictionaries applied in order on top of the base.

    Returns:
        The merged dictionary. When a key holds a dict on both sides the two
        dicts are merged recursively; otherwise the updating value wins.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            # Only dict-on-dict collisions are merged; everything else is replaced.
            mergeable = (
                isinstance(new_value, dict)
                and key in merged
                and isinstance(merged[key], dict)
            )
            merged[key] = (
                deep_update(merged[key], new_value) if mergeable else new_value
            )
    return merged
|
|
|
|
|
|
def lenient_issubclass(
    cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...]]
) -> bool:
    """Check whether ``cls`` is a subclass of ``class_or_tuple``, ignoring type errors.

    Args:
        cls: The object to test; anything that is not a class yields ``False``.
        class_or_tuple: A class or tuple of classes to test against.

    Returns:
        ``True`` only if ``cls`` is a class and a subclass of ``class_or_tuple``;
        ``False`` otherwise, including when ``issubclass`` raises ``TypeError``.
    """
    if not isinstance(cls, type):
        return False
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        return False
|
|
|
|
|
|
def generic_check_issubclass(
    cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...]]
) -> bool:
    """Check whether ``cls`` is a subclass of one of the types in ``class_or_tuple``.

    Special cases:

    - If ``cls`` is a ``typing.Union`` or ``types.UnionType``, every member must
      be a subclass of one of ``class_or_tuple`` or be ``None``.
    - If ``cls`` is a ``typing.Literal``, every value must be an instance of one
      of ``class_or_tuple`` or be ``None``.
    - If ``cls`` is any other generic alias, its origin is checked instead.
    - If ``cls`` is a ``typing.TypeVar``, its ``__constraints__`` (all of them)
      or its ``__bound__`` must satisfy the check.

    Returns:
        ``True`` if the check succeeds, ``False`` otherwise.
    """
    if not type_has_args(cls):
        # Try the direct check first; if issubclass rejects ``cls`` with a
        # TypeError, fall through to the structural checks below.
        try:
            return issubclass(cls, class_or_tuple)
        except TypeError:
            pass

    origin = get_origin(cls)

    if origin_is_union(origin):
        for member in get_args(cls):
            if not (
                is_none_type(member)
                or generic_check_issubclass(member, class_or_tuple)
            ):
                return False
        return True

    if origin_is_literal(origin):
        for literal in all_literal_values(cls):
            if not (is_none_type(literal) or isinstance(literal, class_or_tuple)):
                return False
        return True

    # ensure generic List, Dict can be checked
    if origin:
        # avoid class check error (typing.Final, typing.ClassVar, etc...)
        try:
            return issubclass(origin, class_or_tuple)
        except TypeError:
            return False

    if isinstance(cls, TypeVar):
        if cls.__constraints__:
            for constraint in cls.__constraints__:
                if not (
                    is_none_type(constraint)
                    or generic_check_issubclass(constraint, class_or_tuple)
                ):
                    return False
            return True
        if cls.__bound__:
            return generic_check_issubclass(cls.__bound__, class_or_tuple)

    return False
|
|
|
|
|
|
def type_is_complex(type_: type[Any]) -> bool:
    """Return whether ``type_`` itself, or its generic origin, is a complex type."""
    if _type_is_complex_inner(type_):
        return True
    # Parameterized generics such as ``list[int]`` are judged by their origin.
    return _type_is_complex_inner(get_origin(type_))
|
|
|
|
|
|
def _type_is_complex_inner(type_: Optional[type[Any]]) -> bool:
    """Return whether ``type_`` is a model, container or dataclass type.

    ``str`` and ``bytes`` (and their subclasses) are never treated as complex,
    even though they are sequences.
    """
    text_types = (str, bytes)
    container_types = (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)

    if lenient_issubclass(type_, text_types):
        return False
    if lenient_issubclass(type_, container_types):
        return True
    return dataclasses.is_dataclass(type_)
|
|
|
|
|
|
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Check whether ``call`` is a callable that behaves as a coroutine function.

    Functions and methods are inspected directly, classes are never considered
    coroutine callables, and any other object is judged by its ``__call__``.
    """
    if inspect.isroutine(call):
        target = call
    elif inspect.isclass(call):
        return False
    else:
        target = getattr(call, "__call__", None)
    return inspect.iscoroutinefunction(target)
|
|
|
|
|
|
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Check whether ``call``, or its ``__call__`` attribute, is a generator function."""
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        getattr(call, "__call__", None)
    )
|
|
|
|
|
|
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Check whether ``call``, or its ``__call__`` attribute, is an async generator function."""
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)
    )
|
|
|
|
|
|
def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
    """Decorator that wraps a sync function into an async function.

    The wrapped call is submitted with ``loop.run_in_executor(None, ...)`` on
    the running event loop and executes inside a copy of the caller's
    ``contextvars`` context.

    Args:
        call: The synchronous function to wrap.

    Returns:
        An async function with the same parameters that awaits the result.
    """

    @wraps(call)
    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Snapshot the context so context variables set by the caller are
        # visible to the function when it runs in the executor.
        ctx = copy_context()
        bound_call = partial(call, *args, **kwargs)
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(ctx.run, bound_call)
        )

    return _wrapper
|
|
|
|
|
|
@asynccontextmanager
async def run_sync_ctx_manager(
    cm: AbstractContextManager[T],
) -> AsyncGenerator[T, None]:
    """Run a sync context manager as an async context manager.

    ``cm.__enter__`` and ``cm.__exit__`` are both executed through ``run_sync``.

    Args:
        cm: The synchronous context manager to drive.

    Yields:
        The value returned by ``cm.__enter__()``.

    Raises:
        Exception: Whatever ``cm.__enter__`` raises (``__exit__`` is not called
            in that case), or the exception raised inside the managed block
            unless ``cm.__exit__`` returns a truthy value to suppress it.
    """
    # Enter outside the try block: if __enter__ fails, the context was never
    # entered, so __exit__ must not run (same as a plain ``with`` statement).
    value = await run_sync(cm.__enter__)()
    try:
        yield value
    except Exception as e:
        # Pass the real traceback so __exit__ sees the same arguments a plain
        # ``with`` statement would provide.
        suppressed = await run_sync(cm.__exit__)(type(e), e, e.__traceback__)
        if not suppressed:
            raise
    else:
        await run_sync(cm.__exit__)(None, None, None)
|
|
|
|
|
|
@overload
async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: tuple[type[Exception], ...],
    return_on_err: None = None,
) -> Union[T, None]: ...


@overload
async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: tuple[type[Exception], ...],
    return_on_err: R,
) -> Union[T, R]: ...


async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: tuple[type[Exception], ...],
    return_on_err: Optional[R] = None,
) -> Optional[Union[T, R]]:
    """Await a coroutine and return a fallback value if a given exception occurs.

    Args:
        coro: The coroutine to await.
        exc: The exception types to catch.
        return_on_err: The value returned when one of ``exc`` is raised.

    Returns:
        The coroutine's result, or ``return_on_err`` if a listed exception was
        raised. Exceptions not listed in ``exc`` propagate to the caller.
    """
    try:
        result = await coro
    except exc:
        return return_on_err
    return result
|
|
|
|
|
|
def get_name(obj: Any) -> str:
    """Return the name of ``obj``.

    Functions and classes report their own ``__name__``; every other object
    reports the name of its class.
    """
    named = obj if inspect.isfunction(obj) or inspect.isclass(obj) else obj.__class__
    return named.__name__
|
|
|
|
|
|
def path_to_module_name(path: Path) -> str:
    """Convert a file path into a dotted module name.

    The path is resolved and made relative to the current working directory.
    A file named ``__init__`` maps to its containing package.

    Args:
        path: Path of the module file or package ``__init__`` file.

    Returns:
        The dotted module name.

    Raises:
        ValueError: If ``path`` is not located under the current working
            directory (raised by ``Path.relative_to``).
    """
    relative = path.resolve().relative_to(Path.cwd().resolve())
    name_parts = relative.parts[:-1]
    if relative.stem != "__init__":
        name_parts += (relative.stem,)
    return ".".join(name_parts)
|
|
|
|
|
|
def resolve_dot_notation(
    obj_str: str, default_attr: str, default_prefix: Optional[str] = None
) -> Any:
    """Import and return the object described by ``module:attr.path`` notation.

    Args:
        obj_str: ``"module"`` or ``"module:attr[.attr...]"``. A leading ``~`` in
            the module part is replaced by ``default_prefix`` when one is given.
        default_attr: Attribute fetched from the module when ``obj_str`` has no
            attribute part.
        default_prefix: Replacement for a leading ``~`` in the module name.

    Returns:
        The resolved object.
    """
    module_path, _, attr_path = obj_str.partition(":")
    if module_path.startswith("~") and default_prefix is not None:
        module_path = default_prefix + module_path[1:]

    target = importlib.import_module(module_path)
    if attr_path:
        # Walk the dotted attribute chain starting from the module.
        for attr_name in attr_path.split("."):
            target = getattr(target, attr_name)
        return target
    return getattr(target, default_attr)
|
|
|
|
|
|
class classproperty(Generic[T]):
    """Decorator that exposes a method as a read-only property of the class."""

    def __init__(self, func: Callable[[Any], T]) -> None:
        # The wrapped function receives the class as its only argument.
        self.func = func

    def __get__(self, instance: Any, owner: Optional[type[Any]] = None) -> T:
        # Fall back to the instance's type when no owner class is supplied.
        cls = owner if owner is not None else type(instance)
        return self.func(cls)
|
|
|
|
|
|
class DataclassEncoder(json.JSONEncoder):
    """A `JSONEncoder` that can serialize {ref}`nonebot.adapters.Message` (List[Dataclass])."""

    @override
    def default(self, o):
        # Anything that is not a dataclass is left to the base encoder,
        # which raises TypeError for unsupported objects.
        if not dataclasses.is_dataclass(o):
            return super().default(o)
        # Shallow field-by-field mapping; nested values are encoded later by
        # the encoder itself.
        serialized = {}
        for field in dataclasses.fields(o):
            serialized[field.name] = getattr(o, field.name)
        return serialized
|
|
|
|
|
|
def logger_wrapper(logger_name: str):
    """Create a logging function for an adapter.

    Args:
        logger_name: Name of the adapter; it is escaped and prefixed to every
            message.

    Returns:
        A logging function taking:

        - level: the log level
        - message: the log message
        - exception: optional exception information
    """

    def log(level: str, message: str, exception: Optional[Exception] = None):
        colored = logger.opt(colors=True, exception=exception)
        colored.log(level, f"<m>{escape_tag(logger_name)}</m> | {message}")

    return log
|