import re
import json
import asyncio
import inspect
import collections
import dataclasses
from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, get_args, get_origin
from typing import (
    Any,
    Type,
    Deque,
    Tuple,
    Union,
    TypeVar,
    Callable,
    Optional,
    Awaitable,
    AsyncGenerator,
    ContextManager,
)

from nonebot.log import logger
from nonebot.typing import overrides

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


def escape_tag(s: str) -> str:
    """
    :Description:

      Escape ``<tag>``-style markup tags so that they are emitted literally
      when writing colored logs.

    :Parameters:

      * ``s: str``: the string to escape

    :Returns:

      - ``str``: ``s`` with a backslash inserted in front of every
        ``<...>`` / ``</...>`` tag
    """
    # Matches an opening or closing tag, optionally prefixed with "fg " or
    # "bg ", and re-emits the whole match (\g<0>) behind a literal backslash.
    return re.sub(r"</?((?:[fb]g\s)?[^<>\s]*)>", r"\\\g<0>", s)


def generic_check_issubclass(
    cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]
) -> bool:
    """
    :Description:

      ``issubclass`` that also accepts typing constructs. When the plain
      ``issubclass`` call raises ``TypeError``:

      * a ``Union`` passes only if every member other than ``NoneType``
        passes this same check;
      * any other construct with a truthy ``get_origin`` is checked through
        that origin;
      * otherwise the original ``TypeError`` is re-raised.

    :Parameters:

      * ``cls: Any``: class or typing construct to test
      * ``class_or_tuple``: class, or tuple of classes, to test against

    :Returns:

      - ``bool``
    """
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        origin = get_origin(cls)
        if origin is Union:
            for type_ in get_args(cls):
                # NoneType members (i.e. Optional[...]) are ignored.
                if type_ is not type(None) and not generic_check_issubclass(
                    type_, class_or_tuple
                ):
                    return False
            return True
        elif origin:
            return issubclass(origin, class_or_tuple)
        raise


def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """
    :Description:

      Check whether ``call`` is a coroutine function, or an instance whose
      ``__call__`` is a coroutine function. Classes always yield ``False``.

    :Parameters:

      * ``call: Callable[..., Any]``: the callable to inspect

    :Returns:

      - ``bool``
    """
    if inspect.isroutine(call):
        return inspect.iscoroutinefunction(call)
    if inspect.isclass(call):
        return False
    func_ = getattr(call, "__call__", None)
    return inspect.iscoroutinefunction(func_)


def is_gen_callable(call: Callable[..., Any]) -> bool:
    """
    :Description:

      Check whether ``call`` is a generator function, or an object whose
      ``__call__`` is a generator function.

    :Parameters:

      * ``call: Callable[..., Any]``: the callable to inspect

    :Returns:

      - ``bool``
    """
    if inspect.isgeneratorfunction(call):
        return True
    func_ = getattr(call, "__call__", None)
    return inspect.isgeneratorfunction(func_)


def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """
    :Description:

      Check whether ``call`` is an async generator function, or an object
      whose ``__call__`` is an async generator function.

    :Parameters:

      * ``call: Callable[..., Any]``: the callable to inspect

    :Returns:

      - ``bool``
    """
    if inspect.isasyncgenfunction(call):
        return True
    func_ = getattr(call, "__call__", None)
    return inspect.isasyncgenfunction(func_)


def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]:
    """
    :Description:

      Decorator that wraps a sync function into an async function. The
      wrapped call runs in the running loop's default executor.

    :Parameters:

      * ``call: Callable[P, R]``: the sync function being decorated

    :Returns:

      - ``Callable[P, Awaitable[R]]``
    """

    @wraps(call)
    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        loop = asyncio.get_running_loop()
        # run_in_executor takes no keyword arguments for the callee, so the
        # call is bound up front with partial.
        pfunc = partial(call, *args, **kwargs)
        result = await loop.run_in_executor(None, pfunc)
        return result

    return _wrapper


