🚧 add support of reverse post and forward ws for mirai adapter

This commit is contained in:
Mix 2021-01-31 16:02:59 +08:00
parent 73be9151b0
commit 3f56da9245
5 changed files with 57 additions and 90 deletions

View File

@ -1,3 +1,4 @@
from .bot import MiraiBot
from .bot_ws import MiraiWebsocketBot
from .event import *
from .message import MessageChain, MessageSegment

View File

@ -1,19 +1,19 @@
from datetime import datetime, timedelta
from io import BytesIO
from ipaddress import IPv4Address
from typing import Any, Dict, List, NoReturn, Optional, Tuple
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import RequestDenied
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
@ -41,7 +41,8 @@ class SessionManager:
@staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0)
logger.debug(f'Mirai API returned data: {data}')
logger.opt(colors=True).debug('Mirai API returned data: '
f'<y>{escape_tag(str(data))}</y>')
if code != 0:
raise ActionFailed(code, message=data['msg'])
return data
@ -85,10 +86,10 @@ class SessionManager:
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str):
if self_id in cls.sessions:
manager = cls.get(self_id)
if manager is not None:
return manager
session = cls.get(self_id)
if session is not None:
return session
client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
response = await client.post('/auth', json={'authKey': auth_key})
response.raise_for_status()
@ -102,10 +103,13 @@ class SessionManager:
})
assert response.json()['code'] == 0
cls.sessions[self_id] = session_key, datetime.now(), client
return cls(session_key, client)
@classmethod
def get(cls, self_id: int):
if self_id not in cls.sessions:
return None
key, time, client = cls.sessions[self_id]
if datetime.now() - time > cls.session_expiry:
return None
@ -114,6 +118,7 @@ class SessionManager:
class MiraiBot(BaseBot):
@overrides(BaseBot)
def __init__(self,
connection_type: str,
self_id: str,
@ -179,17 +184,20 @@ class MiraiBot(BaseBot):
@overrides(BaseBot)
async def send(self,
event: Event,
message: MessageChain,
at_sender: bool = False,
**kwargs):
message: Union[MessageChain, MessageSegment, str],
at_sender: bool = False):
if isinstance(message, MessageSegment):
message = MessageChain(message)
elif isinstance(message, str):
message = MessageChain(MessageSegment.plain(message))
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
elif isinstance(event, GroupMessage):
return await self.send_group_message(
group=event.sender.group.id,
message_chain=message if not at_sender else
(MessageSegment.at(target=event.sender.id) + message))
if at_sender:
message = MessageSegment.at(event.sender.id) + message
return await self.send_group_message(group=event.sender.group.id,
message_chain=message)
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,

View File

@ -7,50 +7,21 @@ from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
import httpx
import websockets
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
from .bot import MiraiBot, SessionManager
from .config import Config as MiraiConfig
from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
qq: int) -> str:
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
response = await client.request(method, path, **kwargs)
response.raise_for_status()
return response.json()
about = await request('GET', path='/about')
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
f'<g><b>{about["data"]["version"]}</b></g>')
status = await request('POST', path='/auth', json={'authKey': auth_key})
assert status['code'] == 0
session_key = status['session']
verify = await request('POST',
path='/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert verify['code'] == 0, verify['msg']
return session_key
class WebSocket(BaseWebSocket):
@classmethod
@ -59,6 +30,7 @@ class WebSocket(BaseWebSocket):
listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping())
return cls(websocket)
@overrides(BaseWebSocket)
@ -116,25 +88,24 @@ class WebSocket(BaseWebSocket):
return callable
class MiraiWebsocketBot(BaseBot):
class MiraiWebsocketBot(MiraiBot):
@overrides(MiraiBot)
def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket)
websocket.handle(self.handle_message)
self.driver._bot_connect(self)
@property
@overrides(BaseBot)
@overrides(MiraiBot)
def type(self) -> str:
return "mirai"
return "mirai-ws"
@property
def alive(self) -> bool:
return not self.websocket.closed
@classmethod
@overrides(BaseBot)
@overrides(MiraiBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> NoReturn:
raise RequestDenied(
@ -142,7 +113,7 @@ class MiraiWebsocketBot(BaseBot):
reason=f'Connection {connection_type} not implented')
@classmethod
@overrides(BaseBot)
@overrides(MiraiBot)
def register(cls, driver: "Driver", config: "Config", qq: int):
cls.mirai_config = MiraiConfig(**config.dict())
cls.active = True
@ -152,32 +123,33 @@ class MiraiWebsocketBot(BaseBot):
super().register(driver, config)
async def _bot_connection():
async with httpx.AsyncClient(
base_url=
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
) as client:
session_key = await _ws_authorization(
client,
auth_key=cls.mirai_config.auth_key, # type: ignore
qq=qq) # type: ignore
session: SessionManager = await SessionManager.new(
qq,
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
auth_key=cls.mirai_config.auth_key # type: ignore
)
websocket = await WebSocket.new(
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
session_key=session_key)
session_key=session.session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
websocket.handle(bot.handle_message)
driver._clients[str(qq)] = bot
await websocket.accept()
return bot
async def _connection_ensure():
if str(qq) not in driver._clients:
await _bot_connection()
elif not driver._clients[str(qq)].alive:
driver._clients.pop(str(qq), None)
await _bot_connection()
self_id = str(qq)
if self_id not in driver._clients:
bot = await _bot_connection()
driver._bot_connect(bot)
else:
bot = driver._clients[self_id]
if not bot.alive:
driver._bot_disconnect(bot)
return
@driver.on_startup
async def _startup():
@ -202,19 +174,3 @@ class MiraiWebsocketBot(BaseBot):
if bot is None:
return
await bot.websocket.close() #type:ignore
@overrides(BaseBot)
async def handle_message(self, message: dict):
event = Event.new(message)
await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, **data):
return super().call_api(api, **data)
@overrides(BaseBot)
async def send(self, event: "BaseEvent", message: str, **kwargs):
return super().send(event, message, **kwargs)
def __del__(self):
self.driver._bot_disconnect(self)

View File

@ -86,7 +86,7 @@ class Event(BaseEvent):
@overrides(BaseEvent)
def get_event_description(self) -> str:
return str(self.dict())
return str(self.normalize_dict())
@overrides(BaseEvent)
def get_message(self) -> BaseMessage:

View File

@ -135,10 +135,11 @@ class MessageSegment(BaseMessageSegment):
return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage):
class MessageChain(BaseMessage): #type:List[MessageSegment]
@overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str, Any]], MessageSegment],
def __init__(self, message: Union[List[Dict[str, Any]],
Iterable[MessageSegment], MessageSegment],
**kwargs):
super().__init__(**kwargs)
if isinstance(message, MessageSegment):
@ -152,15 +153,16 @@ class MessageChain(BaseMessage):
@overrides(BaseMessage)
def _construct(
self, message: Iterable[Union[Dict[str, Any], MessageSegment]]
self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment]]
) -> List[MessageSegment]:
if isinstance(message, str):
raise ValueError(
"String operation is not supported in mirai adapter")
return [
*map(
lambda segment: segment if isinstance(segment, MessageSegment)
else MessageSegment(**segment), message)
lambda x: x
if isinstance(x, MessageSegment) else MessageSegment(**x),
message)
]
def export(self) -> List[Dict[str, Any]]: