nonebot2/nonebot/utils.py

242 lines
6.0 KiB
Python
Raw Normal View History

2020-09-30 18:01:31 +08:00
import re
2020-08-10 13:06:02 +08:00
import json
2020-08-14 17:41:24 +08:00
import asyncio
import inspect
2021-11-21 15:46:48 +08:00
import collections
2020-08-10 13:06:02 +08:00
import dataclasses
2020-08-14 17:41:24 +08:00
from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, get_args, get_origin
from typing import (
Any,
Type,
Deque,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Awaitable,
AsyncGenerator,
ContextManager,
)
2020-08-10 13:06:02 +08:00
2020-12-02 19:52:45 +08:00
from nonebot.log import logger
2020-12-06 02:30:19 +08:00
from nonebot.typing import overrides
2020-08-14 17:41:24 +08:00
2021-11-13 19:38:01 +08:00
# Generic type variables shared by the helpers in this module.
P = ParamSpec("P")  # parameter specification of a wrapped callable (used by run_sync)
R = TypeVar("R")  # return type of a wrapped callable (used by run_sync)
T = TypeVar("T")  # value produced by a wrapped context manager (used by run_sync_ctx_manager)
2021-11-13 19:38:01 +08:00
2020-08-14 17:41:24 +08:00
2020-09-30 18:01:31 +08:00
def escape_tag(s: str) -> str:
    """
    :Description:

      Escape ``<tag>``-style markup by prefixing it with a backslash, so that
      text is not interpreted as a color tag when logging with color markup
      enabled.

    :Parameters:

      * ``s: str``: the string to escape

    :Returns:

      - ``str``: ``s`` with every ``<tag>`` / ``</tag>`` occurrence escaped
    """
    # Matches opening/closing tags, including "fg xxx" / "bg xxx" forms.
    tag_pattern = r"</?((?:[fb]g\s)?[^<>\s]*)>"
    escaped = re.sub(tag_pattern, r"\\\g<0>", s)
    return escaped
def generic_check_issubclass(
    cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]
) -> bool:
    """
    :Description:

      ``issubclass`` that also accepts some generic aliases. When ``cls`` is
      not a plain class:

      * ``Union`` / ``Optional`` types pass only if every non-``None``
        member passes (checked recursively);
      * other generic aliases (e.g. ``List[int]``) are checked via their
        origin class.

    :Parameters:

      * ``cls: Any``: the type to check
      * ``class_or_tuple``: a class, or tuple of classes, to check against

    :Returns:

      - ``bool``

    :Raises:

      - ``TypeError``: if ``cls`` is neither a class nor a generic alias
        with an origin
    """
    try:
        return issubclass(cls, class_or_tuple)
    except TypeError:
        origin = get_origin(cls)
        if origin is Union:
            none_type = type(None)
            return all(
                generic_check_issubclass(member, class_or_tuple)
                for member in get_args(cls)
                if member is not none_type
            )
        if origin:
            return issubclass(origin, class_or_tuple)
        # Not a generic we understand: re-raise the original TypeError.
        raise
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """
    :Description:

      Check whether calling ``call`` produces a coroutine.

      Functions and methods are inspected directly. Classes are always
      reported as ``False``. Any other object is judged by its
      ``__call__`` attribute.

    :Parameters:

      * ``call: Callable[..., Any]``: the callable to inspect

    :Returns:

      - ``bool``
    """
    if inspect.isclass(call):
        return False
    target = call if inspect.isroutine(call) else getattr(call, "__call__", None)
    return inspect.iscoroutinefunction(target)
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """
    :Description:

      Check whether ``call`` is a generator function, or an object whose
      ``__call__`` is a generator function.

    :Parameters:

      * ``call: Callable[..., Any]``: the callable to inspect

    :Returns:

      - ``bool``
    """
    if not inspect.isgeneratorfunction(call):
        call_method = getattr(call, "__call__", None)
        return inspect.isgeneratorfunction(call_method)
    return True
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """
    :Description:

      Check whether ``call`` is an async generator function, or an object
      whose ``__call__`` is an async generator function.

    :Parameters:

      * ``call: Callable[..., Any]``: the callable to inspect

    :Returns:

      - ``bool``
    """
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)
    )
def run_sync(call: Callable[P, R]) -> Callable[P, Awaitable[R]]:
    """
    :Description:

      Decorator that wraps a synchronous function as an async function.

      Each call to the returned coroutine function runs ``call`` in the
      running event loop's default executor, so the loop is not blocked
      while ``call`` executes. The caller's ``contextvars`` context is
      copied into the worker, so ``ContextVar`` values set by the awaiting
      task are visible inside ``call``.

    :Parameters:

      * ``call: Callable[P, R]``: the synchronous function to wrap

    :Returns:

      - ``Callable[P, Awaitable[R]]``
    """
    from contextvars import copy_context

    @wraps(call)
    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        loop = asyncio.get_running_loop()
        pfunc = partial(call, *args, **kwargs)
        # run_in_executor does not propagate context variables by itself
        # (unlike asyncio.to_thread). Run the call inside a copy of the
        # current context so ContextVar values remain visible to it.
        context = copy_context()
        result = await loop.run_in_executor(None, partial(context.run, pfunc))
        return result

    return _wrapper
