"""
Backend Driver Base Classes
===========================

All drivers should inherit from the base classes below.
"""

import abc
import asyncio
from typing import (Any, Set, List, Dict, Type, Tuple, Optional, Callable,
                    MutableMapping, TYPE_CHECKING)

from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook

if TYPE_CHECKING:
    from nonebot.adapters import Bot


class Driver(abc.ABC):
    """
    Base class for drivers.
    """

    _adapters: Dict[str, Type["Bot"]] = {}
    """
    :type: ``Dict[str, Type[Bot]]``
    :description: Registered adapters, keyed by adapter name. This is a class
        attribute, so the registry is shared by all Driver instances.
    """
    _bot_connection_hook: Set[T_BotConnectionHook] = set()
    """
    :type: ``Set[T_BotConnectionHook]``
    :description: Functions to run when a Bot connection is established
    """
    _bot_disconnection_hook: Set[T_BotDisconnectionHook] = set()
    """
    :type: ``Set[T_BotDisconnectionHook]``
    :description: Functions to run when a Bot connection is closed
    """

    @abc.abstractmethod
    def __init__(self, env: Env, config: Config):
        """
        :parameters:

          * ``env: Env``: Env object carrying environment information
          * ``config: Config``: Config object carrying configuration
        """
        self.env = env.environment
        """
        :type: ``str``
        :description: Environment name
        """
        self.config = config
        """
        :type: ``Config``
        :description: Configuration object
        """
        self._clients: Dict[str, "Bot"] = {}
        """
        :type: ``Dict[str, Bot]``
        :description: Currently connected Bots, keyed by ``self_id``
        """

    @property
    def bots(self) -> Dict[str, "Bot"]:
        """
        :type:

          ``Dict[str, Bot]``

        :description:

          All currently connected Bots (the internal mapping, not a copy)
        """
        return self._clients

    def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
        """
        :description:

          Register a protocol adapter. If an adapter with the same name is
          already registered, the call is ignored and a debug message is logged.

        :parameters:

          * ``name: str``: Adapter name, used to identify the adapter when a
            connection is made
          * ``adapter: Type[Bot]``: Adapter class
          * ``**kwargs``: Extra keyword arguments passed to ``adapter.register``
        """
        if name in self._adapters:
            logger.opt(
                colors=True).debug(f'Adapter "{name}" already exists')
            return
        self._adapters[name] = adapter
        adapter.register(self, self.config, **kwargs)
        logger.opt(
            colors=True).debug(f'Succeeded to load adapter "{name}"')

    @property
    @abc.abstractmethod
    def type(self):
        """Name of the driver type"""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def logger(self):
        """Driver-specific logger"""
        raise NotImplementedError

    @abc.abstractmethod
    def run(self,
            host: Optional[str] = None,
            port: Optional[int] = None,
            *args,
            **kwargs):
        """
        :description:

          Start the driver framework. The base implementation only logs the
          names of the loaded adapters.

        :parameters:

          * ``host: Optional[str]``: IP address the driver binds to
          * ``port: Optional[int]``: Port the driver binds to
          * ``*args``
          * ``**kwargs``
        """
        logger.opt(colors=True).debug(
            f"Loaded adapters: {', '.join(self._adapters)}")

    @abc.abstractmethod
    def on_startup(self, func: Callable) -> Callable:
        """Register a function to run when the driver starts"""
        raise NotImplementedError

    @abc.abstractmethod
    def on_shutdown(self, func: Callable) -> Callable:
        """Register a function to run when the driver stops"""
        raise NotImplementedError

    def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
        """
        :description:

          Decorator that registers a function to run when a bot connects
          successfully over WebSocket. The hook must return an awaitable
          (for example, be an ``async def`` function); all hooks are run
          concurrently in a background task.

        :hook parameters:

          * ``bot: Bot``: The Bot object that has just connected
        """
        self._bot_connection_hook.add(func)
        return func

    def on_bot_disconnect(
            self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
        """
        :description:

          Decorator that registers a function to run when a bot's WebSocket
          connection is closed. The hook must return an awaitable (for
          example, be an ``async def`` function); all hooks are run
          concurrently in a background task.

        :hook parameters:

          * ``bot: Bot``: The Bot object whose connection was closed
        """
        self._bot_disconnection_hook.add(func)
        return func

    def _bot_connect(self, bot: "Bot") -> None:
        """Register the bot object and run the connection hooks after a
        WebSocket connection is established"""
        self._clients[bot.self_id] = bot

        async def _run_hook(bot: "Bot") -> None:
            coros = list(map(lambda x: x(bot), self._bot_connection_hook))
            if coros:
                try:
                    # With the default ``return_exceptions=False``, gather()
                    # re-raises the first exception but does not cancel the
                    # remaining hooks.
                    await asyncio.gather(*coros)
                except Exception as e:
                    logger.opt(colors=True, exception=e).error(
                        "Error when running WebSocketConnection hook. "
                        "Running cancelled!")

        asyncio.create_task(_run_hook(bot))

    def _bot_disconnect(self, bot: "Bot") -> None:
        """Unregister the bot object and run the disconnection hooks after a
        WebSocket connection is closed"""
        if bot.self_id in self._clients:
            del self._clients[bot.self_id]

        async def _run_hook(bot: "Bot") -> None:
            coros = list(map(lambda x: x(bot), self._bot_disconnection_hook))
            if coros:
                try:
                    # Same semantics as in ``_bot_connect``: the first error is
                    # logged, the other hooks are not cancelled.
                    await asyncio.gather(*coros)
                except Exception as e:
                    logger.opt(colors=True, exception=e).error(
                        "Error when running WebSocketDisconnection hook. "
                        "Running cancelled!")

        asyncio.create_task(_run_hook(bot))


class ForwardDriver(Driver):
    pass


class ReverseDriver(Driver):
    """
    Base class for reverse drivers. Wraps a backend framework so that
    adapters can use it.
    """

    @property
    @abc.abstractmethod
    def server_app(self):
        """The driver's server application object"""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def asgi(self):
        """The driver's ASGI application object"""
        raise NotImplementedError

    @abc.abstractmethod
    async def _handle_http(self, *args, **kwargs):
        """Handler for HTTP requests"""
        raise NotImplementedError

    @abc.abstractmethod
    async def _handle_ws_reverse(self, *args, **kwargs):
        """Handler for WebSocket requests"""
        raise NotImplementedError


class HTTPRequest:
    """HTTP request wrapper. See `asgi http scope`_.

    .. _asgi http scope:
        https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
    """

    def __init__(self, scope: MutableMapping[str, Any]):
        self._scope = scope

    @property
    def type(self) -> str:
        """Always `http`"""
        return "http"

    @property
    def scope(self) -> MutableMapping[str, Any]:
        """Raw scope from asgi.

        The connection scope information, a dictionary that contains at least
        a `type` key specifying the protocol that is incoming.
        """
        return self._scope

    @property
    def http_version(self) -> str:
        """One of `"1.0"`, `"1.1"` or `"2"`."""
        return self.scope["http_version"]

    @property
    def method(self) -> str:
        """The HTTP method name, uppercased."""
        return self.scope["method"]

    @property
    def schema(self) -> str:
        """
        URL scheme portion (likely `"http"` or `"https"`).
        Optional (but must not be empty); default is `"http"`.
        Read from the ASGI ``scheme`` key.
        """
        return self.scope.get("scheme", "http")

    @property
    def path(self) -> str:
        """
        HTTP request target excluding any query string,
        with percent-encoded sequences and UTF-8 byte sequences
        decoded into characters.
        """
        return self.scope["path"]

    @property
    def query_string(self) -> bytes:
        """URL portion after the `?`, percent-encoded."""
        return self.scope["query_string"]

    @property
    def headers(self) -> List[Tuple[bytes, bytes]]:
        """An iterable of [name, value] two-item iterables,
        where name is the header name, and value is the header value.

        Order of header values must be preserved from the original HTTP
        request; order of header names is not important.

        Duplicates are possible and must be preserved in the message as
        received.

        Header names must be lowercased.
        """
        return list(self.scope["headers"])

    @property
    def body(self) -> bytes:
        """Body of the request.

        Optional; if missing defaults to b"".

        In ASGI the body is not part of the connection scope: it arrives in
        ``http.request`` messages (concatenated while ``more_body`` is true),
        so it is only available here if the driver stores it in the scope
        under the ``body`` key.
        """
        return self.scope.get("body", b"")


class HTTPResponse:
    """HTTP response wrapper. See `asgi http scope`_.

    .. _asgi http scope:
        https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
    """

    def __init__(self,
                 status: int,
                 headers: Optional[List[Tuple[bytes, bytes]]] = None,
                 body: Optional[bytes] = None):
        self.status: int = status
        """HTTP status code."""
        # Use a fresh list per response instead of a shared mutable default.
        self.headers: List[Tuple[bytes, bytes]] = (headers
                                                   if headers is not None
                                                   else [])
        """An iterable of [name, value] two-item iterables,
        where name is the header name, and value is the header value.

        Order must be preserved in the HTTP response.

        Header names must be lowercased.

        Optional; if missing defaults to an empty list.
        """
        self.body: Optional[bytes] = body
        """HTTP body content.

        Optional; if missing defaults to `None`.
        """

    @property
    def type(self) -> str:
        """Always `http`"""
        return "http"


class WebSocket(abc.ABC):
    """WebSocket connection wrapper. See `asgi websocket scope`_.

    .. _asgi websocket scope:
        https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
    """

    @abc.abstractmethod
    def __init__(self, websocket):
        """
        :parameters:

          * ``websocket: Any``: The underlying WebSocket connection object
        """
        self._websocket = websocket

    @property
    def websocket(self):
        """The underlying WebSocket connection object"""
        return self._websocket

    @property
    @abc.abstractmethod
    def closed(self):
        """
        :type: ``bool``
        :description: Whether the connection has been closed
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def accept(self):
        """Accept the WebSocket connection request"""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self, code: int):
        """Close the WebSocket connection with the given close ``code``"""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive(self) -> dict:
        """Receive one WebSocket message"""
        raise NotImplementedError

    @abc.abstractmethod
    async def send(self, data: dict):
        """Send one WebSocket message"""
        raise NotImplementedError
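

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the NoneBot API: a minimal, standard-library
# reduction of the hook dispatch that ``Driver._bot_connect`` and
# ``Driver._bot_disconnect`` perform. All names below (``_HookRegistry``,
# ``_demo_hook_dispatch``) are hypothetical and exist only for this example.
# Assumptions: each hook is an ``async def`` function taking one argument (a
# plain ``str`` bot id stands in for a ``Bot`` here), and dispatch runs as a
# fire-and-forget task on the already running event loop.
# ---------------------------------------------------------------------------
import asyncio
from typing import Awaitable, Callable, List, Set


class _HookRegistry:
    """Hypothetical stand-in for the hook sets kept on ``Driver``."""

    def __init__(self) -> None:
        self.hooks: Set[Callable[[str], Awaitable[None]]] = set()

    def on_connect(self, func):
        # Same decorator shape as ``Driver.on_bot_connect``: store, return func.
        self.hooks.add(func)
        return func

    def fire(self, bot_id: str) -> "asyncio.Task[None]":
        async def _run_hooks() -> None:
            coros = [hook(bot_id) for hook in self.hooks]
            if coros:
                try:
                    # Hooks run concurrently. gather() re-raises the first
                    # exception; the other hooks are not cancelled.
                    await asyncio.gather(*coros)
                except Exception as e:
                    # Stand-in for ``logger.opt(exception=e).error(...)``.
                    print(f"hook failed: {e!r}")

        # Like ``_bot_connect``, this needs a running event loop.
        return asyncio.create_task(_run_hooks())


def _demo_hook_dispatch() -> List[str]:
    registry = _HookRegistry()
    seen: List[str] = []

    @registry.on_connect
    async def record(bot_id: str) -> None:
        seen.append(bot_id)

    async def main() -> None:
        # The driver never awaits this task; the demo does so only to be
        # able to inspect ``seen`` afterwards.
        await registry.fire("12345")

    asyncio.run(main())
    return seen


if __name__ == "__main__":
    print(_demo_hook_dispatch())  # prints ['12345']

# Design note: firing hooks in a background task keeps a slow or failing hook
# from blocking the connection handler that registered the bot.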