mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-18 09:25:46 +08:00
🎨 change permission check from driver into adapter #46
This commit is contained in:
parent
1f1f9cd7e6
commit
b2a2234d5c
@ -11,7 +11,7 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.typing import Driver, Message, WebSocket
|
from nonebot.typing import Driver, Message, WebSocket
|
||||||
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
|
from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable
|
||||||
|
|
||||||
|
|
||||||
class BaseBot(abc.ABC):
|
class BaseBot(abc.ABC):
|
||||||
@ -55,6 +55,13 @@ class BaseBot(abc.ABC):
|
|||||||
"""Adapter 类型"""
|
"""Adapter 类型"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def check_permission(cls, driver: Driver, connection_type: str,
|
||||||
|
headers: dict,
|
||||||
|
body: Optional[dict]) -> Union[str, NoReturn]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def handle_message(self, message: dict):
|
async def handle_message(self, message: dict):
|
||||||
"""
|
"""
|
||||||
|
@ -12,6 +12,8 @@ CQHTTP (OneBot) v11 协议适配
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -19,10 +21,10 @@ import httpx
|
|||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
|
||||||
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
|
|
||||||
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||||
|
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||||
|
from nonebot.exception import NetworkError, ActionFailed, RequestDenied, ApiNotAvailable
|
||||||
|
|
||||||
|
|
||||||
def log(level: str, message: str):
|
def log(level: str, message: str):
|
||||||
@ -39,6 +41,16 @@ def log(level: str, message: str):
|
|||||||
return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message)
|
return logger.opt(colors=True).log(level, "<m>CQHTTP</m> | " + message)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_bearer(
|
||||||
|
access_token: Optional[str] = None) -> Union[Optional[str], NoReturn]:
|
||||||
|
if not access_token:
|
||||||
|
return None
|
||||||
|
scheme, _, param = access_token.partition(" ")
|
||||||
|
if scheme.lower() not in ["bearer", "token"]:
|
||||||
|
raise RequestDenied(401, "Not authenticated")
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -264,8 +276,6 @@ class Bot(BaseBot):
|
|||||||
self_id: str,
|
self_id: str,
|
||||||
*,
|
*,
|
||||||
websocket: Optional[WebSocket] = None):
|
websocket: Optional[WebSocket] = None):
|
||||||
if connection_type not in ["http", "websocket"]:
|
|
||||||
raise ValueError("Unsupported connection type")
|
|
||||||
|
|
||||||
super().__init__(driver,
|
super().__init__(driver,
|
||||||
connection_type,
|
connection_type,
|
||||||
@ -281,6 +291,47 @@ class Bot(BaseBot):
|
|||||||
"""
|
"""
|
||||||
return "cqhttp"
|
return "cqhttp"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@overrides(BaseBot)
|
||||||
|
async def check_permission(cls, driver: Driver, connection_type: str,
|
||||||
|
headers: dict,
|
||||||
|
body: Optional[dict]) -> Union[str, NoReturn]:
|
||||||
|
x_self_id = headers.get("x-self-id")
|
||||||
|
x_signature = headers.get("x-signature")
|
||||||
|
access_token = get_auth_bearer(headers.get("authorization"))
|
||||||
|
|
||||||
|
# 检查连接方式
|
||||||
|
if connection_type not in ["http", "websocket"]:
|
||||||
|
log("WARNING", "Unsupported connection type")
|
||||||
|
raise RequestDenied(405, "Unsupported connection type")
|
||||||
|
|
||||||
|
# 检查self_id
|
||||||
|
if not x_self_id:
|
||||||
|
log("WARNING", "Missing X-Self-ID Header")
|
||||||
|
raise RequestDenied(400, "Missing X-Self-ID Header")
|
||||||
|
|
||||||
|
# 检查签名
|
||||||
|
secret = driver.config.secret
|
||||||
|
if secret and connection_type == "http":
|
||||||
|
if not x_signature:
|
||||||
|
log("WARNING", "Missing Signature Header")
|
||||||
|
raise RequestDenied(401, "Missing Signature")
|
||||||
|
sig = hmac.new(secret.encode("utf-8"),
|
||||||
|
json.dumps(body).encode(), "sha1").hexdigest()
|
||||||
|
if x_signature != "sha1=" + sig:
|
||||||
|
log("WARNING", "Signature Header is invalid")
|
||||||
|
raise RequestDenied(403, "Signature is invalid")
|
||||||
|
|
||||||
|
access_token = driver.config.access_token
|
||||||
|
if access_token and access_token != access_token:
|
||||||
|
log(
|
||||||
|
"WARNING", "Authorization Header is invalid"
|
||||||
|
if access_token else "Missing Authorization Header")
|
||||||
|
raise RequestDenied(
|
||||||
|
403, "Authorization Header is invalid"
|
||||||
|
if access_token else "Missing Authorization Header")
|
||||||
|
return str(x_self_id)
|
||||||
|
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def handle_message(self, message: dict):
|
async def handle_message(self, message: dict):
|
||||||
"""
|
"""
|
||||||
|
@ -9,6 +9,11 @@ def log(level: str, message: str):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_bearer(
|
||||||
|
access_token: Optional[str] = ...) -> Union[Optional[str], NoReturn]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def escape(s: str, *, escape_comma: bool = ...) -> str:
|
def escape(s: str, *, escape_comma: bool = ...) -> str:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -69,6 +74,12 @@ class Bot(BaseBot):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_permission(cls, driver: Driver, connection_type: str,
|
||||||
|
headers: dict,
|
||||||
|
body: Optional[dict]) -> Union[str, NoReturn]:
|
||||||
|
...
|
||||||
|
|
||||||
async def handle_message(self, message: dict):
|
async def handle_message(self, message: dict):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -15,12 +15,13 @@ import logging
|
|||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
|
from fastapi import Body, status, Header, Request, FastAPI, Depends, HTTPException
|
||||||
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
|
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.utils import DataclassEncoder
|
from nonebot.utils import DataclassEncoder
|
||||||
|
from nonebot.exception import RequestDenied
|
||||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||||
from nonebot.typing import Optional, Callable, overrides
|
from nonebot.typing import Optional, Callable, overrides
|
||||||
|
|
||||||
@ -127,97 +128,58 @@ class Driver(BaseDriver):
|
|||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
async def _handle_http(self,
|
async def _handle_http(self,
|
||||||
adapter: str,
|
adapter: str,
|
||||||
data: dict = Body(...),
|
request: Request,
|
||||||
x_self_id: Optional[str] = Header(None),
|
data: dict = Body(...)):
|
||||||
x_signature: Optional[str] = Header(None),
|
|
||||||
auth: Optional[str] = Depends(get_auth_bearer)):
|
|
||||||
# 检查self_id
|
|
||||||
if not x_self_id:
|
|
||||||
logger.warning("Missing X-Self-ID Header")
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Missing X-Self-ID Header")
|
|
||||||
|
|
||||||
# 检查签名
|
|
||||||
secret = self.config.secret
|
|
||||||
if secret:
|
|
||||||
if not x_signature:
|
|
||||||
logger.warning("Missing Signature Header")
|
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Missing Signature")
|
|
||||||
sig = hmac.new(secret.encode("utf-8"),
|
|
||||||
json.dumps(data).encode(), "sha1").hexdigest()
|
|
||||||
if x_signature != "sha1=" + sig:
|
|
||||||
logger.warning("Signature Header is invalid")
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Signature is invalid")
|
|
||||||
|
|
||||||
access_token = self.config.access_token
|
|
||||||
if access_token and access_token != auth:
|
|
||||||
logger.warning("Authorization Header is invalid"
|
|
||||||
if auth else "Missing Authorization Header")
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Authorization Header is invalid"
|
|
||||||
if auth else "Missing Authorization Header")
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
logger.warning("Data received is invalid")
|
logger.warning("Data received is invalid")
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
if adapter not in self._adapters:
|
||||||
|
logger.warning("Unknown adapter")
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="adapter not found")
|
||||||
|
|
||||||
|
# 创建 Bot 对象
|
||||||
|
BotClass = self._adapters[adapter]
|
||||||
|
headers = dict(request.headers)
|
||||||
|
try:
|
||||||
|
x_self_id = await BotClass.check_permission(self, "http", headers,
|
||||||
|
data)
|
||||||
|
except RequestDenied as e:
|
||||||
|
raise HTTPException(status_code=e.status_code,
|
||||||
|
detail=e.reason) from None
|
||||||
|
|
||||||
if x_self_id in self._clients:
|
if x_self_id in self._clients:
|
||||||
logger.warning("There's already a reverse websocket api connection,"
|
logger.warning("There's already a reverse websocket api connection,"
|
||||||
"so the event may be handled twice.")
|
"so the event may be handled twice.")
|
||||||
|
|
||||||
# 创建 Bot 对象
|
|
||||||
if adapter in self._adapters:
|
|
||||||
BotClass = self._adapters[adapter]
|
|
||||||
bot = BotClass(self, "http", self.config, x_self_id)
|
bot = BotClass(self, "http", self.config, x_self_id)
|
||||||
else:
|
|
||||||
logger.warning("Unknown adapter")
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="adapter not found")
|
|
||||||
|
|
||||||
asyncio.create_task(bot.handle_message(data))
|
asyncio.create_task(bot.handle_message(data))
|
||||||
return Response("", 204)
|
return Response("", 204)
|
||||||
|
|
||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
async def _handle_ws_reverse(
|
async def _handle_ws_reverse(self, adapter: str,
|
||||||
self,
|
websocket: FastAPIWebSocket):
|
||||||
adapter: str,
|
|
||||||
websocket: FastAPIWebSocket,
|
|
||||||
x_self_id: str = Header(None),
|
|
||||||
auth: Optional[str] = Depends(get_auth_bearer)):
|
|
||||||
ws = WebSocket(websocket)
|
ws = WebSocket(websocket)
|
||||||
|
|
||||||
access_token = self.config.access_token
|
if adapter not in self._adapters:
|
||||||
if access_token and access_token != auth:
|
logger.warning("Unknown adapter")
|
||||||
logger.warning("Authorization Header is invalid"
|
|
||||||
if auth else "Missing Authorization Header")
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not x_self_id:
|
|
||||||
logger.warning(f"Missing X-Self-ID Header")
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
|
||||||
return
|
|
||||||
|
|
||||||
if x_self_id in self._clients:
|
|
||||||
logger.warning(f"Connection Conflict: self_id {x_self_id}")
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create Bot Object
|
# Create Bot Object
|
||||||
if adapter in self._adapters:
|
|
||||||
BotClass = self._adapters[adapter]
|
BotClass = self._adapters[adapter]
|
||||||
bot = BotClass(self,
|
headers = dict(websocket.headers)
|
||||||
"websocket",
|
try:
|
||||||
self.config,
|
x_self_id = await BotClass.check_permission(self, "websocket",
|
||||||
x_self_id,
|
headers, None)
|
||||||
websocket=ws)
|
except RequestDenied:
|
||||||
else:
|
|
||||||
logger.warning("Unknown adapter")
|
|
||||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
|
||||||
|
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
self._clients[x_self_id] = bot
|
self._clients[x_self_id] = bot
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
|
@ -105,6 +105,29 @@ class StopPropagation(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RequestDenied(Exception):
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
Bot 连接请求不合法。
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``status_code: int``: HTTP 状态码
|
||||||
|
* ``reason: str``: 拒绝原因
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, status_code: int, reason: str):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.reason = reason
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<RequestDenied, status_code={self.status_code}, reason={self.reason}>"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
|
||||||
class ApiNotAvailable(Exception):
|
class ApiNotAvailable(Exception):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -131,7 +154,7 @@ class ActionFailed(Exception):
|
|||||||
|
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``retcode``: 错误代码
|
* ``retcode: Optional[int]``: 错误代码
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, retcode: Optional[int]):
|
def __init__(self, retcode: Optional[int]):
|
||||||
|
Loading…
Reference in New Issue
Block a user