🚧 finish forward websocket receive

This commit is contained in:
Mix 2021-01-30 05:58:30 +08:00
parent 0bb0d16d93
commit 02af1c1227
3 changed files with 139 additions and 12 deletions

View File

@ -1,22 +1,89 @@
from pprint import pprint import asyncio
from typing import Optional import json
from ipaddress import IPv4Address
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
TypeVar)
import httpx
import websockets
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent from nonebot.adapters import Event as BaseEvent
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from .config import Config
from .event import Event from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
class WebSocket(BaseWebSocket):
@classmethod
async def new(cls, *, host: IPv4Address, port: int,
session_key: str) -> "WebSocket":
listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
return cls(websocket)
@overrides(BaseWebSocket)
def __init__(self, websocket: websockets.WebSocketClientProtocol):
self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket)
@property
@overrides(BaseWebSocket)
def websocket(self) -> websockets.WebSocketClientProtocol:
return self._websocket
@overrides(BaseWebSocket)
async def send(self, data: Dict[str, Any]):
return await self.websocket.send(json.dumps(data))
@overrides(BaseWebSocket)
async def receive(self) -> Dict[str, Any]:
received = await self.websocket.recv()
return json.loads(received)
async def _dispatcher(self):
while not self.websocket.closed:
try:
data = await self.receive()
except websockets.ConnectionClosedOK:
break
except Exception as e:
logger.exception(f'Websocket client listened {self.websocket} '
f'failed to receive data: {e}')
continue
asyncio.ensure_future(
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
return_exceptions=True))
@overrides(BaseWebSocket)
async def accept(self):
asyncio.ensure_future(self._dispatcher())
@overrides(BaseWebSocket)
async def close(self):
await self.websocket.close()
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
self.event_handlers.add(callable)
return callable
class MiraiBot(BaseBot): class MiraiBot(BaseBot):
def __init__(self, def __init__(self, connection_type: str, self_id: str, *,
connection_type: str, websocket: WebSocket):
self_id: str,
*,
websocket: Optional["WebSocket"] = None):
super().__init__(connection_type, self_id, websocket=websocket) super().__init__(connection_type, self_id, websocket=websocket)
@property @property
@ -27,8 +94,55 @@ class MiraiBot(BaseBot):
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> str: headers: dict, body: Optional[dict]) -> NoReturn:
return '' raise RequestDenied(
status_code=501,
reason=f'Connection {connection_type} not implented')
@classmethod
@overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config", qq: int):
config = Config.parse_obj(config.dict())
assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid'
super().register(driver, config) # type: ignore
@driver.on_startup
async def _startup():
async with httpx.AsyncClient(
base_url=f'http://{config.host}:{config.port}') as client:
response = await client.get('/about')
info = response.json()
logger.debug(f'Mirai API returned info: {info}')
response = await client.post('/auth',
json={'authKey': config.auth_key})
status = response.json()
assert status['code'] == 0
session_key = status['session']
response = await client.post('/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert response.json()['code'] == 0
websocket = await WebSocket.new(
host=config.host, # type: ignore
port=config.port, # type: ignore
session_key=session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
websocket.handle(bot.handle_message)
driver._clients[str(qq)] = bot
await websocket.accept()
@driver.on_shutdown
async def _shutdown():
bot = driver._clients.pop(str(qq), None)
if bot is None:
return
await bot.websocket.close() #type:ignore
@overrides(BaseBot) @overrides(BaseBot)
async def handle_message(self, message: dict): async def handle_message(self, message: dict):

View File

@ -0,0 +1,13 @@
from ipaddress import IPv4Address
from typing import Optional
from pydantic import BaseModel, Extra, Field
class Config(BaseModel):
auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
host: Optional[IPv4Address] = Field(None, alias='mirai_host')
port: Optional[int] = Field(None, alias='mirai_port')
class Config:
extra = Extra.ignore

View File

@ -62,7 +62,7 @@ class Driver(abc.ABC):
:说明: 已连接的 Bot :说明: 已连接的 Bot
""" """
def register_adapter(self, name: str, adapter: Type["Bot"]): def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
""" """
:说明: :说明:
@ -74,7 +74,7 @@ class Driver(abc.ABC):
* ``adapter: Type[Bot]``: 适配器 Class * ``adapter: Type[Bot]``: 适配器 Class
""" """
self._adapters[name] = adapter self._adapters[name] = adapter
adapter.register(self, self.config) adapter.register(self, self.config, **kwargs)
logger.opt( logger.opt(
colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"') colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"')