nonebot2/nonebot/internal/driver/abstract.py

354 lines
11 KiB
Python

import abc
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
T_BotDisconnectionHook,
)
from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import (
Request,
Response,
WebSocket,
QueryTypes,
CookieTypes,
HeaderTypes,
HTTPVersion,
HTTPServerSetup,
WebSocketServerSetup,
)
if TYPE_CHECKING:
from nonebot.internal.adapter import Bot, Adapter
BOT_HOOK_PARAMS = [DependParam, BotParam, DefaultParam]
class Driver(abc.ABC):
"""驱动器基类。
驱动器控制框架的启动和停止,适配器的注册,以及机器人生命周期管理。
参数:
env: 包含环境信息的 Env 对象
config: 包含配置信息的 Config 对象
"""
_adapters: ClassVar[dict[str, "Adapter"]] = {}
"""已注册的适配器列表"""
_bot_connection_hook: ClassVar[set[Dependent[Any]]] = set()
"""Bot 连接建立时执行的函数"""
_bot_disconnection_hook: ClassVar[set[Dependent[Any]]] = set()
"""Bot 连接断开时执行的函数"""
def __init__(self, env: Env, config: Config):
self.env: str = env.environment
"""环境名称"""
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._lifespan = Lifespan()
def __repr__(self) -> str:
return (
f"Driver(type={self.type!r}, "
f"adapters={len(self._adapters)}, bots={len(self._bots)})"
)
@property
def bots(self) -> dict[str, "Bot"]:
"""获取当前所有已连接的 Bot"""
return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
参数:
adapter: 适配器类
kwargs: 其他传递给适配器的参数
"""
name = adapter.get_name()
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return
self._adapters[name] = adapter(self, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@property
@abc.abstractmethod
def type(self) -> str:
"""驱动类型名称"""
raise NotImplementedError
@property
@abc.abstractmethod
def logger(self):
"""驱动专属 logger 日志记录器"""
raise NotImplementedError
@abc.abstractmethod
def run(self, *args, **kwargs):
"""启动驱动框架"""
logger.opt(colors=True).debug(
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
"""装饰一个函数使他在 bot 连接成功时执行。
钩子函数参数:
- bot: 当前连接上的 Bot 对象
"""
cls._bot_connection_hook.add(
Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
)
return func
@classmethod
def on_bot_disconnect(cls, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
"""装饰一个函数使他在 bot 连接断开时执行。
钩子函数参数:
- bot: 当前连接上的 Bot 对象
"""
cls._bot_disconnection_hook.add(
Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
)
return func
def _bot_connect(self, bot: "Bot") -> None:
"""在连接成功后,调用该函数来注册 bot 对象"""
if bot.self_id in self._bots:
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot
if not self._bot_connection_hook:
return
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
self.task_group.start_soon(_run_hook, bot)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]
if not self._bot_disconnection_hook:
return
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
self.task_group.start_soon(_run_hook, bot)
class Mixin(abc.ABC):
"""可与其他驱动器共用的混入基类。"""
@property
@abc.abstractmethod
def type(self) -> str:
"""混入驱动类型名称"""
raise NotImplementedError
class ForwardMixin(Mixin):
"""客户端混入基类。"""
class ReverseMixin(Mixin):
"""服务端混入基类。"""
class HTTPClientSession(abc.ABC):
"""HTTP 客户端会话基类。"""
@abc.abstractmethod
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
raise NotImplementedError
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
@abc.abstractmethod
async def setup(self) -> None:
"""初始化会话"""
raise NotImplementedError
@abc.abstractmethod
async def close(self) -> None:
"""关闭会话"""
raise NotImplementedError
async def __aenter__(self) -> Self:
await self.setup()
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self.close()
class HTTPClientMixin(ForwardMixin):
"""HTTP 客户端混入基类。"""
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
@abc.abstractmethod
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> HTTPClientSession:
"""获取一个 HTTP 会话"""
raise NotImplementedError
class WebSocketClientMixin(ForwardMixin):
"""WebSocket 客户端混入基类。"""
@abc.abstractmethod
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
"""发起一个 WebSocket 连接"""
raise NotImplementedError
yield # used for static type checking's generator detection
class ASGIMixin(ReverseMixin):
"""ASGI 服务端基类。
将后端框架封装,以满足适配器使用。
"""
@property
@abc.abstractmethod
def server_app(self) -> Any:
"""驱动 APP 对象"""
raise NotImplementedError
@property
@abc.abstractmethod
def asgi(self) -> Any:
"""驱动 ASGI 对象"""
raise NotImplementedError
@abc.abstractmethod
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
"""设置一个 HTTP 服务器路由配置"""
raise NotImplementedError
@abc.abstractmethod
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
"""设置一个 WebSocket 服务器路由配置"""
raise NotImplementedError
ForwardDriver: TypeAlias = ForwardMixin
"""支持客户端请求的驱动器。
**Deprecated**,请使用 {ref}`nonebot.drivers.ForwardMixin` 或其子类代替。
"""
ReverseDriver: TypeAlias = ReverseMixin
"""支持服务端请求的驱动器。
**Deprecated**,请使用 {ref}`nonebot.drivers.ReverseMixin` 或其子类代替。
"""