mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🎨 change permission check from driver into adapter #46
This commit is contained in:
parent
1f1f9cd7e6
commit
b2a2234d5c
@ -11,7 +11,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.typing import Driver, Message, WebSocket
|
||||
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
|
||||
from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable
|
||||
|
||||
|
||||
class BaseBot(abc.ABC):
|
||||
@ -55,6 +55,13 @@ class BaseBot(abc.ABC):
|
||||
"""Adapter 类型"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
async def check_permission(cls, driver: Driver, connection_type: str,
|
||||
headers: dict,
|
||||
body: Optional[dict]) -> Union[str, NoReturn]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_message(self, message: dict):
|
||||
"""
|
||||
|
@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配
|
||||
|
||||
import re
|
||||
import sys
|
||||
import hmac
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
@ -19,10 +21,10 @@ import httpx
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
|
||||
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||
from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable
|
||||
|
||||
|
||||
def log(level: str, message: str):
|
||||
@ -39,6 +41,16 @@ def log(level: str, message: str):
|
||||
return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message)
|
||||
|
||||
|
||||
def get_auth_bearer(
|
||||
access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]:
|
||||
if not access_token:
|
||||
return None
|
||||
scheme, _, param = access_token.partition(" ")
|
||||
if scheme.lower() not in ["bearer", "token"]:
|
||||
raise RequestDenied(401, "Not authenticated")
|
||||
return param
|
||||
|
||||
|
||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||
"""
|
||||
:说明:
|
||||
@ -264,8 +276,6 @@ class Bot(BaseBot):
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: Optional[WebSocket] = None):
|
||||
if connection_type not in ["http", "websocket"]:
|
||||
raise ValueError("Unsupported connection type")
|
||||
|
||||
super().__init__(driver,
|
||||
connection_type,
|
||||
@ -281,6 +291,47 @@ class Bot(BaseBot):
|
||||
"""
|
||||
return "cqhttp"
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(cls, driver: Driver, connection_type: str,
|
||||
headers: dict,
|
||||
body: Optional[dict]) -> Union[str, NoReturn]:
|
||||
x_self_id = headers.get("x-self-id")
|
||||
x_signature = headers.get("x-signature")
|
||||
access_token = get_auth_bearer(headers.get("authorization"))
|
||||
|
||||
# 检查连接方式
|
||||
if connection_type not in ["http", "websocket"]:
|
||||
log("WARNING", "Unsupported connection type")
|
||||
raise RequestDenied(405, "Unsupported connection type")
|
||||
|
||||
# 检查self_id
|
||||
if not x_self_id:
|
||||
log("WARNING", "Missing X-Self-ID Header")
|
||||
raise RequestDenied(400, "Missing X-Self-ID Header")
|
||||
|
||||
# 检查签名
|
||||
secret = driver.config.secret
|
||||
if secret and connection_type == "http":
|
||||
if not x_signature:
|
||||
log("WARNING", "Missing Signature Header")
|
||||
raise RequestDenied(401, "Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"),
|
||||
json.dumps(body).encode(), "sha1").hexdigest()
|
||||
if x_signature != "sha1=" + sig:
|
||||
log("WARNING", "Signature Header is invalid")
|
||||
raise RequestDenied(403, "Signature is invalid")
|
||||
|
||||
access_token = driver.config.access_token
|
||||
if access_token and access_token != access_token:
|
||||
log(
|
||||
"WARNING", "Authorization Header is invalid"
|
||||
if access_token else "Missing Authorization Header")
|
||||
raise RequestDenied(
|
||||
403, "Authorization Header is invalid"
|
||||
if access_token else "Missing Authorization Header")
|
||||
return str(x_self_id)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def handle_message(self, message: dict):
|
||||
"""
|
||||
|
@ -9,6 +9,11 @@ def log(level: str, message: str):
|
||||
...
|
||||
|
||||
|
||||
def get_auth_bearer(
|
||||
access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]:
|
||||
...
|
||||
|
||||
|
||||
def escape(s: str, *, escape_comma: bool = ...) -> str:
|
||||
...
|
||||
|
||||
@ -69,6 +74,12 @@ class Bot(BaseBot):
|
||||
def type(self) -> str:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def check_permission(cls, driver: Driver, connection_type: str,
|
||||
headers: dict,
|
||||
body: Optional[dict]) -> Union[str, NoReturn]:
|
||||
...
|
||||
|
||||
async def handle_message(self, message: dict):
|
||||
...
|
||||
|
||||
|
@ -15,12 +15,13 @@ import logging
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import Response
|
||||
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
|
||||
from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
|
||||
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||
from nonebot.typing import Optional, Callable, overrides
|
||||
|
||||
@ -127,97 +128,58 @@ class Driver(BaseDriver):
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_http(self,
|
||||
adapter: str,
|
||||
data: dict = Body(...),
|
||||
x_self_id: Optional[str] = Header(None),
|
||||
x_signature: Optional[str] = Header(None),
|
||||
auth: Optional[str] = Depends(get_auth_bearer)):
|
||||
# 检查self_id
|
||||
if not x_self_id:
|
||||
logger.warning("Missing X-Self-ID Header")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing X-Self-ID Header")
|
||||
|
||||
# 检查签名
|
||||
secret = self.config.secret
|
||||
if secret:
|
||||
if not x_signature:
|
||||
logger.warning("Missing Signature Header")
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"),
|
||||
json.dumps(data).encode(), "sha1").hexdigest()
|
||||
if x_signature != "sha1=" + sig:
|
||||
logger.warning("Signature Header is invalid")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Signature is invalid")
|
||||
|
||||
access_token = self.config.access_token
|
||||
if access_token and access_token != auth:
|
||||
logger.warning("Authorization Header is invalid"
|
||||
if auth else "Missing Authorization Header")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Authorization Header is invalid"
|
||||
if auth else "Missing Authorization Header")
|
||||
|
||||
request: Request,
|
||||
data: dict = Body(...)):
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("Data received is invalid")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if adapter not in self._adapters:
|
||||
logger.warning("Unknown adapter")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="adapter not found")
|
||||
|
||||
# 创建 Bot 对象
|
||||
BotClass = self._adapters[adapter]
|
||||
headers = dict(request.headers)
|
||||
try:
|
||||
x_self_id = await BotClass.check_permission(self, "http", headers,
|
||||
data)
|
||||
except RequestDenied as e:
|
||||
raise HTTPException(status_code=e.status_code,
|
||||
detail=e.reason) from None
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.warning("There's already a reverse websocket api connection,"
|
||||
"so the event may be handled twice.")
|
||||
|
||||
# 创建 Bot 对象
|
||||
if adapter in self._adapters:
|
||||
BotClass = self._adapters[adapter]
|
||||
bot = BotClass(self, "http", self.config, x_self_id)
|
||||
else:
|
||||
logger.warning("Unknown adapter")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="adapter not found")
|
||||
bot = BotClass(self, "http", self.config, x_self_id)
|
||||
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
return Response("", 204)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_ws_reverse(
|
||||
self,
|
||||
adapter: str,
|
||||
websocket: FastAPIWebSocket,
|
||||
x_self_id: str = Header(None),
|
||||
auth: Optional[str] = Depends(get_auth_bearer)):
|
||||
async def _handle_ws_reverse(self, adapter: str,
|
||||
websocket: FastAPIWebSocket):
|
||||
ws = WebSocket(websocket)
|
||||
|
||||
access_token = self.config.access_token
|
||||
if access_token and access_token != auth:
|
||||
logger.warning("Authorization Header is invalid"
|
||||
if auth else "Missing Authorization Header")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
if not x_self_id:
|
||||
logger.warning(f"Missing X-Self-ID Header")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.warning(f"Connection Conflict: self_id {x_self_id}")
|
||||
if adapter not in self._adapters:
|
||||
logger.warning("Unknown adapter")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
# Create Bot Object
|
||||
if adapter in self._adapters:
|
||||
BotClass = self._adapters[adapter]
|
||||
bot = BotClass(self,
|
||||
"websocket",
|
||||
self.config,
|
||||
x_self_id,
|
||||
websocket=ws)
|
||||
else:
|
||||
logger.warning("Unknown adapter")
|
||||
BotClass = self._adapters[adapter]
|
||||
headers = dict(websocket.headers)
|
||||
try:
|
||||
x_self_id = await BotClass.check_permission(self, "websocket",
|
||||
headers, None)
|
||||
except RequestDenied:
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
|
||||
|
||||
await ws.accept()
|
||||
self._clients[x_self_id] = bot
|
||||
logger.opt(colors=True).info(
|
||||
|
@ -105,6 +105,29 @@ class StopPropagation(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequestDenied(Exception):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
Bot 连接请求不合法。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``status_code: int``: HTTP 状态码
|
||||
* ``reason: str``: 拒绝原因
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, reason: str):
|
||||
self.status_code = status_code
|
||||
self.reason = reason
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
class ApiNotAvailable(Exception):
|
||||
"""
|
||||
:说明:
|
||||
@ -131,7 +154,7 @@ class ActionFailed(Exception):
|
||||
|
||||
:参数:
|
||||
|
||||
* ``retcode``: 错误代码
|
||||
* ``retcode: Optional[int]``: 错误代码
|
||||
"""
|
||||
|
||||
def __init__(self, retcode: Optional[int]):
|
||||
|
Loading…
Reference in New Issue
Block a user