🔀 Merge pull request #406

Feature: 支持自定义 Response
This commit is contained in:
Ju4tCode 2021-06-14 14:14:43 +08:00 committed by GitHub
commit e9bc98e74d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 465 additions and 511 deletions

View File

@ -57,7 +57,7 @@ Config 配置对象
### _abstract_ `__init__(connection_type, self_id, *, websocket=None)` ### `__init__(self_id, request)`
* **参数** * **参数**
@ -73,19 +73,14 @@ Config 配置对象
### `connection_type`
连接类型
### `self_id` ### `self_id`
机器人 ID 机器人 ID
### `websocket` ### `request`
Websocket 连接对象 连接信息
### _abstract property_ `type` ### _abstract property_ `type`
@ -102,7 +97,7 @@ Adapter 类型
### _abstract async classmethod_ `check_permission(driver, connection_type, headers, body)` ### _abstract async classmethod_ `check_permission(driver, request)`
* **说明** * **说明**
@ -130,7 +125,10 @@ Adapter 类型
* **返回** * **返回**
* `str`: 连接唯一标识符 * `str`: 连接唯一标识符,`None` 代表连接不合法
* `HTTPResponse`: HTTP 上报响应
@ -153,7 +151,7 @@ Adapter 类型
* **参数** * **参数**
* `message: dict`: 收到的上报消息 * `message: bytes`: 收到的上报消息

View File

@ -204,7 +204,7 @@ CQHTTP 协议 Bot 适配。继承属性参考 [BaseBot](./#class-basebot) 。
* 返回: `"cqhttp"` * 返回: `"cqhttp"`
### _async classmethod_ `check_permission(driver, connection_type, headers, body)` ### _async classmethod_ `check_permission(driver, request)`
* **说明** * **说明**

View File

@ -105,7 +105,7 @@ sidebarDepth: 0
* 返回: `"ding"` * 返回: `"ding"`
### _async classmethod_ `check_permission(driver, connection_type, headers, body)` ### _async classmethod_ `check_permission(driver, request)`
* **说明** * **说明**

View File

@ -682,6 +682,11 @@ API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则
# NoneBot.adapters.mirai.bot_ws 模块 # NoneBot.adapters.mirai.bot_ws 模块
## _class_ `WebSocket`
基类:[`nonebot.drivers.WebSocket`](../drivers/README.md#nonebot.drivers.WebSocket)
## _class_ `WebsocketBot` ## _class_ `WebsocketBot`
基类:`nonebot.adapters.mirai.bot.Bot` 基类:`nonebot.adapters.mirai.bot.Bot`

View File

@ -268,74 +268,71 @@ Reverse Driver 基类。将后端框架封装,以满足适配器使用。
用于处理 WebSocket 类型请求的函数 用于处理 WebSocket 类型请求的函数
## _class_ `HTTPRequest` ## _class_ `HTTPConnection`
基类:`object` 基类:`abc.ABC`
HTTP 请求封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope)。
### _property_ `type` ### `http_version`
Always http
### _property_ `scope`
Raw scope from asgi.
The connection scope information, a dictionary that
contains at least a type key specifying the protocol that is incoming.
### _property_ `http_version`
One of "1.0", "1.1" or "2". One of "1.0", "1.1" or "2".
### _property_ `method` ### `scheme`
The HTTP method name, uppercased.
### _property_ `schema`
URL scheme portion (likely "http" or "https"). URL scheme portion (likely "http" or "https").
Optional (but must not be empty); default is "http".
### _property_ `path` ### `path`
HTTP request target excluding any query string, HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences with percent-encoded sequences and UTF-8 byte sequences
decoded into characters. decoded into characters.
### _property_ `query_string` ### `query_string`
URL portion after the ?, percent-encoded. URL portion after the ?, percent-encoded.
### _property_ `headers` ### `headers`
An iterable of [name, value] two-item iterables, A dict of name-value pairs,
where name is the header name, and value is the header value. where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request; Order of header values must be preserved from the original HTTP request;
order of header names is not important. order of header names is not important.
Duplicates are possible and must be preserved in the message as received.
Header names must be lowercased. Header names must be lowercased.
### _property_ `body` ### _abstract property_ `type`
Connection type.
## _class_ `HTTPRequest`
基类:`nonebot.drivers.HTTPConnection`
HTTP 请求封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope)。
### `method`
The HTTP method name, uppercased.
### `body`
Body of the request. Body of the request.
Optional; if missing defaults to b"". Optional; if missing defaults to b"".
If more_body is set, treat as start of body and concatenate on further chunks.
### _property_ `type`
Always `http`
## _class_ `HTTPResponse` ## _class_ `HTTPResponse`
@ -350,51 +347,40 @@ HTTP 响应封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/late
HTTP status code. HTTP status code.
### `body`
HTTP body content.
Optional; if missing defaults to `None`.
### `headers` ### `headers`
An iterable of [name, value] two-item iterables, A dict of name-value pairs,
where name is the header name, where name is the header name, and value is the header value.
and value is the header value.
Order must be preserved in the HTTP response. Order must be preserved in the HTTP response.
Header names must be lowercased. Header names must be lowercased.
Optional; if missing defaults to an empty list. Optional; if missing defaults to an empty dict.
### `body`
HTTP body content.
Optional; if missing defaults to None.
### _property_ `type` ### _property_ `type`
Always http Always `http`
## _class_ `WebSocket` ## _class_ `WebSocket`
基类:`object` 基类:`nonebot.drivers.HTTPConnection`, `abc.ABC`
WebSocket 连接封装。参考 [asgi websocket scope](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope)。 WebSocket 连接封装。参考 [asgi websocket scope](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope)。
### _abstract_ `__init__(websocket)` ### _property_ `type`
Always `websocket`
* **参数**
* `websocket: Any`: WebSocket 连接对象
### _property_ `websocket`
WebSocket 连接对象
### _abstract property_ `closed` ### _abstract property_ `closed`
@ -424,9 +410,19 @@ WebSocket 连接对象
### _abstract async_ `receive()` ### _abstract async_ `receive()`
接收一条 WebSocket 信息 接收一条 WebSocket text 信息
### _abstract async_ `receive_bytes()`
接收一条 WebSocket binary 信息
### _abstract async_ `send(data)` ### _abstract async_ `send(data)`
发送一条 WebSocket 信息 发送一条 WebSocket text 信息
### _abstract async_ `send_bytes(data)`
发送一条 WebSocket text 信息

