diff --git a/docs/api/adapters/README.md b/docs/api/adapters/README.md index 0036f014..9337ec17 100644 --- a/docs/api/adapters/README.md +++ b/docs/api/adapters/README.md @@ -57,7 +57,7 @@ Config 配置对象 -### _abstract_ `__init__(connection_type, self_id, *, websocket=None)` +### `__init__(self_id, request)` * **参数** @@ -73,19 +73,14 @@ Config 配置对象 -### `connection_type` - -连接类型 - - ### `self_id` 机器人 ID -### `websocket` +### `request` -Websocket 连接对象 +连接信息 ### _abstract property_ `type` @@ -102,7 +97,7 @@ Adapter 类型 -### _abstract async classmethod_ `check_permission(driver, connection_type, headers, body)` +### _abstract async classmethod_ `check_permission(driver, request)` * **说明** @@ -130,7 +125,10 @@ Adapter 类型 * **返回** - * `str`: 连接唯一标识符 + * `str`: 连接唯一标识符,`None` 代表连接不合法 + + + * `HTTPResponse`: HTTP 上报响应 @@ -153,7 +151,7 @@ Adapter 类型 * **参数** - * `message: dict`: 收到的上报消息 + * `message: bytes`: 收到的上报消息 diff --git a/docs/api/adapters/cqhttp.md b/docs/api/adapters/cqhttp.md index ea6e886e..760377e9 100644 --- a/docs/api/adapters/cqhttp.md +++ b/docs/api/adapters/cqhttp.md @@ -204,7 +204,7 @@ CQHTTP 协议 Bot 适配。继承属性参考 [BaseBot](./#class-basebot) 。 * 返回: `"cqhttp"` -### _async classmethod_ `check_permission(driver, connection_type, headers, body)` +### _async classmethod_ `check_permission(driver, request)` * **说明** diff --git a/docs/api/adapters/ding.md b/docs/api/adapters/ding.md index 12057886..57410533 100644 --- a/docs/api/adapters/ding.md +++ b/docs/api/adapters/ding.md @@ -105,7 +105,7 @@ sidebarDepth: 0 * 返回: `"ding"` -### _async classmethod_ `check_permission(driver, connection_type, headers, body)` +### _async classmethod_ `check_permission(driver, request)` * **说明** diff --git a/docs/api/adapters/mirai.md b/docs/api/adapters/mirai.md index c7f72f26..7804c06e 100644 --- a/docs/api/adapters/mirai.md +++ b/docs/api/adapters/mirai.md @@ -682,6 +682,11 @@ API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则 # NoneBot.adapters.mirai.bot_ws 模块 +## _class_ `WebSocket` + +基类:[`nonebot.drivers.WebSocket`](../drivers/README.md#nonebot.drivers.WebSocket) + + ## _class_ `WebsocketBot` 基类:`nonebot.adapters.mirai.bot.Bot` diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md index 7e302f35..903c5d13 100644 --- a/docs/api/drivers/README.md +++ b/docs/api/drivers/README.md @@ -268,74 +268,71 @@ Reverse Driver 基类。将后端框架封装,以满足适配器使用。 用于处理 WebSocket 类型请求的函数 -## _class_ `HTTPRequest` +## _class_ `HTTPConnection` -基类:`object` - -HTTP 请求封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope)。 +基类:`abc.ABC` -### _property_ `type` - -Always http - - -### _property_ `scope` - -Raw scope from asgi. - -The connection scope information, a dictionary that -contains at least a type key specifying the protocol that is incoming. - - -### _property_ `http_version` +### `http_version` One of "1.0", "1.1" or "2". -### _property_ `method` - -The HTTP method name, uppercased. - - -### _property_ `schema` +### `scheme` URL scheme portion (likely "http" or "https"). -Optional (but must not be empty); default is "http". -### _property_ `path` +### `path` HTTP request target excluding any query string, with percent-encoded sequences and UTF-8 byte sequences decoded into characters. -### _property_ `query_string` +### `query_string` URL portion after the ?, percent-encoded. -### _property_ `headers` +### `headers` -An iterable of [name, value] two-item iterables, +A dict of name-value pairs, where name is the header name, and value is the header value. Order of header values must be preserved from the original HTTP request; order of header names is not important. -Duplicates are possible and must be preserved in the message as received. - Header names must be lowercased. -### _property_ `body` +### _abstract property_ `type` + +Connection type. + + +## _class_ `HTTPRequest` + +基类:`nonebot.drivers.HTTPConnection` + +HTTP 请求封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope)。 + + +### `method` + +The HTTP method name, uppercased. + + +### `body` Body of the request. Optional; if missing defaults to b"". -If more_body is set, treat as start of body and concatenate on further chunks. + +### _property_ `type` + +Always `http` ## _class_ `HTTPResponse` @@ -350,51 +347,40 @@ HTTP 响应封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/late HTTP status code. +### `body` + +HTTP body content. + +Optional; if missing defaults to `None`. + + ### `headers` -An iterable of [name, value] two-item iterables, -where name is the header name, -and value is the header value. +A dict of name-value pairs, +where name is the header name, and value is the header value. Order must be preserved in the HTTP response. Header names must be lowercased. -Optional; if missing defaults to an empty list. - - -### `body` - -HTTP body content. - -Optional; if missing defaults to None. +Optional; if missing defaults to an empty dict. ### _property_ `type` -Always http +Always `http` ## _class_ `WebSocket` -基类:`object` +基类:`nonebot.drivers.HTTPConnection`, `abc.ABC` WebSocket 连接封装。参考 [asgi websocket scope](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope)。 -### _abstract_ `__init__(websocket)` +### _property_ `type` - -* **参数** - - - * `websocket: Any`: WebSocket 连接对象 - - - -### _property_ `websocket` - -WebSocket 连接对象 +Always `websocket` ### _abstract property_ `closed` @@ -424,9 +410,19 @@ WebSocket 连接对象 ### _abstract async_ `receive()` -接收一条 WebSocket 信息 +接收一条 WebSocket text 信息 + + +### _abstract async_ `receive_bytes()` + +接收一条 WebSocket binary 信息 ### _abstract async_ `send(data)` -发送一条 WebSocket 信息 +发送一条 WebSocket text 信息 + + +### _abstract async_ `send_bytes(data)` + +发送一条 WebSocket text 信息 diff --git a/docs/api/drivers/fastapi.md b/docs/api/drivers/fastapi.md index 9f2a0be8..005f37b2 100644 --- a/docs/api/drivers/fastapi.md +++ b/docs/api/drivers/fastapi.md @@ -133,3 +133,8 @@ fastapi 使用的 logger ### `run(host=None, port=None, *, app=None, **kwargs)` 使用 `uvicorn` 启动 FastAPI + + +## _class_ `WebSocket` + +基类:[`nonebot.drivers.WebSocket`](README.md#nonebot.drivers.WebSocket) diff --git a/docs/api/drivers/quart.md b/docs/api/drivers/quart.md index 644c2390..35fbc26f 100644 --- a/docs/api/drivers/quart.md +++ b/docs/api/drivers/quart.md @@ -10,6 +10,28 @@ sidebarDepth: 0 后端使用方法请参考: [Quart 文档](https://pgjones.gitlab.io/quart/index.html) +## _class_ `Config` + +基类:`pydantic.env_settings.BaseSettings` + +Quart 驱动框架设置 + + +### `quart_reload_dirs` + + +* **类型** + + `List[str]` + + + +* **说明** + + `debug` 模式下重载监控文件夹列表,默认为 uvicorn 默认值 + + + ## _class_ `Driver` 基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver) @@ -44,7 +66,7 @@ Quart 驱动框架 ### _property_ `logger` -fastapi 使用的 logger +Quart 使用的 logger ### `on_startup(func)` diff --git a/docs/api/exception.md b/docs/api/exception.md index 817c02a9..f48a493b 100644 --- a/docs/api/exception.md +++ b/docs/api/exception.md @@ -132,27 +132,6 @@ sidebarDepth: 0 -## _exception_ `RequestDenied` - -基类:`nonebot.exception.NoneBotException` - - -* **说明** - - Bot 连接请求不合法。 - - - -* **参数** - - - * `status_code: int`: HTTP 状态码 - - - * `reason: str`: 拒绝原因 - - - ## _exception_ `AdapterException` 基类:`nonebot.exception.NoneBotException` diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 92cb5f41..39ff913d 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -11,13 +11,14 @@ from copy import copy from functools import reduce, partial from typing_extensions import Protocol from dataclasses import dataclass, field -from typing import (Any, Set, List, Dict, Union, TypeVar, Mapping, Optional, - Iterable, Awaitable, TYPE_CHECKING) +from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping, + Optional, Iterable, Awaitable, TYPE_CHECKING) from pydantic import BaseModel from nonebot.log import logger from nonebot.utils import DataclassEncoder +from nonebot.drivers import HTTPConnection, HTTPResponse from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook if TYPE_CHECKING: @@ -51,12 +52,7 @@ class Bot(abc.ABC): :说明: call_api 后执行的函数 """ - @abc.abstractmethod - def __init__(self, - connection_type: str, - self_id: str, - *, - websocket: Optional["WebSocket"] = None): + def __init__(self, self_id: str, request: HTTPConnection): """ :参数: @@ -64,12 +60,10 @@ class Bot(abc.ABC): * ``self_id: str``: 机器人 ID * ``websocket: Optional[WebSocket]``: Websocket 连接对象 """ - self.connection_type = connection_type - """连接类型""" - self.self_id = self_id + self.self_id: str = self_id """机器人 ID""" - self.websocket = websocket - """Websocket 连接对象""" + self.request: HTTPConnection = request + """连接信息""" def __getattr__(self, name: str) -> _ApiCall: return partial(self.call_api, name) @@ -92,8 +86,9 @@ class Bot(abc.ABC): @classmethod @abc.abstractmethod - async def check_permission(cls, driver: "Driver", connection_type: str, - headers: dict, body: Optional[bytes]) -> str: + async def check_permission( + cls, driver: "Driver", request: HTTPConnection + ) -> Tuple[Optional[str], Optional[HTTPResponse]]: """ :说明: @@ -108,7 +103,8 @@ class Bot(abc.ABC): :返回: - - ``str``: 连接唯一标识符 + - ``str``: 连接唯一标识符,``None`` 代表连接不合法 + - ``HTTPResponse``: HTTP 上报响应 :异常: @@ -117,7 +113,7 @@ class Bot(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def handle_message(self, message: dict): + async def handle_message(self, message: bytes): """ :说明: @@ -125,7 +121,7 @@ class Bot(abc.ABC): :参数: - * ``message: dict``: 收到的上报消息 + * ``message: bytes``: 收到的上报消息 """ raise NotImplementedError diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 7f6a4675..9cd87cdc 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -7,8 +7,8 @@ import abc import asyncio -from typing import (Any, Set, List, Dict, Type, Tuple, Optional, Callable, - MutableMapping, TYPE_CHECKING) +from dataclasses import dataclass, field +from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING from nonebot.log import logger from nonebot.config import Env, Config @@ -47,12 +47,12 @@ class Driver(abc.ABC): * ``env: Env``: 包含环境信息的 Env 对象 * ``config: Config``: 包含配置信息的 Config 对象 """ - self.env = env.environment + self.env: str = env.environment """ :类型: ``str`` :说明: 环境名称 """ - self.config = config + self.config: Config = config """ :类型: ``Config`` :说明: 配置对象 @@ -231,143 +231,101 @@ class ReverseDriver(Driver): raise NotImplementedError -class HTTPRequest: +@dataclass +class HTTPConnection(abc.ABC): + http_version: str + """One of `"1.0"`, `"1.1"` or `"2"`.""" + scheme: str + """URL scheme portion (likely `"http"` or `"https"`).""" + path: str + """ + HTTP request target excluding any query string, + with percent-encoded sequences and UTF-8 byte sequences + decoded into characters. + """ + query_string: bytes = b"" + """ URL portion after the `?`, percent-encoded.""" + headers: Dict[str, str] = field(default_factory=dict) + """A dict of name-value pairs, + where name is the header name, and value is the header value. + + Order of header values must be preserved from the original HTTP request; + order of header names is not important. + + Header names must be lowercased. + """ + + @property + @abc.abstractmethod + def type(self) -> str: + """Connection type.""" + raise NotImplementedError + + +@dataclass +class HTTPRequest(HTTPConnection): """HTTP 请求封装。参考 `asgi http scope`_。 .. _asgi http scope: https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope """ + method: str = "GET" + """The HTTP method name, uppercased.""" + body: bytes = b"" + """Body of the request. - def __init__(self, scope: MutableMapping[str, Any]): - self._scope = scope + Optional; if missing defaults to b"". + """ @property def type(self) -> str: - """Always `http`""" + """Always ``http``""" return "http" - @property - def scope(self) -> MutableMapping[str, Any]: - """Raw scope from asgi. - - The connection scope information, a dictionary that - contains at least a `type` key specifying the protocol that is incoming. - """ - return self._scope - - @property - def http_version(self) -> str: - """One of `"1.0"`, `"1.1"` or `"2"`.""" - raise self.scope["http_version"] - - @property - def method(self) -> str: - """The HTTP method name, uppercased.""" - raise self.scope["method"] - - @property - def schema(self) -> str: - """ - URL scheme portion (likely `"http"` or `"https"`). - Optional (but must not be empty); default is `"http"`. - """ - raise self.scope["schema"] - - @property - def path(self) -> str: - """ - HTTP request target excluding any query string, - with percent-encoded sequences and UTF-8 byte sequences - decoded into characters. - """ - return self.scope["path"] - - @property - def query_string(self) -> bytes: - """ URL portion after the `?`, percent-encoded.""" - return self.scope["query_string"] - - @property - def headers(self) -> List[Tuple[bytes, bytes]]: - """An iterable of [name, value] two-item iterables, - where name is the header name, and value is the header value. - - Order of header values must be preserved from the original HTTP request; - order of header names is not important. - - Duplicates are possible and must be preserved in the message as received. - - Header names must be lowercased. - """ - return list(self.scope["headers"]) - - @property - def body(self) -> bytes: - """Body of the request. - - Optional; if missing defaults to b"". - - If more_body is set, treat as start of body and concatenate on further chunks. - """ - return self.scope["body"] - +@dataclass class HTTPResponse: """HTTP 响应封装。参考 `asgi http scope`_。 .. _asgi http scope: https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope """ + status: int + """HTTP status code.""" + body: Optional[bytes] = None + """HTTP body content. - def __init__(self, - status: int, - headers: List[Tuple[bytes, bytes]] = [], - body: Optional[bytes] = None): - self.status: int = status - """HTTP status code.""" - self.headers: List[Tuple[bytes, bytes]] = headers - """An iterable of [name, value] two-item iterables, - where name is the header name, - and value is the header value. + Optional; if missing defaults to ``None``. + """ + headers: Dict[str, str] = field(default_factory=dict) + """A dict of name-value pairs, + where name is the header name, and value is the header value. - Order must be preserved in the HTTP response. + Order must be preserved in the HTTP response. - Header names must be lowercased. + Header names must be lowercased. - Optional; if missing defaults to an empty list. - """ - self.body: Optional[bytes] = body - """HTTP body content. - - Optional; if missing defaults to `None`. - """ + Optional; if missing defaults to an empty dict. + """ @property def type(self) -> str: - """Always `http`""" + """Always ``http``""" return "http" -class WebSocket: +@dataclass +class WebSocket(HTTPConnection, abc.ABC): """WebSocket 连接封装。参考 `asgi websocket scope`_。 .. _asgi websocket scope: https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope """ - @abc.abstractmethod - def __init__(self, websocket): - """ - :参数: - - * ``websocket: Any``: WebSocket 连接对象 - """ - self._websocket = websocket - @property - def websocket(self): - """WebSocket 连接对象""" - return self._websocket + def type(self) -> str: + """Always ``websocket``""" + return "websocket" @property @abc.abstractmethod @@ -389,11 +347,21 @@ class WebSocket: raise NotImplementedError @abc.abstractmethod - async def receive(self) -> dict: - """接收一条 WebSocket 信息""" + async def receive(self) -> str: + """接收一条 WebSocket text 信息""" raise NotImplementedError @abc.abstractmethod - async def send(self, data: dict): - """发送一条 WebSocket 信息""" + async def receive_bytes(self) -> bytes: + """接收一条 WebSocket binary 信息""" + raise NotImplementedError + + @abc.abstractmethod + async def send(self, data: str): + """发送一条 WebSocket text 信息""" + raise NotImplementedError + + @abc.abstractmethod + async def send_bytes(self, data: bytes): + """发送一条 WebSocket text 信息""" raise NotImplementedError diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 28a2b19a..a37286e6 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -8,23 +8,22 @@ FastAPI 驱动适配 https://fastapi.tiangolo.com/ """ -import json import asyncio import logging +from dataclasses import dataclass from typing import List, Optional, Callable import uvicorn from pydantic import BaseSettings from fastapi.responses import Response from fastapi import status, Request, FastAPI, HTTPException -from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket +from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket + as FastAPIWebSocket) from nonebot.log import logger from nonebot.typing import overrides -from nonebot.utils import DataclassEncoder -from nonebot.exception import RequestDenied from nonebot.config import Env, Config as NoneBotConfig -from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket +from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket class Config(BaseSettings): @@ -179,11 +178,6 @@ class Driver(ReverseDriver): @overrides(ReverseDriver) async def _handle_http(self, adapter: str, request: Request): data = await request.body() - data_dict = json.loads(data.decode()) - - if not isinstance(data_dict, dict): - logger.warning("Data received is invalid") - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) if adapter not in self._adapters: logger.warning( @@ -194,27 +188,34 @@ class Driver(ReverseDriver): # 创建 Bot 对象 BotClass = self._adapters[adapter] - headers = dict(request.headers) - try: - x_self_id = await BotClass.check_permission(self, "http", headers, - data) - except RequestDenied as e: - raise HTTPException(status_code=e.status_code, - detail=e.reason) from None + http_request = HTTPRequest(request.scope["http_version"], + request.url.scheme, request.url.path, + request.scope["query_string"], + dict(request.headers), request.method, data) + x_self_id, response = await BotClass.check_permission( + self, http_request) + + if not x_self_id: + raise HTTPException(response and response.status or 401, + response.body) if x_self_id in self._clients: logger.warning("There's already a reverse websocket connection," "so the event may be handled twice.") - bot = BotClass("http", x_self_id) + bot = BotClass(x_self_id, http_request) - asyncio.create_task(bot.handle_message(data_dict)) - return Response("", 204) + asyncio.create_task(bot.handle_message(data)) + return Response(response and response.body, + response and response.status or 200) @overrides(ReverseDriver) async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket): - ws = WebSocket(websocket) + ws = WebSocket(websocket.scope.get("http_version", + "1.1"), websocket.url.scheme, + websocket.url.path, websocket.scope["query_string"], + dict(websocket.headers), websocket) if adapter not in self._adapters: logger.warning( @@ -225,11 +226,9 @@ class Driver(ReverseDriver): # Create Bot Object BotClass = self._adapters[adapter] - headers = dict(websocket.headers) - try: - x_self_id = await BotClass.check_permission(self, "websocket", - headers, None) - except RequestDenied: + x_self_id, _ = await BotClass.check_permission(self, ws) + + if not x_self_id: await ws.close(code=status.WS_1008_POLICY_VIOLATION) return @@ -240,7 +239,7 @@ class Driver(ReverseDriver): await ws.close(code=status.WS_1008_POLICY_VIOLATION) return - bot = BotClass("websocket", x_self_id, websocket=ws) + bot = BotClass(x_self_id, ws) await ws.accept() logger.opt(colors=True).info( @@ -251,54 +250,51 @@ class Driver(ReverseDriver): try: while not ws.closed: - data = await ws.receive() + try: + data = await ws.receive() + except WebSocketDisconnect: + logger.error("WebSocket disconnected by peer.") + break + except Exception as e: + logger.opt(exception=e).error( + "Error when receiving data from websocket.") + break - if not data: - continue - - asyncio.create_task(bot.handle_message(data)) + asyncio.create_task(bot.handle_message(data.encode())) finally: self._bot_disconnect(bot) +@dataclass class WebSocket(BaseWebSocket): - - def __init__(self, websocket: FastAPIWebSocket): - super().__init__(websocket) - self._closed = False + websocket: FastAPIWebSocket = None # type: ignore @property @overrides(BaseWebSocket) def closed(self): - return self._closed + return (self.websocket.client_state == WebSocketState.DISCONNECTED or + self.websocket.application_state == WebSocketState.DISCONNECTED) @overrides(BaseWebSocket) async def accept(self): await self.websocket.accept() - self._closed = False @overrides(BaseWebSocket) async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): await self.websocket.close(code=code) - self._closed = True @overrides(BaseWebSocket) - async def receive(self) -> Optional[dict]: - data = None - try: - data = await self.websocket.receive_json() - if not isinstance(data, dict): - data = None - raise ValueError - except ValueError: - logger.warning("Received an invalid json message.") - except WebSocketDisconnect: - self._closed = True - logger.error("WebSocket disconnected by peer.") - - return data + async def receive(self) -> str: + return await self.websocket.receive_text() @overrides(BaseWebSocket) - async def send(self, data: dict) -> None: - text = json.dumps(data, cls=DataclassEncoder) - await self.websocket.send({"type": "websocket.send", "text": text}) + async def receive_bytes(self) -> bytes: + return await self.websocket.receive_bytes() + + @overrides(BaseWebSocket) + async def send(self, data: str) -> None: + await self.websocket.send({"type": "websocket.send", "text": data}) + + @overrides(BaseWebSocket) + async def send_bytes(self, data: bytes) -> None: + await self.websocket.send({"type": "websocket.send", "bytes": data}) diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index cd1ae575..d36de19d 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -9,24 +9,22 @@ Quart 驱动适配 """ import asyncio -from json.decoder import JSONDecodeError -from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar +from typing import List, TypeVar, Callable, Coroutine, Optional import uvicorn +from pydantic import BaseSettings -from nonebot.config import Config as NoneBotConfig -from nonebot.config import Env -from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket -from nonebot.exception import RequestDenied from nonebot.log import logger from nonebot.typing import overrides +from nonebot.config import Env, Config as NoneBotConfig +from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket try: - from quart import Quart, Request, Response - from quart import Websocket as QuartWebSocket from quart import exceptions from quart import request as _request from quart import websocket as _websocket + from quart import Quart, Request, Response + from quart import Websocket as QuartWebSocket except ImportError: raise ValueError( 'Please install Quart by using `pip install nonebot2[quart]`') @@ -34,6 +32,25 @@ except ImportError: _AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) +class Config(BaseSettings): + """ + Quart 驱动框架设置 + """ + quart_reload_dirs: List[str] = [] + """ + :类型: + + ``List[str]`` + + :说明: + + ``debug`` 模式下重载监控文件夹列表,默认为 uvicorn 默认值 + """ + + class Config: + extra = "ignore" + + class Driver(ReverseDriver): """ Quart 驱动框架 @@ -48,18 +65,20 @@ class Driver(ReverseDriver): def __init__(self, env: Env, config: NoneBotConfig): super().__init__(env, config) + self.quart_config = Config(**config.dict()) + self._server_app = Quart(self.__class__.__qualname__) - self._server_app.add_url_rule('//http', - methods=['POST'], + self._server_app.add_url_rule("//http", + methods=["POST"], view_func=self._handle_http) - self._server_app.add_websocket('//ws', + self._server_app.add_websocket("//ws", view_func=self._handle_ws_reverse) @property @overrides(ReverseDriver) def type(self) -> str: """驱动名称: ``quart``""" - return 'quart' + return "quart" @property @overrides(ReverseDriver) @@ -76,17 +95,21 @@ class Driver(ReverseDriver): @property @overrides(ReverseDriver) def logger(self): - """fastapi 使用的 logger""" + """Quart 使用的 logger""" return self._server_app.logger @overrides(ReverseDriver) def on_startup(self, func: _AsyncCallable) -> _AsyncCallable: - """参考文档: `Startup and Shutdown `_""" + """参考文档: `Startup and Shutdown`_ + + .. _Startup and Shutdown: + https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html + """ return self.server_app.before_serving(func) # type: ignore @overrides(ReverseDriver) def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: - """参考文档: `Startup and Shutdown `_""" + """参考文档: `Startup and Shutdown`_""" return self.server_app.after_serving(func) # type: ignore @overrides(ReverseDriver) @@ -121,6 +144,7 @@ class Driver(ReverseDriver): host=host or str(self.config.host), port=port or self.config.port, reload=bool(app) and self.config.debug, + reload_dirs=self.quart_config.quart_reload_dirs or None, debug=self.config.debug, log_config=LOGGING_CONFIG, **kwargs) @@ -128,11 +152,7 @@ class Driver(ReverseDriver): @overrides(ReverseDriver) async def _handle_http(self, adapter: str): request: Request = _request - - try: - data: Dict[str, Any] = await request.get_json() - except Exception as e: - raise exceptions.BadRequest() + data: bytes = await request.get_data() # type: ignore if adapter not in self._adapters: logger.warning(f'Unknown adapter {adapter}. ' @@ -140,25 +160,32 @@ class Driver(ReverseDriver): raise exceptions.NotFound() BotClass = self._adapters[adapter] - headers = {k: v for k, v in request.headers.items(lower=True)} + http_request = HTTPRequest(request.http_version, request.scheme, + request.path, request.query_string, + dict(request.headers), request.method, data) - try: - self_id = await BotClass.check_permission(self, 'http', headers, - data) - except RequestDenied as e: - raise exceptions.HTTPException(status_code=e.status_code, - description=e.reason, - name='Request Denied') + self_id, response = await BotClass.check_permission(self, http_request) + + if not self_id: + raise exceptions.HTTPException( + response and response.status or 401, + description=(response and response.body or b"").decode(), + name="Request Denied") if self_id in self._clients: logger.warning("There's already a reverse websocket connection," "so the event may be handled twice.") - bot = BotClass('http', self_id) + bot = BotClass(self_id, http_request) asyncio.create_task(bot.handle_message(data)) - return Response('', 204) + return Response(response and response.body or "", + response and response.status or 200) @overrides(ReverseDriver) async def _handle_ws_reverse(self, adapter: str): websocket: QuartWebSocket = _websocket + ws = WebSocket(websocket.http_version, websocket.scheme, + websocket.path, websocket.query_string, + dict(websocket.headers), websocket) + if adapter not in self._adapters: logger.warning( f'Unknown adapter {adapter}. Please register the adapter before use.' @@ -166,19 +193,23 @@ class Driver(ReverseDriver): raise exceptions.NotFound() BotClass = self._adapters[adapter] - headers = {k: v for k, v in websocket.headers.items(lower=True)} - try: - self_id = await BotClass.check_permission(self, 'websocket', - headers, None) - except RequestDenied as e: - raise exceptions.HTTPException(status_code=e.status_code, - description=e.reason, - name='Request Denied') + self_id, response = await BotClass.check_permission(self, ws) + + if not self_id: + raise exceptions.HTTPException( + response and response.status or 401, + description=(response and response.body or b"").decode(), + name="Request Denied") + if self_id in self._clients: - logger.warning("There's already a reverse websocket connection," - "so the event may be handled twice.") - ws = WebSocket(websocket) - bot = BotClass('websocket', self_id, websocket=ws) + logger.opt(colors=True).warning( + "There's already a reverse websocket connection, " + f"{adapter.upper()} Bot {self_id} ignored.") + raise exceptions.HTTPException(403, + description="Client already exists", + name="Request Denied") + + bot = BotClass(self_id, ws) await ws.accept() logger.opt(colors=True).info( f"WebSocket Connection from {adapter.upper()} " @@ -187,52 +218,51 @@ class Driver(ReverseDriver): try: while not ws.closed: - data = await ws.receive() - if data is None: - continue - asyncio.create_task(bot.handle_message(data)) + try: + data = await ws.receive() + except asyncio.CancelledError: + logger.warning("WebSocket disconnected by peer.") + break + except Exception as e: + logger.opt(exception=e).error( + "Error when receiving data from websocket.") + break + + asyncio.create_task(bot.handle_message(data.encode())) finally: self._bot_disconnect(bot) class WebSocket(BaseWebSocket): - - @overrides(BaseWebSocket) - def __init__(self, websocket: QuartWebSocket): - super().__init__(websocket) - self._closed = False - - @property - @overrides(BaseWebSocket) - def websocket(self) -> QuartWebSocket: - return self._websocket + websocket: QuartWebSocket = None # type: ignore @property @overrides(BaseWebSocket) def closed(self): - return self._closed + # FIXME + return False @overrides(BaseWebSocket) async def accept(self): await self.websocket.accept() - self._closed = False @overrides(BaseWebSocket) async def close(self): - self._closed = True + # FIXME + pass @overrides(BaseWebSocket) - async def receive(self) -> Optional[Dict[str, Any]]: - data: Optional[Dict[str, Any]] = None - try: - data = await self.websocket.receive_json() - except JSONDecodeError: - logger.warning('Received an invalid json message.') - except asyncio.CancelledError: - self._closed = True - logger.warning('WebSocket disconnected by peer.') - return data + async def receive(self) -> str: + return await self.websocket.receive() # type: ignore @overrides(BaseWebSocket) - async def send(self, data: dict): - await self.websocket.send_json(data) + async def receive_bytes(self) -> bytes: + return await self.websocket.receive() # type: ignore + + @overrides(BaseWebSocket) + async def send(self, data: str): + await self.websocket.send(data) + + @overrides(BaseWebSocket) + async def send_bytes(self, data: bytes): + await self.websocket.send(data) diff --git a/nonebot/exception.py b/nonebot/exception.py index 7eaab701..3cad317a 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -115,29 +115,6 @@ class StopPropagation(NoneBotException): pass -class RequestDenied(NoneBotException): - """ - :说明: - - Bot 连接请求不合法。 - - :参数: - - * ``status_code: int``: HTTP 状态码 - * ``reason: str``: 拒绝原因 - """ - - def __init__(self, status_code: int, reason: str): - self.status_code = status_code - self.reason = reason - - def __repr__(self): - return f"" - - def __str__(self): - return self.__repr__() - - class AdapterException(NoneBotException): """ :说明: diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py index fb7299ec..f1fc4b64 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py @@ -3,14 +3,15 @@ import sys import hmac import json import asyncio -from typing import Any, Dict, Union, Optional, TYPE_CHECKING +from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING import httpx from nonebot.log import logger from nonebot.typing import overrides from nonebot.message import handle_event +from nonebot.utils import DataclassEncoder from nonebot.adapters import Bot as BaseBot -from nonebot.exception import RequestDenied +from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket from .utils import log, escape from .config import Config as CQHTTPConfig @@ -20,7 +21,6 @@ from .exception import NetworkError, ApiNotAvailable, ActionFailed if TYPE_CHECKING: from nonebot.config import Config - from nonebot.drivers import Driver, WebSocket def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]: @@ -28,7 +28,7 @@ def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]: return None scheme, _, param = access_token.partition(" ") if scheme.lower() not in ["bearer", "token"]: - raise RequestDenied(401, "Not authenticated") + return None return param @@ -225,14 +225,6 @@ class Bot(BaseBot): """ cqhttp_config: CQHTTPConfig - def __init__(self, - connection_type: str, - self_id: str, - *, - websocket: Optional["WebSocket"] = None): - - super().__init__(connection_type, self_id, websocket=websocket) - @property @overrides(BaseBot) def type(self) -> str: @@ -242,84 +234,84 @@ class Bot(BaseBot): return "cqhttp" @classmethod - def register(cls, driver: "Driver", config: "Config"): + def register(cls, driver: Driver, config: "Config"): super().register(driver, config) cls.cqhttp_config = CQHTTPConfig(**config.dict()) @classmethod @overrides(BaseBot) - async def check_permission(cls, driver: "Driver", connection_type: str, - headers: dict, body: Optional[bytes]) -> str: + async def check_permission( + cls, driver: Driver, + request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]: """ :说明: CQHTTP (OneBot) 协议鉴权。参考 `鉴权 `_ """ - x_self_id = headers.get("x-self-id") - x_signature = headers.get("x-signature") - token = get_auth_bearer(headers.get("authorization")) + x_self_id = request.headers.get("x-self-id") + x_signature = request.headers.get("x-signature") + token = get_auth_bearer(request.headers.get("authorization")) cqhttp_config = CQHTTPConfig(**driver.config.dict()) - # 检查连接方式 - if connection_type not in ["http", "websocket"]: - log("WARNING", "Unsupported connection type") - raise RequestDenied(405, "Unsupported connection type") - # 检查self_id if not x_self_id: log("WARNING", "Missing X-Self-ID Header") - raise RequestDenied(400, "Missing X-Self-ID Header") + return None, HTTPResponse(400, b"Missing X-Self-ID Header") # 检查签名 secret = cqhttp_config.secret - if secret and connection_type == "http": + if secret and isinstance(request, HTTPRequest): if not x_signature: log("WARNING", "Missing Signature Header") - raise RequestDenied(401, "Missing Signature") - sig = hmac.new(secret.encode("utf-8"), body, "sha1").hexdigest() + return None, HTTPResponse(401, b"Missing Signature") + sig = hmac.new(secret.encode("utf-8"), request.body, + "sha1").hexdigest() if x_signature != "sha1=" + sig: log("WARNING", "Signature Header is invalid") - raise RequestDenied(403, "Signature is invalid") + return None, HTTPResponse(403, b"Signature is invalid") access_token = cqhttp_config.access_token - if access_token and access_token != token and connection_type == "websocket": + if access_token and access_token != token and isinstance( + request, WebSocket): log( "WARNING", "Authorization Header is invalid" if token else "Missing Authorization Header") - raise RequestDenied( - 403, "Authorization Header is invalid" - if token else "Missing Authorization Header") - return str(x_self_id) + return None, HTTPResponse( + 403, b"Authorization Header is invalid" + if token else b"Missing Authorization Header") + return str(x_self_id), HTTPResponse(204, b'') @overrides(BaseBot) - async def handle_message(self, message: dict): + async def handle_message(self, message: bytes): """ :说明: 调用 `_check_reply <#async-check-reply-bot-event>`_, `_check_at_me <#check-at-me-bot-event>`_, `_check_nickname <#check-nickname-bot-event>`_ 处理事件并转换为 `Event <#class-event>`_ """ - if not message: + data = json.loads(message) + + if not data: return - if "post_type" not in message: - ResultStore.add_result(message) + if "post_type" not in data: + ResultStore.add_result(data) return try: - post_type = message['post_type'] - detail_type = message.get(f"{post_type}_type") + post_type = data['post_type'] + detail_type = data.get(f"{post_type}_type") detail_type = f".{detail_type}" if detail_type else "" - sub_type = message.get("sub_type") + sub_type = data.get("sub_type") sub_type = f".{sub_type}" if sub_type else "" models = get_event_model(post_type + detail_type + sub_type) for model in models: try: - event = model.parse_obj(message) + event = model.parse_obj(data) break except Exception as e: log("DEBUG", "Event Parser Error", e) else: - event = Event.parse_obj(message) + event = Event.parse_obj(data) # Check whether user is calling me await _check_reply(self, event) @@ -329,25 +321,28 @@ class Bot(BaseBot): await handle_event(self, event) except Exception as e: logger.opt(colors=True, exception=e).error( - f"Failed to handle event. Raw: {message}" + f"Failed to handle event. Raw: {data}" ) @overrides(BaseBot) async def _call_api(self, api: str, **data) -> Any: log("DEBUG", f"Calling API {api}") - if self.connection_type == "websocket": + if isinstance(self.request, WebSocket): seq = ResultStore.get_seq() - await self.websocket.send({ - "action": api, - "params": data, - "echo": { - "seq": seq - } - }) + json_data = json.dumps( + { + "action": api, + "params": data, + "echo": { + "seq": seq + } + }, + cls=DataclassEncoder) + await self.request.send(json_data) return _handle_api_result(await ResultStore.fetch( seq, self.config.api_timeout)) - elif self.connection_type == "http": + elif isinstance(self.request, HTTPRequest): api_root = self.config.api_root.get(self.self_id) if not api_root: raise ApiNotAvailable @@ -431,7 +426,7 @@ class Bot(BaseBot): message, str) else message msg = message if isinstance(message, Message) else Message(message) - at_sender = at_sender and getattr(event, "user_id", None) + at_sender = at_sender and bool(getattr(event, "user_id", None)) params = {} if getattr(event, "user_id", None): @@ -449,8 +444,7 @@ class Bot(BaseBot): raise ValueError("Cannot guess message type to reply!") if at_sender and params["message_type"] != "private": - params["message"] = MessageSegment.at(params["user_id"]) + \ - MessageSegment.text(" ") + msg + params["message"] = MessageSegment.at(params["user_id"]) + " " + msg else: params["message"] = msg return await self.send_msg(**params) diff --git a/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py b/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py index 8f2880a1..4e13c970 100644 --- a/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py +++ b/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py @@ -1,16 +1,17 @@ import json import urllib.parse -from datetime import datetime import time -from typing import Any, Union, Optional, TYPE_CHECKING +from datetime import datetime +from typing import Any, Tuple, Union, Optional, TYPE_CHECKING import httpx + from nonebot.log import logger from nonebot.typing import overrides from nonebot.message import handle_event from nonebot.adapters import Bot as BaseBot -from nonebot.exception import RequestDenied +from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse from .utils import calc_hmac_base64, log from .config import Config as DingConfig @@ -20,7 +21,6 @@ from .event import MessageEvent, PrivateMessageEvent, GroupMessageEvent, Convers if TYPE_CHECKING: from nonebot.config import Config - from nonebot.drivers import Driver SEND = "send" @@ -31,10 +31,6 @@ class Bot(BaseBot): """ ding_config: DingConfig - def __init__(self, connection_type: str, self_id: str, **kwargs): - - super().__init__(connection_type, self_id, **kwargs) - @property def type(self) -> str: """ @@ -43,57 +39,61 @@ class Bot(BaseBot): return "ding" @classmethod - def register(cls, driver: "Driver", config: "Config"): + def register(cls, driver: Driver, config: "Config"): super().register(driver, config) cls.ding_config = DingConfig(**config.dict()) @classmethod @overrides(BaseBot) - async def check_permission(cls, driver: "Driver", connection_type: str, - headers: dict, body: Optional[bytes]) -> str: + async def check_permission( + cls, driver: Driver, + request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]: """ :说明: 钉钉协议鉴权。参考 `鉴权 `_ """ - timestamp = headers.get("timestamp") - sign = headers.get("sign") + timestamp = request.headers.get("timestamp") + sign = request.headers.get("sign") # 检查连接方式 - if connection_type not in ["http"]: - raise RequestDenied( - 405, "Unsupported connection type, available type: `http`") + if not isinstance(request, HTTPRequest): + return None, HTTPResponse( + 405, b"Unsupported connection type, available type: `http`") # 检查 timestamp if not timestamp: - raise RequestDenied(400, "Missing `timestamp` Header") + return None, HTTPResponse(400, b"Missing `timestamp` Header") # 检查 sign secret = cls.ding_config.secret if secret: if not sign: log("WARNING", "Missing Signature Header") - raise RequestDenied(400, "Missing `sign` Header") + return None, HTTPResponse(400, b"Missing `sign` Header") sign_base64 = calc_hmac_base64(str(timestamp), secret) if sign != sign_base64.decode('utf-8'): log("WARNING", "Signature Header is invalid") - raise RequestDenied(403, "Signature is invalid") + return None, HTTPResponse(403, b"Signature is invalid") else: log("WARNING", "Ding signature check ignored!") - return json.loads(body.decode())["chatbotUserId"] + return (json.loads(request.body.decode())["chatbotUserId"], + HTTPResponse(204, b'')) @overrides(BaseBot) - async def handle_message(self, message: dict): - if not message: + async def handle_message(self, message: bytes): + data = json.loads(message) + + if not data: return # 判断消息类型,生成不同的 Event try: - conversation_type = message["conversationType"] + conversation_type = data["conversationType"] if conversation_type == ConversationType.private: - event = PrivateMessageEvent.parse_obj(message) + event = PrivateMessageEvent.parse_obj(data) elif conversation_type == ConversationType.group: - event = GroupMessageEvent.parse_obj(message) + event = GroupMessageEvent.parse_obj(data) else: raise ValueError("Unsupported conversation type") except Exception as e: @@ -104,7 +104,7 @@ class Bot(BaseBot): await handle_event(self, event) except Exception as e: logger.opt(colors=True, exception=e).error( - f"Failed to handle event. Raw: {message}" + f"Failed to handle event. Raw: {data}" ) return diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py index 2b09e365..5adc7a16 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/__init__.py @@ -1,8 +1,8 @@ -""" +r""" Mirai-API-HTTP 协议适配 ============================ -协议详情请看: `mirai-api-http 文档`_ +协议详情请看: `mirai-api-http 文档`_ \:\:\: tip 该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试 diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py index ebce2d74..fe7dc4bc 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py @@ -5,11 +5,11 @@ from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union import httpx -from nonebot.adapters import Bot as BaseBot from nonebot.config import Config -from nonebot.drivers import Driver, WebSocket -from nonebot.exception import ApiNotAvailable, RequestDenied from nonebot.typing import overrides +from nonebot.adapters import Bot as BaseBot +from nonebot.exception import ApiNotAvailable +from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket from .config import Config as MiraiConfig from .event import Event, FriendMessage, GroupMessage, TempMessage @@ -140,7 +140,7 @@ class SessionManager: class Bot(BaseBot): - """ + r""" mirai-api-http 协议 Bot 适配。 \:\:\: warning @@ -151,14 +151,6 @@ class Bot(BaseBot): """ - @overrides(BaseBot) - def __init__(self, - connection_type: str, - self_id: str, - *, - websocket: Optional[WebSocket] = None): - super().__init__(connection_type, self_id, websocket=websocket) - @property @overrides(BaseBot) def type(self) -> str: @@ -166,7 +158,8 @@ class Bot(BaseBot): @property def alive(self) -> bool: - return not self.websocket.closed + assert isinstance(self.request, WebSocket) + return not self.request.closed @property def api(self) -> SessionManager: @@ -177,27 +170,26 @@ class Bot(BaseBot): @classmethod @overrides(BaseBot) - async def check_permission(cls, driver: "Driver", connection_type: str, - headers: dict, body: Optional[bytes]) -> str: - if connection_type == 'ws': - raise RequestDenied( - status_code=501, - reason='Websocket connection is not implemented') - self_id: Optional[str] = headers.get('bot') + async def check_permission( + cls, driver: Driver, + request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]: + if isinstance(request, WebSocket): + return None, HTTPResponse( + 501, b'Websocket connection is not implemented') + self_id: Optional[str] = request.headers.get('bot') if self_id is None: - raise RequestDenied(status_code=400, - reason='Header `Bot` is required.') + return None, HTTPResponse(400, b'Header `Bot` is required.') self_id = str(self_id).strip() await SessionManager.new( int(self_id), host=cls.mirai_config.host, # type: ignore port=cls.mirai_config.port, #type: ignore auth_key=cls.mirai_config.auth_key) # type: ignore - return self_id + return self_id, HTTPResponse(204, b'') @classmethod @overrides(BaseBot) - def register(cls, driver: "Driver", config: "Config"): + def register(cls, driver: Driver, config: "Config"): cls.mirai_config = MiraiConfig(**config.dict()) if (cls.mirai_config.auth_key and cls.mirai_config.host and cls.mirai_config.port) is None: @@ -224,7 +216,7 @@ class Bot(BaseBot): @overrides(BaseBot) async def call_api(self, api: str, **data) -> NoReturn: - """ + r""" \:\:\: danger 由于Mirai的HTTP API特殊性, 该API暂时无法实现 \:\:\: diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py index 7f990183..c7139772 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py @@ -1,18 +1,16 @@ -import asyncio import json +import asyncio +from dataclasses import dataclass from ipaddress import IPv4Address -from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set, - TypeVar) +from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine import httpx import websockets -from nonebot.config import Config -from nonebot.drivers import Driver -from nonebot.drivers import WebSocket as BaseWebSocket -from nonebot.exception import RequestDenied from nonebot.log import logger +from nonebot.config import Config from nonebot.typing import overrides +from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket from .bot import SessionManager, Bot @@ -21,7 +19,9 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T', bound=WebsocketHandlerFunction) +@dataclass class WebSocket(BaseWebSocket): + websocket: websockets.WebSocketClientProtocol = None # type: ignore @classmethod async def new(cls, *, host: IPv4Address, port: int, @@ -37,24 +37,26 @@ class WebSocket(BaseWebSocket): self.event_handlers: Set[WebsocketHandlerFunction] = set() super().__init__(websocket) - @property - @overrides(BaseWebSocket) - def websocket(self) -> websockets.WebSocketClientProtocol: - return self._websocket - @property @overrides(BaseWebSocket) def closed(self) -> bool: return self.websocket.closed @overrides(BaseWebSocket) - async def send(self, data: Dict[str, Any]): - return await self.websocket.send(json.dumps(data)) + async def send(self, data: str): + return await self.websocket.send(data) @overrides(BaseWebSocket) - async def receive(self) -> Dict[str, Any]: - received = await self.websocket.recv() - return json.loads(received) + async def send_bytes(self, data: str): + return await self.websocket.send(data) + + @overrides(BaseWebSocket) + async def receive(self) -> str: + return await self.websocket.recv() # type: ignore + + @overrides(BaseWebSocket) + async def receive_bytes(self) -> bytes: + return await self.websocket.recv() # type: ignore async def _dispatcher(self): while not self.closed: @@ -93,11 +95,6 @@ class WebsocketBot(Bot): mirai-api-http 正向 Websocket 协议 Bot 适配。 """ - @overrides(Bot) - def __init__(self, connection_type: str, self_id: str, *, - websocket: WebSocket): - super().__init__(connection_type, self_id, websocket=websocket) - @property @overrides(Bot) def type(self) -> str: @@ -105,7 +102,8 @@ class WebsocketBot(Bot): @property def alive(self) -> bool: - return not self.websocket.closed + assert isinstance(self.request, WebSocket) + return not self.request.closed @property def api(self) -> SessionManager: @@ -115,16 +113,14 @@ class WebsocketBot(Bot): @classmethod @overrides(Bot) - async def check_permission(cls, driver: "Driver", connection_type: str, - headers: dict, - body: Optional[bytes]) -> NoReturn: - raise RequestDenied( - status_code=501, - reason=f'Connection {connection_type} not implented') + async def check_permission( + cls, driver: Driver, + request: HTTPConnection) -> Tuple[None, HTTPResponse]: + return None, HTTPResponse(501, b'Connection not implented') @classmethod @overrides(Bot) - def register(cls, driver: "Driver", config: "Config", qq: int): + def register(cls, driver: Driver, config: "Config", qq: int): """ :说明: