add websocket class and coolq message segment

This commit is contained in:
yanyongyu 2020-07-15 20:39:59 +08:00
parent d616290626
commit 3dbd927a2a
4 changed files with 165 additions and 21 deletions

View File

@ -21,22 +21,46 @@ class BaseBot(object):
class BaseMessageSegment(dict):
def __init__(self,
d: Optional[Dict[str, Any]] = None,
*,
type_: Optional[str] = None,
data: Optional[Dict[str, str]] = None):
super().__init__()
if isinstance(d, dict) and d.get('type'):
self.update(d)
elif type_:
if type_:
self.type = type_
self.data = data
else:
raise ValueError('the "type" field cannot be None or empty')
raise ValueError('The "type" field cannot be empty')
def __str__(self):
raise NotImplementedError
def __getitem__(self, item):
if item not in ("type", "data"):
raise KeyError(f'Key "{item}" is not allowed')
return super().__getitem__(item)
def __setitem__(self, key, value):
if key not in ("type", "data"):
raise KeyError(f'Key "{key}" is not allowed')
return super().__setitem__(key, value)
# TODO: __eq__ __add__
@property
def type(self) -> str:
return self["type"]
@type.setter
def type(self, value: str):
self["type"] = value
@property
def data(self) -> Dict[str, str]:
return self["data"]
@data.setter
def data(self, data: Optional[Dict[str, str]]):
self["data"] = data or {}
class BaseMessage(list):

View File

@ -5,13 +5,40 @@ import httpx
from nonebot.event import Event
from nonebot.config import Config
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
from nonebot.message import handle_event
from nonebot.drivers import BaseWebSocket
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
def escape(s: str, *, escape_comma: bool = True) -> str:
"""
对字符串进行 CQ 码转义
``escape_comma`` 参数控制是否转义逗号``,``
"""
s = s.replace("&", "&") \
.replace("[", "[") \
.replace("]", "]")
if escape_comma:
s = s.replace(",", ",")
return s
def unescape(s: str) -> str:
"""对字符串进行 CQ 码去转义。"""
return s.replace(",", ",") \
.replace("[", "[") \
.replace("]", "]") \
.replace("&", "&")
class Bot(BaseBot):
def __init__(self, type_: str, config: Config, *, websocket=None):
def __init__(self,
type_: str,
config: Config,
*,
websocket: BaseWebSocket = None):
if type_ not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")
self.type = type_
@ -33,7 +60,32 @@ class Bot(BaseBot):
class MessageSegment(BaseMessageSegment):
pass
def __str__(self):
type_ = self.type
data = self.data.copy()
# process special types
if type_ == "text":
return escape(data.get("text", ""), escape_comma=False)
elif type_ == "at_all":
type_ = "at"
data = {"qq": "all"}
params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()])
return f"[CQ:{type_}{',' if params else ''}{params}]"
@staticmethod
def at(user_id: int) -> "MessageSegment":
return MessageSegment("at", {"qq": str(user_id)})
@staticmethod
def at_all() -> "MessageSegment":
return MessageSegment("at_all")
@staticmethod
def dice() -> "MessageSegment":
return MessageSegment(type_="dice")
class Message(BaseMessage):

View File

@ -39,3 +39,29 @@ class BaseDriver(object):
async def _handle_http_api(self):
raise NotImplementedError
class BaseWebSocket(object):
def __init__(self, websocket):
self._websocket = websocket
@property
def websocket(self):
return self._websocket
@property
def closed(self):
raise NotImplementedError
async def accept(self):
raise NotImplementedError
async def close(self):
raise NotImplementedError
async def receive(self) -> dict:
raise NotImplementedError
async def send(self, data: dict):
raise NotImplementedError

View File

@ -7,12 +7,13 @@ from typing import Optional
from ipaddress import IPv4Address
import uvicorn
from fastapi.security import OAuth2PasswordBearer
from starlette.websockets import WebSocketDisconnect
from fastapi import Body, FastAPI, WebSocket
from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Config
from nonebot.drivers import BaseDriver
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.adapters.coolq import Bot as CoolQBot
@ -86,7 +87,11 @@ class Driver(BaseDriver):
log_config=LOGGING_CONFIG,
**kwargs)
async def _handle_http(self, adapter: str, data: dict = Body(...)):
async def _handle_http(self,
adapter: str,
data: dict = Body(...),
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
# TODO: Check authorization
logger.debug(f"Received message: {data}")
if adapter == "coolq":
@ -94,20 +99,57 @@ class Driver(BaseDriver):
await bot.handle_message(data)
return {"status": 200, "message": "success"}
async def _handle_ws_reverse(self, adapter: str, websocket: WebSocket):
async def _handle_ws_reverse(self,
adapter: str,
websocket: FastAPIWebSocket,
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
websocket = WebSocket(websocket)
# TODO: Check authorization
await websocket.accept()
while True:
try:
data = await websocket.receive_json()
except json.decoder.JSONDecodeError as e:
logger.exception(e)
while not websocket.closed:
data = await websocket.receive()
if not data:
continue
except WebSocketDisconnect:
logger.error("WebSocket Disconnect")
return
logger.debug(f"Received message: {data}")
if adapter == "coolq":
bot = CoolQBot("websocket", self.config, websocket=websocket)
await bot.handle_message(data)
class WebSocket(BaseWebSocket):
def __init__(self, websocket: FastAPIWebSocket):
super().__init__(websocket)
self._closed = None
@property
def closed(self):
return self._closed
async def accept(self):
await self.websocket.accept()
self._closed = False
async def close(self):
await self.websocket.close()
self._closed = True
async def receive(self) -> Optional[dict]:
data = None
try:
data = await self.websocket.receive_json()
except ValueError:
logger.debug("Received an invalid json message.")
except WebSocketDisconnect:
self._closed = True
logger.error("WebSocket disconnected by peer.")
return data
async def send(self, data: dict) -> None:
await self.websocket.send_json(data)