mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 09:05:04 +08:00
♻️ rewrite adapter abc class
This commit is contained in:
parent
180aaadda9
commit
d80c02ae46
@ -22,6 +22,7 @@ except Exception:
|
|||||||
|
|
||||||
from ._bot import Bot as Bot
|
from ._bot import Bot as Bot
|
||||||
from ._event import Event as Event
|
from ._event import Event as Event
|
||||||
|
from ._adapter import Adapter as Adapter
|
||||||
from ._message import Message as Message
|
from ._message import Message as Message
|
||||||
from ._message import MessageSegment as MessageSegment
|
from ._message import MessageSegment as MessageSegment
|
||||||
from ._template import MessageTemplate as MessageTemplate
|
from ._template import MessageTemplate as MessageTemplate
|
||||||
|
59
nonebot/adapters/_adapter.py
Normal file
59
nonebot/adapters/_adapter.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from ._bot import Bot
|
||||||
|
from nonebot.config import Config
|
||||||
|
from nonebot.drivers import (
|
||||||
|
Driver,
|
||||||
|
ForwardDriver,
|
||||||
|
ReverseDriver,
|
||||||
|
HTTPServerSetup,
|
||||||
|
WebSocketServerSetup,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Adapter(abc.ABC):
|
||||||
|
def __init__(self, driver: Driver, **kwargs: Any):
|
||||||
|
self.driver = driver
|
||||||
|
self.bots: Dict[str, Bot] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_name(cls) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config(self) -> Config:
|
||||||
|
return self.driver.config
|
||||||
|
|
||||||
|
def bot_connect(self, bot: Bot):
|
||||||
|
self.driver._bot_connect(bot)
|
||||||
|
self.bots[bot.self_id] = bot
|
||||||
|
|
||||||
|
def bot_disconnect(self, bot: Bot):
|
||||||
|
self.driver._bot_disconnect(bot)
|
||||||
|
self.bots.pop(bot.self_id, None)
|
||||||
|
|
||||||
|
def setup_http_server(self, setup: HTTPServerSetup):
|
||||||
|
if not isinstance(self.driver, ReverseDriver):
|
||||||
|
raise TypeError("Current driver does not support http server")
|
||||||
|
self.driver.setup_http_server(setup)
|
||||||
|
|
||||||
|
def setup_websocket_server(self, setup: WebSocketServerSetup):
|
||||||
|
if not isinstance(self.driver, ReverseDriver):
|
||||||
|
raise TypeError("Current driver does not support websocket server")
|
||||||
|
self.driver.setup_websocket_server(setup)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def _call_api(self, api: str, **data) -> Any:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``api: str``: API 名称
|
||||||
|
* ``**data``: API 数据
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
@ -12,6 +12,7 @@ from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ._event import Event
|
from ._event import Event
|
||||||
|
from ._adapter import Adapter
|
||||||
from ._message import Message, MessageSegment
|
from ._message import Message, MessageSegment
|
||||||
|
|
||||||
|
|
||||||
@ -25,10 +26,6 @@ class Bot(abc.ABC):
|
|||||||
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
driver: Driver
|
|
||||||
"""Driver 对象"""
|
|
||||||
config: Config
|
|
||||||
"""Config 配置对象"""
|
|
||||||
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
||||||
"""
|
"""
|
||||||
:类型: ``Set[T_CallingAPIHook]``
|
:类型: ``Set[T_CallingAPIHook]``
|
||||||
@ -40,36 +37,27 @@ class Bot(abc.ABC):
|
|||||||
:说明: call_api 后执行的函数
|
:说明: call_api 后执行的函数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, self_id: str, request: HTTPConnection):
|
def __init__(self, adapter: "Adapter", self_id: str):
|
||||||
"""
|
"""
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``self_id: str``: 机器人 ID
|
* ``self_id: str``: 机器人 ID
|
||||||
* ``request: HTTPConnection``: request 连接对象
|
* ``request: HTTPConnection``: request 连接对象
|
||||||
"""
|
"""
|
||||||
|
self.adapter = adapter
|
||||||
self.self_id: str = self_id
|
self.self_id: str = self_id
|
||||||
"""机器人 ID"""
|
"""机器人 ID"""
|
||||||
self.request: HTTPConnection = request
|
|
||||||
"""连接信息"""
|
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> _ApiCall:
|
def __getattr__(self, name: str) -> _ApiCall:
|
||||||
return partial(self.call_api, name)
|
return partial(self.call_api, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
"""Adapter 类型"""
|
return self.adapter.get_name()
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
@property
|
||||||
def register(cls, driver: Driver, config: Config, **kwargs):
|
def config(self) -> Config:
|
||||||
"""
|
return self.adapter.config
|
||||||
:说明:
|
|
||||||
|
|
||||||
``register`` 方法会在 ``driver.register_adapter`` 时被调用,用于初始化相关配置
|
|
||||||
"""
|
|
||||||
cls.driver = driver
|
|
||||||
cls.config = config
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@ -106,20 +94,6 @@ class Bot(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def _call_api(self, api: str, **data) -> Any:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``api: str``: API 名称
|
|
||||||
* ``**data``: API 数据
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def call_api(self, api: str, **data: Any) -> Any:
|
async def call_api(self, api: str, **data: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -162,7 +136,7 @@ class Bot(abc.ABC):
|
|||||||
|
|
||||||
if not skip_calling_api:
|
if not skip_calling_api:
|
||||||
try:
|
try:
|
||||||
result = await self._call_api(api, **data)
|
result = await self.adapter._call_api(api, **data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exception = e
|
exception = e
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ from nonebot.config import Env, Config
|
|||||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot, Adapter
|
||||||
|
|
||||||
|
|
||||||
class Driver(abc.ABC):
|
class Driver(abc.ABC):
|
||||||
@ -34,9 +34,9 @@ class Driver(abc.ABC):
|
|||||||
Driver 基类。
|
Driver 基类。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_adapters: Dict[str, Type["Bot"]] = {}
|
_adapters: Dict[str, "Adapter"] = {}
|
||||||
"""
|
"""
|
||||||
:类型: ``Dict[str, Type[Bot]]``
|
:类型: ``Dict[str, Adapter]``
|
||||||
:说明: 已注册的适配器列表
|
:说明: 已注册的适配器列表
|
||||||
"""
|
"""
|
||||||
_bot_connection_hook: Set[T_BotConnectionHook] = set()
|
_bot_connection_hook: Set[T_BotConnectionHook] = set()
|
||||||
@ -85,7 +85,7 @@ class Driver(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
return self._clients
|
return self._clients
|
||||||
|
|
||||||
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
|
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
@ -97,13 +97,13 @@ class Driver(abc.ABC):
|
|||||||
* ``adapter: Type[Bot]``: 适配器 Class
|
* ``adapter: Type[Bot]``: 适配器 Class
|
||||||
* ``**kwargs``: 其他传递给适配器的参数
|
* ``**kwargs``: 其他传递给适配器的参数
|
||||||
"""
|
"""
|
||||||
|
name = adapter.get_name()
|
||||||
if name in self._adapters:
|
if name in self._adapters:
|
||||||
logger.opt(colors=True).debug(
|
logger.opt(colors=True).debug(
|
||||||
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
|
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
self._adapters[name] = adapter
|
self._adapters[name] = adapter(self, **kwargs)
|
||||||
adapter.register(self, self.config, **kwargs)
|
|
||||||
logger.opt(colors=True).debug(
|
logger.opt(colors=True).debug(
|
||||||
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
|
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
|
||||||
)
|
)
|
||||||
@ -213,34 +213,11 @@ class ForwardDriver(Driver):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def setup_http_polling(
|
async def request(self, setup: "HTTPRequest") -> Any:
|
||||||
self,
|
|
||||||
setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def setup_websocket(
|
async def websocket(self, setup: "HTTPConnection") -> Any:
|
||||||
self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -261,7 +238,16 @@ class ReverseDriver(Driver):
|
|||||||
"""驱动 ASGI 对象"""
|
"""驱动 ASGI 对象"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: repack dataclass
|
||||||
@dataclass
|
@dataclass
|
||||||
class HTTPConnection(abc.ABC):
|
class HTTPConnection(abc.ABC):
|
||||||
http_version: str
|
http_version: str
|
||||||
@ -401,36 +387,13 @@ class WebSocket(HTTPConnection, abc.ABC):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HTTPPollingSetup:
|
class HTTPServerSetup:
|
||||||
adapter: str
|
path: str
|
||||||
"""协议适配器名称"""
|
|
||||||
self_id: str
|
|
||||||
"""机器人 ID"""
|
|
||||||
url: str
|
|
||||||
"""URL"""
|
|
||||||
method: str
|
method: str
|
||||||
"""HTTP method"""
|
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
|
||||||
body: bytes
|
|
||||||
"""HTTP body"""
|
|
||||||
headers: Dict[str, str]
|
|
||||||
"""HTTP headers"""
|
|
||||||
http_version: str
|
|
||||||
"""HTTP version"""
|
|
||||||
poll_interval: float
|
|
||||||
"""HTTP 轮询间隔"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WebSocketSetup:
|
class WebSocketServerSetup:
|
||||||
adapter: str
|
path: str
|
||||||
"""协议适配器名称"""
|
handle_func: Callable[[WebSocket], Awaitable[Any]]
|
||||||
self_id: str
|
|
||||||
"""机器人 ID"""
|
|
||||||
url: str
|
|
||||||
"""URL"""
|
|
||||||
headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
"""HTTP headers"""
|
|
||||||
reconnect: bool = True
|
|
||||||
"""WebSocket 是否重连"""
|
|
||||||
reconnect_interval: float = 3.0
|
|
||||||
"""WebSocket 重连间隔"""
|
|
||||||
|
@ -12,38 +12,35 @@ FastAPI 驱动适配
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from functools import partial
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
|
from typing import Any, List, Union, Callable, Optional, Awaitable
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from websockets.exceptions import ConnectionClosed
|
from starlette.websockets import WebSocketState
|
||||||
from fastapi import FastAPI, Request, HTTPException, status
|
from fastapi import Depends, FastAPI, Request, status
|
||||||
from starlette.websockets import WebSocket as FastAPIWebSocket
|
from starlette.websockets import WebSocket as FastAPIWebSocket
|
||||||
from starlette.websockets import WebSocketState, WebSocketDisconnect
|
|
||||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||||
|
|
||||||
from nonebot.config import Env
|
from nonebot.config import Env
|
||||||
from nonebot.log import logger
|
|
||||||
from nonebot.adapters import Bot
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
|
from nonebot.drivers import WebSocket
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
HTTPRequest,
|
HTTPRequest,
|
||||||
|
HTTPResponse,
|
||||||
ForwardDriver,
|
ForwardDriver,
|
||||||
ReverseDriver,
|
ReverseDriver,
|
||||||
WebSocketSetup,
|
HTTPConnection,
|
||||||
HTTPPollingSetup,
|
HTTPServerSetup,
|
||||||
|
WebSocketServerSetup,
|
||||||
)
|
)
|
||||||
|
|
||||||
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
|
|
||||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
|
|
||||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseSettings):
|
class Config(BaseSettings):
|
||||||
"""
|
"""
|
||||||
@ -136,16 +133,7 @@ class Config(BaseSettings):
|
|||||||
|
|
||||||
|
|
||||||
class Driver(ReverseDriver):
|
class Driver(ReverseDriver):
|
||||||
"""
|
"""FastAPI 驱动框架。包含反向 Server 功能。"""
|
||||||
FastAPI 驱动框架。包含反向 Server 功能。
|
|
||||||
|
|
||||||
:上报地址:
|
|
||||||
|
|
||||||
* ``/{adapter name}/``: HTTP POST 上报
|
|
||||||
* ``/{adapter name}/http/``: HTTP POST 上报
|
|
||||||
* ``/{adapter name}/ws``: WebSocket 上报
|
|
||||||
* ``/{adapter name}/ws/``: WebSocket 上报
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: Env, config: NoneBotConfig):
|
def __init__(self, env: Env, config: NoneBotConfig):
|
||||||
super(Driver, self).__init__(env, config)
|
super(Driver, self).__init__(env, config)
|
||||||
@ -159,11 +147,6 @@ class Driver(ReverseDriver):
|
|||||||
redoc_url=self.fastapi_config.fastapi_redoc_url,
|
redoc_url=self.fastapi_config.fastapi_redoc_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._server_app.post("/{adapter}/")(self._handle_http)
|
|
||||||
self._server_app.post("/{adapter}/http")(self._handle_http)
|
|
||||||
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
|
|
||||||
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@ -188,6 +171,30 @@ class Driver(ReverseDriver):
|
|||||||
"""fastapi 使用的 logger"""
|
"""fastapi 使用的 logger"""
|
||||||
return logging.getLogger("fastapi")
|
return logging.getLogger("fastapi")
|
||||||
|
|
||||||
|
@overrides(ReverseDriver)
|
||||||
|
def setup_http_server(self, setup: HTTPServerSetup):
|
||||||
|
def _get_handle_func():
|
||||||
|
return setup.handle_func
|
||||||
|
|
||||||
|
self._server_app.add_api_route(
|
||||||
|
setup.path,
|
||||||
|
partial(self._handle_http, handle_func=Depends(_get_handle_func)),
|
||||||
|
methods=[setup.method],
|
||||||
|
)
|
||||||
|
|
||||||
|
@overrides(ReverseDriver)
|
||||||
|
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
|
||||||
|
def _get_handle_func():
|
||||||
|
return setup.handle_func
|
||||||
|
|
||||||
|
self._server_app.add_api_websocket_route(
|
||||||
|
setup.path,
|
||||||
|
partial(
|
||||||
|
self._handle_ws,
|
||||||
|
handle_func=Depends(_get_handle_func),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
def on_startup(self, func: Callable) -> Callable:
|
def on_startup(self, func: Callable) -> Callable:
|
||||||
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
|
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
|
||||||
@ -241,19 +248,11 @@ class Driver(ReverseDriver):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_http(self, adapter: str, request: Request):
|
async def _handle_http(
|
||||||
data = await request.body()
|
self,
|
||||||
|
request: Request,
|
||||||
if adapter not in self._adapters:
|
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
|
||||||
logger.warning(
|
):
|
||||||
f"Unknown adapter {adapter}. Please register the adapter before use."
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 Bot 对象
|
|
||||||
BotClass = self._adapters[adapter]
|
|
||||||
http_request = HTTPRequest(
|
http_request = HTTPRequest(
|
||||||
request.scope["http_version"],
|
request.scope["http_version"],
|
||||||
request.url.scheme,
|
request.url.scheme,
|
||||||
@ -261,28 +260,17 @@ class Driver(ReverseDriver):
|
|||||||
request.scope["query_string"],
|
request.scope["query_string"],
|
||||||
dict(request.headers),
|
dict(request.headers),
|
||||||
request.method,
|
request.method,
|
||||||
data,
|
await request.body(),
|
||||||
)
|
|
||||||
x_self_id, response = await BotClass.check_permission(self, http_request)
|
|
||||||
|
|
||||||
if not x_self_id:
|
|
||||||
raise HTTPException(
|
|
||||||
response and response.status or 401,
|
|
||||||
response and response.body and response.body.decode("utf-8"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if x_self_id in self._clients:
|
response = await handle_func(http_request)
|
||||||
logger.warning(
|
return Response(response.body, response.status, response.headers)
|
||||||
"There's already a reverse websocket connection,"
|
|
||||||
"so the event may be handled twice."
|
|
||||||
)
|
|
||||||
|
|
||||||
bot = BotClass(x_self_id, http_request)
|
async def _handle_ws(
|
||||||
|
self,
|
||||||
asyncio.create_task(bot.handle_message(data))
|
websocket: FastAPIWebSocket,
|
||||||
return Response(response and response.body, response and response.status or 200)
|
handle_func: Callable[[WebSocket], Awaitable[Any]],
|
||||||
|
):
|
||||||
async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
|
|
||||||
ws = WebSocket(
|
ws = WebSocket(
|
||||||
websocket.scope.get("http_version", "1.1"),
|
websocket.scope.get("http_version", "1.1"),
|
||||||
websocket.url.scheme,
|
websocket.url.scheme,
|
||||||
@ -292,55 +280,7 @@ class Driver(ReverseDriver):
|
|||||||
websocket,
|
websocket,
|
||||||
)
|
)
|
||||||
|
|
||||||
if adapter not in self._adapters:
|
await handle_func(ws)
|
||||||
logger.warning(
|
|
||||||
f"Unknown adapter {adapter}. Please register the adapter before use."
|
|
||||||
)
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create Bot Object
|
|
||||||
BotClass = self._adapters[adapter]
|
|
||||||
self_id, _ = await BotClass.check_permission(self, ws)
|
|
||||||
|
|
||||||
if not self_id:
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self_id in self._clients:
|
|
||||||
logger.opt(colors=True).warning(
|
|
||||||
"There's already a websocket connection, "
|
|
||||||
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
|
|
||||||
)
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
||||||
return
|
|
||||||
|
|
||||||
bot = BotClass(self_id, ws)
|
|
||||||
|
|
||||||
await ws.accept()
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
|
|
||||||
f"Bot {escape_tag(self_id)}</y> Accepted!"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._bot_connect(bot)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while not ws.closed:
|
|
||||||
try:
|
|
||||||
data = await ws.receive()
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
logger.error("WebSocket disconnected by peer.")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(exception=e).error(
|
|
||||||
"Error when receiving data from websocket."
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
asyncio.create_task(bot.handle_message(data.encode()))
|
|
||||||
finally:
|
|
||||||
self._bot_disconnect(bot)
|
|
||||||
|
|
||||||
|
|
||||||
class FullDriver(ForwardDriver, Driver):
|
class FullDriver(ForwardDriver, Driver):
|
||||||
@ -354,17 +294,6 @@ class FullDriver(ForwardDriver, Driver):
|
|||||||
DRIVER=nonebot.drivers.fastapi:FullDriver
|
DRIVER=nonebot.drivers.fastapi:FullDriver
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, config: NoneBotConfig):
|
|
||||||
super(FullDriver, self).__init__(env, config)
|
|
||||||
|
|
||||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
|
||||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
|
||||||
self.shutdown: asyncio.Event = asyncio.Event()
|
|
||||||
self.connections: List[asyncio.Task] = []
|
|
||||||
|
|
||||||
self.on_startup(self._run_forward)
|
|
||||||
self.on_shutdown(self._shutdown_forward)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardDriver)
|
@overrides(ForwardDriver)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@ -372,217 +301,25 @@ class FullDriver(ForwardDriver, Driver):
|
|||||||
return "fastapi_full"
|
return "fastapi_full"
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(ForwardDriver)
|
||||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
async def request(self, setup: "HTTPRequest") -> Any:
|
||||||
"""
|
async with httpx.AsyncClient(
|
||||||
:说明:
|
http2=setup.http_version == "2", follow_redirects=True
|
||||||
|
) as client:
|
||||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
|
||||||
"""
|
|
||||||
self.http_pollings.append(setup)
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
|
||||||
"""
|
|
||||||
self.websockets.append(setup)
|
|
||||||
|
|
||||||
def _run_forward(self):
|
|
||||||
for setup in self.http_pollings:
|
|
||||||
self.connections.append(asyncio.create_task(self._http_loop(setup)))
|
|
||||||
for setup in self.websockets:
|
|
||||||
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
|
|
||||||
|
|
||||||
def _shutdown_forward(self):
|
|
||||||
self.shutdown.set()
|
|
||||||
for task in self.connections:
|
|
||||||
if not task.done():
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
async def _prepare_setup(
|
|
||||||
self, setup: Union[S, Callable[[], Awaitable[S]]]
|
|
||||||
) -> Optional[S]:
|
|
||||||
try:
|
|
||||||
if callable(setup):
|
|
||||||
return await setup()
|
|
||||||
else:
|
|
||||||
return setup
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error while parsing setup "
|
|
||||||
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
|
||||||
url = httpx.URL(setup.url)
|
|
||||||
if not url.netloc:
|
|
||||||
logger.opt(colors=True).error(
|
|
||||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
return HTTPRequest(
|
|
||||||
setup.http_version,
|
|
||||||
url.scheme,
|
|
||||||
url.path,
|
|
||||||
url.query,
|
|
||||||
setup.headers,
|
|
||||||
setup.method,
|
|
||||||
setup.body,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
|
|
||||||
|
|
||||||
http2: bool = False
|
|
||||||
bot: Optional[Bot] = None
|
|
||||||
request: Optional[HTTPRequest] = None
|
|
||||||
client: Optional[httpx.AsyncClient] = None
|
|
||||||
|
|
||||||
# FIXME: seperate const values from setup (self_id, adapter)
|
|
||||||
# logger.opt(colors=True).info(
|
|
||||||
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
|
|
||||||
# f"Bot {escape_tag(_setup.self_id)}</y>"
|
|
||||||
# )
|
|
||||||
|
|
||||||
try:
|
|
||||||
while not self.shutdown.is_set():
|
|
||||||
|
|
||||||
setup = await self._prepare_setup(_setup)
|
|
||||||
if not setup:
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
continue
|
|
||||||
request = self._build_http_request(setup)
|
|
||||||
if not request:
|
|
||||||
await asyncio.sleep(setup.poll_interval)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not client:
|
|
||||||
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
|
|
||||||
elif http2 != (setup.http_version == "2"):
|
|
||||||
await client.aclose()
|
|
||||||
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
|
|
||||||
http2 = setup.http_version == "2"
|
|
||||||
|
|
||||||
if not bot:
|
|
||||||
BotClass = self._adapters[setup.adapter]
|
|
||||||
bot = BotClass(setup.self_id, request)
|
|
||||||
self._bot_connect(bot)
|
|
||||||
else:
|
|
||||||
bot.request = request
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await client.request(
|
response = await client.request(
|
||||||
request.method,
|
setup.method,
|
||||||
setup.url,
|
setup.url,
|
||||||
content=request.body,
|
content=setup.body,
|
||||||
headers=request.headers,
|
headers=setup.headers,
|
||||||
timeout=30.0,
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
return HTTPResponse(
|
||||||
data = response.read()
|
response.status_code, response.content, response.headers
|
||||||
asyncio.create_task(bot.handle_message(data))
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.sleep(setup.poll_interval)
|
@overrides(ForwardDriver)
|
||||||
|
async def websocket(self, setup: "HTTPConnection") -> Any:
|
||||||
except asyncio.CancelledError:
|
ws = await Connect(setup.url, extra_headers=setup.headers)
|
||||||
pass
|
return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
|
||||||
"while http polling</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if bot:
|
|
||||||
self._bot_disconnect(bot)
|
|
||||||
if client:
|
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
|
|
||||||
bot: Optional[Bot] = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
|
|
||||||
setup = await self._prepare_setup(_setup)
|
|
||||||
if not setup:
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
continue
|
|
||||||
|
|
||||||
url = httpx.URL(setup.url)
|
|
||||||
if not url.netloc:
|
|
||||||
logger.opt(colors=True).error(
|
|
||||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
connection = Connect(setup.url, extra_headers=setup.headers)
|
|
||||||
async with connection as ws:
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
|
|
||||||
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
|
|
||||||
)
|
|
||||||
request = WebSocket(
|
|
||||||
"1.1", url.scheme, url.path, url.query, setup.headers, ws
|
|
||||||
)
|
|
||||||
|
|
||||||
BotClass = self._adapters[setup.adapter]
|
|
||||||
bot = BotClass(setup.self_id, request)
|
|
||||||
self._bot_connect(bot)
|
|
||||||
while not self.shutdown.is_set():
|
|
||||||
# use try except instead of "request.closed" because of queued message
|
|
||||||
try:
|
|
||||||
msg = await request.receive_bytes()
|
|
||||||
asyncio.create_task(bot.handle_message(msg))
|
|
||||||
except ConnectionClosed:
|
|
||||||
logger.opt(colors=True).error(
|
|
||||||
"<r><bg #f8bbd0>WebSocket connection closed. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
f"<r><bg #f8bbd0>Error while connecting to {url}. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if bot:
|
|
||||||
self._bot_disconnect(bot)
|
|
||||||
bot = None
|
|
||||||
|
|
||||||
if not setup.reconnect:
|
|
||||||
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
|
|
||||||
break
|
|
||||||
await asyncio.sleep(setup.reconnect_interval)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
|
||||||
"while websocket loop</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -111,9 +111,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
# FIXME: run handler with try/except skipped exception
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
|
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
|
||||||
try:
|
try:
|
||||||
return await coro
|
return await coro
|
||||||
|
@ -14,13 +14,12 @@ from contextlib import AsyncExitStack
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
|
||||||
Type,
|
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
Callable,
|
Callable,
|
||||||
NoReturn,
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
|
Coroutine,
|
||||||
)
|
)
|
||||||
|
|
||||||
from nonebot import params
|
from nonebot import params
|
||||||
@ -30,6 +29,13 @@ from nonebot.exception import SkippedException
|
|||||||
from nonebot.typing import T_PermissionChecker
|
from nonebot.typing import T_PermissionChecker
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
||||||
|
try:
|
||||||
|
return await coro
|
||||||
|
except SkippedException:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Permission:
|
class Permission:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -100,20 +106,18 @@ class Permission:
|
|||||||
return True
|
return True
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
|
_run_coro_with_catch(
|
||||||
checker(
|
checker(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
_stack=stack,
|
_stack=stack,
|
||||||
_dependency_cache=dependency_cache,
|
_dependency_cache=dependency_cache,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
for checker in self.checkers
|
for checker in self.checkers
|
||||||
),
|
),
|
||||||
return_exceptions=True,
|
|
||||||
)
|
|
||||||
return next(
|
|
||||||
filter(lambda x: bool(x) and not isinstance(x, SkippedException), results),
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
|
return any(results)
|
||||||
|
|
||||||
def __and__(self, other) -> NoReturn:
|
def __and__(self, other) -> NoReturn:
|
||||||
raise RuntimeError("And operation between Permissions is not allowed.")
|
raise RuntimeError("And operation between Permissions is not allowed.")
|
||||||
|
Loading…
Reference in New Issue
Block a user