🔀 Merge pull request #406

Feature: 支持自定义 Response
This commit is contained in:
Ju4tCode 2021-06-14 14:14:43 +08:00 committed by GitHub
commit e9bc98e74d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 465 additions and 511 deletions

View File

@ -57,7 +57,7 @@ Config 配置对象
### _abstract_ `__init__(connection_type, self_id, *, websocket=None)`
### `__init__(self_id, request)`
* **参数**
@ -73,19 +73,14 @@ Config 配置对象
### `connection_type`
连接类型
### `self_id`
机器人 ID
### `websocket`
### `request`
Websocket 连接对象
连接信息
### _abstract property_ `type`
@ -102,7 +97,7 @@ Adapter 类型
### _abstract async classmethod_ `check_permission(driver, connection_type, headers, body)`
### _abstract async classmethod_ `check_permission(driver, request)`
* **说明**
@ -130,7 +125,10 @@ Adapter 类型
* **返回**
* `str`: 连接唯一标识符
* `str`: 连接唯一标识符,`None` 代表连接不合法
* `HTTPResponse`: HTTP 上报响应
@ -153,7 +151,7 @@ Adapter 类型
* **参数**
* `message: dict`: 收到的上报消息
* `message: bytes`: 收到的上报消息

View File

@ -204,7 +204,7 @@ CQHTTP 协议 Bot 适配。继承属性参考 [BaseBot](./#class-basebot) 。
* 返回: `"cqhttp"`
### _async classmethod_ `check_permission(driver, connection_type, headers, body)`
### _async classmethod_ `check_permission(driver, request)`
* **说明**

View File

@ -105,7 +105,7 @@ sidebarDepth: 0
* 返回: `"ding"`
### _async classmethod_ `check_permission(driver, connection_type, headers, body)`
### _async classmethod_ `check_permission(driver, request)`
* **说明**

View File

@ -682,6 +682,11 @@ API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则
# NoneBot.adapters.mirai.bot_ws 模块
## _class_ `WebSocket`
基类:[`nonebot.drivers.WebSocket`](../drivers/README.md#nonebot.drivers.WebSocket)
## _class_ `WebsocketBot`
基类:`nonebot.adapters.mirai.bot.Bot`

View File

@ -268,74 +268,71 @@ Reverse Driver 基类。将后端框架封装,以满足适配器使用。
用于处理 WebSocket 类型请求的函数
## _class_ `HTTPRequest`
## _class_ `HTTPConnection`
基类:`object`
HTTP 请求封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope)。
基类:`abc.ABC`
### _property_ `type`
Always http
### _property_ `scope`
Raw scope from asgi.
The connection scope information, a dictionary that
contains at least a type key specifying the protocol that is incoming.
### _property_ `http_version`
### `http_version`
One of "1.0", "1.1" or "2".
### _property_ `method`
The HTTP method name, uppercased.
### _property_ `schema`
### `scheme`
URL scheme portion (likely "http" or "https").
Optional (but must not be empty); default is "http".
### _property_ `path`
### `path`
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
### _property_ `query_string`
### `query_string`
URL portion after the ?, percent-encoded.
### _property_ `headers`
### `headers`
An iterable of [name, value] two-item iterables,
A dict of name-value pairs,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Duplicates are possible and must be preserved in the message as received.
Header names must be lowercased.
### _property_ `body`
### _abstract property_ `type`
Connection type.
## _class_ `HTTPRequest`
基类:`nonebot.drivers.HTTPConnection`
HTTP 请求封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope)。
### `method`
The HTTP method name, uppercased.
### `body`
Body of the request.
Optional; if missing defaults to b"".
If more_body is set, treat as start of body and concatenate on further chunks.
### _property_ `type`
Always `http`
## _class_ `HTTPResponse`
@ -350,51 +347,40 @@ HTTP 响应封装。参考 [asgi http scope](https://asgi.readthedocs.io/en/late
HTTP status code.
### `body`
HTTP body content.
Optional; if missing defaults to `None`.
### `headers`
An iterable of [name, value] two-item iterables,
where name is the header name,
and value is the header value.
A dict of name-value pairs,
where name is the header name, and value is the header value.
Order must be preserved in the HTTP response.
Header names must be lowercased.
Optional; if missing defaults to an empty list.
### `body`
HTTP body content.
Optional; if missing defaults to None.
Optional; if missing defaults to an empty dict.
### _property_ `type`
Always http
Always `http`
## _class_ `WebSocket`
基类:`object`
基类:`nonebot.drivers.HTTPConnection`, `abc.ABC`
WebSocket 连接封装。参考 [asgi websocket scope](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope)。
### _abstract_ `__init__(websocket)`
### _property_ `type`
* **参数**
* `websocket: Any`: WebSocket 连接对象
### _property_ `websocket`
WebSocket 连接对象
Always `websocket`
### _abstract property_ `closed`
@ -424,9 +410,19 @@ WebSocket 连接对象
### _abstract async_ `receive()`
接收一条 WebSocket 信息
接收一条 WebSocket text 信息
### _abstract async_ `receive_bytes()`
接收一条 WebSocket binary 信息
### _abstract async_ `send(data)`
发送一条 WebSocket 信息
发送一条 WebSocket text 信息
### _abstract async_ `send_bytes(data)`
发送一条 WebSocket text 信息

View File

@ -133,3 +133,8 @@ fastapi 使用的 logger
### `run(host=None, port=None, *, app=None, **kwargs)`
使用 `uvicorn` 启动 FastAPI
## _class_ `WebSocket`
基类:[`nonebot.drivers.WebSocket`](README.md#nonebot.drivers.WebSocket)

View File

@ -10,6 +10,28 @@ sidebarDepth: 0
后端使用方法请参考: [Quart 文档](https://pgjones.gitlab.io/quart/index.html)
## _class_ `Config`
基类:`pydantic.env_settings.BaseSettings`
Quart 驱动框架设置
### `quart_reload_dirs`
* **类型**
`List[str]`
* **说明**
`debug` 模式下重载监控文件夹列表,默认为 uvicorn 默认值
## _class_ `Driver`
基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver)
@ -44,7 +66,7 @@ Quart 驱动框架
### _property_ `logger`
fastapi 使用的 logger
Quart 使用的 logger
### `on_startup(func)`

View File

@ -132,27 +132,6 @@ sidebarDepth: 0
## _exception_ `RequestDenied`
基类:`nonebot.exception.NoneBotException`
* **说明**
Bot 连接请求不合法。
* **参数**
* `status_code: int`: HTTP 状态码
* `reason: str`: 拒绝原因
## _exception_ `AdapterException`
基类:`nonebot.exception.NoneBotException`

View File

@ -11,13 +11,14 @@ from copy import copy
from functools import reduce, partial
from typing_extensions import Protocol
from dataclasses import dataclass, field
from typing import (Any, Set, List, Dict, Union, TypeVar, Mapping, Optional,
Iterable, Awaitable, TYPE_CHECKING)
from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping,
Optional, Iterable, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel
from nonebot.log import logger
from nonebot.utils import DataclassEncoder
from nonebot.drivers import HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
if TYPE_CHECKING:
@ -51,12 +52,7 @@ class Bot(abc.ABC):
:说明: call_api 后执行的函数
"""
@abc.abstractmethod
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional["WebSocket"] = None):
def __init__(self, self_id: str, request: HTTPConnection):
"""
:参数:
@ -64,12 +60,10 @@ class Bot(abc.ABC):
* ``self_id: str``: 机器人 ID
* ``websocket: Optional[WebSocket]``: Websocket 连接对象
"""
self.connection_type = connection_type
"""连接类型"""
self.self_id = self_id
self.self_id: str = self_id
"""机器人 ID"""
self.websocket = websocket
"""Websocket 连接对象"""
self.request: HTTPConnection = request
"""连接信息"""
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@ -92,8 +86,9 @@ class Bot(abc.ABC):
@classmethod
@abc.abstractmethod
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[bytes]) -> str:
async def check_permission(
cls, driver: "Driver", request: HTTPConnection
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
"""
:说明:
@ -108,7 +103,8 @@ class Bot(abc.ABC):
:返回:
- ``str``: 连接唯一标识符
- ``str``: 连接唯一标识符``None`` 代表连接不合法
- ``HTTPResponse``: HTTP 上报响应
:异常:
@ -117,7 +113,7 @@ class Bot(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def handle_message(self, message: dict):
async def handle_message(self, message: bytes):
"""
:说明:
@ -125,7 +121,7 @@ class Bot(abc.ABC):
:参数:
* ``message: dict``: 收到的上报消息
* ``message: bytes``: 收到的上报消息
"""
raise NotImplementedError

View File

@ -7,8 +7,8 @@
import abc
import asyncio
from typing import (Any, Set, List, Dict, Type, Tuple, Optional, Callable,
MutableMapping, TYPE_CHECKING)
from dataclasses import dataclass, field
from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger
from nonebot.config import Env, Config
@ -47,12 +47,12 @@ class Driver(abc.ABC):
* ``env: Env``: 包含环境信息的 Env 对象
* ``config: Config``: 包含配置信息的 Config 对象
"""
self.env = env.environment
self.env: str = env.environment
"""
:类型: ``str``
:说明: 环境名称
"""
self.config = config
self.config: Config = config
"""
:类型: ``Config``
:说明: 配置对象
@ -231,143 +231,101 @@ class ReverseDriver(Driver):
raise NotImplementedError
class HTTPRequest:
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
"""One of `"1.0"`, `"1.1"` or `"2"`."""
scheme: str
"""URL scheme portion (likely `"http"` or `"https"`)."""
path: str
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
query_string: bytes = b""
""" URL portion after the `?`, percent-encoded."""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Header names must be lowercased.
"""
@property
@abc.abstractmethod
def type(self) -> str:
"""Connection type."""
raise NotImplementedError
@dataclass
class HTTPRequest(HTTPConnection):
"""HTTP 请求封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
method: str = "GET"
"""The HTTP method name, uppercased."""
body: bytes = b""
"""Body of the request.
def __init__(self, scope: MutableMapping[str, Any]):
self._scope = scope
Optional; if missing defaults to b"".
"""
@property
def type(self) -> str:
"""Always `http`"""
"""Always ``http``"""
return "http"
@property
def scope(self) -> MutableMapping[str, Any]:
"""Raw scope from asgi.
The connection scope information, a dictionary that
contains at least a `type` key specifying the protocol that is incoming.
"""
return self._scope
@property
def http_version(self) -> str:
"""One of `"1.0"`, `"1.1"` or `"2"`."""
raise self.scope["http_version"]
@property
def method(self) -> str:
"""The HTTP method name, uppercased."""
raise self.scope["method"]
@property
def schema(self) -> str:
"""
URL scheme portion (likely `"http"` or `"https"`).
Optional (but must not be empty); default is `"http"`.
"""
raise self.scope["schema"]
@property
def path(self) -> str:
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
return self.scope["path"]
@property
def query_string(self) -> bytes:
""" URL portion after the `?`, percent-encoded."""
return self.scope["query_string"]
@property
def headers(self) -> List[Tuple[bytes, bytes]]:
"""An iterable of [name, value] two-item iterables,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Duplicates are possible and must be preserved in the message as received.
Header names must be lowercased.
"""
return list(self.scope["headers"])
@property
def body(self) -> bytes:
"""Body of the request.
Optional; if missing defaults to b"".
If more_body is set, treat as start of body and concatenate on further chunks.
"""
return self.scope["body"]
@dataclass
class HTTPResponse:
"""HTTP 响应封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
status: int
"""HTTP status code."""
body: Optional[bytes] = None
"""HTTP body content.
def __init__(self,
status: int,
headers: List[Tuple[bytes, bytes]] = [],
body: Optional[bytes] = None):
self.status: int = status
"""HTTP status code."""
self.headers: List[Tuple[bytes, bytes]] = headers
"""An iterable of [name, value] two-item iterables,
where name is the header name,
and value is the header value.
Optional; if missing defaults to ``None``.
"""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order must be preserved in the HTTP response.
Order must be preserved in the HTTP response.
Header names must be lowercased.
Header names must be lowercased.
Optional; if missing defaults to an empty list.
"""
self.body: Optional[bytes] = body
"""HTTP body content.
Optional; if missing defaults to `None`.
"""
Optional; if missing defaults to an empty dict.
"""
@property
def type(self) -> str:
"""Always `http`"""
"""Always ``http``"""
return "http"
class WebSocket:
@dataclass
class WebSocket(HTTPConnection, abc.ABC):
"""WebSocket 连接封装。参考 `asgi websocket scope`_。
.. _asgi websocket scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
"""
@abc.abstractmethod
def __init__(self, websocket):
"""
:参数:
* ``websocket: Any``: WebSocket 连接对象
"""
self._websocket = websocket
@property
def websocket(self):
"""WebSocket 连接对象"""
return self._websocket
def type(self) -> str:
"""Always ``websocket``"""
return "websocket"
@property
@abc.abstractmethod
@ -389,11 +347,21 @@ class WebSocket:
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> dict:
"""接收一条 WebSocket 信息"""
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: dict):
"""发送一条 WebSocket 信息"""
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError

View File

@ -8,23 +8,22 @@ FastAPI 驱动适配
https://fastapi.tiangolo.com/
"""
import json
import asyncio
import logging
from dataclasses import dataclass
from typing import List, Optional, Callable
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from fastapi import status, Request, FastAPI, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from starlette.websockets import (WebSocketState, WebSocketDisconnect, WebSocket
as FastAPIWebSocket)
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
class Config(BaseSettings):
@ -179,11 +178,6 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
data_dict = json.loads(data.decode())
if not isinstance(data_dict, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters:
logger.warning(
@ -194,27 +188,34 @@ class Driver(ReverseDriver):
# 创建 Bot 对象
BotClass = self._adapters[adapter]
headers = dict(request.headers)
try:
x_self_id = await BotClass.check_permission(self, "http", headers,
data)
except RequestDenied as e:
raise HTTPException(status_code=e.status_code,
detail=e.reason) from None
http_request = HTTPRequest(request.scope["http_version"],
request.url.scheme, request.url.path,
request.scope["query_string"],
dict(request.headers), request.method, data)
x_self_id, response = await BotClass.check_permission(
self, http_request)
if not x_self_id:
raise HTTPException(response and response.status or 401,
response.body)
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
bot = BotClass("http", x_self_id)
bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data_dict))
return Response("", 204)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body,
response and response.status or 200)
@overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
ws = WebSocket(websocket.scope.get("http_version",
"1.1"), websocket.url.scheme,
websocket.url.path, websocket.scope["query_string"],
dict(websocket.headers), websocket)
if adapter not in self._adapters:
logger.warning(
@ -225,11 +226,9 @@ class Driver(ReverseDriver):
# Create Bot Object
BotClass = self._adapters[adapter]
headers = dict(websocket.headers)
try:
x_self_id = await BotClass.check_permission(self, "websocket",
headers, None)
except RequestDenied:
x_self_id, _ = await BotClass.check_permission(self, ws)
if not x_self_id:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
@ -240,7 +239,7 @@ class Driver(ReverseDriver):
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass("websocket", x_self_id, websocket=ws)
bot = BotClass(x_self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
@ -251,54 +250,51 @@ class Driver(ReverseDriver):
try:
while not ws.closed:
data = await ws.receive()
try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
if not data:
continue
asyncio.create_task(bot.handle_message(data))
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket):
def __init__(self, websocket: FastAPIWebSocket):
super().__init__(websocket)
self._closed = False
websocket: FastAPIWebSocket = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self._closed
return (self.websocket.client_state == WebSocketState.DISCONNECTED or
self.websocket.application_state == WebSocketState.DISCONNECTED)
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code)
self._closed = True
@overrides(BaseWebSocket)
async def receive(self) -> Optional[dict]:
data = None
try:
data = await self.websocket.receive_json()
if not isinstance(data, dict):
data = None
raise ValueError
except ValueError:
logger.warning("Received an invalid json message.")
except WebSocketDisconnect:
self._closed = True
logger.error("WebSocket disconnected by peer.")
return data
async def receive(self) -> str:
return await self.websocket.receive_text()
@overrides(BaseWebSocket)
async def send(self, data: dict) -> None:
text = json.dumps(data, cls=DataclassEncoder)
await self.websocket.send({"type": "websocket.send", "text": text})
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})

View File

@ -9,24 +9,22 @@ Quart 驱动适配
"""
import asyncio
from json.decoder import JSONDecodeError
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar
from typing import List, TypeVar, Callable, Coroutine, Optional
import uvicorn
from pydantic import BaseSettings
from nonebot.config import Config as NoneBotConfig
from nonebot.config import Env
from nonebot.drivers import ReverseDriver, WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket
try:
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
from quart import exceptions
from quart import request as _request
from quart import websocket as _websocket
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
except ImportError:
raise ValueError(
'Please install Quart by using `pip install nonebot2[quart]`')
@ -34,6 +32,25 @@ except ImportError:
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
class Config(BaseSettings):
"""
Quart 驱动框架设置
"""
quart_reload_dirs: List[str] = []
"""
:类型:
``List[str]``
:说明:
``debug`` 模式下重载监控文件夹列表默认为 uvicorn 默认值
"""
class Config:
extra = "ignore"
class Driver(ReverseDriver):
"""
Quart 驱动框架
@ -48,18 +65,20 @@ class Driver(ReverseDriver):
def __init__(self, env: Env, config: NoneBotConfig):
super().__init__(env, config)
self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule('/<adapter>/http',
methods=['POST'],
self._server_app.add_url_rule("/<adapter>/http",
methods=["POST"],
view_func=self._handle_http)
self._server_app.add_websocket('/<adapter>/ws',
self._server_app.add_websocket("/<adapter>/ws",
view_func=self._handle_ws_reverse)
@property
@overrides(ReverseDriver)
def type(self) -> str:
"""驱动名称: ``quart``"""
return 'quart'
return "quart"
@property
@overrides(ReverseDriver)
@ -76,17 +95,21 @@ class Driver(ReverseDriver):
@property
@overrides(ReverseDriver)
def logger(self):
"""fastapi 使用的 logger"""
"""Quart 使用的 logger"""
return self._server_app.logger
@overrides(ReverseDriver)
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_"""
"""参考文档: `Startup and Shutdown`_
.. _Startup and Shutdown:
https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html
"""
return self.server_app.before_serving(func) # type: ignore
@overrides(ReverseDriver)
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_"""
"""参考文档: `Startup and Shutdown`_"""
return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver)
@ -121,6 +144,7 @@ class Driver(ReverseDriver):
host=host or str(self.config.host),
port=port or self.config.port,
reload=bool(app) and self.config.debug,
reload_dirs=self.quart_config.quart_reload_dirs or None,
debug=self.config.debug,
log_config=LOGGING_CONFIG,
**kwargs)
@ -128,11 +152,7 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
async def _handle_http(self, adapter: str):
request: Request = _request
try:
data: Dict[str, Any] = await request.get_json()
except Exception as e:
raise exceptions.BadRequest()
data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. '
@ -140,25 +160,32 @@ class Driver(ReverseDriver):
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
headers = {k: v for k, v in request.headers.items(lower=True)}
http_request = HTTPRequest(request.http_version, request.scheme,
request.path, request.query_string,
dict(request.headers), request.method, data)
try:
self_id = await BotClass.check_permission(self, 'http', headers,
data)
except RequestDenied as e:
raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason,
name='Request Denied')
self_id, response = await BotClass.check_permission(self, http_request)
if not self_id:
raise exceptions.HTTPException(
response and response.status or 401,
description=(response and response.body or b"").decode(),
name="Request Denied")
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
bot = BotClass('http', self_id)
bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response('', 204)
return Response(response and response.body or "",
response and response.status or 200)
@overrides(ReverseDriver)
async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket
ws = WebSocket(websocket.http_version, websocket.scheme,
websocket.path, websocket.query_string,
dict(websocket.headers), websocket)
if adapter not in self._adapters:
logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.'
@ -166,19 +193,23 @@ class Driver(ReverseDriver):
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
headers = {k: v for k, v in websocket.headers.items(lower=True)}
try:
self_id = await BotClass.check_permission(self, 'websocket',
headers, None)
except RequestDenied as e:
raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason,
name='Request Denied')
self_id, response = await BotClass.check_permission(self, ws)
if not self_id:
raise exceptions.HTTPException(
response and response.status or 401,
description=(response and response.body or b"").decode(),
name="Request Denied")
if self_id in self._clients:
logger.warning("There's already a reverse websocket connection,"
"so the event may be handled twice.")
ws = WebSocket(websocket)
bot = BotClass('websocket', self_id, websocket=ws)
logger.opt(colors=True).warning(
"There's already a reverse websocket connection, "
f"<y>{adapter.upper()} Bot {self_id}</y> ignored.")
raise exceptions.HTTPException(403,
description="Client already exists",
name="Request Denied")
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{adapter.upper()} "
@ -187,52 +218,51 @@ class Driver(ReverseDriver):
try:
while not ws.closed:
data = await ws.receive()
if data is None:
continue
asyncio.create_task(bot.handle_message(data))
try:
data = await ws.receive()
except asyncio.CancelledError:
logger.warning("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket.")
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, websocket: QuartWebSocket):
super().__init__(websocket)
self._closed = False
@property
@overrides(BaseWebSocket)
def websocket(self) -> QuartWebSocket:
return self._websocket
websocket: QuartWebSocket = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self._closed
# FIXME
return False
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
self._closed = False
@overrides(BaseWebSocket)
async def close(self):
self._closed = True
# FIXME
pass
@overrides(BaseWebSocket)
async def receive(self) -> Optional[Dict[str, Any]]:
data: Optional[Dict[str, Any]] = None
try:
data = await self.websocket.receive_json()
except JSONDecodeError:
logger.warning('Received an invalid json message.')
except asyncio.CancelledError:
self._closed = True
logger.warning('WebSocket disconnected by peer.')
return data
async def receive(self) -> str:
return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: dict):
await self.websocket.send_json(data)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive() # type: ignore
@overrides(BaseWebSocket)
async def send(self, data: str):
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)

View File

@ -115,29 +115,6 @@ class StopPropagation(NoneBotException):
pass
class RequestDenied(NoneBotException):
"""
:说明:
Bot 连接请求不合法
:参数:
* ``status_code: int``: HTTP 状态码
* ``reason: str``: 拒绝原因
"""
def __init__(self, status_code: int, reason: str):
self.status_code = status_code
self.reason = reason
def __repr__(self):
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
def __str__(self):
return self.__repr__()
class AdapterException(NoneBotException):
"""
:说明:

View File

@ -3,14 +3,15 @@ import sys
import hmac
import json
import asyncio
from typing import Any, Dict, Union, Optional, TYPE_CHECKING
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.utils import DataclassEncoder
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape
from .config import Config as CQHTTPConfig
@ -20,7 +21,6 @@ from .exception import NetworkError, ApiNotAvailable, ActionFailed
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
@ -28,7 +28,7 @@ def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
return None
scheme, _, param = access_token.partition(" ")
if scheme.lower() not in ["bearer", "token"]:
raise RequestDenied(401, "Not authenticated")
return None
return param
@ -225,14 +225,6 @@ class Bot(BaseBot):
"""
cqhttp_config: CQHTTPConfig
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional["WebSocket"] = None):
super().__init__(connection_type, self_id, websocket=websocket)
@property
@overrides(BaseBot)
def type(self) -> str:
@ -242,84 +234,84 @@ class Bot(BaseBot):
return "cqhttp"
@classmethod
def register(cls, driver: "Driver", config: "Config"):
def register(cls, driver: Driver, config: "Config"):
super().register(driver, config)
cls.cqhttp_config = CQHTTPConfig(**config.dict())
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[bytes]) -> str:
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
"""
:说明:
CQHTTP (OneBot) 协议鉴权参考 `鉴权 <https://github.com/howmanybots/onebot/blob/master/v11/specs/communication/authorization.md>`_
"""
x_self_id = headers.get("x-self-id")
x_signature = headers.get("x-signature")
token = get_auth_bearer(headers.get("authorization"))
x_self_id = request.headers.get("x-self-id")
x_signature = request.headers.get("x-signature")
token = get_auth_bearer(request.headers.get("authorization"))
cqhttp_config = CQHTTPConfig(**driver.config.dict())
# 检查连接方式
if connection_type not in ["http", "websocket"]:
log("WARNING", "Unsupported connection type")
raise RequestDenied(405, "Unsupported connection type")
# 检查self_id
if not x_self_id:
log("WARNING", "Missing X-Self-ID Header")
raise RequestDenied(400, "Missing X-Self-ID Header")
return None, HTTPResponse(400, b"Missing X-Self-ID Header")
# 检查签名
secret = cqhttp_config.secret
if secret and connection_type == "http":
if secret and isinstance(request, HTTPRequest):
if not x_signature:
log("WARNING", "Missing Signature Header")
raise RequestDenied(401, "Missing Signature")
sig = hmac.new(secret.encode("utf-8"), body, "sha1").hexdigest()
return None, HTTPResponse(401, b"Missing Signature")
sig = hmac.new(secret.encode("utf-8"), request.body,
"sha1").hexdigest()
if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
return None, HTTPResponse(403, b"Signature is invalid")
access_token = cqhttp_config.access_token
if access_token and access_token != token and connection_type == "websocket":
if access_token and access_token != token and isinstance(
request, WebSocket):
log(
"WARNING", "Authorization Header is invalid"
if token else "Missing Authorization Header")
raise RequestDenied(
403, "Authorization Header is invalid"
if token else "Missing Authorization Header")
return str(x_self_id)
return None, HTTPResponse(
403, b"Authorization Header is invalid"
if token else b"Missing Authorization Header")
return str(x_self_id), HTTPResponse(204, b'')
@overrides(BaseBot)
async def handle_message(self, message: dict):
async def handle_message(self, message: bytes):
"""
:说明:
调用 `_check_reply <#async-check-reply-bot-event>`_, `_check_at_me <#check-at-me-bot-event>`_, `_check_nickname <#check-nickname-bot-event>`_ 处理事件并转换为 `Event <#class-event>`_
"""
if not message:
data = json.loads(message)
if not data:
return
if "post_type" not in message:
ResultStore.add_result(message)
if "post_type" not in data:
ResultStore.add_result(data)
return
try:
post_type = message['post_type']
detail_type = message.get(f"{post_type}_type")
post_type = data['post_type']
detail_type = data.get(f"{post_type}_type")
detail_type = f".{detail_type}" if detail_type else ""
sub_type = message.get("sub_type")
sub_type = data.get("sub_type")
sub_type = f".{sub_type}" if sub_type else ""
models = get_event_model(post_type + detail_type + sub_type)
for model in models:
try:
event = model.parse_obj(message)
event = model.parse_obj(data)
break
except Exception as e:
log("DEBUG", "Event Parser Error", e)
else:
event = Event.parse_obj(message)
event = Event.parse_obj(data)
# Check whether user is calling me
await _check_reply(self, event)
@ -329,25 +321,28 @@ class Bot(BaseBot):
await handle_event(self, event)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>"
f"<r><bg #f8bbd0>Failed to handle event. Raw: {data}</bg #f8bbd0></r>"
)
@overrides(BaseBot)
async def _call_api(self, api: str, **data) -> Any:
log("DEBUG", f"Calling API <y>{api}</y>")
if self.connection_type == "websocket":
if isinstance(self.request, WebSocket):
seq = ResultStore.get_seq()
await self.websocket.send({
"action": api,
"params": data,
"echo": {
"seq": seq
}
})
json_data = json.dumps(
{
"action": api,
"params": data,
"echo": {
"seq": seq
}
},
cls=DataclassEncoder)
await self.request.send(json_data)
return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout))
elif self.connection_type == "http":
elif isinstance(self.request, HTTPRequest):
api_root = self.config.api_root.get(self.self_id)
if not api_root:
raise ApiNotAvailable
@ -431,7 +426,7 @@ class Bot(BaseBot):
message, str) else message
msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and getattr(event, "user_id", None)
at_sender = at_sender and bool(getattr(event, "user_id", None))
params = {}
if getattr(event, "user_id", None):
@ -449,8 +444,7 @@ class Bot(BaseBot):
raise ValueError("Cannot guess message type to reply!")
if at_sender and params["message_type"] != "private":
params["message"] = MessageSegment.at(params["user_id"]) + \
MessageSegment.text(" ") + msg
params["message"] = MessageSegment.at(params["user_id"]) + " " + msg
else:
params["message"] = msg
return await self.send_msg(**params)

View File

@ -1,16 +1,17 @@
import json
import urllib.parse
from datetime import datetime
import time
from typing import Any, Union, Optional, TYPE_CHECKING
from datetime import datetime
from typing import Any, Tuple, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse
from .utils import calc_hmac_base64, log
from .config import Config as DingConfig
@ -20,7 +21,6 @@ from .event import MessageEvent, PrivateMessageEvent, GroupMessageEvent, Convers
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver
SEND = "send"
@ -31,10 +31,6 @@ class Bot(BaseBot):
"""
ding_config: DingConfig
def __init__(self, connection_type: str, self_id: str, **kwargs):
super().__init__(connection_type, self_id, **kwargs)
@property
def type(self) -> str:
"""
@ -43,57 +39,61 @@ class Bot(BaseBot):
return "ding"
@classmethod
def register(cls, driver: "Driver", config: "Config"):
def register(cls, driver: Driver, config: "Config"):
super().register(driver, config)
cls.ding_config = DingConfig(**config.dict())
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[bytes]) -> str:
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
"""
:说明:
钉钉协议鉴权参考 `鉴权 <https://ding-doc.dingtalk.com/doc#/serverapi2/elzz1p>`_
"""
timestamp = headers.get("timestamp")
sign = headers.get("sign")
timestamp = request.headers.get("timestamp")
sign = request.headers.get("sign")
# 检查连接方式
if connection_type not in ["http"]:
raise RequestDenied(
405, "Unsupported connection type, available type: `http`")
if not isinstance(request, HTTPRequest):
return None, HTTPResponse(
405, b"Unsupported connection type, available type: `http`")
# 检查 timestamp
if not timestamp:
raise RequestDenied(400, "Missing `timestamp` Header")
return None, HTTPResponse(400, b"Missing `timestamp` Header")
# 检查 sign
secret = cls.ding_config.secret
if secret:
if not sign:
log("WARNING", "Missing Signature Header")
raise RequestDenied(400, "Missing `sign` Header")
return None, HTTPResponse(400, b"Missing `sign` Header")
sign_base64 = calc_hmac_base64(str(timestamp), secret)
if sign != sign_base64.decode('utf-8'):
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
return None, HTTPResponse(403, b"Signature is invalid")
else:
log("WARNING", "Ding signature check ignored!")
return json.loads(body.decode())["chatbotUserId"]
return (json.loads(request.body.decode())["chatbotUserId"],
HTTPResponse(204, b''))
@overrides(BaseBot)
async def handle_message(self, message: dict):
if not message:
async def handle_message(self, message: bytes):
data = json.loads(message)
if not data:
return
# 判断消息类型,生成不同的 Event
try:
conversation_type = message["conversationType"]
conversation_type = data["conversationType"]
if conversation_type == ConversationType.private:
event = PrivateMessageEvent.parse_obj(message)
event = PrivateMessageEvent.parse_obj(data)
elif conversation_type == ConversationType.group:
event = GroupMessageEvent.parse_obj(message)
event = GroupMessageEvent.parse_obj(data)
else:
raise ValueError("Unsupported conversation type")
except Exception as e:
@ -104,7 +104,7 @@ class Bot(BaseBot):
await handle_event(self, event)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>"
f"<r><bg #f8bbd0>Failed to handle event. Raw: {data}</bg #f8bbd0></r>"
)
return

View File

@ -1,8 +1,8 @@
"""
r"""
Mirai-API-HTTP 协议适配
============================
协议详情请看: `mirai-api-http 文档`_
协议详情请看: `mirai-api-http 文档`_
\:\:\: tip
该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试

View File

@ -5,11 +5,11 @@ from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.typing import overrides
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import ApiNotAvailable
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
@ -140,7 +140,7 @@ class SessionManager:
class Bot(BaseBot):
"""
r"""
mirai-api-http 协议 Bot 适配
\:\:\: warning
@ -151,14 +151,6 @@ class Bot(BaseBot):
"""
@overrides(BaseBot)
def __init__(self,
connection_type: str,
self_id: str,
*,
websocket: Optional[WebSocket] = None):
super().__init__(connection_type, self_id, websocket=websocket)
@property
@overrides(BaseBot)
def type(self) -> str:
@ -166,7 +158,8 @@ class Bot(BaseBot):
@property
def alive(self) -> bool:
return not self.websocket.closed
assert isinstance(self.request, WebSocket)
return not self.request.closed
@property
def api(self) -> SessionManager:
@ -177,27 +170,26 @@ class Bot(BaseBot):
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[bytes]) -> str:
if connection_type == 'ws':
raise RequestDenied(
status_code=501,
reason='Websocket connection is not implemented')
self_id: Optional[str] = headers.get('bot')
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
if isinstance(request, WebSocket):
return None, HTTPResponse(
501, b'Websocket connection is not implemented')
self_id: Optional[str] = request.headers.get('bot')
if self_id is None:
raise RequestDenied(status_code=400,
reason='Header `Bot` is required.')
return None, HTTPResponse(400, b'Header `Bot` is required.')
self_id = str(self_id).strip()
await SessionManager.new(
int(self_id),
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, #type: ignore
auth_key=cls.mirai_config.auth_key) # type: ignore
return self_id
return self_id, HTTPResponse(204, b'')
@classmethod
@overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config"):
def register(cls, driver: Driver, config: "Config"):
cls.mirai_config = MiraiConfig(**config.dict())
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
@ -224,7 +216,7 @@ class Bot(BaseBot):
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
"""
r"""
\:\:\: danger
由于Mirai的HTTP API特殊性, 该API暂时无法实现
\:\:\:

View File

@ -1,18 +1,16 @@
import asyncio
import json
import asyncio
from dataclasses import dataclass
from ipaddress import IPv4Address
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
TypeVar)
from typing import Any, Set, Dict, Tuple, TypeVar, Optional, Callable, Coroutine
import httpx
import websockets
from nonebot.config import Config
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse, WebSocket as BaseWebSocket
from .bot import SessionManager, Bot
@ -21,7 +19,9 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
@dataclass
class WebSocket(BaseWebSocket):
websocket: websockets.WebSocketClientProtocol = None # type: ignore
@classmethod
async def new(cls, *, host: IPv4Address, port: int,
@ -37,24 +37,26 @@ class WebSocket(BaseWebSocket):
self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket)
@property
@overrides(BaseWebSocket)
def websocket(self) -> websockets.WebSocketClientProtocol:
return self._websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return self.websocket.closed
@overrides(BaseWebSocket)
async def send(self, data: Dict[str, Any]):
return await self.websocket.send(json.dumps(data))
async def send(self, data: str):
return await self.websocket.send(data)
@overrides(BaseWebSocket)
async def receive(self) -> Dict[str, Any]:
received = await self.websocket.recv()
return json.loads(received)
async def send_bytes(self, data: str):
return await self.websocket.send(data)
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.recv() # type: ignore
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.recv() # type: ignore
async def _dispatcher(self):
while not self.closed:
@ -93,11 +95,6 @@ class WebsocketBot(Bot):
mirai-api-http 正向 Websocket 协议 Bot 适配
"""
@overrides(Bot)
def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket)
@property
@overrides(Bot)
def type(self) -> str:
@ -105,7 +102,8 @@ class WebsocketBot(Bot):
@property
def alive(self) -> bool:
return not self.websocket.closed
assert isinstance(self.request, WebSocket)
return not self.request.closed
@property
def api(self) -> SessionManager:
@ -115,16 +113,14 @@ class WebsocketBot(Bot):
@classmethod
@overrides(Bot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict,
body: Optional[bytes]) -> NoReturn:
raise RequestDenied(
status_code=501,
reason=f'Connection {connection_type} not implented')
async def check_permission(
cls, driver: Driver,
request: HTTPConnection) -> Tuple[None, HTTPResponse]:
return None, HTTPResponse(501, b'Connection not implented')
@classmethod
@overrides(Bot)
def register(cls, driver: "Driver", config: "Config", qq: int):
def register(cls, driver: Driver, config: "Config", qq: int):
"""
:说明: