From b2a2234d5cd0c445628b294be4fc6a234f411628 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Wed, 11 Nov 2020 15:14:29 +0800 Subject: [PATCH] :art: change permission check from driver into adapter #46 --- nonebot/adapters/__init__.py | 9 ++- nonebot/adapters/cqhttp/__init__.py | 59 ++++++++++++++-- nonebot/adapters/cqhttp/__init__.pyi | 11 +++ nonebot/drivers/fastapi.py | 102 +++++++++------------------ nonebot/exception.py | 25 ++++++- 5 files changed, 130 insertions(+), 76 deletions(-) diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 1f34d01e..ca2ef828 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -11,7 +11,7 @@ from dataclasses import dataclass, field from nonebot.config import Config from nonebot.typing import Driver, Message, WebSocket -from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable +from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable class BaseBot(abc.ABC): @@ -55,6 +55,13 @@ class BaseBot(abc.ABC): """Adapter 类型""" raise NotImplementedError + @classmethod + @abc.abstractmethod + async def check_permission(cls, driver: Driver, connection_type: str, + headers: dict, + body: Optional[dict]) -> Union[str, NoReturn]: + raise NotImplementedError + @abc.abstractmethod async def handle_message(self, message: dict): """ diff --git a/nonebot/adapters/cqhttp/__init__.py b/nonebot/adapters/cqhttp/__init__.py index 8a895f17..aa47e9cf 100644 --- a/nonebot/adapters/cqhttp/__init__.py +++ b/nonebot/adapters/cqhttp/__init__.py @@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配 import re import sys +import hmac +import json import asyncio import httpx @@ -19,10 +21,10 @@ import httpx from nonebot.log import logger from nonebot.config import Config from nonebot.message import handle_event -from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional -from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable from nonebot.typing import overrides, Driver, WebSocket, NoReturn +from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment +from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable def log(level: str, message: str): @@ -39,6 +41,16 @@ def log(level: str, message: str): return logger.opt(colors=True).log(level, "CQHTTP | " + message) +def get_auth_bearer( + access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]: + if not access_token: + return None + scheme, _, param = access_token.partition(" ") + if scheme.lower() not in ["bearer", "token"]: + raise RequestDenied(401, "Not authenticated") + return param + + def escape(s: str, *, escape_comma: bool = True) -> str: """ :说明: @@ -264,8 +276,6 @@ class Bot(BaseBot): self_id: str, *, websocket: Optional[WebSocket] = None): - if connection_type not in ["http", "websocket"]: - raise ValueError("Unsupported connection type") super().__init__(driver, connection_type, @@ -281,6 +291,47 @@ class Bot(BaseBot): """ return "cqhttp" + @classmethod + @overrides(BaseBot) + async def check_permission(cls, driver: Driver, connection_type: str, + headers: dict, + body: Optional[dict]) -> Union[str, NoReturn]: + x_self_id = headers.get("x-self-id") + x_signature = headers.get("x-signature") + access_token = get_auth_bearer(headers.get("authorization")) + + # 检查连接方式 + if connection_type not in ["http", "websocket"]: + log("WARNING", "Unsupported connection type") + raise RequestDenied(405, "Unsupported connection type") + + # 检查self_id + if not x_self_id: + log("WARNING", "Missing X-Self-ID Header") + raise RequestDenied(400, "Missing X-Self-ID Header") + + # 检查签名 + secret = driver.config.secret + if secret and connection_type == "http": + if not x_signature: + log("WARNING", "Missing Signature Header") + raise RequestDenied(401, "Missing Signature") + sig = hmac.new(secret.encode("utf-8"), + json.dumps(body).encode(), "sha1").hexdigest() + if x_signature != "sha1=" + sig: + log("WARNING", "Signature Header is invalid") + raise RequestDenied(403, "Signature is invalid") + + access_token = driver.config.access_token + if access_token and access_token != access_token: + log( + "WARNING", "Authorization Header is invalid" + if access_token else "Missing Authorization Header") + raise RequestDenied( + 403, "Authorization Header is invalid" + if access_token else "Missing Authorization Header") + return str(x_self_id) + @overrides(BaseBot) async def handle_message(self, message: dict): """ diff --git a/nonebot/adapters/cqhttp/__init__.pyi b/nonebot/adapters/cqhttp/__init__.pyi index 7920bfe3..e5398588 100644 --- a/nonebot/adapters/cqhttp/__init__.pyi +++ b/nonebot/adapters/cqhttp/__init__.pyi @@ -9,6 +9,11 @@ def log(level: str, message: str): ... +def get_auth_bearer( + access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]: + ... + + def escape(s: str, *, escape_comma: bool = ...) -> str: ... @@ -69,6 +74,12 @@ class Bot(BaseBot): def type(self) -> str: ... + @classmethod + async def check_permission(cls, driver: Driver, connection_type: str, + headers: dict, + body: Optional[dict]) -> Union[str, NoReturn]: + ... + async def handle_message(self, message: dict): ... diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index bb0009b4..6a5585dd 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -15,12 +15,13 @@ import logging import uvicorn from fastapi.responses import Response -from fastapi import Body, status, Header, FastAPI, Depends, HTTPException +from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket from nonebot.log import logger from nonebot.config import Env, Config from nonebot.utils import DataclassEncoder +from nonebot.exception import RequestDenied from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.typing import Optional, Callable, overrides @@ -127,97 +128,58 @@ class Driver(BaseDriver): @overrides(BaseDriver) async def _handle_http(self, adapter: str, - data: dict = Body(...), - x_self_id: Optional[str] = Header(None), - x_signature: Optional[str] = Header(None), - auth: Optional[str] = Depends(get_auth_bearer)): - # 检查self_id - if not x_self_id: - logger.warning("Missing X-Self-ID Header") - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, - detail="Missing X-Self-ID Header") - - # 检查签名 - secret = self.config.secret - if secret: - if not x_signature: - logger.warning("Missing Signature Header") - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, - detail="Missing Signature") - sig = hmac.new(secret.encode("utf-8"), - json.dumps(data).encode(), "sha1").hexdigest() - if x_signature != "sha1=" + sig: - logger.warning("Signature Header is invalid") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, - detail="Signature is invalid") - - access_token = self.config.access_token - if access_token and access_token != auth: - logger.warning("Authorization Header is invalid" - if auth else "Missing Authorization Header") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, - detail="Authorization Header is invalid" - if auth else "Missing Authorization Header") - + request: Request, + data: dict = Body(...)): if not isinstance(data, dict): logger.warning("Data received is invalid") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + if adapter not in self._adapters: + logger.warning("Unknown adapter") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, + detail="adapter not found") + + # 创建 Bot 对象 + BotClass = self._adapters[adapter] + headers = dict(request.headers) + try: + x_self_id = await BotClass.check_permission(self, "http", headers, + data) + except RequestDenied as e: + raise HTTPException(status_code=e.status_code, + detail=e.reason) from None + if x_self_id in self._clients: logger.warning("There's already a reverse websocket api connection," "so the event may be handled twice.") - # 创建 Bot 对象 - if adapter in self._adapters: - BotClass = self._adapters[adapter] - bot = BotClass(self, "http", self.config, x_self_id) - else: - logger.warning("Unknown adapter") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail="adapter not found") + bot = BotClass(self, "http", self.config, x_self_id) asyncio.create_task(bot.handle_message(data)) return Response("", 204) @overrides(BaseDriver) - async def _handle_ws_reverse( - self, - adapter: str, - websocket: FastAPIWebSocket, - x_self_id: str = Header(None), - auth: Optional[str] = Depends(get_auth_bearer)): + async def _handle_ws_reverse(self, adapter: str, + websocket: FastAPIWebSocket): ws = WebSocket(websocket) - access_token = self.config.access_token - if access_token and access_token != auth: - logger.warning("Authorization Header is invalid" - if auth else "Missing Authorization Header") - await ws.close(code=status.WS_1008_POLICY_VIOLATION) - return - - if not x_self_id: - logger.warning(f"Missing X-Self-ID Header") - await ws.close(code=status.WS_1008_POLICY_VIOLATION) - return - - if x_self_id in self._clients: - logger.warning(f"Connection Conflict: self_id {x_self_id}") + if adapter not in self._adapters: + logger.warning("Unknown adapter") await ws.close(code=status.WS_1008_POLICY_VIOLATION) return # Create Bot Object - if adapter in self._adapters: - BotClass = self._adapters[adapter] - bot = BotClass(self, - "websocket", - self.config, - x_self_id, - websocket=ws) - else: - logger.warning("Unknown adapter") + BotClass = self._adapters[adapter] + headers = dict(websocket.headers) + try: + x_self_id = await BotClass.check_permission(self, "websocket", + headers, None) + except RequestDenied: await ws.close(code=status.WS_1008_POLICY_VIOLATION) return + bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) + await ws.accept() self._clients[x_self_id] = bot logger.opt(colors=True).info( diff --git a/nonebot/exception.py b/nonebot/exception.py index 45062635..4d862e41 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -105,6 +105,29 @@ class StopPropagation(Exception): pass +class RequestDenied(Exception): + """ + :说明: + + Bot 连接请求不合法。 + + :参数: + + * ``status_code: int``: HTTP 状态码 + * ``reason: str``: 拒绝原因 + """ + + def __init__(self, status_code: int, reason: str): + self.status_code = status_code + self.reason = reason + + def __repr__(self): + return f"" + + def __str__(self): + return self.__repr__() + + class ApiNotAvailable(Exception): """ :说明: @@ -131,7 +154,7 @@ class ActionFailed(Exception): :参数: - * ``retcode``: 错误代码 + * ``retcode: Optional[int]``: 错误代码 """ def __init__(self, retcode: Optional[int]):