diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index efa6bc71..ad0d9817 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -11,7 +11,7 @@ from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING from nonebot.log import logger from nonebot.config import Env, Config -from nonebot.typing import T_WebSocketConnectionHook, T_WebSocketDisconnectionHook +from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook if TYPE_CHECKING: from nonebot.adapters import Bot @@ -19,7 +19,7 @@ if TYPE_CHECKING: class Driver(abc.ABC): """ - Driver 基类。将后端框架封装,以满足适配器使用。 + Driver 基类。 """ _adapters: Dict[str, Type["Bot"]] = {} @@ -27,15 +27,15 @@ class Driver(abc.ABC): :类型: ``Dict[str, Type[Bot]]`` :说明: 已注册的适配器列表 """ - _ws_connection_hook: Set[T_WebSocketConnectionHook] = set() + _bot_connection_hook: Set[T_BotConnectionHook] = set() """ - :类型: ``Set[T_WebSocketConnectionHook]`` - :说明: WebSocket 连接建立时执行的函数 + :类型: ``Set[T_BotConnectionHook]`` + :说明: Bot 连接建立时执行的函数 """ - _ws_disconnection_hook: Set[T_WebSocketDisconnectionHook] = set() + _bot_disconnection_hook: Set[T_BotDisconnectionHook] = set() """ - :类型: ``Set[T_WebSocketDisconnectionHook]`` - :说明: WebSocket 连接断开时执行的函数 + :类型: ``Set[T_BotDisconnectionHook]`` + :说明: Bot 连接断开时执行的函数 """ @abc.abstractmethod @@ -62,6 +62,18 @@ class Driver(abc.ABC): :说明: 已连接的 Bot """ + @property + def bots(self) -> Dict[str, "Bot"]: + """ + :类型: + + ``Dict[str, Bot]`` + :说明: + + 获取当前所有已连接的 Bot + """ + return self._clients + def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): """ :说明: @@ -88,108 +100,12 @@ class Driver(abc.ABC): """驱动类型名称""" raise NotImplementedError - @property - @abc.abstractmethod - def server_app(self): - """驱动 APP 对象""" - raise NotImplementedError - - @property - @abc.abstractmethod - def asgi(self): - """驱动 ASGI 对象""" - raise NotImplementedError - @property @abc.abstractmethod def logger(self): """驱动专属 logger 日志记录器""" raise NotImplementedError - @property - def bots(self) -> Dict[str, "Bot"]: - """ - :类型: - - ``Dict[str, Bot]`` - :说明: - - 获取当前所有已连接的 Bot - """ - return self._clients - - @abc.abstractmethod - def on_startup(self, func: Callable) -> Callable: - """注册一个在驱动启动时运行的函数""" - raise NotImplementedError - - @abc.abstractmethod - def on_shutdown(self, func: Callable) -> Callable: - """注册一个在驱动停止时运行的函数""" - raise NotImplementedError - - def on_bot_connect( - self, func: T_WebSocketConnectionHook) -> T_WebSocketConnectionHook: - """ - :说明: - - 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。 - - :函数参数: - - * ``bot: Bot``: 当前连接上的 Bot 对象 - """ - self._ws_connection_hook.add(func) - return func - - def on_bot_disconnect( - self, - func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook: - """ - :说明: - - 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。 - - :函数参数: - - * ``bot: Bot``: 当前连接上的 Bot 对象 - """ - self._ws_disconnection_hook.add(func) - return func - - def _bot_connect(self, bot: "Bot") -> None: - """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" - self._clients[bot.self_id] = bot - - async def _run_hook(bot: "Bot") -> None: - coros = list(map(lambda x: x(bot), self._ws_connection_hook)) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketConnection hook. " - "Running cancelled!") - - asyncio.create_task(_run_hook(bot)) - - def _bot_disconnect(self, bot: "Bot") -> None: - """在 WebSocket 连接断开后,调用该函数来注销 bot 对象""" - if bot.self_id in self._clients: - del self._clients[bot.self_id] - - async def _run_hook(bot: "Bot") -> None: - coros = list(map(lambda x: x(bot), self._ws_disconnection_hook)) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketDisConnection hook. " - "Running cancelled!") - - asyncio.create_task(_run_hook(bot)) - @abc.abstractmethod def run(self, host: Optional[str] = None, @@ -211,6 +127,98 @@ class Driver(abc.ABC): logger.opt(colors=True).debug( f"Loaded adapters: {', '.join(self._adapters)}") + @abc.abstractmethod + def on_startup(self, func: Callable) -> Callable: + """注册一个在驱动启动时运行的函数""" + raise NotImplementedError + + @abc.abstractmethod + def on_shutdown(self, func: Callable) -> Callable: + """注册一个在驱动停止时运行的函数""" + raise NotImplementedError + + def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook: + """ + :说明: + + 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。 + + :函数参数: + + * ``bot: Bot``: 当前连接上的 Bot 对象 + """ + self._bot_connection_hook.add(func) + return func + + def on_bot_disconnect( + self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook: + """ + :说明: + + 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。 + + :函数参数: + + * ``bot: Bot``: 当前连接上的 Bot 对象 + """ + self._bot_disconnection_hook.add(func) + return func + + def _bot_connect(self, bot: "Bot") -> None: + """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" + self._clients[bot.self_id] = bot + + async def _run_hook(bot: "Bot") -> None: + coros = list(map(lambda x: x(bot), self._bot_connection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketConnection hook. " + "Running cancelled!") + + asyncio.create_task(_run_hook(bot)) + + def _bot_disconnect(self, bot: "Bot") -> None: + """在 WebSocket 连接断开后,调用该函数来注销 bot 对象""" + if bot.self_id in self._clients: + del self._clients[bot.self_id] + + async def _run_hook(bot: "Bot") -> None: + coros = list(map(lambda x: x(bot), self._bot_disconnection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketDisConnection hook. " + "Running cancelled!") + + asyncio.create_task(_run_hook(bot)) + + +class ForwardDriver(Driver): + pass + + +class ReverseDriver(Driver): + """ + Reverse Driver 基类。将后端框架封装,以满足适配器使用。 + """ + + @property + @abc.abstractmethod + def server_app(self): + """驱动 APP 对象""" + raise NotImplementedError + + @property + @abc.abstractmethod + def asgi(self): + """驱动 ASGI 对象""" + raise NotImplementedError + @abc.abstractmethod async def _handle_http(self): """用于处理 HTTP 类型请求的函数""" diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 18a5f1db..28a2b19a 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -24,7 +24,7 @@ from nonebot.typing import overrides from nonebot.utils import DataclassEncoder from nonebot.exception import RequestDenied from nonebot.config import Env, Config as NoneBotConfig -from nonebot.drivers import Driver as BaseDriver, WebSocket as BaseWebSocket +from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket class Config(BaseSettings): @@ -76,7 +76,7 @@ class Config(BaseSettings): extra = "ignore" -class Driver(BaseDriver): +class Driver(ReverseDriver): """ FastAPI 驱动框架 @@ -106,40 +106,40 @@ class Driver(BaseDriver): self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) @property - @overrides(BaseDriver) + @overrides(ReverseDriver) def type(self) -> str: """驱动名称: ``fastapi``""" return "fastapi" @property - @overrides(BaseDriver) + @overrides(ReverseDriver) def server_app(self) -> FastAPI: """``FastAPI APP`` 对象""" return self._server_app @property - @overrides(BaseDriver) + @overrides(ReverseDriver) def asgi(self): """``FastAPI APP`` 对象""" return self._server_app @property - @overrides(BaseDriver) + @overrides(ReverseDriver) def logger(self) -> logging.Logger: """fastapi 使用的 logger""" return logging.getLogger("fastapi") - @overrides(BaseDriver) + @overrides(ReverseDriver) def on_startup(self, func: Callable) -> Callable: """参考文档: `Events `_""" return self.server_app.on_event("startup")(func) - @overrides(BaseDriver) + @overrides(ReverseDriver) def on_shutdown(self, func: Callable) -> Callable: """参考文档: `Events `_""" return self.server_app.on_event("shutdown")(func) - @overrides(BaseDriver) + @overrides(ReverseDriver) def run(self, host: Optional[str] = None, port: Optional[int] = None, @@ -176,7 +176,7 @@ class Driver(BaseDriver): log_config=LOGGING_CONFIG, **kwargs) - @overrides(BaseDriver) + @overrides(ReverseDriver) async def _handle_http(self, adapter: str, request: Request): data = await request.body() data_dict = json.loads(data.decode()) @@ -211,7 +211,7 @@ class Driver(BaseDriver): asyncio.create_task(bot.handle_message(data_dict)) return Response("", 204) - @overrides(BaseDriver) + @overrides(ReverseDriver) async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket): ws = WebSocket(websocket) diff --git a/nonebot/typing.py b/nonebot/typing.py index d73865bd..9a5bcf80 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -55,21 +55,21 @@ T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]] 事件处理状态 State 类工厂函数 """ -T_WebSocketConnectionHook = Callable[["Bot"], Awaitable[None]] +T_BotConnectionHook = Callable[["Bot"], Awaitable[None]] """ :类型: ``Callable[[Bot], Awaitable[None]]`` :说明: - WebSocket 连接建立时执行的函数 + Bot 连接建立时执行的函数 """ -T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]] +T_BotDisconnectionHook = Callable[["Bot"], Awaitable[None]] """ :类型: ``Callable[[Bot], Awaitable[None]]`` :说明: - WebSocket 连接断开时执行的函数 + Bot 连接断开时执行的函数 """ T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]] """