2020-08-10 13:06:02 +08:00
@asynccontextmanager
async def run_sync_ctx_manager(
    cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
    """
    :Description:

      Use a synchronous context manager from async code.

      ``cm.__enter__`` and ``cm.__exit__`` are each executed through
      ``run_sync``. The semantics follow the ``with`` statement:

      * ``__exit__`` is called only if ``__enter__`` succeeded;
      * ``__exit__`` receives the exception type, value and traceback of
        any exception raised in the body, including ``BaseException``
        subclasses such as ``asyncio.CancelledError``;
      * a truthy return from ``__exit__`` suppresses that exception.

    :Parameters:

      * ``cm: ContextManager[T]``: the synchronous context manager to wrap

    :Returns:

      - ``AsyncGenerator[T, None]``: yields the value returned by
        ``cm.__enter__``
    """
    # Enter outside the try: per `with` semantics, __exit__ must not run
    # when __enter__ itself fails.
    value = await run_sync(cm.__enter__)()
    try:
        yield value
    except BaseException as e:
        # BaseException (not Exception) so that cancellation still releases
        # the wrapped resource, as a synchronous `with` would.
        suppress = await run_sync(cm.__exit__)(type(e), e, e.__traceback__)
        if not suppress:
            raise
    else:
        await run_sync(cm.__exit__)(None, None, None)
def get_name(obj: Any) -> str:
    """
    :Description:

      Return a readable name for ``obj``: its own ``__name__`` for Python
      functions and classes, otherwise the name of its class.

    :Parameters:

      * ``obj: Any``: the object to name

    :Returns:

      - ``str``
    """
    has_own_name = inspect.isfunction(obj) or inspect.isclass(obj)
    return obj.__name__ if has_own_name else obj.__class__.__name__
2021-11-21 15:46:48 +08:00
class CacheLock:
    """
    An asyncio mutual-exclusion lock with a FIFO queue of waiter futures.

    The event loop is looked up with ``asyncio.get_running_loop()`` only when
    ``acquire`` has to wait; nothing is bound at construction time.

    NOTE(review): the acquire/release/wake-up logic appears to mirror
    ``asyncio.Lock``. Presumably it is a copy so the lock can be created
    outside a running event loop — confirm intent.
    NOTE: ownership is not tracked. ``release`` only checks that the lock is
    held, so any task may release it.
    """

    def __init__(self):
        # Futures of tasks blocked in ``acquire``, in arrival order; created lazily.
        self._waiters: Optional[Deque[asyncio.Future]] = None
        # True while the lock is held.
        self._locked = False

    def __repr__(self):
        extra = "locked" if self._locked else "unlocked"
        if self._waiters:
            extra = f"{extra}, waiters: {len(self._waiters)}"
        return f"<{self.__class__.__name__} [{extra}]>"

    async def __aenter__(self):
        """Acquire the lock; ``async with`` binds ``None``."""
        await self.acquire()
        return None

    async def __aexit__(self, exc_type, exc, tb):
        """Release the lock. Exceptions from the body are not suppressed."""
        self.release()

    def locked(self):
        """Return True if the lock is currently held."""
        return self._locked

    async def acquire(self):
        """
        Acquire the lock, waiting until it is free. Returns True.

        If the waiting task is cancelled and the lock is not held, the next
        waiter is woken before the cancellation is re-raised, so the wake-up
        is not lost.
        """
        # Fast path: the lock is free and no live (non-cancelled) waiter is
        # queued ahead of us.
        if not self._locked and (
            self._waiters is None or all(w.cancelled() for w in self._waiters)
        ):
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._waiters.append(future)

        # Finally block should be called before the CancelledError
        # handling as we don't want CancelledError to call
        # _wake_up_first() and attempt to wake up itself.
        try:
            try:
                await future
            finally:
                self._waiters.remove(future)
        except asyncio.CancelledError:
            if not self._locked:
                self._wake_up_first()
            raise

        self._locked = True
        return True

    def release(self):
        """
        Release the lock and wake the first queued waiter, if any.

        Raises RuntimeError if the lock is not currently held.
        """
        if self._locked:
            self._locked = False
            self._wake_up_first()
        else:
            raise RuntimeError("Lock is not acquired.")

    def _wake_up_first(self):
        """Complete the first queued waiter's future, unless it is already done."""
        if not self._waiters:
            return
        try:
            future = next(iter(self._waiters))
        except StopIteration:
            return

        # .done() necessarily means that a waiter will wake up later on and
        # either take the lock, or, if it was cancelled and lock wasn't
        # taken already, will hit this again and wake up a new waiter.
        if not future.done():
            future.set_result(True)
2020-08-10 13:06:02 +08:00
class DataclassEncoder(json.JSONEncoder):
    """
    :Description:

      ``JSONEncoder`` used when serializing ``Message`` (``List[Dataclass]``)
      to JSON.

      Dataclass *instances* are converted with ``dataclasses.asdict``.
      Everything else, including dataclass classes themselves, falls back to
      the base ``JSONEncoder.default``, which raises ``TypeError``.
    """

    @overrides(json.JSONEncoder)
    def default(self, o):
        # is_dataclass() is also True for dataclass *types*, but asdict()
        # only accepts instances, so exclude classes explicitly.
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            return dataclasses.asdict(o)
        return super().default(o)
2020-12-02 19:52:45 +08:00
def logger_wrapper(logger_name: str):
    """
    :Description:

      Build a log function for an adapter. Every message is prefixed with
      ``logger_name``: the name is escaped, wrapped in ``<m>`` color markup,
      and followed by `` | ``.

    :log parameters:

      * ``level: Literal['WARNING', 'DEBUG', 'INFO']``: log level
      * ``message: str``: log message
      * ``exception: Optional[Exception]``: exception to attach to the record
    """

    def log(level: str, message: str, exception: Optional[Exception] = None):
        prefix = f"<m>{escape_tag(logger_name)}</m> | "
        colored_logger = logger.opt(colors=True, exception=exception)
        return colored_logger.log(level, prefix + message)

    return log