🎨 change permission check from driver into adapter #46

This commit is contained in:
yanyongyu 2020-11-11 15:14:29 +08:00
parent 1f1f9cd7e6
commit b2a2234d5c
5 changed files with 130 additions and 76 deletions

View File

@ -11,7 +11,7 @@ from dataclasses import dataclass, field
from nonebot.config import Config
from nonebot.typing import Driver, Message, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable
class BaseBot(abc.ABC):
@ -55,6 +55,13 @@ class BaseBot(abc.ABC):
"""Adapter 类型"""
raise NotImplementedError
@classmethod
@abc.abstractmethod
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
raise NotImplementedError
@abc.abstractmethod
async def handle_message(self, message: dict):
"""

View File

@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配
import re
import sys
import hmac
import json
import asyncio
import httpx
@ -19,10 +21,10 @@ import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable
def log(level: str, message: str):
@ -39,6 +41,16 @@ def log(level: str, message: str):
return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message)
def get_auth_bearer(
access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]:
if not access_token:
return None
scheme, _, param = access_token.partition(" ")
if scheme.lower() not in ["bearer", "token"]:
raise RequestDenied(401, "Not authenticated")
return param
def escape(s: str, *, escape_comma: bool = True) -> str:
"""
:说明:
@ -264,8 +276,6 @@ class Bot(BaseBot):
self_id: str,
*,
websocket: Optional[WebSocket] = None):
if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")
super().__init__(driver,
connection_type,
@ -281,6 +291,47 @@ class Bot(BaseBot):
"""
return "cqhttp"
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
x_self_id = headers.get("x-self-id")
x_signature = headers.get("x-signature")
access_token = get_auth_bearer(headers.get("authorization"))
# 检查连接方式
if connection_type not in ["http", "websocket"]:
log("WARNING", "Unsupported connection type")
raise RequestDenied(405, "Unsupported connection type")
# 检查self_id
if not x_self_id:
log("WARNING", "Missing X-Self-ID Header")
raise RequestDenied(400, "Missing X-Self-ID Header")
# 检查签名
secret = driver.config.secret
if secret and connection_type == "http":
if not x_signature:
log("WARNING", "Missing Signature Header")
raise RequestDenied(401, "Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(body).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
access_token = driver.config.access_token
if access_token and access_token != access_token:
log(
"WARNING", "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
raise RequestDenied(
403, "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
return str(x_self_id)
@overrides(BaseBot)
async def handle_message(self, message: dict):
"""

View File

@ -9,6 +9,11 @@ def log(level: str, message: str):
...
def get_auth_bearer(
access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]:
...
def escape(s: str, *, escape_comma: bool = ...) -> str:
...
@ -69,6 +74,12 @@ class Bot(BaseBot):
def type(self) -> str:
...
@classmethod
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
...
async def handle_message(self, message: dict):
...

View File

@ -15,12 +15,13 @@ import logging
import uvicorn
from fastapi.responses import Response
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder
from nonebot.exception import RequestDenied
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Optional, Callable, overrides
@ -127,97 +128,58 @@ class Driver(BaseDriver):
@overrides(BaseDriver)
async def _handle_http(self,
adapter: str,
data: dict = Body(...),
x_self_id: Optional[str] = Header(None),
x_signature: Optional[str] = Header(None),
auth: Optional[str] = Depends(get_auth_bearer)):
# 检查self_id
if not x_self_id:
logger.warning("Missing X-Self-ID Header")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing X-Self-ID Header")
# 检查签名
secret = self.config.secret
if secret:
if not x_signature:
logger.warning("Missing Signature Header")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(data).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
logger.warning("Signature Header is invalid")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Signature is invalid")
access_token = self.config.access_token
if access_token and access_token != auth:
logger.warning("Authorization Header is invalid"
if auth else "Missing Authorization Header")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Authorization Header is invalid"
if auth else "Missing Authorization Header")
request: Request,
data: dict = Body(...)):
if not isinstance(data, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if adapter not in self._adapters:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
# 创建 Bot 对象
BotClass = self._adapters[adapter]
headers = dict(request.headers)
try:
x_self_id = await BotClass.check_permission(self, "http", headers,
data)
except RequestDenied as e:
raise HTTPException(status_code=e.status_code,
detail=e.reason) from None
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket api connection,"
"so the event may be handled twice.")
# 创建 Bot 对象
if adapter in self._adapters:
BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id)
else:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found")
bot = BotClass(self, "http", self.config, x_self_id)
asyncio.create_task(bot.handle_message(data))
return Response("", 204)
@overrides(BaseDriver)
async def _handle_ws_reverse(
self,
adapter: str,
websocket: FastAPIWebSocket,
x_self_id: str = Header(None),
auth: Optional[str] = Depends(get_auth_bearer)):
async def _handle_ws_reverse(self, adapter: str,
websocket: FastAPIWebSocket):
ws = WebSocket(websocket)
access_token = self.config.access_token
if access_token and access_token != auth:
logger.warning("Authorization Header is invalid"
if auth else "Missing Authorization Header")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if not x_self_id:
logger.warning(f"Missing X-Self-ID Header")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if x_self_id in self._clients:
logger.warning(f"Connection Conflict: self_id {x_self_id}")
if adapter not in self._adapters:
logger.warning("Unknown adapter")
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
# Create Bot Object
if adapter in self._adapters:
BotClass = self._adapters[adapter]
bot = BotClass(self,
"websocket",
self.config,
x_self_id,
websocket=ws)
else:
logger.warning("Unknown adapter")
BotClass = self._adapters[adapter]
headers = dict(websocket.headers)
try:
x_self_id = await BotClass.check_permission(self, "websocket",
headers, None)
except RequestDenied:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
await ws.accept()
self._clients[x_self_id] = bot
logger.opt(colors=True).info(

View File

@ -105,6 +105,29 @@ class StopPropagation(Exception):
pass
class RequestDenied(Exception):
"""
:说明:
Bot 连接请求不合法
:参数:
* ``status_code: int``: HTTP 状态码
* ``reason: str``: 拒绝原因
"""
def __init__(self, status_code: int, reason: str):
self.status_code = status_code
self.reason = reason
def __repr__(self):
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
def __str__(self):
return self.__repr__()
class ApiNotAvailable(Exception):
"""
:说明:
@ -131,7 +154,7 @@ class ActionFailed(Exception):
:参数:
* ``retcode``: 错误代码
* ``retcode: Optional[int]``: 错误代码
"""
def __init__(self, retcode: Optional[int]):