@asynccontextmanager
async def run_sync_ctx_manager(
    cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
    """
    :Description:

      Drive a sync context manager as an async context manager; both
      ``__enter__`` and ``__exit__`` are executed through ``run_sync``.

    :Parameters:

      * ``cm: ContextManager[T]``: the sync context manager to wrap

    :Returns:

      - ``AsyncGenerator[T, None]``: yields the value of ``cm.__enter__()``
    """
    # Enter outside the try block: if __enter__ itself fails the manager was
    # never entered, so __exit__ must not be called for it.
    value = await run_sync(cm.__enter__)()
    try:
        yield value
    except Exception as e:
        ok = await run_sync(cm.__exit__)(type(e), e, e.__traceback__)
        # A truthy return from __exit__ suppresses the exception.
        if not ok:
            raise
    else:
        await run_sync(cm.__exit__)(None, None, None)


def get_name(obj: Any) -> str:
    """
    :Description:

      Return ``obj.__name__`` for functions and classes, otherwise the name
      of the object's class.

    :Parameters:

      * ``obj: Any``: the object to name

    :Returns:

      - ``str``
    """
    if inspect.isfunction(obj) or inspect.isclass(obj):
        return obj.__name__
    return obj.__class__.__name__


class CacheLock:
    """
    :Description:

      An async lock usable with ``async with``. Waiters are queued as
      futures; ``release`` wakes the first queued waiter.
    """

    def __init__(self):
        # Queue of futures for tasks waiting on the lock; created lazily on
        # the first contended acquire.
        self._waiters: Optional[Deque[asyncio.Future]] = None
        self._locked = False

    def __repr__(self):
        extra = "locked" if self._locked else "unlocked"
        if self._waiters:
            extra = f"{extra}, waiters: {len(self._waiters)}"
        return f"<{self.__class__.__name__} [{extra}]>"

    async def __aenter__(self):
        await self.acquire()
        return None

    async def __aexit__(self, exc_type, exc, tb):
        self.release()

    def locked(self):
        """Return ``True`` if the lock is currently held."""
        return self._locked

    async def acquire(self):
        """
        :Description:

          Acquire the lock, waiting until it is released if necessary.

        :Returns:

          - ``bool``: always ``True``
        """
        # Fast path: lock is free and there is no live (non-cancelled) waiter.
        if not self._locked and (
            self._waiters is None or all(w.cancelled() for w in self._waiters)
        ):
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()

        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._waiters.append(future)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_first() and attempt to wake up itself.
        try:
            try:
                await future
            finally:
                self._waiters.remove(future)
        except asyncio.CancelledError:
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """
        :Description:

          Release the lock and wake up the first waiter, if any.

        :Raises:

          - ``RuntimeError``: the lock is not currently held
        """
        if self._locked:
            self._locked = False
            self._wake_up_first()
        else:
            raise RuntimeError("Lock is not acquired.")

    def _wake_up_first(self):
        """Resolve the future of the first queued waiter, if it is pending."""
        if not self._waiters:
            return
        try:
            future = next(iter(self._waiters))
        except StopIteration:
            return

        # .done() necessarily means that a waiter will wake up later on and
        # either take the lock, or, if it was cancelled and lock wasn't
        # taken already, will hit this again and wake up a new waiter.
        if not future.done():
            future.set_result(True)


class DataclassEncoder(json.JSONEncoder):
    """
    :Description:

      ``JSONEncoder`` used when JSON-serializing a ``Message``
      (List[Dataclass]).
    """

    @overrides(json.JSONEncoder)
    def default(self, o):
        # Dataclasses are converted to plain dicts; everything else is left
        # to the base encoder.
        if dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)
        return super().default(o)


def logger_wrapper(logger_name: str):
    """
    :Description:

      Build a log function used to print adapter logs, prefixed with
      ``logger_name``.

    :Parameters of the returned ``log``:

      * ``level: Literal['WARNING', 'DEBUG', 'INFO']``: log level
      * ``message: str``: log message
      * ``exception: Optional[Exception]``: exception information
    """

    def log(level: str, message: str, exception: Optional[Exception] = None):
        # Only the logger name is escaped; ``message`` is passed on as given
        # while colors=True is set, so callers escape it themselves if needed.
        return logger.opt(colors=True, exception=exception).log(
            level, f"{escape_tag(logger_name)} | " + message
        )

    return log