add check for self_id

This commit is contained in:
yanyongyu 2020-08-21 16:59:41 +08:00
parent b99c9688e2
commit e86362572b

View File

@ -5,8 +5,7 @@ import json
import logging import logging
import uvicorn import uvicorn
from fastapi import FastAPI, status, HTTPException from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
from fastapi import Body, Header, Response, Depends
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
@ -120,7 +119,6 @@ class Driver(BaseDriver):
async def _handle_http( async def _handle_http(
self, self,
adapter: str, adapter: str,
response: Response,
data: dict = Body(...), data: dict = Body(...),
x_self_id: str = Header(None), x_self_id: str = Header(None),
access_token: Optional[str] = Depends(get_auth_bearer)): access_token: Optional[str] = Depends(get_auth_bearer)):
@ -135,8 +133,8 @@ class Driver(BaseDriver):
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id) bot = BotClass(self, "http", self.config, x_self_id)
else: else:
response.status_code = status.HTTP_404_NOT_FOUND raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
return {"status": 404, "message": "adapter not found"} detail="adapter not found")
await bot.handle_message(data) await bot.handle_message(data)
return {"status": 200, "message": "success"} return {"status": 200, "message": "success"}
@ -156,6 +154,14 @@ class Driver(BaseDriver):
websocket = WebSocket(websocket) websocket = WebSocket(websocket)
if not x_self_id:
logger.error(f"Error Connection Unkown: self_id {x_self_id}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if x_self_id in self._clients:
logger.error(f"Error Connection Conflict: self_id {x_self_id}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# Create Bot Object # Create Bot Object
if adapter in self._adapters: if adapter in self._adapters:
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
@ -165,8 +171,8 @@ class Driver(BaseDriver):
x_self_id, x_self_id,
websocket=websocket) websocket=websocket)
else: else:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
return detail="adapter not found")
await websocket.accept() await websocket.accept()
self._clients[x_self_id] = bot self._clients[x_self_id] = bot