nonebot2/nonebot/adapters/mirai/bot.py

159 lines
5.6 KiB
Python
Raw Normal View History

import asyncio
import json
from ipaddress import IPv4Address
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
TypeVar)
import httpx
import websockets
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
from .config import Config
from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
class WebSocket(BaseWebSocket):
@classmethod
async def new(cls, *, host: IPv4Address, port: int,
session_key: str) -> "WebSocket":
listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
return cls(websocket)
@overrides(BaseWebSocket)
def __init__(self, websocket: websockets.WebSocketClientProtocol):
self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket)
@property
@overrides(BaseWebSocket)
def websocket(self) -> websockets.WebSocketClientProtocol:
return self._websocket
@overrides(BaseWebSocket)
async def send(self, data: Dict[str, Any]):
return await self.websocket.send(json.dumps(data))
@overrides(BaseWebSocket)
async def receive(self) -> Dict[str, Any]:
received = await self.websocket.recv()
return json.loads(received)
async def _dispatcher(self):
while not self.websocket.closed:
try:
data = await self.receive()
except websockets.ConnectionClosedOK:
break
except Exception as e:
logger.exception(f'Websocket client listened {self.websocket} '
f'failed to receive data: {e}')
continue
asyncio.ensure_future(
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
return_exceptions=True))
@overrides(BaseWebSocket)
async def accept(self):
asyncio.ensure_future(self._dispatcher())
@overrides(BaseWebSocket)
async def close(self):
await self.websocket.close()
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
self.event_handlers.add(callable)
return callable
class MiraiBot(BaseBot):
def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket)
@property
@overrides(BaseBot)
def type(self) -> str:
return "mirai"
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> NoReturn:
raise RequestDenied(
status_code=501,
reason=f'Connection {connection_type} not implented')
@classmethod
@overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config", qq: int):
config = Config.parse_obj(config.dict())
assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid'
super().register(driver, config) # type: ignore
@driver.on_startup
async def _startup():
async with httpx.AsyncClient(
base_url=f'http://{config.host}:{config.port}') as client:
response = await client.get('/about')
info = response.json()
logger.debug(f'Mirai API returned info: {info}')
response = await client.post('/auth',
json={'authKey': config.auth_key})
status = response.json()
assert status['code'] == 0
session_key = status['session']
response = await client.post('/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert response.json()['code'] == 0
websocket = await WebSocket.new(
host=config.host, # type: ignore
port=config.port, # type: ignore
session_key=session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
websocket.handle(bot.handle_message)
driver._clients[str(qq)] = bot
await websocket.accept()
@driver.on_shutdown
async def _shutdown():
bot = driver._clients.pop(str(qq), None)
if bot is None:
return
await bot.websocket.close() #type:ignore
@overrides(BaseBot)
async def handle_message(self, message: dict):
event = Event.new(message)
await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, **data):
return super().call_api(api, **data)
@overrides(BaseBot)
async def send(self, event: "BaseEvent", message: str, **kwargs):
return super().send(event, message, **kwargs)