nonebot2/nonebot/drivers/__init__.py

368 lines
9.8 KiB
Python
Raw Normal View History

2020-10-10 23:40:01 +08:00
"""
后端驱动适配基类
2020-11-13 01:46:26 +08:00
=================
2020-10-10 23:40:01 +08:00
各驱动请继承以下基类
"""
2020-07-04 22:51:10 +08:00
2020-07-18 18:18:43 +08:00
import abc
2020-12-28 13:36:00 +08:00
import asyncio
2021-06-10 21:52:20 +08:00
from dataclasses import dataclass, field
from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
2020-07-04 22:51:10 +08:00
2020-08-14 17:41:24 +08:00
from nonebot.log import logger
2020-08-10 13:06:02 +08:00
from nonebot.config import Env, Config
2021-05-21 17:06:20 +08:00
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
2020-12-06 02:30:19 +08:00
if TYPE_CHECKING:
2020-12-07 00:06:09 +08:00
from nonebot.adapters import Bot
2020-07-11 17:32:03 +08:00
2020-07-04 22:51:10 +08:00
2020-12-07 00:06:09 +08:00
class Driver(abc.ABC):
    """
    Driver base class.

    Keeps the registered adapters, the currently connected bots and the
    functions to run when a bot connects or disconnects. Concrete drivers
    subclass this (or ``ForwardDriver`` / ``ReverseDriver``).
    """

    _adapters: Dict[str, Type["Bot"]] = {}
    """
    :type: ``Dict[str, Type[Bot]]``
    :description: Registered adapters, keyed by adapter name. This is a class
        attribute, so the mapping is shared by every Driver instance.
    """
    _bot_connection_hook: Set[T_BotConnectionHook] = set()
    """
    :type: ``Set[T_BotConnectionHook]``
    :description: Functions run when a bot connection is established
        (class attribute, shared by every Driver instance).
    """
    _bot_disconnection_hook: Set[T_BotDisconnectionHook] = set()
    """
    :type: ``Set[T_BotDisconnectionHook]``
    :description: Functions run when a bot connection is closed
        (class attribute, shared by every Driver instance).
    """
    _hook_tasks: Set["asyncio.Task"] = set()
    """
    :type: ``Set[asyncio.Task]``
    :description: Strong references to hook tasks that are still running.
        The event loop only keeps weak references to tasks, so without this
        set a pending hook task could be garbage-collected mid-execution.
    """

    @abc.abstractmethod
    def __init__(self, env: Env, config: Config):
        """
        :params:

          * ``env: Env``: Env object holding the environment information
          * ``config: Config``: Config object holding the configuration
        """
        self.env: str = env.environment
        """
        :type: ``str``
        :description: Environment name
        """
        self.config: Config = config
        """
        :type: ``Config``
        :description: Configuration object
        """
        self._clients: Dict[str, "Bot"] = {}
        """
        :type: ``Dict[str, Bot]``
        :description: Connected bots, keyed by ``bot.self_id``
        """

    @property
    def bots(self) -> Dict[str, "Bot"]:
        """
        :type:
          ``Dict[str, Bot]``
        :description:
          All currently connected bots, keyed by ``self_id``.
        """
        return self._clients

    def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
        """
        :description:

          Register a protocol adapter. If ``name`` is already registered this
          only logs a debug message and returns.

        :params:

          * ``name: str``: adapter name, used to identify the adapter on connection
          * ``adapter: Type[Bot]``: adapter class
          * ``**kwargs``: forwarded to ``adapter.register``
        """
        if name in self._adapters:
            logger.opt(
                colors=True).debug(f'Adapter "<y>{name}</y>" already exists')
            return
        self._adapters[name] = adapter
        adapter.register(self, self.config, **kwargs)
        logger.opt(
            colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"')

    @property
    @abc.abstractmethod
    def type(self):
        """Driver type name."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def logger(self):
        """Logger dedicated to this driver."""
        raise NotImplementedError

    @abc.abstractmethod
    def run(self,
            host: Optional[str] = None,
            port: Optional[int] = None,
            *args,
            **kwargs):
        """
        :description:

          Start the driver framework. Subclasses should call this base
          implementation; it only logs the loaded adapters.

        :params:

          * ``host: Optional[str]``: IP address the driver binds to
          * ``port: Optional[int]``: port the driver binds to
          * ``*args``
          * ``**kwargs``
        """
        logger.opt(colors=True).debug(
            f"<g>Loaded adapters: {', '.join(self._adapters)}</g>")

    @abc.abstractmethod
    def on_startup(self, func: Callable) -> Callable:
        """Register a function to run when the driver starts."""
        raise NotImplementedError

    @abc.abstractmethod
    def on_shutdown(self, func: Callable) -> Callable:
        """Register a function to run when the driver stops."""
        raise NotImplementedError

    def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
        """
        :description:

          Decorator: run the decorated function when a bot connects
          successfully over WebSocket. Returns ``func`` unchanged.

        :hook params:

          * ``bot: Bot``: the bot that just connected
        """
        self._bot_connection_hook.add(func)
        return func

    def on_bot_disconnect(
            self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
        """
        :description:

          Decorator: run the decorated function when a bot's WebSocket
          connection is closed. Returns ``func`` unchanged.

        :hook params:

          * ``bot: Bot``: the bot whose connection was closed
        """
        self._bot_disconnection_hook.add(func)
        return func

    def _schedule_bot_hooks(self, hooks: Set[Callable], bot: "Bot",
                            hook_name: str) -> None:
        """
        Run every function in ``hooks`` with ``bot`` concurrently in a
        background task. Exceptions raised by the hooks are logged, not
        propagated to the caller.

        :params:

          * ``hooks``: set of hook functions, each called as ``hook(bot)``
          * ``bot``: bot passed to every hook
          * ``hook_name``: name used in the error log message
        """

        async def _run_hook() -> None:
            # Hooks are read when the task runs, not when it is scheduled.
            coros = [hook(bot) for hook in hooks]
            if not coros:
                return
            try:
                await asyncio.gather(*coros)
            except Exception as e:
                logger.opt(colors=True, exception=e).error(
                    f"<r><bg #f8bbd0>Error when running {hook_name} hook. "
                    "Running cancelled!</bg #f8bbd0></r>")

        task = asyncio.create_task(_run_hook())
        # Keep a strong reference until the task finishes; the event loop
        # only holds weak references to tasks.
        self._hook_tasks.add(task)
        task.add_done_callback(self._hook_tasks.discard)

    def _bot_connect(self, bot: "Bot") -> None:
        """Register ``bot`` after its WebSocket connection succeeds, then run the connection hooks in the background."""
        self._clients[bot.self_id] = bot
        self._schedule_bot_hooks(self._bot_connection_hook, bot,
                                 "WebSocketConnection")

    def _bot_disconnect(self, bot: "Bot") -> None:
        """Unregister ``bot`` after its WebSocket connection closes, then run the disconnection hooks in the background."""
        self._clients.pop(bot.self_id, None)
        self._schedule_bot_hooks(self._bot_disconnection_hook, bot,
                                 "WebSocketDisConnection")
2020-12-28 13:36:00 +08:00
2020-11-30 11:08:00 +08:00
2021-05-21 17:06:20 +08:00
class ForwardDriver(Driver):
    """
    Forward driver base class.

    It currently adds nothing on top of ``Driver``.
    """
2020-11-30 11:08:00 +08:00
2021-05-21 17:06:20 +08:00
class ReverseDriver(Driver):
    """
    Reverse driver base class.

    Wraps a backend server framework so that adapters can use it.
    Subclasses must provide the app objects and the request handlers below.
    """

    @property
    @abc.abstractmethod
    def server_app(self):
        """The driver's server app object."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def asgi(self):
        """The driver's ASGI object."""
        raise NotImplementedError

    @abc.abstractmethod
    async def _handle_http(self, *args, **kwargs):
        """Handle an incoming HTTP request."""
        raise NotImplementedError

    @abc.abstractmethod
    async def _handle_ws_reverse(self, *args, **kwargs):
        """Handle an incoming WebSocket request."""
        raise NotImplementedError
2020-07-11 17:32:03 +08:00
2021-06-10 21:52:20 +08:00
@dataclass
class HTTPConnection(abc.ABC):
    """
    Fields shared by HTTP requests and WebSocket connections.

    Subclasses must implement the ``type`` property.
    """

    #: HTTP version: one of ``"1.0"``, ``"1.1"`` or ``"2"``.
    http_version: str
    #: URL scheme, likely ``"http"`` or ``"https"``.
    scheme: str
    #: Request target without the query string; percent-encoded sequences
    #: and UTF-8 byte sequences are decoded into characters.
    path: str
    #: Part of the URL after ``?``, still percent-encoded.
    query_string: bytes = b""
    #: Mapping of header name to header value. Header names must be
    #: lowercased.
    headers: Dict[str, str] = field(default_factory=dict)

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Connection type name."""
        raise NotImplementedError
2021-05-31 00:27:31 +08:00
2021-06-10 21:52:20 +08:00
@dataclass
class HTTPRequest(HTTPConnection):
    """
    HTTP request wrapper, modelled on the `asgi http scope`_.

    .. _asgi http scope:
       https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
    """

    #: HTTP method name, uppercased.
    method: str = "GET"
    #: Request body; defaults to ``b""`` when missing.
    body: bytes = b""

    @property
    def type(self) -> str:
        """Connection type; always ``"http"``."""
        return "http"
2021-06-10 21:52:20 +08:00
@dataclass
class HTTPResponse:
    """
    HTTP response wrapper, modelled on the ASGI HTTP spec (`asgi http scope`_).

    .. _asgi http scope:
       https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
    """

    #: HTTP status code.
    status: int
    #: Response body; ``None`` when missing.
    body: Optional[bytes] = None
    #: Mapping of header name to header value; header names must be
    #: lowercased. Defaults to an empty dict when missing.
    headers: Dict[str, str] = field(default_factory=dict)

    @property
    def type(self) -> str:
        """Always ``"http"``."""
        return "http"
2021-06-10 21:52:20 +08:00
@dataclass
class WebSocket(HTTPConnection, abc.ABC):
    """
    WebSocket connection wrapper, modelled on the `asgi websocket scope`_.

    Concrete drivers implement the abstract connection operations below.

    .. _asgi websocket scope:
       https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
    """

    @property
    def type(self) -> str:
        """Connection type; always ``"websocket"``."""
        return "websocket"

    @property
    @abc.abstractmethod
    def closed(self) -> bool:
        """
        :type: ``bool``
        :description: Whether the connection has already been closed.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def accept(self) -> None:
        """Accept the WebSocket connection request."""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self, code: int) -> None:
        """
        Close the WebSocket connection.

        ``code`` is the close code handed to the implementation (presumably
        a WebSocket close status code; confirm against the concrete driver).
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def receive(self) -> str:
        """Receive one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_bytes(self) -> bytes:
        """Receive one WebSocket binary message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def send(self, data: str) -> None:
        """Send one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def send_bytes(self, data: bytes) -> None:
        """Send one WebSocket binary message."""
        raise NotImplementedError