nonebot2/nonebot/drivers/__init__.py

266 lines
7.9 KiB
Python
Raw Normal View History

2022-01-22 15:23:07 +08:00
"""本模块定义了驱动适配器基类。
各驱动请继承以下基类
2020-10-10 23:40:01 +08:00
2022-01-22 15:23:07 +08:00
FrontMatter:
sidebar_position: 0
description: nonebot.drivers 模块
2020-10-10 23:40:01 +08:00
"""
2020-07-04 22:51:10 +08:00
2020-07-18 18:18:43 +08:00
import abc
2020-12-28 13:36:00 +08:00
import asyncio
from dataclasses import dataclass
from contextlib import asynccontextmanager
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Callable,
Awaitable,
AsyncGenerator,
)
2020-07-04 22:51:10 +08:00
2020-08-14 17:41:24 +08:00
from nonebot.log import logger
from nonebot.utils import escape_tag
2020-08-10 13:06:02 +08:00
from nonebot.config import Env, Config
2022-01-15 21:27:43 +08:00
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
from ._model import URL as URL
from ._model import Request as Request
2021-12-20 00:28:02 +08:00
from ._model import Response as Response
from ._model import WebSocket as WebSocket
from ._model import HTTPVersion as HTTPVersion
2020-12-06 02:30:19 +08:00
if TYPE_CHECKING:
2021-12-06 22:19:05 +08:00
from nonebot.adapters import Bot, Adapter
2020-07-11 17:32:03 +08:00
2020-07-04 22:51:10 +08:00
2020-12-07 00:06:09 +08:00
class Driver(abc.ABC):
    """Base class for drivers.

    A driver wraps a concrete network framework. It keeps track of the
    registered adapters, the currently connected bots and the hooks that run
    when a bot connects or disconnects.

    Args:
        env: ``Env`` object holding environment information.
        config: ``Config`` object holding the global configuration.
    """

    # NOTE: the containers below are class attributes, so they are shared by
    # every instance of Driver and its subclasses (unless a subclass overrides them).
    _adapters: Dict[str, "Adapter"] = {}
    """Registered adapters, keyed by adapter name."""
    _bot_connection_hook: Set[T_BotConnectionHook] = set()
    """Functions run when a bot connection is established."""
    _bot_disconnection_hook: Set[T_BotDisconnectionHook] = set()
    """Functions run when a bot connection is closed."""
    _bot_hook_tasks: Set["asyncio.Task"] = set()
    """Strong references to pending bot hook tasks until they finish."""

    def __init__(self, env: Env, config: Config):
        self.env: str = env.environment
        """Name of the current environment."""
        self.config: Config = config
        """Global configuration object."""
        self._clients: Dict[str, "Bot"] = {}

    @property
    def bots(self) -> Dict[str, "Bot"]:
        """All currently connected bots, keyed by ``self_id``."""
        return self._clients

    def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
        """Register a protocol adapter.

        The adapter is instantiated as ``adapter(self, **kwargs)``. Registering
        an adapter whose name is already registered is a no-op (logged at
        debug level).

        Args:
            adapter: the adapter class.
            kwargs: extra keyword arguments passed to the adapter constructor.
        """
        name = adapter.get_name()
        if name in self._adapters:
            logger.opt(colors=True).debug(
                f'Adapter "<y>{escape_tag(name)}</y>" already exists'
            )
            return
        self._adapters[name] = adapter(self, **kwargs)
        logger.opt(colors=True).debug(
            f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
        )

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Name of the driver type."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def logger(self):
        """Logger dedicated to this driver."""
        raise NotImplementedError

    @abc.abstractmethod
    def run(self, *args, **kwargs):
        """Start the driver framework.

        Subclasses are expected to call this base implementation, which logs
        the names of the loaded adapters.
        """
        logger.opt(colors=True).debug(
            f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
        )

    @abc.abstractmethod
    def on_startup(self, func: Callable) -> Callable:
        """Register a function to run when the driver starts."""
        raise NotImplementedError

    @abc.abstractmethod
    def on_shutdown(self, func: Callable) -> Callable:
        """Register a function to run when the driver stops."""
        raise NotImplementedError

    def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
        """Decorate a function so that it runs after a bot connects.

        Hook arguments:
            - bot: the Bot object that has just connected.

        Returns the decorated function unchanged.
        """
        self._bot_connection_hook.add(func)
        return func

    def on_bot_disconnect(self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
        """Decorate a function so that it runs after a bot disconnects.

        Hook arguments:
            - bot: the Bot object whose connection was closed.

        Returns the decorated function unchanged.
        """
        self._bot_disconnection_hook.add(func)
        return func

    def _bot_connect(self, bot: "Bot") -> None:
        """Register ``bot`` once its connection is established.

        The connection hooks are then run in a background task, so this must
        be called while an asyncio event loop is running.

        Raises:
            RuntimeError: a bot with the same ``self_id`` is already connected.
        """
        if bot.self_id in self._clients:
            raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
        self._clients[bot.self_id] = bot
        self._spawn_bot_hooks(
            self._bot_connection_hook,
            bot,
            "<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
            "Running cancelled!</bg #f8bbd0></r>",
        )

    def _bot_disconnect(self, bot: "Bot") -> None:
        """Unregister ``bot`` once its connection is closed.

        The disconnection hooks are run in a background task even if ``bot``
        was not registered. This must be called while an asyncio event loop
        is running.
        """
        if bot.self_id in self._clients:
            del self._clients[bot.self_id]
        self._spawn_bot_hooks(
            self._bot_disconnection_hook,
            bot,
            "<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
            "Running cancelled!</bg #f8bbd0></r>",
        )

    # Helper shared by _bot_connect/_bot_disconnect: run a hook set in the background.
    def _spawn_bot_hooks(self, hooks: Set[Any], bot: "Bot", error_message: str) -> None:
        """Run every function in ``hooks`` with ``bot`` concurrently in a new task.

        Args:
            hooks: one of the bot connection/disconnection hook sets.
            bot: the bot passed to each hook.
            error_message: message logged if calling or awaiting a hook raises.
        """

        async def _run_hooks() -> None:
            try:
                # Hooks are read when the task starts (not when it is
                # scheduled). Calling them inside the try also logs
                # synchronous failures instead of leaving an unretrieved
                # task exception.
                coros = [hook(bot) for hook in hooks]
                if coros:
                    await asyncio.gather(*coros)
            except Exception as e:
                logger.opt(colors=True, exception=e).error(error_message)

        task = asyncio.create_task(_run_hooks())
        # The event loop keeps only weak references to tasks. Hold a strong
        # reference until the task is done so it cannot be garbage-collected
        # mid-execution.
        self._bot_hook_tasks.add(task)
        task.add_done_callback(self._bot_hook_tasks.discard)
2020-12-28 13:36:00 +08:00
2020-11-30 11:08:00 +08:00
2021-12-22 16:53:55 +08:00
class ForwardMixin(abc.ABC):
    """Base class for client-side (forward connection) mixins.

    A mixin provides outgoing HTTP requests and WebSocket connections. It is
    meant to be merged into a driver class by ``combine_driver``.
    """

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Name of the client driver type."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def request(self, setup: Request) -> Response:
        """Send the HTTP request described by ``setup`` and return the response."""
        raise NotImplementedError()

    @abc.abstractmethod
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
        """Open the WebSocket connection described by ``setup``.

        Implementations are async context managers that yield the connection.
        """
        raise NotImplementedError()
        # Unreachable; present only so the function is recognised as an
        # async generator, which asynccontextmanager requires.
        yield
2020-11-30 11:08:00 +08:00
2021-12-22 16:53:55 +08:00
class ForwardDriver(Driver, ForwardMixin):
    """Base class for client-side drivers.

    Wraps a client framework so that adapters can issue outgoing HTTP
    requests and WebSocket connections through it.
    """
2021-12-22 16:53:55 +08:00
2021-05-21 17:06:20 +08:00
class ReverseDriver(Driver):
    """Base class for server-side drivers.

    Wraps a server framework so that adapters can set up HTTP and WebSocket
    routes on it.
    """

    @property
    @abc.abstractmethod
    def server_app(self) -> Any:
        """The underlying framework application object."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def asgi(self) -> Any:
        """The ASGI application object exposed by the driver."""
        raise NotImplementedError()

    @abc.abstractmethod
    def setup_http_server(self, setup: "HTTPServerSetup") -> None:
        """Set up the HTTP server route described by ``setup``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
        """Set up the WebSocket server route described by ``setup``."""
        raise NotImplementedError()
2021-12-22 16:53:55 +08:00
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
    """Merge a driver class with any number of client mixin classes.

    Args:
        driver: the base driver class; must be a subclass of ``Driver``.
        *mixins: client mixin classes; each must be a subclass of ``ForwardMixin``.

    Returns:
        ``driver`` itself when no mixins are given. Otherwise, a new class
        inheriting from every mixin, ``driver`` and ``ForwardDriver``, whose
        ``type`` is the driver's type followed by each mixin's type, joined
        with ``"+"``.
    """
    # validate the arguments before building anything
    assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
    assert all(
        issubclass(mixin, ForwardMixin) for mixin in mixins
    ), "`mixins` must be subclass of ForwardMixin"

    if not mixins:
        return driver

    class CombinedDriver(*mixins, driver, ForwardDriver):  # type: ignore
        @property
        def type(self) -> str:
            # `type` is a property on each component class; call its getter
            # explicitly against this combined instance.
            parts = [driver.type.__get__(self)]
            parts.extend(mixin.type.__get__(self) for mixin in mixins)
            return "+".join(parts)

    return CombinedDriver
2021-07-31 12:24:11 +08:00
@dataclass
class HTTPServerSetup:
    """Route configuration for an HTTP endpoint, consumed by
    ``ReverseDriver.setup_http_server``.

    Fields:
        path: route path; it should be relative (``path.is_absolute()``
            is expected to be ``False``).
        method: HTTP method of the route.
        name: route name.
        handle_func: coroutine function turning a ``Request`` into a ``Response``.
    """

    path: URL  # must not be absolute: URL.is_absolute() is expected to be False
    method: str
    name: str
    handle_func: Callable[[Request], Awaitable[Response]]
2021-07-31 12:24:11 +08:00
@dataclass
class WebSocketServerSetup:
    """Route configuration for a WebSocket endpoint, consumed by
    ``ReverseDriver.setup_websocket_server``.

    Fields:
        path: route path; it should be relative (``path.is_absolute()``
            is expected to be ``False``).
        name: route name.
        handle_func: coroutine function that handles an accepted ``WebSocket``.
    """

    path: URL  # must not be absolute: URL.is_absolute() is expected to be False
    name: str
    handle_func: Callable[[WebSocket], Awaitable[Any]]