diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py
index efa6bc71..ad0d9817 100644
--- a/nonebot/drivers/__init__.py
+++ b/nonebot/drivers/__init__.py
@@ -11,7 +11,7 @@ from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger
from nonebot.config import Env, Config
-from nonebot.typing import T_WebSocketConnectionHook, T_WebSocketDisconnectionHook
+from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
from nonebot.adapters import Bot
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
class Driver(abc.ABC):
"""
- Driver 基类。将后端框架封装,以满足适配器使用。
+ Driver 基类。
"""
_adapters: Dict[str, Type["Bot"]] = {}
@@ -27,15 +27,15 @@ class Driver(abc.ABC):
:类型: ``Dict[str, Type[Bot]]``
:说明: 已注册的适配器列表
"""
- _ws_connection_hook: Set[T_WebSocketConnectionHook] = set()
+ _bot_connection_hook: Set[T_BotConnectionHook] = set()
"""
- :类型: ``Set[T_WebSocketConnectionHook]``
- :说明: WebSocket 连接建立时执行的函数
+ :类型: ``Set[T_BotConnectionHook]``
+ :说明: Bot 连接建立时执行的函数
"""
- _ws_disconnection_hook: Set[T_WebSocketDisconnectionHook] = set()
+ _bot_disconnection_hook: Set[T_BotDisconnectionHook] = set()
"""
- :类型: ``Set[T_WebSocketDisconnectionHook]``
- :说明: WebSocket 连接断开时执行的函数
+ :类型: ``Set[T_BotDisconnectionHook]``
+ :说明: Bot 连接断开时执行的函数
"""
@abc.abstractmethod
@@ -62,6 +62,18 @@ class Driver(abc.ABC):
:说明: 已连接的 Bot
"""
+ @property
+ def bots(self) -> Dict[str, "Bot"]:
+ """
+ :类型:
+
+ ``Dict[str, Bot]``
+ :说明:
+
+ 获取当前所有已连接的 Bot
+ """
+ return self._clients
+
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
"""
:说明:
@@ -88,108 +100,12 @@ class Driver(abc.ABC):
"""驱动类型名称"""
raise NotImplementedError
- @property
- @abc.abstractmethod
- def server_app(self):
- """驱动 APP 对象"""
- raise NotImplementedError
-
- @property
- @abc.abstractmethod
- def asgi(self):
- """驱动 ASGI 对象"""
- raise NotImplementedError
-
@property
@abc.abstractmethod
def logger(self):
"""驱动专属 logger 日志记录器"""
raise NotImplementedError
- @property
- def bots(self) -> Dict[str, "Bot"]:
- """
- :类型:
-
- ``Dict[str, Bot]``
- :说明:
-
- 获取当前所有已连接的 Bot
- """
- return self._clients
-
- @abc.abstractmethod
- def on_startup(self, func: Callable) -> Callable:
- """注册一个在驱动启动时运行的函数"""
- raise NotImplementedError
-
- @abc.abstractmethod
- def on_shutdown(self, func: Callable) -> Callable:
- """注册一个在驱动停止时运行的函数"""
- raise NotImplementedError
-
- def on_bot_connect(
- self, func: T_WebSocketConnectionHook) -> T_WebSocketConnectionHook:
- """
- :说明:
-
- 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。
-
- :函数参数:
-
- * ``bot: Bot``: 当前连接上的 Bot 对象
- """
- self._ws_connection_hook.add(func)
- return func
-
- def on_bot_disconnect(
- self,
- func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook:
- """
- :说明:
-
- 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。
-
- :函数参数:
-
- * ``bot: Bot``: 当前连接上的 Bot 对象
- """
- self._ws_disconnection_hook.add(func)
- return func
-
- def _bot_connect(self, bot: "Bot") -> None:
- """在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
- self._clients[bot.self_id] = bot
-
- async def _run_hook(bot: "Bot") -> None:
- coros = list(map(lambda x: x(bot), self._ws_connection_hook))
- if coros:
- try:
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running WebSocketConnection hook. "
- "Running cancelled!")
-
- asyncio.create_task(_run_hook(bot))
-
- def _bot_disconnect(self, bot: "Bot") -> None:
- """在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
- if bot.self_id in self._clients:
- del self._clients[bot.self_id]
-
- async def _run_hook(bot: "Bot") -> None:
- coros = list(map(lambda x: x(bot), self._ws_disconnection_hook))
- if coros:
- try:
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running WebSocketDisConnection hook. "
- "Running cancelled!")
-
- asyncio.create_task(_run_hook(bot))
-
@abc.abstractmethod
def run(self,
host: Optional[str] = None,
@@ -211,6 +127,98 @@ class Driver(abc.ABC):
logger.opt(colors=True).debug(
f"Loaded adapters: {', '.join(self._adapters)}")
+ @abc.abstractmethod
+ def on_startup(self, func: Callable) -> Callable:
+ """注册一个在驱动启动时运行的函数"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def on_shutdown(self, func: Callable) -> Callable:
+ """注册一个在驱动停止时运行的函数"""
+ raise NotImplementedError
+
+ def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
+ """
+ :说明:
+
+ 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。
+
+ :函数参数:
+
+ * ``bot: Bot``: 当前连接上的 Bot 对象
+ """
+ self._bot_connection_hook.add(func)
+ return func
+
+ def on_bot_disconnect(
+ self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
+ """
+ :说明:
+
+ 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。
+
+ :函数参数:
+
+ * ``bot: Bot``: 当前连接上的 Bot 对象
+ """
+ self._bot_disconnection_hook.add(func)
+ return func
+
+ def _bot_connect(self, bot: "Bot") -> None:
+ """在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
+ self._clients[bot.self_id] = bot
+
+ async def _run_hook(bot: "Bot") -> None:
+ coros = list(map(lambda x: x(bot), self._bot_connection_hook))
+ if coros:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketConnection hook. "
+ "Running cancelled!")
+
+ asyncio.create_task(_run_hook(bot))
+
+ def _bot_disconnect(self, bot: "Bot") -> None:
+ """在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
+ if bot.self_id in self._clients:
+ del self._clients[bot.self_id]
+
+ async def _run_hook(bot: "Bot") -> None:
+ coros = list(map(lambda x: x(bot), self._bot_disconnection_hook))
+ if coros:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketDisConnection hook. "
+ "Running cancelled!")
+
+ asyncio.create_task(_run_hook(bot))
+
+
+class ForwardDriver(Driver):
+ pass
+
+
+class ReverseDriver(Driver):
+ """
+ Reverse Driver 基类。将后端框架封装,以满足适配器使用。
+ """
+
+ @property
+ @abc.abstractmethod
+ def server_app(self):
+ """驱动 APP 对象"""
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def asgi(self):
+ """驱动 ASGI 对象"""
+ raise NotImplementedError
+
@abc.abstractmethod
async def _handle_http(self):
"""用于处理 HTTP 类型请求的函数"""
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index 18a5f1db..28a2b19a 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -24,7 +24,7 @@ from nonebot.typing import overrides
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.config import Env, Config as NoneBotConfig
-from nonebot.drivers import Driver as BaseDriver, WebSocket as BaseWebSocket
+from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
class Config(BaseSettings):
@@ -76,7 +76,7 @@ class Config(BaseSettings):
extra = "ignore"
-class Driver(BaseDriver):
+class Driver(ReverseDriver):
"""
FastAPI 驱动框架
@@ -106,40 +106,40 @@ class Driver(BaseDriver):
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
@property
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def type(self) -> str:
"""驱动名称: ``fastapi``"""
return "fastapi"
@property
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def server_app(self) -> FastAPI:
"""``FastAPI APP`` 对象"""
return self._server_app
@property
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def asgi(self):
"""``FastAPI APP`` 对象"""
return self._server_app
@property
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def logger(self) -> logging.Logger:
"""fastapi 使用的 logger"""
return logging.getLogger("fastapi")
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events `_"""
return self.server_app.on_event("startup")(func)
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def on_shutdown(self, func: Callable) -> Callable:
"""参考文档: `Events `_"""
return self.server_app.on_event("shutdown")(func)
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
def run(self,
host: Optional[str] = None,
port: Optional[int] = None,
@@ -176,7 +176,7 @@ class Driver(BaseDriver):
log_config=LOGGING_CONFIG,
**kwargs)
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
data_dict = json.loads(data.decode())
@@ -211,7 +211,7 @@ class Driver(BaseDriver):
asyncio.create_task(bot.handle_message(data_dict))
return Response("", 204)
- @overrides(BaseDriver)
+ @overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
diff --git a/nonebot/typing.py b/nonebot/typing.py
index d73865bd..9a5bcf80 100644
--- a/nonebot/typing.py
+++ b/nonebot/typing.py
@@ -55,21 +55,21 @@ T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]]
事件处理状态 State 类工厂函数
"""
-T_WebSocketConnectionHook = Callable[["Bot"], Awaitable[None]]
+T_BotConnectionHook = Callable[["Bot"], Awaitable[None]]
"""
:类型: ``Callable[[Bot], Awaitable[None]]``
:说明:
- WebSocket 连接建立时执行的函数
+ Bot 连接建立时执行的函数
"""
-T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
+T_BotDisconnectionHook = Callable[["Bot"], Awaitable[None]]
"""
:类型: ``Callable[[Bot], Awaitable[None]]``
:说明:
- WebSocket 连接断开时执行的函数
+ Bot 连接断开时执行的函数
"""
T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]]
"""