diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 9f7c69b7..ef6f4cc5 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -11,7 +11,12 @@ from nonebot.config import Config class BaseBot(abc.ABC): @abc.abstractmethod - def __init__(self, type: str, config: Config, *, websocket=None): + def __init__(self, + type: str, + config: Config, + self_id: int, + *, + websocket=None): raise NotImplementedError @abc.abstractmethod diff --git a/nonebot/adapters/coolq.py b/nonebot/adapters/coolq.py index 3de28006..2bc46b75 100644 --- a/nonebot/adapters/coolq.py +++ b/nonebot/adapters/coolq.py @@ -10,6 +10,7 @@ from nonebot.event import Event from nonebot.config import Config from nonebot.message import handle_event from nonebot.drivers import BaseWebSocket +from nonebot.exception import ApiNotAvailable from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment @@ -44,6 +45,7 @@ class Bot(BaseBot): def __init__(self, connection_type: str, config: Config, + self_id: int, *, websocket: BaseWebSocket = None): if connection_type not in ["http", "websocket"]: @@ -51,6 +53,7 @@ class Bot(BaseBot): self.type = "coolq" self.connection_type = connection_type self.config = config + self.self_id = self_id self.websocket = websocket async def handle_message(self, message: dict): @@ -70,7 +73,24 @@ class Bot(BaseBot): if self.type == "websocket": pass elif self.type == "http": - pass + api_root = self.config.api_root.get(self.self_id) + if not api_root: + raise ApiNotAvailable + elif not api_root.endswith("/"): + api_root += "/" + + headers = {} + if self.config.access_token: + headers["Authorization"] = "Bearer " + self.config.access_token + + async with httpx.AsyncClient() as client: + response = await client.post(api_root + api) + + if 200 <= response.status_code < 300: + # TODO: handle http api response + return ... + raise httpx.HTTPError( + "", response) class MessageSegment(BaseMessageSegment): diff --git a/nonebot/config.py b/nonebot/config.py index b06c0625..8e4d8efb 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Set, Union from ipaddress import IPv4Address +from typing import Set, Dict, Union, Optional from pydantic import BaseSettings @@ -15,14 +15,22 @@ class Env(BaseSettings): class Config(BaseSettings): + # nonebot configs driver: str = "nonebot.drivers.fastapi" host: IPv4Address = IPv4Address("127.0.0.1") port: int = 8080 + secret: Optional[str] = None debug: bool = False + # bot connection configs + api_root: Dict[int, str] = {} + access_token: Optional[str] = None + + # bot runtime configs superusers: Set[int] = set() nickname: Union[str, Set[str]] = "" + # custom configs custom_config: dict = {} class Config: diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index dcafe0df..0028a827 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -67,7 +67,7 @@ class BaseWebSocket(object): raise NotImplementedError @abc.abstractmethod - async def close(self): + async def close(self, code: int): raise NotImplementedError @abc.abstractmethod diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 32ab8693..e9acf773 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -3,18 +3,19 @@ import json import logging -from typing import Optional +from typing import Dict, Optional from ipaddress import IPv4Address import uvicorn from fastapi.security import OAuth2PasswordBearer from starlette.websockets import WebSocketDisconnect -from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket +from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket from nonebot.log import logger from nonebot.config import Config -from nonebot.drivers import BaseDriver, BaseWebSocket +from nonebot.adapters import BaseBot from nonebot.adapters.coolq import Bot as CoolQBot +from nonebot.drivers import BaseDriver, BaseWebSocket class Driver(BaseDriver): @@ -28,6 +29,7 @@ class Driver(BaseDriver): ) self.config = config + self._clients: Dict[int, BaseBot] = {} self._server_app.post("/{adapter}/")(self._handle_http) self._server_app.post("/{adapter}/http")(self._handle_http) @@ -43,9 +45,13 @@ class Driver(BaseDriver): return self._server_app @property - def logger(self): + def logger(self) -> logging.Logger: return logging.getLogger("fastapi") + @property + def bots(self) -> Dict[int, BaseBot]: + return self._clients + def run(self, host: Optional[IPv4Address] = None, port: Optional[int] = None, @@ -102,12 +108,25 @@ class Driver(BaseDriver): async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket, + self_id: int = Header(None), access_token: str = OAuth2PasswordBearer( "/", auto_error=False)): websocket = WebSocket(websocket) # TODO: Check authorization + + # Create Bot Object + if adapter == "coolq": + bot = CoolQBot("websocket", + self.config, + self_id, + websocket=websocket) + else: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + return + await websocket.accept() + self._clients[self_id] = bot while not websocket.closed: data = await websocket.receive() @@ -115,10 +134,9 @@ class Driver(BaseDriver): if not data: continue - logger.debug(f"Received message: {data}") - if adapter == "coolq": - bot = CoolQBot("websocket", self.config, websocket=websocket) - await bot.handle_message(data) + await bot.handle_message(data) + + del self._clients[self_id] class WebSocket(BaseWebSocket): @@ -135,8 +153,8 @@ class WebSocket(BaseWebSocket): await self.websocket.accept() self._closed = False - async def close(self): - await self.websocket.close() + async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): + await self.websocket.close(code=code) self._closed = True async def receive(self) -> Optional[dict]: diff --git a/nonebot/exception.py b/nonebot/exception.py index 83de121e..de001dad 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -28,3 +28,8 @@ class RejectedException(Exception): class FinishedException(Exception): """Finish handling a message""" pass + + +class ApiNotAvailable(Exception): + """Api is not available""" + pass