View File

@ -133,3 +133,8 @@ fastapi 使用的 logger
### `run(host=None, port=None, *, app=None, **kwargs)` ### `run(host=None, port=None, *, app=None, **kwargs)`
使用 `uvicorn` 启动 FastAPI 使用 `uvicorn` 启动 FastAPI
## _class_ `WebSocket`
基类:[`nonebot.drivers.WebSocket`](README.md#nonebot.drivers.WebSocket)

View File

@ -10,6 +10,28 @@ sidebarDepth: 0
后端使用方法请参考: [Quart 文档](https://pgjones.gitlab.io/quart/index.html) 后端使用方法请参考: [Quart 文档](https://pgjones.gitlab.io/quart/index.html)
## _class_ `Config`
基类:`pydantic.env_settings.BaseSettings`
Quart 驱动框架设置
### `quart_reload_dirs`
* **类型**
`List[str]`
* **说明**
`debug` 模式下重载监控文件夹列表,默认为 uvicorn 默认值
## _class_ `Driver` ## _class_ `Driver`
基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver) 基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver)
@ -44,7 +66,7 @@ Quart 驱动框架
### _property_ `logger` ### _property_ `logger`
fastapi 使用的 logger Quart 使用的 logger
### `on_startup(func)` ### `on_startup(func)`

View File

@ -132,27 +132,6 @@ sidebarDepth: 0
## _exception_ `RequestDenied`
基类:`nonebot.exception.NoneBotException`
* **说明**
Bot 连接请求不合法。
* **参数**
* `status_code: int`: HTTP 状态码
* `reason: str`: 拒绝原因
## _exception_ `AdapterException` ## _exception_ `AdapterException`
基类:`nonebot.exception.NoneBotException` 基类:`nonebot.exception.NoneBotException`

View File

