2022-02-06 14:52:50 +08:00
import abc
import asyncio
2024-04-05 21:11:05 +08:00
from types import TracebackType
2024-04-16 00:33:48 +08:00
from collections . abc import AsyncGenerator
2024-04-05 21:11:05 +08:00
from typing_extensions import Self , TypeAlias
2023-03-20 22:37:57 +08:00
from contextlib import AsyncExitStack , asynccontextmanager
2024-04-16 00:33:48 +08:00
from typing import TYPE_CHECKING , Any , Union , ClassVar , Optional
2022-02-06 14:52:50 +08:00
from nonebot . log import logger
from nonebot . config import Env , Config
from nonebot . dependencies import Dependent
2022-04-04 10:35:14 +08:00
from nonebot . exception import SkippedException
from nonebot . utils import escape_tag , run_coro_with_catch
2022-02-06 17:08:11 +08:00
from nonebot . internal . params import BotParam , DependParam , DefaultParam
2023-03-20 22:37:57 +08:00
from nonebot . typing import (
T_DependencyCache ,
T_BotConnectionHook ,
T_BotDisconnectionHook ,
)
2022-02-06 14:52:50 +08:00
2023-12-10 18:12:10 +08:00
from . _lifespan import LIFESPAN_FUNC , Lifespan
2024-04-05 21:11:05 +08:00
from . model import (
Request ,
Response ,
WebSocket ,
QueryTypes ,
CookieTypes ,
HeaderTypes ,
HTTPVersion ,
HTTPServerSetup ,
WebSocketServerSetup ,
)
2022-02-06 14:52:50 +08:00
if TYPE_CHECKING :
2022-02-06 17:08:11 +08:00
from nonebot . internal . adapter import Bot , Adapter
2022-02-06 14:52:50 +08:00
# Parameter types allowed when parsing bot connect/disconnect hooks; passed as
# `allow_types` to `Dependent.parse` in `Driver.on_bot_connect` /
# `Driver.on_bot_disconnect`.
BOT_HOOK_PARAMS = [
    DependParam,
    BotParam,
    DefaultParam,
]


class Driver(abc.ABC):
    """Base class for all drivers.

    A driver controls framework startup and shutdown, adapter registration,
    and bot lifecycle management.

    Args:
        env: `Env` object holding the environment information.
        config: `Config` object holding the global configuration.
    """

    _adapters: ClassVar[dict[str, "Adapter"]] = {}
    """Registered adapters (class-level, shared by every driver instance)"""
    _bot_connection_hook: ClassVar[set[Dependent[Any]]] = set()
    """Functions run when a bot connection is established"""
    _bot_disconnection_hook: ClassVar[set[Dependent[Any]]] = set()
    """Functions run when a bot connection is closed"""

    def __init__(self, env: Env, config: Config):
        self.env: str = env.environment
        """Environment name"""
        self.config: Config = config
        """Global configuration object"""

        # Currently connected bots, keyed by `bot.self_id`.
        self._bots: dict[str, "Bot"] = {}
        # References to in-flight hook tasks; each task removes itself when
        # done, and any still pending are awaited in `_cleanup` on shutdown.
        self._bot_tasks: set[asyncio.Task] = set()
        self._lifespan = Lifespan()

    def __repr__(self) -> str:
        return (
            f"Driver(type={self.type!r}, "
            f"adapters={len(self._adapters)}, bots={len(self._bots)})"
        )

    @property
    def bots(self) -> dict[str, "Bot"]:
        """All currently connected bots, keyed by `self_id`."""
        return self._bots

    # NOTE: keep this method above the `type` property below. The `type[...]`
    # annotation is evaluated in class scope at definition time and would
    # otherwise resolve to that property instead of the builtin.
    def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
        """Register a protocol adapter.

        Adapters are keyed by `adapter.get_name()`. Registering a name that is
        already present only logs a debug message and is otherwise a no-op.

        Args:
            adapter: The adapter class; instantiated as `adapter(self, **kwargs)`.
            kwargs: Extra keyword arguments passed to the adapter constructor.
        """
        name = adapter.get_name()
        if name in self._adapters:
            logger.opt(colors=True).debug(
                f'Adapter "<y>{escape_tag(name)}</y>" already exists'
            )
            return
        self._adapters[name] = adapter(self, **kwargs)
        logger.opt(colors=True).debug(
            f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
        )

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Driver type name."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def logger(self):
        """Driver-specific logger."""
        raise NotImplementedError

    @abc.abstractmethod
    def run(self, *args, **kwargs):
        """Start the driver.

        Subclasses must override this. The base implementation logs the loaded
        adapters and registers `_cleanup` as a shutdown hook; overrides can
        reach it via `super().run(...)`.
        """
        logger.opt(colors=True).debug(
            f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
        )
        self.on_shutdown(self._cleanup)

    def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register a function to run on startup."""
        return self._lifespan.on_startup(func)

    def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register a function to run on shutdown."""
        return self._lifespan.on_shutdown(func)

    @classmethod
    def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
        """Decorator: run the function whenever a bot connects successfully.

        Hook parameters:

        - bot: the Bot object that just connected
        """
        cls._bot_connection_hook.add(
            Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
        )
        return func

    @classmethod
    def on_bot_disconnect(
        cls, func: T_BotDisconnectionHook
    ) -> T_BotDisconnectionHook:
        """Decorator: run the function whenever a bot connection is closed.

        Hook parameters:

        - bot: the Bot object whose connection was closed
        """
        cls._bot_disconnection_hook.add(
            Dependent[Any].parse(call=func, allow_types=BOT_HOOK_PARAMS)
        )
        return func

    def _run_bot_hooks(
        self, bot: "Bot", hooks: set[Dependent[Any]], hook_name: str
    ) -> None:
        """Run every hook in `hooks` for `bot` in a background task.

        This is shared by `_bot_connect` and `_bot_disconnect`. `hooks` is
        iterated when the task starts, not when it is scheduled. All hooks
        share one dependency cache and one `AsyncExitStack`.
        `SkippedException` is handed to `run_coro_with_catch` (presumably
        swallowed per hook). Any other exception is logged, not raised.

        Args:
            bot: The bot passed to every hook.
            hooks: The hook set to run.
            hook_name: Label used in the error log message.
        """

        async def _run_hooks() -> None:
            dependency_cache: T_DependencyCache = {}
            async with AsyncExitStack() as stack:
                coros = [
                    run_coro_with_catch(
                        hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
                        (SkippedException,),
                    )
                    for hook in hooks
                ]
                if not coros:
                    return
                try:
                    await asyncio.gather(*coros)
                except Exception as e:
                    # NOTE(review): gather() does not cancel the sibling hooks
                    # when one raises, so they may keep running after the exit
                    # stack is closed, despite "Running cancelled!" below.
                    # Confirm this is acceptable.
                    logger.opt(colors=True, exception=e).error(
                        "<r><bg #f8bbd0>"
                        f"Error when running {hook_name} hook. "
                        "Running cancelled!"
                        "</bg #f8bbd0></r>"
                    )

        task = asyncio.create_task(_run_hooks())
        task.add_done_callback(self._bot_tasks.discard)
        self._bot_tasks.add(task)

    def _bot_connect(self, bot: "Bot") -> None:
        """Register `bot` once its connection is established.

        The connection hooks are then run in the background.

        Raises:
            RuntimeError: A bot with the same `self_id` is already connected.
        """
        if bot.self_id in self._bots:
            raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
        self._bots[bot.self_id] = bot
        self._run_bot_hooks(bot, self._bot_connection_hook, "WebSocketConnection")

    def _bot_disconnect(self, bot: "Bot") -> None:
        """Unregister `bot` once its connection is closed.

        The disconnection hooks are then run in the background. Unregistering
        a bot that is not connected is silently ignored, and the hooks still
        run.
        """
        self._bots.pop(bot.self_id, None)
        self._run_bot_hooks(
            bot, self._bot_disconnection_hook, "WebSocketDisConnection"
        )

    async def _cleanup(self) -> None:
        """Release driver resources by waiting for pending bot hook tasks."""
        if self._bot_tasks:
            logger.opt(colors=True).debug(
                "<y>Waiting for running bot connection hooks...</y>"
            )
            await asyncio.gather(*self._bot_tasks, return_exceptions=True)


class Mixin(abc.ABC):
    """Base class for mixins that can be combined with other drivers."""

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Type name of the mixin driver."""
        raise NotImplementedError


class ForwardMixin(Mixin):
    """Base class for client-side mixins."""


class ReverseMixin(Mixin):
    """Base class for server-side mixins."""


class HTTPClientSession(abc.ABC):
    """Base class for HTTP client sessions.

    Usable as an async context manager: `setup()` is awaited on enter and
    `close()` on exit. `__aexit__` returns `None`, so exceptions raised in the
    body propagate after the session is closed.
    """

    @abc.abstractmethod
    def __init__(
        self,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ):
        """Create the session; concrete sessions must implement this."""
        raise NotImplementedError

    @abc.abstractmethod
    async def request(self, setup: Request) -> Response:
        """Send an HTTP request within this session."""
        raise NotImplementedError

    @abc.abstractmethod
    async def setup(self) -> None:
        """Initialize the session."""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self) -> None:
        """Close the session."""
        raise NotImplementedError

    async def __aenter__(self) -> Self:
        await self.setup()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        # Always close, whether or not the body raised.
        await self.close()


class HTTPClientMixin(ForwardMixin):
    """Base class for HTTP client mixins."""

    @abc.abstractmethod
    async def request(self, setup: Request) -> Response:
        """Send a single HTTP request."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_session(
        self,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ) -> HTTPClientSession:
        """Create an HTTP client session configured with the given options."""
        raise NotImplementedError


class WebSocketClientMixin(ForwardMixin):
    """Base class for WebSocket client mixins."""

    @abc.abstractmethod
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
        """Open a WebSocket connection."""
        raise NotImplementedError
        # Unreachable, but makes this an async generator function, which is
        # what `asynccontextmanager` and static type checkers expect.
        yield


class ASGIMixin(ReverseMixin):
    """Base class for ASGI server mixins.

    Wraps a backend web framework so that adapters can use it.
    """

    @property
    @abc.abstractmethod
    def server_app(self) -> Any:
        """The driver's application object."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def asgi(self) -> Any:
        """The driver's ASGI object."""
        raise NotImplementedError

    @abc.abstractmethod
    def setup_http_server(self, setup: "HTTPServerSetup") -> None:
        """Register an HTTP server route configuration."""
        raise NotImplementedError

    @abc.abstractmethod
    def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
        """Register a WebSocket server route configuration."""
        raise NotImplementedError


ForwardDriver: TypeAlias = ForwardMixin
"""Driver that supports client-side requests.

**Deprecated**: use {ref}`nonebot.drivers.ForwardMixin` or its subclasses instead.
"""

ReverseDriver: TypeAlias = ReverseMixin
"""Driver that supports server-side requests.

**Deprecated**: use {ref}`nonebot.drivers.ReverseMixin` or its subclasses instead.
"""