diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 1871eec1..7977acfb 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -212,6 +212,9 @@ class Driver(ForwardDriver): BotClass = self._adapters[setup.adapter] bot = BotClass(setup.self_id, request) self._bot_connect(bot) + logger.opt(colors=True).info( + f"Start http polling for {setup.adapter.upper()} " + f"Bot {setup.self_id}") headers = request.headers timeout = aiohttp.ClientTimeout(30) @@ -289,11 +292,13 @@ class Driver(ForwardDriver): ) try: async with session.ws_connect(url) as ws: + logger.opt(colors=True).info( + f"WebSocket Connection to {setup.adapter.upper()} " + f"Bot {setup.self_id} succeeded!") request = WebSocket( setup.http_version, url.scheme, url.path, - url.raw_query_string.encode("latin-1"), { - **setup.headers, "host": host - }, ws) + url.raw_query_string.encode("latin-1"), headers, + ws) BotClass = self._adapters[setup.adapter] bot = BotClass(setup.self_id, request) diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 9e013a5f..caf9e2e9 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -11,19 +11,46 @@ FastAPI 驱动适配 import asyncio import logging from dataclasses import dataclass -from typing import List, Optional, Callable +from typing import List, Dict, Union, Optional, Callable +import httpx import uvicorn from pydantic import BaseSettings from fastapi.responses import Response +from websockets.exceptions import ConnectionClosed from fastapi import status, Request, FastAPI, HTTPException +from websockets.legacy.client import Connect, WebSocketClientProtocol from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket as FastAPIWebSocket) from nonebot.log import logger +from nonebot.adapters import Bot from nonebot.typing import overrides from nonebot.config import Env, Config as NoneBotConfig -from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket +from nonebot.drivers import ReverseDriver, ForwardDriver +from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket + + +@dataclass +class HTTPPollingSetup: + adapter: str + self_id: str + url: str + method: str + body: bytes + headers: Dict[str, str] + http_version: str + poll_interval: float + + +@dataclass +class WebSocketSetup: + adapter: str + self_id: str + url: str + headers: Dict[str, str] + http_version: str + reconnect_interval: float class Config(BaseSettings): @@ -75,7 +102,7 @@ class Config(BaseSettings): extra = "ignore" -class Driver(ReverseDriver): +class Driver(ReverseDriver, ForwardDriver): """ FastAPI 驱动框架 @@ -90,7 +117,11 @@ class Driver(ReverseDriver): def __init__(self, env: Env, config: NoneBotConfig): super().__init__(env, config) - self.fastapi_config = Config(**config.dict()) + self.fastapi_config: Config = Config(**config.dict()) + self.http_pollings: List[HTTPPollingSetup] = [] + self.websockets: List[WebSocketSetup] = [] + self.shutdown: asyncio.Event = asyncio.Event() + self.connections: List[asyncio.Task] = [] self._server_app = FastAPI( debug=config.debug, @@ -104,6 +135,9 @@ class Driver(ReverseDriver): self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) + self.on_startup(self._run_forward) + self.on_shutdown(self._shutdown_forward) + @property @overrides(ReverseDriver) def type(self) -> str: @@ -138,6 +172,32 @@ class Driver(ReverseDriver): """参考文档: `Events `_""" return self.server_app.on_event("shutdown")(func) + @overrides(ForwardDriver) + def setup_http_polling(self, + adapter: str, + self_id: str, + url: str, + polling_interval: float = 3., + method: str = "GET", + body: bytes = b"", + headers: Dict[str, str] = {}, + http_version: str = "1.1") -> None: + self.http_pollings.append( + HTTPPollingSetup(adapter, self_id, url, method, body, headers, + http_version, polling_interval)) + + @overrides(ForwardDriver) + def setup_websocket(self, + adapter: str, + self_id: str, + url: str, + reconnect_interval: float = 3., + headers: Dict[str, str] = {}, + http_version: str = "1.1") -> None: + self.websockets.append( + WebSocketSetup(adapter, self_id, url, headers, http_version, + reconnect_interval)) + @overrides(ReverseDriver) def run(self, host: Optional[str] = None, @@ -166,14 +226,27 @@ class Driver(ReverseDriver): }, }, } - uvicorn.run(app or self.server_app, - host=host or str(self.config.host), - port=port or self.config.port, - reload=bool(app) and self.config.debug, - reload_dirs=self.fastapi_config.fastapi_reload_dirs or None, - debug=self.config.debug, - log_config=LOGGING_CONFIG, - **kwargs) + uvicorn.run( + app or self.server_app, # type: ignore + host=host or str(self.config.host), + port=port or self.config.port, + reload=bool(app) and self.config.debug, + reload_dirs=self.fastapi_config.fastapi_reload_dirs or None, + debug=self.config.debug, + log_config=LOGGING_CONFIG, + **kwargs) + + def _run_forward(self): + for setup in self.http_pollings: + self.connections.append(asyncio.create_task(self._http_loop(setup))) + for setup in self.websockets: + self.connections.append(asyncio.create_task(self._ws_loop(setup))) + + def _shutdown_forward(self): + self.shutdown.set() + for task in self.connections: + if not task.done(): + task.cancel() async def _handle_http(self, adapter: str, request: Request): data = await request.body() @@ -263,37 +336,166 @@ class Driver(ReverseDriver): finally: self._bot_disconnect(bot) + async def _http_loop(self, setup: HTTPPollingSetup): + url = httpx.URL(setup.url) + if not url.netloc: + logger.opt(colors=True).error( + f"Error parsing url {url}") + return + request = HTTPRequest( + setup.http_version, url.scheme, url.path, url.query, { + **setup.headers, "host": url.netloc.decode("ascii") + }, setup.method, setup.body) + + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + logger.opt(colors=True).info( + f"Start http polling for {setup.adapter.upper()} " + f"Bot {setup.self_id}") + + headers = request.headers + http2: bool = False + if request.http_version == "2": + http2 = True + + try: + async with httpx.AsyncClient(headers=headers, + timeout=30., + http2=http2) as session: + while not self.shutdown.is_set(): + logger.debug( + f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" + ) + try: + response = await session.request(request.method, + url, + content=request.body) + response.raise_for_status() + data = response.read() + asyncio.create_task(bot.handle_message(data)) + except httpx.HTTPError as e: + logger.opt(colors=True, exception=e).error( + f"Error occurred while requesting {url}. " + "Try to reconnect...") + + await asyncio.sleep(setup.poll_interval) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Unexpected exception occurred " + "while http polling") + finally: + self._bot_disconnect(bot) + + async def _ws_loop(self, setup: WebSocketSetup): + url = httpx.URL(setup.url) + if not url.netloc: + logger.opt(colors=True).error( + f"Error parsing url {url}") + return + + headers = {**setup.headers, "host": url.netloc.decode("ascii")} + + bot: Optional[Bot] = None + try: + while True: + logger.debug( + f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" + ) + try: + connection = Connect(setup.url) + async with connection as ws: + logger.opt(colors=True).info( + f"WebSocket Connection to {setup.adapter.upper()} " + f"Bot {setup.self_id} succeeded!") + request = WebSocket(setup.http_version, url.scheme, + url.path, url.query, headers, ws) + + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + while not self.shutdown.is_set(): + # use try except instead of "request.closed" because of queued message + try: + msg = await request.receive_bytes() + asyncio.create_task(bot.handle_message(msg)) + except ConnectionClosed: + logger.opt(colors=True).error( + "WebSocket connection closed by peer. " + "Try to reconnect...") + except Exception as e: + logger.opt(colors=True, exception=e).error( + f"Error while connecting to {url}. " + "Try to reconnect...") + finally: + if bot: + self._bot_disconnect(bot) + bot = None + await asyncio.sleep(setup.reconnect_interval) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Unexpected exception occurred " + "while websocket loop") + @dataclass class WebSocket(BaseWebSocket): - websocket: FastAPIWebSocket = None # type: ignore + websocket: Union[FastAPIWebSocket, + WebSocketClientProtocol] = None # type: ignore @property @overrides(BaseWebSocket) - def closed(self): - return (self.websocket.client_state == WebSocketState.DISCONNECTED or + def closed(self) -> bool: + if isinstance(self.websocket, FastAPIWebSocket): + return ( + self.websocket.client_state == WebSocketState.DISCONNECTED or self.websocket.application_state == WebSocketState.DISCONNECTED) + else: + return self.websocket.closed @overrides(BaseWebSocket) async def accept(self): - await self.websocket.accept() + if isinstance(self.websocket, FastAPIWebSocket): + await self.websocket.accept() + else: + raise NotImplementedError @overrides(BaseWebSocket) async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): - await self.websocket.close(code=code) + await self.websocket.close(code) @overrides(BaseWebSocket) async def receive(self) -> str: - return await self.websocket.receive_text() + if isinstance(self.websocket, FastAPIWebSocket): + return await self.websocket.receive_text() + else: + msg = await self.websocket.recv() + return msg.decode("utf-8") if isinstance(msg, bytes) else msg @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: - return await self.websocket.receive_bytes() + if isinstance(self.websocket, FastAPIWebSocket): + return await self.websocket.receive_bytes() + else: + msg = await self.websocket.recv() + return msg.encode("utf-8") if isinstance(msg, str) else msg @overrides(BaseWebSocket) async def send(self, data: str) -> None: - await self.websocket.send({"type": "websocket.send", "text": data}) + if isinstance(self.websocket, FastAPIWebSocket): + await self.websocket.send({"type": "websocket.send", "text": data}) + else: + await self.websocket.send(data) @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: - await self.websocket.send({"type": "websocket.send", "bytes": data}) + if isinstance(self.websocket, FastAPIWebSocket): + await self.websocket.send({"type": "websocket.send", "bytes": data}) + else: + await self.websocket.send(data) diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 21dc88d3..3c0c643f 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -140,14 +140,15 @@ class Driver(ReverseDriver): }, }, } - uvicorn.run(app or self.server_app, - host=host or str(self.config.host), - port=port or self.config.port, - reload=bool(app) and self.config.debug, - reload_dirs=self.quart_config.quart_reload_dirs or None, - debug=self.config.debug, - log_config=LOGGING_CONFIG, - **kwargs) + uvicorn.run( + app or self.server_app, # type: ignore + host=host or str(self.config.host), + port=port or self.config.port, + reload=bool(app) and self.config.debug, + reload_dirs=self.quart_config.quart_reload_dirs or None, + debug=self.config.debug, + log_config=LOGGING_CONFIG, + **kwargs) async def _handle_http(self, adapter: str): request: Request = _request diff --git a/pages/changelog.md b/pages/changelog.md index 63c58c72..4245df40 100644 --- a/pages/changelog.md +++ b/pages/changelog.md @@ -14,6 +14,9 @@ sidebar: auto - 修复 `type_updater` `permission_updater` 未传递的错误 - 修复 `type_updater` `permission_updater` 参数 `state` 错误 - 修复使用 `state_factory` 后导致无法在 session 内传递 `state` +- 新增正向 Driver(Client) 支持 +- 新增 `aiohttp` 正向 Driver +- `fastapi` Driver 新增正向支持 ## v2.0.0a13.post1 diff --git a/tests/.env.dev b/tests/.env.dev index 7056d21a..b04033ea 100644 --- a/tests/.env.dev +++ b/tests/.env.dev @@ -1,4 +1,4 @@ -DRIVER=nonebot.drivers.aiohttp:Driver +DRIVER=nonebot.drivers.fastapi:Driver HOST=0.0.0.0 PORT=2333 DEBUG=true