diff --git a/nonebot/adapters/mirai/__init__.py b/nonebot/adapters/mirai/__init__.py
index 991f30fd..1107af38 100644
--- a/nonebot/adapters/mirai/__init__.py
+++ b/nonebot/adapters/mirai/__init__.py
@@ -1,3 +1,4 @@
from .bot import MiraiBot
+from .bot_ws import MiraiWebsocketBot
from .event import *
from .message import MessageChain, MessageSegment
diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py
index 2414dca8..ebb9b768 100644
--- a/nonebot/adapters/mirai/bot.py
+++ b/nonebot/adapters/mirai/bot.py
@@ -1,19 +1,19 @@
from datetime import datetime, timedelta
from io import BytesIO
from ipaddress import IPv4Address
-from typing import Any, Dict, List, NoReturn, Optional, Tuple
+from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
import httpx
from nonebot.adapters import Bot as BaseBot
-from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
-from nonebot.exception import RequestDenied
from nonebot.exception import ActionFailed as BaseActionFailed
+from nonebot.exception import RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
+from nonebot.utils import escape_tag
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
@@ -41,7 +41,8 @@ class SessionManager:
@staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0)
- logger.debug(f'Mirai API returned data: {data}')
+ logger.opt(colors=True).debug('Mirai API returned data: '
+ f'{escape_tag(str(data))}')
if code != 0:
raise ActionFailed(code, message=data['msg'])
return data
@@ -85,10 +86,10 @@ class SessionManager:
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str):
- if self_id in cls.sessions:
- manager = cls.get(self_id)
- if manager is not None:
- return manager
+ session = cls.get(self_id)
+ if session is not None:
+ return session
+
client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
response = await client.post('/auth', json={'authKey': auth_key})
response.raise_for_status()
@@ -102,10 +103,13 @@ class SessionManager:
})
assert response.json()['code'] == 0
cls.sessions[self_id] = session_key, datetime.now(), client
+
return cls(session_key, client)
@classmethod
def get(cls, self_id: int):
+ if self_id not in cls.sessions:
+ return None
key, time, client = cls.sessions[self_id]
if datetime.now() - time > cls.session_expiry:
return None
@@ -114,6 +118,7 @@ class SessionManager:
class MiraiBot(BaseBot):
+ @overrides(BaseBot)
def __init__(self,
connection_type: str,
self_id: str,
@@ -179,17 +184,20 @@ class MiraiBot(BaseBot):
@overrides(BaseBot)
async def send(self,
event: Event,
- message: MessageChain,
- at_sender: bool = False,
- **kwargs):
+ message: Union[MessageChain, MessageSegment, str],
+ at_sender: bool = False):
+ if isinstance(message, MessageSegment):
+ message = MessageChain(message)
+ elif isinstance(message, str):
+ message = MessageChain(MessageSegment.plain(message))
if isinstance(event, FriendMessage):
return await self.send_friend_message(target=event.sender.id,
message_chain=message)
elif isinstance(event, GroupMessage):
- return await self.send_group_message(
- group=event.sender.group.id,
- message_chain=message if not at_sender else
- (MessageSegment.at(target=event.sender.id) + message))
+ if at_sender:
+ message = MessageSegment.at(event.sender.id) + message
+ return await self.send_group_message(group=event.sender.group.id,
+ message_chain=message)
elif isinstance(event, TempMessage):
return await self.send_temp_message(qq=event.sender.id,
group=event.sender.group.id,
diff --git a/nonebot/adapters/mirai/bot_ws.py b/nonebot/adapters/mirai/bot_ws.py
index d9803c47..d20d81dd 100644
--- a/nonebot/adapters/mirai/bot_ws.py
+++ b/nonebot/adapters/mirai/bot_ws.py
@@ -7,50 +7,21 @@ from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
import httpx
import websockets
-from nonebot.adapters import Bot as BaseBot
-from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
from nonebot.log import logger
-from nonebot.message import handle_event
from nonebot.typing import overrides
+from .bot import MiraiBot, SessionManager
from .config import Config as MiraiConfig
-from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
-async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
- qq: int) -> str:
-
- async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
- response = await client.request(method, path, **kwargs)
- response.raise_for_status()
- return response.json()
-
- about = await request('GET', path='/about')
- logger.opt(colors=True).debug('Mirai API HTTP backend version: '
- f'{about["data"]["version"]}')
-
- status = await request('POST', path='/auth', json={'authKey': auth_key})
- assert status['code'] == 0
- session_key = status['session']
-
- verify = await request('POST',
- path='/verify',
- json={
- 'sessionKey': session_key,
- 'qq': qq
- })
- assert verify['code'] == 0, verify['msg']
- return session_key
-
-
class WebSocket(BaseWebSocket):
@classmethod
@@ -59,6 +30,7 @@ class WebSocket(BaseWebSocket):
listen_address = httpx.URL(f'ws://{host}:{port}/all',
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
+ await (await websocket.ping())
return cls(websocket)
@overrides(BaseWebSocket)
@@ -116,25 +88,24 @@ class WebSocket(BaseWebSocket):
return callable
-class MiraiWebsocketBot(BaseBot):
+class MiraiWebsocketBot(MiraiBot):
+ @overrides(MiraiBot)
def __init__(self, connection_type: str, self_id: str, *,
websocket: WebSocket):
super().__init__(connection_type, self_id, websocket=websocket)
- websocket.handle(self.handle_message)
- self.driver._bot_connect(self)
@property
- @overrides(BaseBot)
+ @overrides(MiraiBot)
def type(self) -> str:
- return "mirai"
+ return "mirai-ws"
@property
def alive(self) -> bool:
return not self.websocket.closed
@classmethod
- @overrides(BaseBot)
+ @overrides(MiraiBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> NoReturn:
raise RequestDenied(
@@ -142,7 +113,7 @@ class MiraiWebsocketBot(BaseBot):
reason=f'Connection {connection_type} not implented')
@classmethod
- @overrides(BaseBot)
+ @overrides(MiraiBot)
def register(cls, driver: "Driver", config: "Config", qq: int):
cls.mirai_config = MiraiConfig(**config.dict())
cls.active = True
@@ -152,32 +123,33 @@ class MiraiWebsocketBot(BaseBot):
super().register(driver, config)
async def _bot_connection():
- async with httpx.AsyncClient(
- base_url=
- f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
- ) as client:
- session_key = await _ws_authorization(
- client,
- auth_key=cls.mirai_config.auth_key, # type: ignore
- qq=qq) # type: ignore
-
+ session: SessionManager = await SessionManager.new(
+ qq,
+ host=cls.mirai_config.host, # type: ignore
+ port=cls.mirai_config.port, # type: ignore
+ auth_key=cls.mirai_config.auth_key # type: ignore
+ )
websocket = await WebSocket.new(
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
- session_key=session_key)
+ session_key=session.session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
websocket.handle(bot.handle_message)
- driver._clients[str(qq)] = bot
await websocket.accept()
+ return bot
async def _connection_ensure():
- if str(qq) not in driver._clients:
- await _bot_connection()
- elif not driver._clients[str(qq)].alive:
- driver._clients.pop(str(qq), None)
- await _bot_connection()
+ self_id = str(qq)
+ if self_id not in driver._clients:
+ bot = await _bot_connection()
+ driver._bot_connect(bot)
+ else:
+ bot = driver._clients[self_id]
+ if not bot.alive:
+ driver._bot_disconnect(bot)
+ return
@driver.on_startup
async def _startup():
@@ -202,19 +174,3 @@ class MiraiWebsocketBot(BaseBot):
if bot is None:
return
await bot.websocket.close() #type:ignore
-
- @overrides(BaseBot)
- async def handle_message(self, message: dict):
- event = Event.new(message)
- await handle_event(self, event)
-
- @overrides(BaseBot)
- async def call_api(self, api: str, **data):
- return super().call_api(api, **data)
-
- @overrides(BaseBot)
- async def send(self, event: "BaseEvent", message: str, **kwargs):
- return super().send(event, message, **kwargs)
-
- def __del__(self):
- self.driver._bot_disconnect(self)
diff --git a/nonebot/adapters/mirai/event/base.py b/nonebot/adapters/mirai/event/base.py
index 6fbb30ff..3b6916f5 100644
--- a/nonebot/adapters/mirai/event/base.py
+++ b/nonebot/adapters/mirai/event/base.py
@@ -86,7 +86,7 @@ class Event(BaseEvent):
@overrides(BaseEvent)
def get_event_description(self) -> str:
- return str(self.dict())
+ return str(self.normalize_dict())
@overrides(BaseEvent)
def get_message(self) -> BaseMessage:
diff --git a/nonebot/adapters/mirai/message.py b/nonebot/adapters/mirai/message.py
index ef3949a6..a577a807 100644
--- a/nonebot/adapters/mirai/message.py
+++ b/nonebot/adapters/mirai/message.py
@@ -135,10 +135,11 @@ class MessageSegment(BaseMessageSegment):
return cls(type=MessageType.POKE, name=name)
-class MessageChain(BaseMessage):
+class MessageChain(BaseMessage): #type:List[MessageSegment]
@overrides(BaseMessage)
- def __init__(self, message: Union[List[Dict[str, Any]], MessageSegment],
+ def __init__(self, message: Union[List[Dict[str, Any]],
+ Iterable[MessageSegment], MessageSegment],
**kwargs):
super().__init__(**kwargs)
if isinstance(message, MessageSegment):
@@ -152,15 +153,16 @@ class MessageChain(BaseMessage):
@overrides(BaseMessage)
def _construct(
- self, message: Iterable[Union[Dict[str, Any], MessageSegment]]
+ self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment]]
) -> List[MessageSegment]:
if isinstance(message, str):
raise ValueError(
"String operation is not supported in mirai adapter")
return [
*map(
- lambda segment: segment if isinstance(segment, MessageSegment)
- else MessageSegment(**segment), message)
+ lambda x: x
+ if isinstance(x, MessageSegment) else MessageSegment(**x),
+ message)
]
def export(self) -> List[Dict[str, Any]]: