🎨 change permission check from driver into adapter #46

This commit is contained in:
yanyongyu 2020-11-11 15:14:29 +08:00
parent 1f1f9cd7e6
commit b2a2234d5c
5 changed files with 130 additions and 76 deletions

View File

@ -11,7 +11,7 @@ from dataclasses import dataclass, field
from nonebot.config import Config from nonebot.config import Config
from nonebot.typing import Driver, Message, WebSocket from nonebot.typing import Driver, Message, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable
class BaseBot(abc.ABC): class BaseBot(abc.ABC):
@ -55,6 +55,13 @@ class BaseBot(abc.ABC):
"""Adapter 类型""" """Adapter 类型"""
raise NotImplementedError raise NotImplementedError
@classmethod
@abc.abstractmethod
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
""" """

View File

@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配
import re import re
import sys import sys
import hmac
import json
import asyncio import asyncio
import httpx import httpx
@ -19,10 +21,10 @@ import httpx
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable
def log(level: str, message: str): def log(level: str, message: str):
@ -39,6 +41,16 @@ def log(level: str, message: str):
return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message) return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message)
def get_auth_bearer(
access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]:
if not access_token:
return None
scheme, _, param = access_token.partition(" ")
if scheme.lower() not in ["bearer", "token"]:
raise RequestDenied(401, "Not authenticated")
return param
def escape(s: str, *, escape_comma: bool = True) -> str: def escape(s: str, *, escape_comma: bool = True) -> str:
""" """
:说明: :说明:
@ -264,8 +276,6 @@ class Bot(BaseBot):
self_id: str, self_id: str,
*, *,
websocket: Optional[WebSocket] = None): websocket: Optional[WebSocket] = None):
if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")
super().__init__(driver, super().__init__(driver,
connection_type, connection_type,
@ -281,6 +291,47 @@ class Bot(BaseBot):
""" """
return "cqhttp" return "cqhttp"
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
x_self_id = headers.get("x-self-id")
x_signature = headers.get("x-signature")
access_token = get_auth_bearer(headers.get("authorization"))
# 检查连接方式
if connection_type not in ["http", "websocket"]:
log("WARNING", "Unsupported connection type")
raise RequestDenied(405, "Unsupported connection type")
# 检查self_id
if not x_self_id:
log("WARNING", "Missing X-Self-ID Header")
raise RequestDenied(400, "Missing X-Self-ID Header")
# 检查签名
secret = driver.config.secret
if secret and connection_type == "http":
if not x_signature:
log("WARNING", "Missing Signature Header")
raise RequestDenied(401, "Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(body).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
access_token = driver.config.access_token
if access_token and access_token != access_token:
log(
"WARNING", "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
raise RequestDenied(
403, "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
return str(x_self_id)
@overrides(BaseBot) @overrides(BaseBot)
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
""" """

View File

@ -9,6 +9,11 @@ def log(level: str, message: str):
... ...
def get_auth_bearer(
access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]:
...
def escape(s: str, *, escape_comma: bool = ...) -> str: def escape(s: str, *, escape_comma: bool = ...) -> str:
... ...
@ -69,6 +74,12 @@ class Bot(BaseBot):
def type(self) -> str: def type(self) -> str:
... ...
@classmethod
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
...
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
... ...

View File

@ -15,12 +15,13 @@ import logging
import uvicorn import uvicorn
from fastapi.responses import Response from fastapi.responses import Response
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Optional, Callable, overrides from nonebot.typing import Optional, Callable, overrides
@ -127,97 +128,58 @@ class Driver(BaseDriver):
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_http(self, async def _handle_http(self,
adapter: str, adapter: str,
data: dict = Body(...), request: Request,
x_self_id: Optional[str] = Header(None), data: dict = Body(...)):
x_signature: Optional[str] = Header(None),
auth: Optional[str] = Depends(get_auth_bearer)):
# 检查self_id
if not x_self_id:
logger.warning("Missing X-Self-ID Header")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Self-ID Header")
# 检查签名
secret = self.config.secret
if secret:
if not x_signature:
logger.warning("Missing Signature Header")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(data).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
logger.warning("Signature Header is invalid")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Signature is invalid")
access_token = self.config.access_token
if access_token and access_token != auth:
logger.warning("Authorization Header is invalid"
if auth else "Missing Authorization Header")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Authorization Header is invalid"
if auth else "Missing Authorization Header")
if not isinstance(data, dict): if not isinstance(data, dict):
logger.warning("Data received is invalid") logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
# 创建 Bot 对象
BotClass = self._adapters[adapter]
headers = dict(request.headers)
try:
x_self_id = await BotClass.check_permission(self, "http", headers,
data)
except RequestDenied as e:
raise HTTPException(status_code=e.status_code,
detail=e.reason) from None
if x_self_id in self._clients: if x_self_id in self._clients:
logger.warning("There's already a reverse websocket api connection," logger.warning("There's already a reverse websocket api connection,"
"so the event may be handled twice.") "so the event may be handled twice.")
# 创建 Bot 对象
if adapter in self._adapters:
BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id) bot = BotClass(self, "http", self.config, x_self_id)
else:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
return Response("", 204) return Response("", 204)
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_ws_reverse( async def _handle_ws_reverse(self, adapter: str,
self, websocket: FastAPIWebSocket):
adapter: str,
websocket: FastAPIWebSocket,
x_self_id: str = Header(None),
auth: Optional[str] = Depends(get_auth_bearer)):
ws = WebSocket(websocket) ws = WebSocket(websocket)
access_token = self.config.access_token if adapter not in self._adapters:
if access_token and access_token != auth: logger.warning("Unknown adapter")
logger.warning("Authorization Header is invalid"
if auth else "Missing Authorization Header")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if not x_self_id:
logger.warning(f"Missing X-Self-ID Header")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if x_self_id in self._clients:
logger.warning(f"Connection Conflict: self_id {x_self_id}")
await ws.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return return
# Create Bot Object # Create Bot Object
if adapter in self._adapters:
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
bot = BotClass(self, headers = dict(websocket.headers)
"websocket", try:
self.config, x_self_id = await BotClass.check_permission(self, "websocket",
x_self_id, headers, None)
websocket=ws) except RequestDenied:
else:
logger.warning("Unknown adapter")
await ws.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return return
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
await ws.accept() await ws.accept()
self._clients[x_self_id] = bot self._clients[x_self_id] = bot
logger.opt(colors=True).info( logger.opt(colors=True).info(

View File

@ -105,6 +105,29 @@ class StopPropagation(Exception):
pass pass
class RequestDenied(Exception):
"""
:说明:
Bot 连接请求不合法
:参数:
* ``status_code: int``: HTTP 状态码
* ``reason: str``: 拒绝原因
"""
def __init__(self, status_code: int, reason: str):
self.status_code = status_code
self.reason = reason
def __repr__(self):
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
def __str__(self):
return self.__repr__()
class ApiNotAvailable(Exception): class ApiNotAvailable(Exception):
""" """
:说明: :说明:
@ -131,7 +154,7 @@ class ActionFailed(Exception):
:参数: :参数:
* ``retcode``: 错误代码 * ``retcode: Optional[int]``: 错误代码
""" """
def __init__(self, retcode: Optional[int]): def __init__(self, retcode: Optional[int]):