@ -11,13 +11,14 @@ from copy import copy
from functools import reduce, partial from functools import reduce, partial
from typing_extensions import Protocol from typing_extensions import Protocol
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Any, Set, List, Dict, Union, TypeVar, Mapping, Optional, from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping,
Iterable, Awaitable, TYPE_CHECKING) Optional, Iterable, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel from pydantic import BaseModel
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.drivers import HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
if TYPE_CHECKING: if TYPE_CHECKING:
@ -51,12 +52,7 @@ class Bot(abc.ABC):
:说明: call_api 后执行的函数 :说明: call_api 后执行的函数
""" """
@abc.abstractmethod def __init__(self, self_id: str, request: HTTPConnection):
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional["WebSocket"] = None):
""" """
:参数: :参数:
@ -64,12 +60,10 @@ class Bot(abc.ABC):
* ``self_id: str``: 机器人 ID * ``self_id: str``: 机器人 ID
* ``websocket: Optional[WebSocket]``: Websocket 连接对象 * ``websocket: Optional[WebSocket]``: Websocket 连接对象
""" """
self.connection_type = connection_type self.self_id: str = self_id
"""连接类型"""
self.self_id = self_id
"""机器人 ID""" """机器人 ID"""
self.websocket = websocket self.request: HTTPConnection = request
"""Websocket 连接对象""" """连接信息"""
def __getattr__(self, name: str) -> _ApiCall: def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name) return partial(self.call_api, name)
@ -92,8 +86,9 @@ class Bot(abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(
headers: dict, body: Optional[bytes]) -> str: cls, driver: "Driver", request: HTTPConnection
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
""" """
:说明: :说明:
@ -108,7 +103,8 @@ class Bot(abc.ABC):
:返回: :返回:
- ``str``: 连接唯一标识符 - ``str``: 连接唯一标识符``None`` 代表连接不合法
- ``HTTPResponse``: HTTP 上报响应
:异常: :异常:
@ -117,7 +113,7 @@ class Bot(abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def handle_message(self, message: dict): async def handle_message(self, message: bytes):
""" """
:说明: :说明:
@ -125,7 +121,7 @@ class Bot(abc.ABC):
:参数: :参数:
* ``message: dict``: 收到的上报消息 * ``message: bytes``: 收到的上报消息
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -7,8 +7,8 @@
import abc import abc
import asyncio import asyncio
from typing import (Any, Set, List, Dict, Type, Tuple, Optional, Callable, from dataclasses import dataclass, field
MutableMapping, TYPE_CHECKING) from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
@ -47,12 +47,12 @@ class Driver(abc.ABC):
* ``env: Env``: 包含环境信息的 Env 对象 * ``env: Env``: 包含环境信息的 Env 对象
* ``config: Config``: 包含配置信息的 Config 对象 * ``config: Config``: 包含配置信息的 Config 对象
""" """
self.env = env.environment self.env: str = env.environment
""" """
:类型: ``str`` :类型: ``str``
:说明: 环境名称 :说明: 环境名称
""" """
self.config = config self.config: Config = config
""" """
:类型: ``Config`` :类型: ``Config``
:说明: 配置对象 :说明: 配置对象
@ -231,143 +231,101 @@ class ReverseDriver(Driver):
raise NotImplementedError raise NotImplementedError
class HTTPRequest: @dataclass
class HTTPConnection(abc.ABC):
http_version: str
"""One of `"1.0"`, `"1.1"` or `"2"`."""
scheme: str
"""URL scheme portion (likely `"http"` or `"https"`)."""
path: str
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
query_string: bytes = b""
""" URL portion after the `?`, percent-encoded."""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Header names must be lowercased.
"""
@property
@abc.abstractmethod
def type(self) -> str:
"""Connection type."""
raise NotImplementedError
@dataclass
class HTTPRequest(HTTPConnection):
"""HTTP 请求封装。参考 `asgi http scope`_。 """HTTP 请求封装。参考 `asgi http scope`_。
.. _asgi http scope: .. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
""" """
method: str = "GET"
"""The HTTP method name, uppercased."""
body: bytes = b""
"""Body of the request.
def __init__(self, scope: MutableMapping[str, Any]): Optional; if missing defaults to b"".
self._scope = scope """
@property @property
def type(self) -> str: def type(self) -> str:
"""Always `http`""" """Always ``http``"""
return "http" return "http"
@property
def scope(self) -> MutableMapping[str, Any]:
"""Raw scope from asgi.
The connection scope information, a dictionary that
contains at least a `type` key specifying the protocol that is incoming.
"""
return self._scope
@property
def http_version(self) -> str:
"""One of `"1.0"`, `"1.1"` or `"2"`."""
raise self.scope["http_version"]
@property
def method(self) -> str:
"""The HTTP method name, uppercased."""
raise self.scope["method"]
@property
def schema(self) -> str:
"""
URL scheme portion (likely `"http"` or `"https"`).
Optional (but must not be empty); default is `"http"`.
"""
raise self.scope["schema"]
@property
def path(self) -> str:
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
return self.scope["path"]
@property
def query_string(self) -> bytes:
""" URL portion after the `?`, percent-encoded."""
return self.scope["query_string"]
@property
def headers(self) -> List[Tuple[bytes, bytes]]:
"""An iterable of [name, value] two-item iterables,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Duplicates are possible and must be preserved in the message as received.
Header names must be lowercased.
"""
return list(self.scope["headers"])
@property
def body(self) -> bytes:
"""Body of the request.
Optional; if missing defaults to b"".
If more_body is set, treat as start of body and concatenate on further chunks.
"""
return self.scope["body"]
@dataclass
class HTTPResponse: class HTTPResponse:
"""HTTP 响应封装。参考 `asgi http scope`_。 """HTTP 响应封装。参考 `asgi http scope`_。
.. _asgi http scope: .. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
""" """
status: int
"""HTTP status code."""
body: Optional[bytes] = None
"""HTTP body content.
def __init__(self, Optional; if missing defaults to ``None``.
status: int, """
headers: List[Tuple[bytes, bytes]] = [], headers: Dict[str, str] = field(default_factory=dict)
body: Optional[bytes] = None): """A dict of name-value pairs,
self.status: int = status where name is the header name, and value is the header value.
"""HTTP status code."""
self.headers: List[Tuple[bytes, bytes]] = headers
"""An iterable of [name, value] two-item iterables,
where name is the header name,
and value is the header value.
Order must be preserved in the HTTP response. Order must be preserved in the HTTP response.
Header names must be lowercased. Header names must be lowercased.
Optional; if missing defaults to an empty list. Optional; if missing defaults to an empty dict.
""" """
self.body: Optional[bytes] = body
"""HTTP body content.
Optional; if missing defaults to `None`.
"""
@property @property
def type(self) -> str: def type(self) -> str:
"""Always `http`""" """Always ``http``"""
return "http" return "http"
class WebSocket: @dataclass
class WebSocket(HTTPConnection, abc.ABC):
"""WebSocket 连接封装。参考 `asgi websocket scope`_。 """WebSocket 连接封装。参考 `asgi websocket scope`_。
.. _asgi websocket scope: .. _asgi websocket scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
""" """
@abc.abstractmethod
def __init__(self, websocket):
"""
:参数:
* ``websocket: Any``: WebSocket 连接对象
"""
self._websocket = websocket
@property @property
def websocket(self): def type(self) -> str:
"""WebSocket 连接对象""" """Always ``websocket``"""
return self._websocket return "websocket"
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -389,11 +347,21 @@ class WebSocket:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def receive(self) -> dict: async def receive(self) -> str:
"""接收一条 WebSocket 信息""" """接收一条 WebSocket text 信息"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def send(self, data: dict): async def receive_bytes(self) -> bytes:
"""发送一条 WebSocket 信息""" """接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError raise NotImplementedError

View File

@ -8,23 +8,22 @@ FastAPI 驱动适配
https://fastapi.tiangolo.com/ https://fastapi.tiangolo.com/
""" """
import json
import asyncio import asyncio
import logging import logging
from dataclasses import dataclass
from typing import List, Optional, Callable from typing import List, Optional, Callable
import uvicorn import uvicorn
from pydantic import BaseSettings from pydantic import BaseSettings
from fastapi.responses import Response from fastapi.responses import Response
from fastapi import status, Request, FastAPI, HTTPException from fastapi import status, Request, FastAPI, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket
as FastAPIWebSocket)
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.config import Env, Config as NoneBotConfig from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
class Config(BaseSettings): class Config(BaseSettings):
@ -179,11 +178,6 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver) @overrides(ReverseDriver)
async def _handle_http(self, adapter: str, request: Request): async def _handle_http(self, adapter: str, request: Request):
data = await request.body() data = await request.body()
data_dict = json.loads(data.decode())
if not isinstance(data_dict, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning( logger.warning(
@ -194,27 +188,34 @@ class Driver(ReverseDriver):
# 创建 Bot 对象 # 创建 Bot 对象
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
headers = dict(request.headers) http_request = HTTPRequest(request.scope["http_version"],
try: request.url.scheme, request.url.path,
x_self_id = await BotClass.check_permission(self, "http", headers, request.scope["query_string"],
data) dict(request.headers), request.method, data)
except RequestDenied as e: x_self_id, response = await BotClass.check_permission(
raise HTTPException(status_code=e.status_code, self, http_request)
detail=e.reason) from None
if not x_self_id:
raise HTTPException(response and response.status or 401,
response.body)
if x_self_id in self._clients: if x_self_id in self._clients:
logger.warning("There's already a reverse websocket connection," logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.") "so the event may be handled twice.")
bot = BotClass("http", x_self_id) bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data_dict)) asyncio.create_task(bot.handle_message(data))
return Response("", 204) return Response(response and response.body,
response and response.status or 200)
@overrides(ReverseDriver) @overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str, async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket): websocket: FastAPIWebSocket):
ws = WebSocket(websocket) ws = WebSocket(websocket.scope.get("http_version",
"1.1"), websocket.url.scheme,
websocket.url.path, websocket.scope["query_string"],
dict(websocket.headers), websocket)
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning( logger.warning(
@ -225,11 +226,9 @@ class Driver(ReverseDriver):
# Create Bot Object # Create Bot Object
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
headers = dict(websocket.headers) x_self_id, _ = await BotClass.check_permission(self, ws)
try:
x_self_id = await BotClass.check_permission(self, "websocket", if not x_self_id:
headers, None)
except RequestDenied:
await ws.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return return
@ -240,7 +239,7 @@ class Driver(ReverseDriver):
await ws.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return return
bot = BotClass("websocket", x_self_id, websocket=ws) bot = BotClass(x_self_id, ws)
await ws.accept() await ws.accept()
logger.opt(colors=True).info( logger.opt(colors=True).info(
@ -251,54 +250,51 @@ class Driver(ReverseDriver):
try: try:
while not ws.closed: while not ws.closed:
data = await ws.receive() try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
if not data: asyncio.create_task(bot.handle_message(data.encode()))
continue
asyncio.create_task(bot.handle_message(data))
finally: finally:
self._bot_disconnect(bot) self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
websocket: FastAPIWebSocket = None # type: ignore
def __init__(self, websocket: FastAPIWebSocket):
super().__init__(websocket)
self._closed = False
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self): def closed(self):
return self._closed return (self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
await self.websocket.accept() await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code) await self.websocket.close(code=code)
self._closed = True
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> Optional[dict]: async def receive(self) -> str:
data = None return await self.websocket.receive_text()
try:
data = await self.websocket.receive_json()
if not isinstance(data, dict):
data = None
raise ValueError
except ValueError:
logger.warning("Received an invalid json message.")
except WebSocketDisconnect:
self._closed = True
logger.error("WebSocket disconnected by peer.")
return data
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: dict) -> None: async def receive_bytes(self) -> bytes:
text = json.dumps(data, cls=DataclassEncoder) return await self.websocket.receive_bytes()
await self.websocket.send({"type": "websocket.send", "text": text})
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})

View File

@ -9,24 +9,22 @@ Quart 驱动适配
""" """
import asyncio import asyncio
from json.decoder import JSONDecodeError from typing import List, TypeVar, Callable, Coroutine, Optional
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar
import uvicorn import uvicorn
from pydantic import BaseSettings
from nonebot.config import Config as NoneBotConfig
from nonebot.config import Env
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
try: try:
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
from quart import exceptions from quart import exceptions
from quart import request as _request from quart import request as _request
from quart import websocket as _websocket from quart import websocket as _websocket
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
except ImportError: except ImportError:
raise ValueError( raise ValueError(
'Please install Quart by using `pip install nonebot2[quart]`') 'Please install Quart by using `pip install nonebot2[quart]`')
@ -34,6 +32,25 @@ except ImportError:
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) _AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
class Config(BaseSettings):
"""
Quart 驱动框架设置
"""
quart_reload_dirs: List[str] = []
"""
:类型:
``List[str]``
:说明:
``debug`` 模式下重载监控文件夹列表默认为 uvicorn 默认值
"""
class Config:
extra = "ignore"
class Driver(ReverseDriver): class Driver(ReverseDriver):
""" """
Quart 驱动框架 Quart 驱动框架
@ -48,18 +65,20 @@ class Driver(ReverseDriver):
def __init__(self, env: Env, config: NoneBotConfig): def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config) super().__init__(env, config)
self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__) self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule('/<adapter>/http', self._server_app.add_url_rule("/<adapter>/http",
methods=['POST'], methods=["POST"],
view_func=self._handle_http) view_func=self._handle_http)
self._server_app.add_websocket('/<adapter>/ws', self._server_app.add_websocket("/<adapter>/ws",
view_func=self._handle_ws_reverse) view_func=self._handle_ws_reverse)
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
def type(self) -> str: def type(self) -> str:
"""驱动名称: ``quart``""" """驱动名称: ``quart``"""
return 'quart' return "quart"
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
@ -76,17 +95,21 @@ class Driver(ReverseDriver):
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
def logger(self): def logger(self):
"""fastapi 使用的 logger""" """Quart 使用的 logger"""
return self._server_app.logger return self._server_app.logger
@overrides(ReverseDriver) @overrides(ReverseDriver)
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable: def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_""" """参考文档: `Startup and Shutdown`_
.. _Startup and Shutdown:
https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html
"""
return self.server_app.before_serving(func) # type: ignore return self.server_app.before_serving(func) # type: ignore
@overrides(ReverseDriver) @overrides(ReverseDriver)
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_""" """参考文档: `Startup and Shutdown`_"""
return self.server_app.after_serving(func) # type: ignore return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver) @overrides(ReverseDriver)
@ -121,6 +144,7 @@ class Driver(ReverseDriver):
host=host or str(self.config.host), host=host or str(self.config.host),
port=port or self.config.port, port=port or self.config.port,
reload=bool(app) and self.config.debug, reload=bool(app) and self.config.debug,
reload_dirs=self.quart_config.quart_reload_dirs or None,
debug=self.config.debug, debug=self.config.debug,
log_config=LOGGING_CONFIG, log_config=LOGGING_CONFIG,
**kwargs) **kwargs)
@ -128,11 +152,7 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver) @overrides(ReverseDriver)
async def _handle_http(self, adapter: str): async def _handle_http(self, adapter: str):
request: Request = _request request: Request = _request
data: bytes = await request.get_data() # type: ignore
try:
data: Dict[str, Any] = await request.get_json()
except Exception as e:
raise exceptions.BadRequest()
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. ' logger.warning(f'Unknown adapter {adapter}. '
@ -140,25 +160,32 @@ class Driver(ReverseDriver):
raise exceptions.NotFound() raise exceptions.NotFound()
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
headers = {k: v for k, v in request.headers.items(lower=True)} http_request = HTTPRequest(request.http_version, request.scheme,
request.path, request.query_string,
dict(request.headers), request.method, data)
try: self_id, response = await BotClass.check_permission(self, http_request)
self_id = await BotClass.check_permission(self, 'http', headers,
data) if not self_id:
except RequestDenied as e: raise exceptions.HTTPException(
raise exceptions.HTTPException(status_code=e.status_code, response and response.status or 401,
description=e.reason, description=(response and response.body or b"").decode(),
name='Request Denied') name="Request Denied")
if self_id in self._clients: if self_id in self._clients:
logger.warning("There's already a reverse websocket connection," logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.") "so the event may be handled twice.")
bot = BotClass('http', self_id) bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
return Response('', 204) return Response(response and response.body or "",
response and response.status or 200)
@overrides(ReverseDriver) @overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str): async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket websocket: QuartWebSocket = _websocket
ws = WebSocket(websocket.http_version, websocket.scheme,
websocket.path, websocket.query_string,
dict(websocket.headers), websocket)
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning( logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.' f'Unknown adapter {adapter}. Please register the adapter before use.'
@ -166,19 +193,23 @@ class Driver(ReverseDriver):
raise exceptions.NotFound() raise exceptions.NotFound()
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
headers = {k: v for k, v in websocket.headers.items(lower=True)} self_id, response = await BotClass.check_permission(self, ws)
try:
self_id = await BotClass.check_permission(self, 'websocket', if not self_id:
headers, None) raise exceptions.HTTPException(
except RequestDenied as e: response and response.status or 401,
raise exceptions.HTTPException(status_code=e.status_code, description=(response and response.body or b"").decode(),
description=e.reason, name="Request Denied")
name='Request Denied')
if self_id in self._clients: if self_id in self._clients:
logger.warning("There's already a reverse websocket connection," logger.opt(colors=True).warning(
"so the event may be handled twice.") "There's already a reverse websocket connection, "
ws = WebSocket(websocket) f"<y>{adapter.upper()} Bot {self_id}</y> ignored.")
bot = BotClass('websocket', self_id, websocket=ws) raise exceptions.HTTPException(403,
description="Client already exists",
name="Request Denied")
bot = BotClass(self_id, ws)
await ws.accept() await ws.accept()
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection from <y>{adapter.upper()} " f"WebSocket Connection from <y>{adapter.upper()} "
@ -187,52 +218,51 @@ class Driver(ReverseDriver):
try: try:
while not ws.closed: while not ws.closed:
data = await ws.receive() try:
if data is None: data = await ws.receive()
continue except asyncio.CancelledError:
asyncio.create_task(bot.handle_message(data)) logger.warning("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
asyncio.create_task(bot.handle_message(data.encode()))
finally: finally:
self._bot_disconnect(bot) self._bot_disconnect(bot)
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
websocket: QuartWebSocket = None # type: ignore
@overrides(BaseWebSocket)
def __init__(self, websocket: QuartWebSocket):
super().__init__(websocket)
self._closed = False
@property
@overrides(BaseWebSocket)
def websocket(self) -> QuartWebSocket:
return self._websocket
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self): def closed(self):
return self._closed # FIXME
return False
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
await self.websocket.accept() await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self): async def close(self):
self._closed = True # FIXME
pass
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> Optional[Dict[str, Any]]: async def receive(self) -> str:
data: Optional[Dict[str, Any]] = None return await self.websocket.receive() # type: ignore
try:
data = await self.websocket.receive_json()
except JSONDecodeError:
logger.warning('Received an invalid json message.')
except asyncio.CancelledError:
self._closed = True
logger.warning('WebSocket disconnected by peer.')
return data
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: dict): async def receive_bytes(self) -> bytes:
await self.websocket.send_json(data) return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: str):
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)

View File

@ -115,29 +115,6 @@ class StopPropagation(NoneBotException):
pass pass
class RequestDenied(NoneBotException):
"""
:说明:
Bot 连接请求不合法
:参数:
* ``status_code: int``: HTTP 状态码
* ``reason: str``: 拒绝原因
"""
def __init__(self, status_code: int, reason: str):
self.status_code = status_code
self.reason = reason
def __repr__(self):
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
def __str__(self):
return self.__repr__()
class AdapterException(NoneBotException): class AdapterException(NoneBotException):
""" """
:说明: :说明:

View File

@ -3,14 +3,15 @@ import sys
import hmac import hmac
import json import json
import asyncio import asyncio
from typing import Any, Dict, Union, Optional, TYPE_CHECKING from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx import httpx
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.utils import DataclassEncoder
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape from .utils import log, escape
from .config import Config as CQHTTPConfig from .config import Config as CQHTTPConfig
@ -20,7 +21,6 @@ from .exception import NetworkError, ApiNotAvailable, ActionFailed
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]: def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
@ -28,7 +28,7 @@ def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
return None return None
scheme, _, param = access_token.partition(" ") scheme, _, param = access_token.partition(" ")
if scheme.lower() not in ["bearer", "token"]: if scheme.lower() not in ["bearer", "token"]:
raise RequestDenied(401, "Not authenticated") return None
return param return param
@ -225,14 +225,6 @@ class Bot(BaseBot):
""" """
cqhttp_config: CQHTTPConfig cqhttp_config: CQHTTPConfig
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional["WebSocket"] = None):
super().__init__(connection_type, self_id, websocket=websocket)
@property @property
@overrides(BaseBot) @overrides(BaseBot)
def type(self) -> str: def type(self) -> str:
@ -242,84 +234,84 @@ class Bot(BaseBot):
return "cqhttp" return "cqhttp"
@classmethod @classmethod
def register(cls, driver: "Driver", config: "Config"): def register(cls, driver: Driver, config: "Config"):
super().register(driver, config) super().register(driver, config)
cls.cqhttp_config = CQHTTPConfig(**config.dict()) cls.cqhttp_config = CQHTTPConfig(**config.dict())
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(
headers: dict, body: Optional[bytes]) -> str: cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
""" """
:说明: :说明:
CQHTTP (OneBot) 协议鉴权参考 `鉴权 <https://github.com/howmanybots/onebot/blob/master/v11/specs/communication/authorization.md>`_ CQHTTP (OneBot) 协议鉴权参考 `鉴权 <https://github.com/howmanybots/onebot/blob/master/v11/specs/communication/authorization.md>`_
""" """
x_self_id = headers.get("x-self-id") x_self_id = request.headers.get("x-self-id")
x_signature = headers.get("x-signature") x_signature = request.headers.get("x-signature")
token = get_auth_bearer(headers.get("authorization")) token = get_auth_bearer(request.headers.get("authorization"))
cqhttp_config = CQHTTPConfig(**driver.config.dict()) cqhttp_config = CQHTTPConfig(**driver.config.dict())
# 检查连接方式
if connection_type not in ["http", "websocket"]:
log("WARNING", "Unsupported connection type")
raise RequestDenied(405, "Unsupported connection type")
# 检查self_id # 检查self_id
if not x_self_id: if not x_self_id:
log("WARNING", "Missing X-Self-ID Header") log("WARNING", "Missing X-Self-ID Header")
raise RequestDenied(400, "Missing X-Self-ID Header") return None, HTTPResponse(400, b"Missing X-Self-ID Header")
# 检查签名 # 检查签名
secret = cqhttp_config.secret secret = cqhttp_config.secret
if secret and connection_type == "http": if secret and isinstance(request, HTTPRequest):
if not x_signature: if not x_signature:
log("WARNING", "Missing Signature Header") log("WARNING", "Missing Signature Header")
raise RequestDenied(401, "Missing Signature") return None, HTTPResponse(401, b"Missing Signature")
sig = hmac.new(secret.encode("utf-8"), body, "sha1").hexdigest() sig = hmac.new(secret.encode("utf-8"), request.body,
"sha1").hexdigest()
if x_signature != "sha1=" + sig: if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid") log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid") return None, HTTPResponse(403, b"Signature is invalid")
access_token = cqhttp_config.access_token access_token = cqhttp_config.access_token
if access_token and access_token != token and connection_type == "websocket": if access_token and access_token != token and isinstance(
request, WebSocket):
log( log(
"WARNING", "Authorization Header is invalid" "WARNING", "Authorization Header is invalid"
if token else "Missing Authorization Header") if token else "Missing Authorization Header")
raise RequestDenied( return None, HTTPResponse(
403, "Authorization Header is invalid" 403, b"Authorization Header is invalid"
if token else "Missing Authorization Header") if token else b"Missing Authorization Header")
return str(x_self_id) return str(x_self_id), HTTPResponse(204, b'')
@overrides(BaseBot) @overrides(BaseBot)
async def handle_message(self, message: dict): async def handle_message(self, message: bytes):
""" """
:说明: :说明:
调用 `_check_reply <#async-check-reply-bot-event>`_, `_check_at_me <#check-at-me-bot-event>`_, `_check_nickname <#check-nickname-bot-event>`_ 处理事件并转换为 `Event <#class-event>`_ 调用 `_check_reply <#async-check-reply-bot-event>`_, `_check_at_me <#check-at-me-bot-event>`_, `_check_nickname <#check-nickname-bot-event>`_ 处理事件并转换为 `Event <#class-event>`_
""" """
if not message: data = json.loads(message)
if not data:
return return
if "post_type" not in message: if "post_type" not in data:
ResultStore.add_result(message) ResultStore.add_result(data)
return return
try: try:
post_type = message['post_type'] post_type = data['post_type']
detail_type = message.get(f"{post_type}_type") detail_type = data.get(f"{post_type}_type")
detail_type = f".{detail_type}" if detail_type else "" detail_type = f".{detail_type}" if detail_type else ""
sub_type = message.get("sub_type") sub_type = data.get("sub_type")
sub_type = f".{sub_type}" if sub_type else "" sub_type = f".{sub_type}" if sub_type else ""
models = get_event_model(post_type + detail_type + sub_type) models = get_event_model(post_type + detail_type + sub_type)
for model in models: for model in models:
try: try:
event = model.parse_obj(message) event = model.parse_obj(data)
break break
except Exception as e: except Exception as e:
log("DEBUG", "Event Parser Error", e) log("DEBUG", "Event Parser Error", e)
else: else:
event = Event.parse_obj(message) event = Event.parse_obj(data)
# Check whether user is calling me # Check whether user is calling me
await _check_reply(self, event) await _check_reply(self, event)
@ -329,25 +321,28 @@ class Bot(BaseBot):
await handle_event(self, event) await handle_event(self, event)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>" f"<r><bg #f8bbd0>Failed to handle event. Raw: {data}</bg #f8bbd0></r>"
) )
@overrides(BaseBot) @overrides(BaseBot)
async def _call_api(self, api: str, **data) -> Any: async def _call_api(self, api: str, **data) -> Any:
log("DEBUG", f"Calling API <y>{api}</y>") log("DEBUG", f"Calling API <y>{api}</y>")
if self.connection_type == "websocket": if isinstance(self.request, WebSocket):
seq = ResultStore.get_seq() seq = ResultStore.get_seq()
await self.websocket.send({ json_data = json.dumps(
"action": api, {
"params": data, "action": api,
"echo": { "params": data,
"seq": seq "echo": {
} "seq": seq
}) }
},
cls=DataclassEncoder)
await self.request.send(json_data)
return _handle_api_result(await ResultStore.fetch( return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout)) seq, self.config.api_timeout))
elif self.connection_type == "http": elif isinstance(self.request, HTTPRequest):
api_root = self.config.api_root.get(self.self_id) api_root = self.config.api_root.get(self.self_id)
if not api_root: if not api_root:
raise ApiNotAvailable raise ApiNotAvailable
@ -431,7 +426,7 @@ class Bot(BaseBot):
message, str) else message message, str) else message
msg = message if isinstance(message, Message) else Message(message) msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and getattr(event, "user_id", None) at_sender = at_sender and bool(getattr(event, "user_id", None))
params = {} params = {}
if getattr(event, "user_id", None): if getattr(event, "user_id", None):
@ -449,8 +444,7 @@ class Bot(BaseBot):
raise ValueError("Cannot guess message type to reply!") raise ValueError("Cannot guess message type to reply!")
if at_sender and params["message_type"] != "private": if at_sender and params["message_type"] != "private":
params["message"] = MessageSegment.at(params["user_id"]) + \ params["message"] = MessageSegment.at(params["user_id"]) + " " + msg
MessageSegment.text(" ") + msg
else: else:
params["message"] = msg params["message"] = msg
return await self.send_msg(**params) return await self.send_msg(**params)

View File

@ -1,16 +1,17 @@
import json import json
import urllib.parse import urllib.parse
from datetime import datetime
import time import time
from typing import Any, Union, Optional, TYPE_CHECKING from datetime import datetime
from typing import Any, Tuple, Union, Optional, TYPE_CHECKING
import httpx import httpx
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse
from .utils import calc_hmac_base64, log from .utils import calc_hmac_base64, log
from .config import Config as DingConfig from .config import Config as DingConfig
@ -20,7 +21,6 @@ from .event import MessageEvent, PrivateMessageEvent, GroupMessageEvent, Convers
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver
SEND = "send" SEND = "send"
@ -31,10 +31,6 @@ class Bot(BaseBot):
""" """
ding_config: DingConfig ding_config: DingConfig
def __init__(self, connection_type: str, self_id: str, **kwargs):
super().__init__(connection_type, self_id, **kwargs)
@property @property
def type(self) -> str: def type(self) -> str:
""" """
@ -43,57 +39,61 @@ class Bot(BaseBot):
return "ding" return "ding"
@classmethod @classmethod
def register(cls, driver: "Driver", config: "Config"): def register(cls, driver: Driver, config: "Config"):
super().register(driver, config) super().register(driver, config)
cls.ding_config = DingConfig(**config.dict()) cls.ding_config = DingConfig(**config.dict())
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(
headers: dict, body: Optional[bytes]) -> str: cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
""" """
:说明: :说明:
钉钉协议鉴权参考 `鉴权 <https://ding-doc.dingtalk.com/doc#/serverapi2/elzz1p>`_ 钉钉协议鉴权参考 `鉴权 <https://ding-doc.dingtalk.com/doc#/serverapi2/elzz1p>`_
""" """
timestamp = headers.get("timestamp") timestamp = request.headers.get("timestamp")
sign = headers.get("sign") sign = request.headers.get("sign")
# 检查连接方式 # 检查连接方式
if connection_type not in ["http"]: if not isinstance(request, HTTPRequest):
raise RequestDenied( return None, HTTPResponse(
405, "Unsupported connection type, available type: `http`") 405, b"Unsupported connection type, available type: `http`")
# 检查 timestamp # 检查 timestamp
if not timestamp: if not timestamp:
raise RequestDenied(400, "Missing `timestamp` Header") return None, HTTPResponse(400, b"Missing `timestamp` Header")
# 检查 sign # 检查 sign
secret = cls.ding_config.secret secret = cls.ding_config.secret
if secret: if secret:
if not sign: if not sign:
log("WARNING", "Missing Signature Header") log("WARNING", "Missing Signature Header")
raise RequestDenied(400, "Missing `sign` Header") return None, HTTPResponse(400, b"Missing `sign` Header")
sign_base64 = calc_hmac_base64(str(timestamp), secret) sign_base64 = calc_hmac_base64(str(timestamp), secret)
if sign != sign_base64.decode('utf-8'): if sign != sign_base64.decode('utf-8'):
log("WARNING", "Signature Header is invalid") log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid") return None, HTTPResponse(403, b"Signature is invalid")
else: else:
log("WARNING", "Ding signature check ignored!") log("WARNING", "Ding signature check ignored!")
return json.loads(body.decode())["chatbotUserId"] return (json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b''))
@overrides(BaseBot) @overrides(BaseBot)
async def handle_message(self, message: dict): async def handle_message(self, message: bytes):
if not message: data = json.loads(message)
if not data:
return return
# 判断消息类型,生成不同的 Event # 判断消息类型,生成不同的 Event
try: try:
conversation_type = message["conversationType"] conversation_type = data["conversationType"]
if conversation_type == ConversationType.private: if conversation_type == ConversationType.private:
event = PrivateMessageEvent.parse_obj(message) event = PrivateMessageEvent.parse_obj(data)
elif conversation_type == ConversationType.group: elif conversation_type == ConversationType.group:
event = GroupMessageEvent.parse_obj(message) event = GroupMessageEvent.parse_obj(data)
else: else:
raise ValueError("Unsupported conversation type") raise ValueError("Unsupported conversation type")
except Exception as e: except Exception as e:
@ -104,7 +104,7 @@ class Bot(BaseBot):
await handle_event(self, event) await handle_event(self, event)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>" f"<r><bg #f8bbd0>Failed to handle event. Raw: {data}</bg #f8bbd0></r>"
) )
return return

View File

@ -1,8 +1,8 @@
""" r"""
Mirai-API-HTTP 协议适配 Mirai-API-HTTP 协议适配
============================ ============================
协议详情请看: `mirai-api-http 文档`_ 协议详情请看: `mirai-api-http 文档`_
\:\:\: tip \:\:\: tip
该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试 该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试

View File

@ -5,11 +5,11 @@ from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx import httpx
from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import ApiNotAvailable
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage from .event import Event, FriendMessage, GroupMessage, TempMessage
@ -140,7 +140,7 @@ class SessionManager:
class Bot(BaseBot): class Bot(BaseBot):
""" r"""
mirai-api-http 协议 Bot 适配 mirai-api-http 协议 Bot 适配
\:\:\: warning \:\:\: warning
@ -151,14 +151,6 @@ class Bot(BaseBot):
""" """
@overrides(BaseBot)
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional[WebSocket] = None):
super().__init__(connection_type, self_id, websocket=websocket)
@property @property
@overrides(BaseBot) @overrides(BaseBot)
def type(self) -> str: def type(self) -> str:
@ -166,7 +158,8 @@ class Bot(BaseBot):
@property @property
def alive(self) -> bool: def alive(self) -> bool:
return not self.websocket.closed assert isinstance(self.request, WebSocket)
return not self.request.closed
@property @property
def api(self) -> SessionManager: def api(self) -> SessionManager:
@ -177,27 +170,26 @@ class Bot(BaseBot):
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(
headers: dict, body: Optional[bytes]) -> str: cls, driver: Driver,
if connection_type == 'ws': request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
raise RequestDenied( if isinstance(request, WebSocket):
status_code=501, return None, HTTPResponse(
reason='Websocket connection is not implemented') 501, b'Websocket connection is not implemented')
self_id: Optional[str] = headers.get('bot') self_id: Optional[str] = request.headers.get('bot')
if self_id is None: if self_id is None:
raise RequestDenied(status_code=400, return None, HTTPResponse(400, b'Header `Bot` is required.')
reason='Header `Bot` is required.')
self_id = str(self_id).strip() self_id = str(self_id).strip()
await SessionManager.new( await SessionManager.new(
int(self_id), int(self_id),
host=cls.mirai_config.host, # type: ignore host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, #type: ignore port=cls.mirai_config.port, #type: ignore
auth_key=cls.mirai_config.auth_key) # type: ignore auth_key=cls.mirai_config.auth_key) # type: ignore
return self_id return self_id, HTTPResponse(204, b'')
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config"): def register(cls, driver: Driver, config: "Config"):
cls.mirai_config = MiraiConfig(**config.dict()) cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None: cls.mirai_config.port) is None:
@ -224,7 +216,7 @@ class Bot(BaseBot):
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn: async def call_api(self, api: str, **data) -> NoReturn:
""" r"""
\:\:\: danger \:\:\: danger
由于Mirai的HTTP API特殊性, 该API暂时无法实现 由于Mirai的HTTP API特殊性, 该API暂时无法实现
\:\:\: \:\:\:

View File

@ -1,18 +1,16 @@
import asyncio
import json import json
import asyncio
from dataclasses import dataclass
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set, from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine
TypeVar)
import httpx import httpx
import websockets import websockets
from nonebot.config import Config
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket
from .bot import SessionManager, Bot from .bot import SessionManager, Bot
@ -21,7 +19,9 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction) bound=WebsocketHandlerFunction)
@dataclass
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
websocket: websockets.WebSocketClientProtocol = None # type: ignore
@classmethod @classmethod
async def new(cls, *, host: IPv4Address, port: int, async def new(cls, *, host: IPv4Address, port: int,
@ -37,24 +37,26 @@ class WebSocket(BaseWebSocket):
self.event_handlers: Set[WebsocketHandlerFunction] = set() self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket) super().__init__(websocket)
@property
@overrides(BaseWebSocket)
def websocket(self) -> websockets.WebSocketClientProtocol:
return self._websocket
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self) -> bool: def closed(self) -> bool:
return self.websocket.closed return self.websocket.closed
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: Dict[str, Any]): async def send(self, data: str):
return await self.websocket.send(json.dumps(data)) return await self.websocket.send(data)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> Dict[str, Any]: async def send_bytes(self, data: str):
received = await self.websocket.recv() return await self.websocket.send(data)
return json.loads(received)
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.recv() # type: ignore
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.recv() # type: ignore
async def _dispatcher(self): async def _dispatcher(self):
while not self.closed: while not self.closed:
@ -93,11 +95,6 @@ class WebsocketBot(Bot):
mirai-api-http 正向 Websocket 协议 Bot 适配 mirai-api-http 正向 Websocket 协议 Bot 适配
""" """
@overrides(Bot)
def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket)
@property @property
@overrides(Bot) @overrides(Bot)
def type(self) -> str: def type(self) -> str:
@ -105,7 +102,8 @@ class WebsocketBot(Bot):
@property @property
def alive(self) -> bool: def alive(self) -> bool:
return not self.websocket.closed assert isinstance(self.request, WebSocket)
return not self.request.closed
@property @property
def api(self) -> SessionManager: def api(self) -> SessionManager:
@ -115,16 +113,14 @@ class WebsocketBot(Bot):
@classmethod @classmethod
@overrides(Bot) @overrides(Bot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(
headers: dict, cls, driver: Driver,
body: Optional[bytes]) -> NoReturn: request: HTTPConnection) -> Tuple[None, HTTPResponse]:
raise RequestDenied( return None, HTTPResponse(501, b'Connection not implented')
status_code=501,
reason=f'Connection {connection_type} not implented')
@classmethod @classmethod
@overrides(Bot) @overrides(Bot)
def register(cls, driver: "Driver", config: "Config", qq: int): def register(cls, driver: Driver, config: "Config", qq: int):
""" """
:说明: :说明: