mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-02-22 18:45:26 +08:00
🚧 finish forward websocket receive
This commit is contained in:
parent
0bb0d16d93
commit
02af1c1227
@ -1,22 +1,89 @@
|
|||||||
from pprint import pprint
|
import asyncio
|
||||||
from typing import Optional
|
import json
|
||||||
|
from ipaddress import IPv4Address
|
||||||
|
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
|
||||||
|
TypeVar)
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import websockets
|
||||||
|
|
||||||
from nonebot.adapters import Bot as BaseBot
|
from nonebot.adapters import Bot as BaseBot
|
||||||
from nonebot.adapters import Event as BaseEvent
|
from nonebot.adapters import Event as BaseEvent
|
||||||
from nonebot.drivers import Driver, WebSocket
|
from nonebot.drivers import Driver
|
||||||
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
|
from nonebot.exception import RequestDenied
|
||||||
|
from nonebot.log import logger
|
||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
|
from .config import Config
|
||||||
from .event import Event
|
from .event import Event
|
||||||
|
|
||||||
|
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
||||||
|
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
||||||
|
bound=WebsocketHandlerFunction)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocket(BaseWebSocket):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def new(cls, *, host: IPv4Address, port: int,
|
||||||
|
session_key: str) -> "WebSocket":
|
||||||
|
listen_address = httpx.URL(f'ws://{host}:{port}/all',
|
||||||
|
params={'sessionKey': session_key})
|
||||||
|
websocket = await websockets.connect(uri=str(listen_address))
|
||||||
|
return cls(websocket)
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
def __init__(self, websocket: websockets.WebSocketClientProtocol):
|
||||||
|
self.event_handlers: Set[WebsocketHandlerFunction] = set()
|
||||||
|
super().__init__(websocket)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
def websocket(self) -> websockets.WebSocketClientProtocol:
|
||||||
|
return self._websocket
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def send(self, data: Dict[str, Any]):
|
||||||
|
return await self.websocket.send(json.dumps(data))
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def receive(self) -> Dict[str, Any]:
|
||||||
|
received = await self.websocket.recv()
|
||||||
|
return json.loads(received)
|
||||||
|
|
||||||
|
async def _dispatcher(self):
|
||||||
|
while not self.websocket.closed:
|
||||||
|
try:
|
||||||
|
data = await self.receive()
|
||||||
|
except websockets.ConnectionClosedOK:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f'Websocket client listened {self.websocket} '
|
||||||
|
f'failed to receive data: {e}')
|
||||||
|
continue
|
||||||
|
asyncio.ensure_future(
|
||||||
|
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
|
||||||
|
return_exceptions=True))
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def accept(self):
|
||||||
|
asyncio.ensure_future(self._dispatcher())
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def close(self):
|
||||||
|
await self.websocket.close()
|
||||||
|
|
||||||
|
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
|
||||||
|
self.event_handlers.add(callable)
|
||||||
|
return callable
|
||||||
|
|
||||||
|
|
||||||
class MiraiBot(BaseBot):
|
class MiraiBot(BaseBot):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, connection_type: str, self_id: str, *,
|
||||||
connection_type: str,
|
websocket: WebSocket):
|
||||||
self_id: str,
|
|
||||||
*,
|
|
||||||
websocket: Optional["WebSocket"] = None):
|
|
||||||
super().__init__(connection_type, self_id, websocket=websocket)
|
super().__init__(connection_type, self_id, websocket=websocket)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -27,8 +94,55 @@ class MiraiBot(BaseBot):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def check_permission(cls, driver: "Driver", connection_type: str,
|
async def check_permission(cls, driver: "Driver", connection_type: str,
|
||||||
headers: dict, body: Optional[dict]) -> str:
|
headers: dict, body: Optional[dict]) -> NoReturn:
|
||||||
return ''
|
raise RequestDenied(
|
||||||
|
status_code=501,
|
||||||
|
reason=f'Connection {connection_type} not implented')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@overrides(BaseBot)
|
||||||
|
def register(cls, driver: "Driver", config: "Config", qq: int):
|
||||||
|
config = Config.parse_obj(config.dict())
|
||||||
|
assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid'
|
||||||
|
|
||||||
|
super().register(driver, config) # type: ignore
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def _startup():
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
base_url=f'http://{config.host}:{config.port}') as client:
|
||||||
|
response = await client.get('/about')
|
||||||
|
info = response.json()
|
||||||
|
logger.debug(f'Mirai API returned info: {info}')
|
||||||
|
response = await client.post('/auth',
|
||||||
|
json={'authKey': config.auth_key})
|
||||||
|
status = response.json()
|
||||||
|
assert status['code'] == 0
|
||||||
|
session_key = status['session']
|
||||||
|
response = await client.post('/verify',
|
||||||
|
json={
|
||||||
|
'sessionKey': session_key,
|
||||||
|
'qq': qq
|
||||||
|
})
|
||||||
|
assert response.json()['code'] == 0
|
||||||
|
|
||||||
|
websocket = await WebSocket.new(
|
||||||
|
host=config.host, # type: ignore
|
||||||
|
port=config.port, # type: ignore
|
||||||
|
session_key=session_key)
|
||||||
|
bot = cls(connection_type='forward_ws',
|
||||||
|
self_id=str(qq),
|
||||||
|
websocket=websocket)
|
||||||
|
websocket.handle(bot.handle_message)
|
||||||
|
driver._clients[str(qq)] = bot
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
|
async def _shutdown():
|
||||||
|
bot = driver._clients.pop(str(qq), None)
|
||||||
|
if bot is None:
|
||||||
|
return
|
||||||
|
await bot.websocket.close() #type:ignore
|
||||||
|
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def handle_message(self, message: dict):
|
async def handle_message(self, message: dict):
|
||||||
|
13
nonebot/adapters/mirai/config.py
Normal file
13
nonebot/adapters/mirai/config.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from ipaddress import IPv4Address
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseModel):
|
||||||
|
auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
|
||||||
|
host: Optional[IPv4Address] = Field(None, alias='mirai_host')
|
||||||
|
port: Optional[int] = Field(None, alias='mirai_port')
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = Extra.ignore
|
@ -62,7 +62,7 @@ class Driver(abc.ABC):
|
|||||||
:说明: 已连接的 Bot
|
:说明: 已连接的 Bot
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register_adapter(self, name: str, adapter: Type["Bot"]):
|
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ class Driver(abc.ABC):
|
|||||||
* ``adapter: Type[Bot]``: 适配器 Class
|
* ``adapter: Type[Bot]``: 适配器 Class
|
||||||
"""
|
"""
|
||||||
self._adapters[name] = adapter
|
self._adapters[name] = adapter
|
||||||
adapter.register(self, self.config)
|
adapter.register(self, self.config, **kwargs)
|
||||||
logger.opt(
|
logger.opt(
|
||||||
colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"')
|
colors=True).debug(f'Succeeded to load adapter "<y>{name}</y>"')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user