♻️ separate driver

This commit is contained in:
yanyongyu 2021-05-21 17:06:20 +08:00
parent 61512ff01f
commit 8680a954f8
3 changed files with 127 additions and 119 deletions

View File

@ -11,7 +11,7 @@ from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.typing import T_WebSocketConnectionHook, T_WebSocketDisconnectionHook from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -19,7 +19,7 @@ if TYPE_CHECKING:
class Driver(abc.ABC): class Driver(abc.ABC):
""" """
Driver 基类将后端框架封装以满足适配器使用 Driver 基类
""" """
_adapters: Dict[str, Type["Bot"]] = {} _adapters: Dict[str, Type["Bot"]] = {}
@ -27,15 +27,15 @@ class Driver(abc.ABC):
:类型: ``Dict[str, Type[Bot]]`` :类型: ``Dict[str, Type[Bot]]``
:说明: 已注册的适配器列表 :说明: 已注册的适配器列表
""" """
_ws_connection_hook: Set[T_WebSocketConnectionHook] = set() _bot_connection_hook: Set[T_BotConnectionHook] = set()
""" """
:类型: ``Set[T_WebSocketConnectionHook]`` :类型: ``Set[T_BotConnectionHook]``
:说明: WebSocket 连接建立时执行的函数 :说明: Bot 连接建立时执行的函数
""" """
_ws_disconnection_hook: Set[T_WebSocketDisconnectionHook] = set() _bot_disconnection_hook: Set[T_BotDisconnectionHook] = set()
""" """
:类型: ``Set[T_WebSocketDisconnectionHook]`` :类型: ``Set[T_BotDisconnectionHook]``
:说明: WebSocket 连接断开时执行的函数 :说明: Bot 连接断开时执行的函数
""" """
@abc.abstractmethod @abc.abstractmethod
@ -62,6 +62,18 @@ class Driver(abc.ABC):
:说明: 已连接的 Bot :说明: 已连接的 Bot
""" """
@property
def bots(self) -> Dict[str, "Bot"]:
"""
:类型:
``Dict[str, Bot]``
:说明:
获取当前所有已连接的 Bot
"""
return self._clients
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
""" """
:说明: :说明:
@ -88,108 +100,12 @@ class Driver(abc.ABC):
"""驱动类型名称""" """驱动类型名称"""
raise NotImplementedError raise NotImplementedError
@property
@abc.abstractmethod
def server_app(self):
"""驱动 APP 对象"""
raise NotImplementedError
@property
@abc.abstractmethod
def asgi(self):
"""驱动 ASGI 对象"""
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def logger(self): def logger(self):
"""驱动专属 logger 日志记录器""" """驱动专属 logger 日志记录器"""
raise NotImplementedError raise NotImplementedError
@property
def bots(self) -> Dict[str, "Bot"]:
"""
:类型:
``Dict[str, Bot]``
:说明:
获取当前所有已连接的 Bot
"""
return self._clients
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动启动时运行的函数"""
raise NotImplementedError
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动停止时运行的函数"""
raise NotImplementedError
def on_bot_connect(
self, func: T_WebSocketConnectionHook) -> T_WebSocketConnectionHook:
"""
:说明:
装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行
:函数参数:
* ``bot: Bot``: 当前连接上的 Bot 对象
"""
self._ws_connection_hook.add(func)
return func
def on_bot_disconnect(
self,
func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook:
"""
:说明:
装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行
:函数参数:
* ``bot: Bot``: 当前连接上的 Bot 对象
"""
self._ws_disconnection_hook.add(func)
return func
def _bot_connect(self, bot: "Bot") -> None:
"""在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot), self._ws_connection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
asyncio.create_task(_run_hook(bot))
def _bot_disconnect(self, bot: "Bot") -> None:
"""在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._clients:
del self._clients[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot), self._ws_disconnection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
asyncio.create_task(_run_hook(bot))
@abc.abstractmethod @abc.abstractmethod
def run(self, def run(self,
host: Optional[str] = None, host: Optional[str] = None,
@ -211,6 +127,98 @@ class Driver(abc.ABC):
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"<g>Loaded adapters: {', '.join(self._adapters)}</g>") f"<g>Loaded adapters: {', '.join(self._adapters)}</g>")
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动启动时运行的函数"""
raise NotImplementedError
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动停止时运行的函数"""
raise NotImplementedError
def on_bot_connect(self, func: T_BotConnectionHook) -> T_BotConnectionHook:
"""
:说明:
装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行
:函数参数:
* ``bot: Bot``: 当前连接上的 Bot 对象
"""
self._bot_connection_hook.add(func)
return func
def on_bot_disconnect(
self, func: T_BotDisconnectionHook) -> T_BotDisconnectionHook:
"""
:说明:
装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行
:函数参数:
* ``bot: Bot``: 当前连接上的 Bot 对象
"""
self._bot_disconnection_hook.add(func)
return func
def _bot_connect(self, bot: "Bot") -> None:
"""在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot), self._bot_connection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
asyncio.create_task(_run_hook(bot))
def _bot_disconnect(self, bot: "Bot") -> None:
"""在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._clients:
del self._clients[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot), self._bot_disconnection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
asyncio.create_task(_run_hook(bot))
class ForwardDriver(Driver):
pass
class ReverseDriver(Driver):
"""
Reverse Driver 基类将后端框架封装以满足适配器使用
"""
@property
@abc.abstractmethod
def server_app(self):
"""驱动 APP 对象"""
raise NotImplementedError
@property
@abc.abstractmethod
def asgi(self):
"""驱动 ASGI 对象"""
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def _handle_http(self): async def _handle_http(self):
"""用于处理 HTTP 类型请求的函数""" """用于处理 HTTP 类型请求的函数"""

View File

@ -24,7 +24,7 @@ from nonebot.typing import overrides
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied from nonebot.exception import RequestDenied
from nonebot.config import Env, Config as NoneBotConfig from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import Driver as BaseDriver, WebSocket as BaseWebSocket from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
class Config(BaseSettings): class Config(BaseSettings):
@ -76,7 +76,7 @@ class Config(BaseSettings):
extra = "ignore" extra = "ignore"
class Driver(BaseDriver): class Driver(ReverseDriver):
""" """
FastAPI 驱动框架 FastAPI 驱动框架
@ -106,40 +106,40 @@ class Driver(BaseDriver):
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
@property @property
@overrides(BaseDriver) @overrides(ReverseDriver)
def type(self) -> str: def type(self) -> str:
"""驱动名称: ``fastapi``""" """驱动名称: ``fastapi``"""
return "fastapi" return "fastapi"
@property @property
@overrides(BaseDriver) @overrides(ReverseDriver)
def server_app(self) -> FastAPI: def server_app(self) -> FastAPI:
"""``FastAPI APP`` 对象""" """``FastAPI APP`` 对象"""
return self._server_app return self._server_app
@property @property
@overrides(BaseDriver) @overrides(ReverseDriver)
def asgi(self): def asgi(self):
"""``FastAPI APP`` 对象""" """``FastAPI APP`` 对象"""
return self._server_app return self._server_app
@property @property
@overrides(BaseDriver) @overrides(ReverseDriver)
def logger(self) -> logging.Logger: def logger(self) -> logging.Logger:
"""fastapi 使用的 logger""" """fastapi 使用的 logger"""
return logging.getLogger("fastapi") return logging.getLogger("fastapi")
@overrides(BaseDriver) @overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable: def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
return self.server_app.on_event("startup")(func) return self.server_app.on_event("startup")(func)
@overrides(BaseDriver) @overrides(ReverseDriver)
def on_shutdown(self, func: Callable) -> Callable: def on_shutdown(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
return self.server_app.on_event("shutdown")(func) return self.server_app.on_event("shutdown")(func)
@overrides(BaseDriver) @overrides(ReverseDriver)
def run(self, def run(self,
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[int] = None, port: Optional[int] = None,
@ -176,7 +176,7 @@ class Driver(BaseDriver):
log_config=LOGGING_CONFIG, log_config=LOGGING_CONFIG,
**kwargs) **kwargs)
@overrides(BaseDriver) @overrides(ReverseDriver)
async def _handle_http(self, adapter: str, request: Request): async def _handle_http(self, adapter: str, request: Request):
data = await request.body() data = await request.body()
data_dict = json.loads(data.decode()) data_dict = json.loads(data.decode())
@ -211,7 +211,7 @@ class Driver(BaseDriver):
asyncio.create_task(bot.handle_message(data_dict)) asyncio.create_task(bot.handle_message(data_dict))
return Response("", 204) return Response("", 204)
@overrides(BaseDriver) @overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str, async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket): websocket: FastAPIWebSocket):
ws = WebSocket(websocket) ws = WebSocket(websocket)

View File

@ -55,21 +55,21 @@ T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]]
事件处理状态 State 类工厂函数 事件处理状态 State 类工厂函数
""" """
T_WebSocketConnectionHook = Callable[["Bot"], Awaitable[None]] T_BotConnectionHook = Callable[["Bot"], Awaitable[None]]
""" """
:类型: ``Callable[[Bot], Awaitable[None]]`` :类型: ``Callable[[Bot], Awaitable[None]]``
:说明: :说明:
WebSocket 连接建立时执行的函数 Bot 连接建立时执行的函数
""" """
T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]] T_BotDisconnectionHook = Callable[["Bot"], Awaitable[None]]
""" """
:类型: ``Callable[[Bot], Awaitable[None]]`` :类型: ``Callable[[Bot], Awaitable[None]]``
:说明: :说明:
WebSocket 连接断开时执行的函数 Bot 连接断开时执行的函数
""" """
T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]] T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]]
""" """