diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py
index 1f34d01e..ca2ef828 100644
--- a/nonebot/adapters/__init__.py
+++ b/nonebot/adapters/__init__.py
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
from nonebot.config import Config
from nonebot.typing import Driver, Message, WebSocket
-from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
+from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable
class BaseBot(abc.ABC):
@@ -55,6 +55,13 @@ class BaseBot(abc.ABC):
"""Adapter 类型"""
raise NotImplementedError
+ @classmethod
+ @abc.abstractmethod
+ async def check_permission(cls, driver: Driver, connection_type: str,
+ headers: dict,
+ body: Optional[dict]) -> Union[str, NoReturn]:
+ raise NotImplementedError
+
@abc.abstractmethod
async def handle_message(self, message: dict):
"""
diff --git a/nonebot/adapters/cqhttp/__init__.py b/nonebot/adapters/cqhttp/__init__.py
index 8a895f17..aa47e9cf 100644
--- a/nonebot/adapters/cqhttp/__init__.py
+++ b/nonebot/adapters/cqhttp/__init__.py
@@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配
import re
import sys
+import hmac
+import json
import asyncio
import httpx
@@ -19,10 +21,10 @@ import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.message import handle_event
-from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
-from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
+from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
+from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable
def log(level: str, message: str):
@@ -39,6 +41,16 @@ def log(level: str, message: str):
return logger.opt(colors=True).log(level, "CQHTTP | " + message)
+def get_auth_bearer(
+ access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]:
+ if not access_token:
+ return None
+ scheme, _, param = access_token.partition(" ")
+ if scheme.lower() not in ["bearer", "token"]:
+ raise RequestDenied(401, "Not authenticated")
+ return param
+
+
def escape(s: str, *, escape_comma: bool = True) -> str:
"""
:说明:
@@ -264,8 +276,6 @@ class Bot(BaseBot):
self_id: str,
*,
websocket: Optional[WebSocket] = None):
- if connection_type not in ["http", "websocket"]:
- raise ValueError("Unsupported connection type")
super().__init__(driver,
connection_type,
@@ -281,6 +291,47 @@ class Bot(BaseBot):
"""
return "cqhttp"
+ @classmethod
+ @overrides(BaseBot)
+ async def check_permission(cls, driver: Driver, connection_type: str,
+ headers: dict,
+ body: Optional[dict]) -> Union[str, NoReturn]:
+ x_self_id = headers.get("x-self-id")
+ x_signature = headers.get("x-signature")
+ access_token = get_auth_bearer(headers.get("authorization"))
+
+ # 检查连接方式
+ if connection_type not in ["http", "websocket"]:
+ log("WARNING", "Unsupported connection type")
+ raise RequestDenied(405, "Unsupported connection type")
+
+ # 检查self_id
+ if not x_self_id:
+ log("WARNING", "Missing X-Self-ID Header")
+ raise RequestDenied(400, "Missing X-Self-ID Header")
+
+ # 检查签名
+ secret = driver.config.secret
+ if secret and connection_type == "http":
+ if not x_signature:
+ log("WARNING", "Missing Signature Header")
+ raise RequestDenied(401, "Missing Signature")
+ sig = hmac.new(secret.encode("utf-8"),
+ json.dumps(body).encode(), "sha1").hexdigest()
+ if x_signature != "sha1=" + sig:
+ log("WARNING", "Signature Header is invalid")
+ raise RequestDenied(403, "Signature is invalid")
+
+ access_token = driver.config.access_token
+ if access_token and access_token != access_token:
+ log(
+ "WARNING", "Authorization Header is invalid"
+ if access_token else "Missing Authorization Header")
+ raise RequestDenied(
+ 403, "Authorization Header is invalid"
+ if access_token else "Missing Authorization Header")
+ return str(x_self_id)
+
@overrides(BaseBot)
async def handle_message(self, message: dict):
"""
diff --git a/nonebot/adapters/cqhttp/__init__.pyi b/nonebot/adapters/cqhttp/__init__.pyi
index 7920bfe3..e5398588 100644
--- a/nonebot/adapters/cqhttp/__init__.pyi
+++ b/nonebot/adapters/cqhttp/__init__.pyi
@@ -9,6 +9,11 @@ def log(level: str, message: str):
...
+def get_auth_bearer(
+ access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]:
+ ...
+
+
def escape(s: str, *, escape_comma: bool = ...) -> str:
...
@@ -69,6 +74,12 @@ class Bot(BaseBot):
def type(self) -> str:
...
+ @classmethod
+ async def check_permission(cls, driver: Driver, connection_type: str,
+ headers: dict,
+ body: Optional[dict]) -> Union[str, NoReturn]:
+ ...
+
async def handle_message(self, message: dict):
...
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index bb0009b4..6a5585dd 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -15,12 +15,13 @@ import logging
import uvicorn
from fastapi.responses import Response
-from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
+from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder
+from nonebot.exception import RequestDenied
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Optional, Callable, overrides
@@ -127,97 +128,58 @@ class Driver(BaseDriver):
@overrides(BaseDriver)
async def _handle_http(self,
adapter: str,
- data: dict = Body(...),
- x_self_id: Optional[str] = Header(None),
- x_signature: Optional[str] = Header(None),
- auth: Optional[str] = Depends(get_auth_bearer)):
- # 检查self_id
- if not x_self_id:
- logger.warning("Missing X-Self-ID Header")
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
- detail="Missing X-Self-ID Header")
-
- # 检查签名
- secret = self.config.secret
- if secret:
- if not x_signature:
- logger.warning("Missing Signature Header")
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Missing Signature")
- sig = hmac.new(secret.encode("utf-8"),
- json.dumps(data).encode(), "sha1").hexdigest()
- if x_signature != "sha1=" + sig:
- logger.warning("Signature Header is invalid")
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
- detail="Signature is invalid")
-
- access_token = self.config.access_token
- if access_token and access_token != auth:
- logger.warning("Authorization Header is invalid"
- if auth else "Missing Authorization Header")
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
- detail="Authorization Header is invalid"
- if auth else "Missing Authorization Header")
-
+ request: Request,
+ data: dict = Body(...)):
if not isinstance(data, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
+ if adapter not in self._adapters:
+ logger.warning("Unknown adapter")
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
+ detail="adapter not found")
+
+ # 创建 Bot 对象
+ BotClass = self._adapters[adapter]
+ headers = dict(request.headers)
+ try:
+ x_self_id = await BotClass.check_permission(self, "http", headers,
+ data)
+ except RequestDenied as e:
+ raise HTTPException(status_code=e.status_code,
+ detail=e.reason) from None
+
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket api connection,"
"so the event may be handled twice.")
- # 创建 Bot 对象
- if adapter in self._adapters:
- BotClass = self._adapters[adapter]
- bot = BotClass(self, "http", self.config, x_self_id)
- else:
- logger.warning("Unknown adapter")
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
- detail="adapter not found")
+ bot = BotClass(self, "http", self.config, x_self_id)
asyncio.create_task(bot.handle_message(data))
return Response("", 204)
@overrides(BaseDriver)
- async def _handle_ws_reverse(
- self,
- adapter: str,
- websocket: FastAPIWebSocket,
- x_self_id: str = Header(None),
- auth: Optional[str] = Depends(get_auth_bearer)):
+ async def _handle_ws_reverse(self, adapter: str,
+ websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
- access_token = self.config.access_token
- if access_token and access_token != auth:
- logger.warning("Authorization Header is invalid"
- if auth else "Missing Authorization Header")
- await ws.close(code=status.WS_1008_POLICY_VIOLATION)
- return
-
- if not x_self_id:
- logger.warning(f"Missing X-Self-ID Header")
- await ws.close(code=status.WS_1008_POLICY_VIOLATION)
- return
-
- if x_self_id in self._clients:
- logger.warning(f"Connection Conflict: self_id {x_self_id}")
+ if adapter not in self._adapters:
+ logger.warning("Unknown adapter")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
# Create Bot Object
- if adapter in self._adapters:
- BotClass = self._adapters[adapter]
- bot = BotClass(self,
- "websocket",
- self.config,
- x_self_id,
- websocket=ws)
- else:
- logger.warning("Unknown adapter")
+ BotClass = self._adapters[adapter]
+ headers = dict(websocket.headers)
+ try:
+ x_self_id = await BotClass.check_permission(self, "websocket",
+ headers, None)
+ except RequestDenied:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
+ bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
+
await ws.accept()
self._clients[x_self_id] = bot
logger.opt(colors=True).info(
diff --git a/nonebot/exception.py b/nonebot/exception.py
index 45062635..4d862e41 100644
--- a/nonebot/exception.py
+++ b/nonebot/exception.py
@@ -105,6 +105,29 @@ class StopPropagation(Exception):
pass
+class RequestDenied(Exception):
+ """
+ :说明:
+
+ Bot 连接请求不合法。
+
+ :参数:
+
+ * ``status_code: int``: HTTP 状态码
+ * ``reason: str``: 拒绝原因
+ """
+
+ def __init__(self, status_code: int, reason: str):
+ self.status_code = status_code
+ self.reason = reason
+
+ def __repr__(self):
+ return f""
+
+ def __str__(self):
+ return self.__repr__()
+
+
class ApiNotAvailable(Exception):
"""
:说明:
@@ -131,7 +154,7 @@ class ActionFailed(Exception):
:参数:
- * ``retcode``: 错误代码
+ * ``retcode: Optional[int]``: 错误代码
"""
def __init__(self, retcode: Optional[int]):