diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py
index 1871eec1..7977acfb 100644
--- a/nonebot/drivers/aiohttp.py
+++ b/nonebot/drivers/aiohttp.py
@@ -212,6 +212,9 @@ class Driver(ForwardDriver):
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
+ logger.opt(colors=True).info(
+ f"Start http polling for {setup.adapter.upper()} "
+ f"Bot {setup.self_id}")
headers = request.headers
timeout = aiohttp.ClientTimeout(30)
@@ -289,11 +292,13 @@ class Driver(ForwardDriver):
)
try:
async with session.ws_connect(url) as ws:
+ logger.opt(colors=True).info(
+ f"WebSocket Connection to {setup.adapter.upper()} "
+ f"Bot {setup.self_id} succeeded!")
request = WebSocket(
setup.http_version, url.scheme, url.path,
- url.raw_query_string.encode("latin-1"), {
- **setup.headers, "host": host
- }, ws)
+ url.raw_query_string.encode("latin-1"), headers,
+ ws)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index 9e013a5f..caf9e2e9 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -11,19 +11,46 @@ FastAPI 驱动适配
import asyncio
import logging
from dataclasses import dataclass
-from typing import List, Optional, Callable
+from typing import List, Dict, Union, Optional, Callable
+import httpx
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
+from websockets.exceptions import ConnectionClosed
from fastapi import status, Request, FastAPI, HTTPException
+from websockets.legacy.client import Connect, WebSocketClientProtocol
from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket
as FastAPIWebSocket)
from nonebot.log import logger
+from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
-from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
+from nonebot.drivers import ReverseDriver, ForwardDriver
+from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket
+
+
+@dataclass
+class HTTPPollingSetup:
+ adapter: str
+ self_id: str
+ url: str
+ method: str
+ body: bytes
+ headers: Dict[str, str]
+ http_version: str
+ poll_interval: float
+
+
+@dataclass
+class WebSocketSetup:
+ adapter: str
+ self_id: str
+ url: str
+ headers: Dict[str, str]
+ http_version: str
+ reconnect_interval: float
class Config(BaseSettings):
@@ -75,7 +102,7 @@ class Config(BaseSettings):
extra = "ignore"
-class Driver(ReverseDriver):
+class Driver(ReverseDriver, ForwardDriver):
"""
FastAPI 驱动框架
@@ -90,7 +117,11 @@ class Driver(ReverseDriver):
def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config)
- self.fastapi_config = Config(**config.dict())
+ self.fastapi_config: Config = Config(**config.dict())
+ self.http_pollings: List[HTTPPollingSetup] = []
+ self.websockets: List[WebSocketSetup] = []
+ self.shutdown: asyncio.Event = asyncio.Event()
+ self.connections: List[asyncio.Task] = []
self._server_app = FastAPI(
debug=config.debug,
@@ -104,6 +135,9 @@ class Driver(ReverseDriver):
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
+ self.on_startup(self._run_forward)
+ self.on_shutdown(self._shutdown_forward)
+
@property
@overrides(ReverseDriver)
def type(self) -> str:
@@ -138,6 +172,32 @@ class Driver(ReverseDriver):
"""参考文档: `Events `_"""
return self.server_app.on_event("shutdown")(func)
+ @overrides(ForwardDriver)
+ def setup_http_polling(self,
+ adapter: str,
+ self_id: str,
+ url: str,
+ polling_interval: float = 3.,
+ method: str = "GET",
+ body: bytes = b"",
+ headers: Dict[str, str] = {},
+ http_version: str = "1.1") -> None:
+ self.http_pollings.append(
+ HTTPPollingSetup(adapter, self_id, url, method, body, headers,
+ http_version, polling_interval))
+
+ @overrides(ForwardDriver)
+ def setup_websocket(self,
+ adapter: str,
+ self_id: str,
+ url: str,
+ reconnect_interval: float = 3.,
+ headers: Dict[str, str] = {},
+ http_version: str = "1.1") -> None:
+ self.websockets.append(
+ WebSocketSetup(adapter, self_id, url, headers, http_version,
+ reconnect_interval))
+
@overrides(ReverseDriver)
def run(self,
host: Optional[str] = None,
@@ -166,14 +226,27 @@ class Driver(ReverseDriver):
},
},
}
- uvicorn.run(app or self.server_app,
- host=host or str(self.config.host),
- port=port or self.config.port,
- reload=bool(app) and self.config.debug,
- reload_dirs=self.fastapi_config.fastapi_reload_dirs or None,
- debug=self.config.debug,
- log_config=LOGGING_CONFIG,
- **kwargs)
+ uvicorn.run(
+ app or self.server_app, # type: ignore
+ host=host or str(self.config.host),
+ port=port or self.config.port,
+ reload=bool(app) and self.config.debug,
+ reload_dirs=self.fastapi_config.fastapi_reload_dirs or None,
+ debug=self.config.debug,
+ log_config=LOGGING_CONFIG,
+ **kwargs)
+
+ def _run_forward(self):
+ for setup in self.http_pollings:
+ self.connections.append(asyncio.create_task(self._http_loop(setup)))
+ for setup in self.websockets:
+ self.connections.append(asyncio.create_task(self._ws_loop(setup)))
+
+ def _shutdown_forward(self):
+ self.shutdown.set()
+ for task in self.connections:
+ if not task.done():
+ task.cancel()
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
@@ -263,37 +336,166 @@ class Driver(ReverseDriver):
finally:
self._bot_disconnect(bot)
+ async def _http_loop(self, setup: HTTPPollingSetup):
+ url = httpx.URL(setup.url)
+ if not url.netloc:
+ logger.opt(colors=True).error(
+ f"Error parsing url {url}")
+ return
+ request = HTTPRequest(
+ setup.http_version, url.scheme, url.path, url.query, {
+ **setup.headers, "host": url.netloc.decode("ascii")
+ }, setup.method, setup.body)
+
+ BotClass = self._adapters[setup.adapter]
+ bot = BotClass(setup.self_id, request)
+ self._bot_connect(bot)
+ logger.opt(colors=True).info(
+ f"Start http polling for {setup.adapter.upper()} "
+ f"Bot {setup.self_id}")
+
+ headers = request.headers
+ http2: bool = False
+ if request.http_version == "2":
+ http2 = True
+
+ try:
+ async with httpx.AsyncClient(headers=headers,
+ timeout=30.,
+ http2=http2) as session:
+ while not self.shutdown.is_set():
+ logger.debug(
+ f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
+ )
+ try:
+ response = await session.request(request.method,
+ url,
+ content=request.body)
+ response.raise_for_status()
+ data = response.read()
+ asyncio.create_task(bot.handle_message(data))
+ except httpx.HTTPError as e:
+ logger.opt(colors=True, exception=e).error(
+ f"Error occurred while requesting {url}. "
+ "Try to reconnect...")
+
+ await asyncio.sleep(setup.poll_interval)
+
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Unexpected exception occurred "
+ "while http polling")
+ finally:
+ self._bot_disconnect(bot)
+
+ async def _ws_loop(self, setup: WebSocketSetup):
+ url = httpx.URL(setup.url)
+ if not url.netloc:
+ logger.opt(colors=True).error(
+ f"Error parsing url {url}")
+ return
+
+ headers = {**setup.headers, "host": url.netloc.decode("ascii")}
+
+ bot: Optional[Bot] = None
+ try:
+ while True:
+ logger.debug(
+ f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
+ )
+ try:
+ connection = Connect(setup.url)
+ async with connection as ws:
+ logger.opt(colors=True).info(
+ f"WebSocket Connection to {setup.adapter.upper()} "
+ f"Bot {setup.self_id} succeeded!")
+ request = WebSocket(setup.http_version, url.scheme,
+ url.path, url.query, headers, ws)
+
+ BotClass = self._adapters[setup.adapter]
+ bot = BotClass(setup.self_id, request)
+ self._bot_connect(bot)
+ while not self.shutdown.is_set():
+ # use try except instead of "request.closed" because of queued message
+ try:
+ msg = await request.receive_bytes()
+ asyncio.create_task(bot.handle_message(msg))
+ except ConnectionClosed:
+ logger.opt(colors=True).error(
+ "WebSocket connection closed by peer. "
+ "Try to reconnect...")
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ f"Error while connecting to {url}. "
+ "Try to reconnect...")
+ finally:
+ if bot:
+ self._bot_disconnect(bot)
+ bot = None
+ await asyncio.sleep(setup.reconnect_interval)
+
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Unexpected exception occurred "
+ "while websocket loop")
+
@dataclass
class WebSocket(BaseWebSocket):
- websocket: FastAPIWebSocket = None # type: ignore
+ websocket: Union[FastAPIWebSocket,
+ WebSocketClientProtocol] = None # type: ignore
@property
@overrides(BaseWebSocket)
- def closed(self):
- return (self.websocket.client_state == WebSocketState.DISCONNECTED or
+ def closed(self) -> bool:
+ if isinstance(self.websocket, FastAPIWebSocket):
+ return (
+ self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED)
+ else:
+ return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
- await self.websocket.accept()
+ if isinstance(self.websocket, FastAPIWebSocket):
+ await self.websocket.accept()
+ else:
+ raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
- await self.websocket.close(code=code)
+ await self.websocket.close(code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
- return await self.websocket.receive_text()
+ if isinstance(self.websocket, FastAPIWebSocket):
+ return await self.websocket.receive_text()
+ else:
+ msg = await self.websocket.recv()
+ return msg.decode("utf-8") if isinstance(msg, bytes) else msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
- return await self.websocket.receive_bytes()
+ if isinstance(self.websocket, FastAPIWebSocket):
+ return await self.websocket.receive_bytes()
+ else:
+ msg = await self.websocket.recv()
+ return msg.encode("utf-8") if isinstance(msg, str) else msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
- await self.websocket.send({"type": "websocket.send", "text": data})
+ if isinstance(self.websocket, FastAPIWebSocket):
+ await self.websocket.send({"type": "websocket.send", "text": data})
+ else:
+ await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
- await self.websocket.send({"type": "websocket.send", "bytes": data})
+ if isinstance(self.websocket, FastAPIWebSocket):
+ await self.websocket.send({"type": "websocket.send", "bytes": data})
+ else:
+ await self.websocket.send(data)
diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py
index 21dc88d3..3c0c643f 100644
--- a/nonebot/drivers/quart.py
+++ b/nonebot/drivers/quart.py
@@ -140,14 +140,15 @@ class Driver(ReverseDriver):
},
},
}
- uvicorn.run(app or self.server_app,
- host=host or str(self.config.host),
- port=port or self.config.port,
- reload=bool(app) and self.config.debug,
- reload_dirs=self.quart_config.quart_reload_dirs or None,
- debug=self.config.debug,
- log_config=LOGGING_CONFIG,
- **kwargs)
+ uvicorn.run(
+ app or self.server_app, # type: ignore
+ host=host or str(self.config.host),
+ port=port or self.config.port,
+ reload=bool(app) and self.config.debug,
+ reload_dirs=self.quart_config.quart_reload_dirs or None,
+ debug=self.config.debug,
+ log_config=LOGGING_CONFIG,
+ **kwargs)
async def _handle_http(self, adapter: str):
request: Request = _request
diff --git a/pages/changelog.md b/pages/changelog.md
index 63c58c72..4245df40 100644
--- a/pages/changelog.md
+++ b/pages/changelog.md
@@ -14,6 +14,9 @@ sidebar: auto
- 修复 `type_updater` `permission_updater` 未传递的错误
- 修复 `type_updater` `permission_updater` 参数 `state` 错误
- 修复使用 `state_factory` 后导致无法在 session 内传递 `state`
+- 新增正向 Driver(Client) 支持
+- 新增 `aiohttp` 正向 Driver
+- `fastapi` Driver 新增正向支持
## v2.0.0a13.post1
diff --git a/tests/.env.dev b/tests/.env.dev
index 7056d21a..b04033ea 100644
--- a/tests/.env.dev
+++ b/tests/.env.dev
@@ -1,4 +1,4 @@
-DRIVER=nonebot.drivers.aiohttp:Driver
+DRIVER=nonebot.drivers.fastapi:Driver
HOST=0.0.0.0
PORT=2333
DEBUG=true