"""本模块包含了 NoneBot 的一些工具函数 FrontMatter: mdx: format: md sidebar_position: 8 description: nonebot.utils 模块 """ import re import json import asyncio import inspect import importlib import contextlib import dataclasses from pathlib import Path from collections import deque from contextvars import copy_context from functools import wraps, partial from contextlib import AbstractContextManager, asynccontextmanager from typing_extensions import ParamSpec, get_args, override, get_origin from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload from pydantic import BaseModel from nonebot.log import logger from nonebot.typing import ( is_none_type, type_has_args, origin_is_union, origin_is_literal, all_literal_values, ) P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") K = TypeVar("K") V = TypeVar("V") def escape_tag(s: str) -> str: """用于记录带颜色日志时转义 `` 类型特殊标签 参考: [loguru color 标签](https://loguru.readthedocs.io/en/stable/api/logger.html#color) 参数: s: 需要转义的字符串 """ return re.sub(r"\s]*)>", r"\\\g<0>", s) def deep_update( mapping: dict[K, Any], *updating_mappings: dict[K, Any] ) -> dict[K, Any]: """深度更新合并字典""" updated_mapping = mapping.copy() for updating_mapping in updating_mappings: for k, v in updating_mapping.items(): if ( k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict) ): updated_mapping[k] = deep_update(updated_mapping[k], v) else: updated_mapping[k] = v return updated_mapping def lenient_issubclass( cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...]] ) -> bool: """检查 cls 是否是 class_or_tuple 中的一个类型子类并忽略类型错误。""" try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) except TypeError: return False def generic_check_issubclass( cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...]] ) -> bool: """检查 cls 是否是 class_or_tuple 中的一个类型子类。 特别的: - 如果 cls 是 `typing.Union` 或 `types.UnionType` 类型, 则会检查其中的所有类型是否是 
class_or_tuple 中一个类型的子类或 None。 - 如果 cls 是 `typing.Literal` 类型, 则会检查其中的所有值是否是 class_or_tuple 中一个类型的实例。 - 如果 cls 是 `typing.TypeVar` 类型, 则会检查其 `__bound__` 或 `__constraints__` 是否是 class_or_tuple 中一个类型的子类或 None。 """ if not type_has_args(cls): with contextlib.suppress(TypeError): return issubclass(cls, class_or_tuple) origin = get_origin(cls) if origin_is_union(origin): return all( is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple) for type_ in get_args(cls) ) elif origin_is_literal(origin): return all( is_none_type(value) or isinstance(value, class_or_tuple) for value in all_literal_values(cls) ) # ensure generic List, Dict can be checked elif origin: # avoid class check error (typing.Final, typing.ClassVar, etc...) try: return issubclass(origin, class_or_tuple) except TypeError: return False elif isinstance(cls, TypeVar): if cls.__constraints__: return all( is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple) for type_ in cls.__constraints__ ) elif cls.__bound__: return generic_check_issubclass(cls.__bound__, class_or_tuple) return False def type_is_complex(type_: type[Any]) -> bool: """检查 type_ 是否是复杂类型""" origin = get_origin(type_) return _type_is_complex_inner(type_) or _type_is_complex_inner(origin) def _type_is_complex_inner(type_: Optional[type[Any]]) -> bool: if lenient_issubclass(type_, (str, bytes)): return False return lenient_issubclass( type_, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque) ) or dataclasses.is_dataclass(type_) def is_coroutine_callable(call: Callable[..., Any]) -> bool: """检查 call 是否是一个 callable 协程函数""" if inspect.isroutine(call): return inspect.iscoroutinefunction(call) if inspect.isclass(call): return False func_ = getattr(call, "__call__", None) return inspect.iscoroutinefunction(func_) def is_gen_callable(call: Callable[..., Any]) -> bool: """检查 call 是否是一个生成器函数""" if inspect.isgeneratorfunction(call): return True func_ = getattr(call, "__call__", None) return 
inspect.isgeneratorfunction(func_) def is_async_gen_callable(call: Callable[..., Any]) -> bool: """检查 call 是否是一个异步生成器函数""" if inspect.isasyncgenfunction(call): return True func_ = getattr(call, "__call__", None) return inspect.isasyncgenfunction(func_) def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]: """一个用于包装 sync function 为 async function 的装饰器 参数: call: 被装饰的同步函数 """ @wraps(call) async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: loop = asyncio.get_running_loop() pfunc = partial(call, *args, **kwargs) context = copy_context() result = await loop.run_in_executor(None, partial(context.run, pfunc)) return result return _wrapper @asynccontextmanager async def run_sync_ctx_manager( cm: AbstractContextManager[T], ) -> AsyncGenerator[T, None]: """一个用于包装 sync context manager 为 async context manager 的执行函数""" try: yield await run_sync(cm.__enter__)() except Exception as e: ok = await run_sync(cm.__exit__)(type(e), e, None) if not ok: raise e else: await run_sync(cm.__exit__)(None, None, None) @overload async def run_coro_with_catch( coro: Coroutine[Any, Any, T], exc: tuple[type[Exception], ...], return_on_err: None = None, ) -> Union[T, None]: ... @overload async def run_coro_with_catch( coro: Coroutine[Any, Any, T], exc: tuple[type[Exception], ...], return_on_err: R, ) -> Union[T, R]: ... 
async def run_coro_with_catch(
    coro: Coroutine[Any, Any, T],
    exc: tuple[type[Exception], ...],
    return_on_err: Optional[R] = None,
) -> Optional[Union[T, R]]:
    """Await `coro`, returning a fallback value if it raises one of `exc`.

    Args:
        coro: the coroutine to run
        exc: exception types to catch
        return_on_err: value returned when one of `exc` is raised

    Returns:
        The coroutine's result, or `return_on_err` if a caught exception occurred.
    """
    try:
        return await coro
    except exc:
        return return_on_err


def get_name(obj: Any) -> str:
    """Return the `__name__` of a function or class, otherwise the class name of `obj`."""
    if inspect.isfunction(obj) or inspect.isclass(obj):
        return obj.__name__
    return obj.__class__.__name__


def path_to_module_name(path: Path) -> str:
    """Convert a file path to a dotted module name relative to the cwd.

    A trailing `__init__` file maps to its package name.

    Raises:
        ValueError: if `path` does not resolve to a location under the
            current working directory (raised by `Path.relative_to`).
    """
    rel_path = path.resolve().relative_to(Path.cwd().resolve())
    if rel_path.stem == "__init__":
        return ".".join(rel_path.parts[:-1])
    else:
        return ".".join(rel_path.parts[:-1] + (rel_path.stem,))


def resolve_dot_notation(
    obj_str: str, default_attr: str, default_prefix: Optional[str] = None
) -> Any:
    """Import a module and resolve an object from a `module:attr.path` string.

    Args:
        obj_str: `module` or `module:attr.sub_attr`; a leading `~` in the
            module part is replaced by `default_prefix` when that is given
        default_attr: attribute returned when `obj_str` has no `:attr` part
        default_prefix: replacement for a leading `~` in the module name

    Returns:
        The resolved object.
    """
    modulename, _, cls = obj_str.partition(":")
    if default_prefix is not None and modulename.startswith("~"):
        modulename = default_prefix + modulename[1:]
    module = importlib.import_module(modulename)
    if not cls:
        return getattr(module, default_attr)
    instance = module
    for attr_str in cls.split("."):
        instance = getattr(instance, attr_str)
    return instance


class classproperty(Generic[T]):
    """Decorator defining a read-only property evaluated on the class."""

    def __init__(self, func: Callable[[Any], T]) -> None:
        self.func = func

    def __get__(self, instance: Any, owner: Optional[type[Any]] = None) -> T:
        # `func` always receives the class, whether accessed via class or instance
        return self.func(type(instance) if owner is None else owner)


class DataclassEncoder(json.JSONEncoder):
    """`JSONEncoder` able to serialize {ref}`nonebot.adapters.Message` (List[Dataclass])."""

    @override
    def default(self, o: Any) -> Any:
        # `is_dataclass` is also true for dataclass *classes*; only instances
        # are converted, anything else falls back to the base (TypeError).
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)}
        return super().default(o)


def logger_wrapper(logger_name: str):
    """Create a logging function for an adapter.

    Messages are logged through NoneBot's `logger` with color markup enabled
    and prefixed with `logger_name` (escaped, highlighted in magenta).

    Args:
        logger_name: the adapter name used as the message prefix

    Returns:
        A function `log(level, message, exception=None)` where:

        - level: log level passed to `logger.log`
        - message: log message; may contain loguru color markup and is NOT escaped
        - exception: optional exception passed to `logger.opt(exception=...)`
    """

    def log(level: str, message: str, exception: Optional[Exception] = None):
        # Only the adapter name is escaped; `message` may deliberately use markup.
        logger.opt(colors=True, exception=exception).log(
            level, f"<m>{escape_tag(logger_name)}</m> | {message}"
        )

    return log