mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-02-17 16:20:05 +08:00
⏪ revert call method to http post, add api handle
This commit is contained in:
parent
5b3ef53301
commit
c82ceefc8b
@ -1 +1,3 @@
|
||||
from .bot import MiraiBot
|
||||
from .bot import MiraiBot
|
||||
from .event import *
|
||||
from .message import MessageChain, MessageSegment
|
||||
|
@ -1,128 +1,74 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from ipaddress import IPv4Address
|
||||
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
|
||||
TypeVar)
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides
|
||||
|
||||
from .config import Config as MiraiConfig
|
||||
from .event import Event
|
||||
|
||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
||||
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
||||
bound=WebsocketHandlerFunction)
|
||||
from .event import Event, FriendMessage, TempMessage, GroupMessage
|
||||
|
||||
|
||||
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
|
||||
qq: int) -> str:
|
||||
class SessionManager:
|
||||
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
|
||||
session_expiry: timedelta = timedelta(minutes=15)
|
||||
|
||||
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
|
||||
response = await client.request(method, path, **kwargs)
|
||||
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
||||
self.session_key, self.client = session_key, client
|
||||
|
||||
async def post(self, path: str, *, params: Optional[Dict[str, Any]] = None):
|
||||
params = {**(params or {}), 'sessionKey': self.session_key}
|
||||
response = await self.client.post(path, json=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
about = await request('GET', path='/about')
|
||||
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
|
||||
f'<g><b>{about["data"]["version"]}</b></g>')
|
||||
|
||||
status = await request('POST', path='/auth', json={'authKey': auth_key})
|
||||
assert status['code'] == 0
|
||||
session_key = status['session']
|
||||
|
||||
verify = await request('POST',
|
||||
path='/verify',
|
||||
json={
|
||||
'sessionKey': session_key,
|
||||
'qq': qq
|
||||
})
|
||||
assert verify['code'] == 0, verify['msg']
|
||||
return session_key
|
||||
|
||||
|
||||
class WebSocket(BaseWebSocket):
|
||||
@classmethod
|
||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||
auth_key: str):
|
||||
if self_id in cls.sessions:
|
||||
manager = cls.get(self_id)
|
||||
if manager is not None:
|
||||
return manager
|
||||
client = httpx.AsyncClient(base_url=f'http://{host}:{port}')
|
||||
response = await client.post('/auth', json={'authKey': auth_key})
|
||||
response.raise_for_status()
|
||||
auth = response.json()
|
||||
assert auth['code'] == 0
|
||||
session_key = auth['session']
|
||||
response = await client.post('/verify',
|
||||
json={
|
||||
'sessionKey': session_key,
|
||||
'qq': self_id
|
||||
})
|
||||
assert response.json()['code'] == 0
|
||||
cls.sessions[self_id] = session_key, datetime.now(), client
|
||||
return cls(session_key, client)
|
||||
|
||||
@classmethod
|
||||
async def new(cls, *, host: IPv4Address, port: int,
|
||||
session_key: str) -> "WebSocket":
|
||||
listen_address = httpx.URL(f'ws://{host}:{port}/all',
|
||||
params={'sessionKey': session_key})
|
||||
websocket = await websockets.connect(uri=str(listen_address))
|
||||
return cls(websocket)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, websocket: websockets.WebSocketClientProtocol):
|
||||
self.event_handlers: Set[WebsocketHandlerFunction] = set()
|
||||
super().__init__(websocket)
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def websocket(self) -> websockets.WebSocketClientProtocol:
|
||||
return self._websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def closed(self) -> bool:
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: Dict[str, Any]):
|
||||
return await self.websocket.send(json.dumps(data))
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> Dict[str, Any]:
|
||||
received = await self.websocket.recv()
|
||||
return json.loads(received)
|
||||
|
||||
async def _dispatcher(self):
|
||||
while not self.closed:
|
||||
try:
|
||||
data = await self.receive()
|
||||
except websockets.ConnectionClosedOK:
|
||||
logger.debug(f'Websocket connection {self.websocket} closed')
|
||||
break
|
||||
except websockets.ConnectionClosedError:
|
||||
logger.exception(f'Websocket connection {self.websocket} '
|
||||
'connection closed abnormally:')
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.exception(f'Websocket client listened {self.websocket} '
|
||||
f'failed to decode data: {e}')
|
||||
continue
|
||||
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
|
||||
return_exceptions=True)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def accept(self):
|
||||
asyncio.create_task(self._dispatcher())
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def close(self):
|
||||
await self.websocket.close()
|
||||
|
||||
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
|
||||
self.event_handlers.add(callable)
|
||||
return callable
|
||||
def get(cls, self_id: int):
|
||||
key, time, client = cls.sessions[self_id]
|
||||
if datetime.now() - time > cls.session_expiry:
|
||||
return None
|
||||
return cls(key, client)
|
||||
|
||||
|
||||
class MiraiBot(BaseBot):
|
||||
|
||||
def __init__(self, connection_type: str, self_id: str, *,
|
||||
websocket: WebSocket):
|
||||
def __init__(self,
|
||||
connection_type: str,
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: Optional[WebSocket] = None):
|
||||
super().__init__(connection_type, self_id, websocket=websocket)
|
||||
websocket.handle(self.handle_message)
|
||||
self.driver._bot_connect(self)
|
||||
self.api = SessionManager.get(int(self_id))
|
||||
|
||||
@property
|
||||
@overrides(BaseBot)
|
||||
@ -136,85 +82,44 @@ class MiraiBot(BaseBot):
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(cls, driver: "Driver", connection_type: str,
|
||||
headers: dict, body: Optional[dict]) -> NoReturn:
|
||||
raise RequestDenied(
|
||||
status_code=501,
|
||||
reason=f'Connection {connection_type} not implented')
|
||||
headers: dict, body: Optional[dict]) -> str:
|
||||
if connection_type == 'ws':
|
||||
raise RequestDenied(
|
||||
status_code=501,
|
||||
reason='Websocket connection is not implemented')
|
||||
self_id: Optional[str] = headers.get('bot')
|
||||
if self_id is None:
|
||||
raise RequestDenied(status_code=400,
|
||||
reason='Header `Bot` is required.')
|
||||
self_id = str(self_id).strip()
|
||||
await SessionManager.new(
|
||||
int(self_id),
|
||||
host=cls.mirai_config.host, # type: ignore
|
||||
port=cls.mirai_config.port, #type: ignore
|
||||
auth_key=cls.mirai_config.auth_key) # type: ignore
|
||||
return self_id
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
def register(cls, driver: "Driver", config: "Config", qq: int):
|
||||
def register(cls, driver: "Driver", config: "Config"):
|
||||
cls.mirai_config = MiraiConfig(**config.dict())
|
||||
cls.active = True
|
||||
assert cls.mirai_config.auth_key is not None
|
||||
assert cls.mirai_config.host is not None
|
||||
assert cls.mirai_config.port is not None
|
||||
super().register(driver, config)
|
||||
|
||||
async def _bot_connection():
|
||||
async with httpx.AsyncClient(
|
||||
base_url=
|
||||
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
|
||||
) as client:
|
||||
session_key = await _ws_authorization(
|
||||
client,
|
||||
auth_key=cls.mirai_config.auth_key, # type: ignore
|
||||
qq=qq) # type: ignore
|
||||
|
||||
websocket = await WebSocket.new(
|
||||
host=cls.mirai_config.host, # type: ignore
|
||||
port=cls.mirai_config.port, # type: ignore
|
||||
session_key=session_key)
|
||||
bot = cls(connection_type='forward_ws',
|
||||
self_id=str(qq),
|
||||
websocket=websocket)
|
||||
websocket.handle(bot.handle_message)
|
||||
driver._clients[str(qq)] = bot
|
||||
await websocket.accept()
|
||||
|
||||
async def _connection_ensure():
|
||||
if str(qq) not in driver._clients:
|
||||
await _bot_connection()
|
||||
elif not driver._clients[str(qq)].alive:
|
||||
driver._clients.pop(str(qq), None)
|
||||
await _bot_connection()
|
||||
|
||||
@driver.on_startup
|
||||
async def _startup():
|
||||
|
||||
async def _checker():
|
||||
while cls.active:
|
||||
try:
|
||||
await _connection_ensure()
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).warning(
|
||||
'Failed to create mirai connection to '
|
||||
f'<y>{qq}</y>, reason: <r>{e}</r>. '
|
||||
'Will retry after 3 seconds')
|
||||
await asyncio.sleep(3)
|
||||
|
||||
asyncio.create_task(_checker())
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _shutdown():
|
||||
cls.active = False
|
||||
bot = driver._clients.pop(str(qq), None)
|
||||
if bot is None:
|
||||
return
|
||||
await bot.websocket.close() #type:ignore
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def handle_message(self, message: dict):
|
||||
event = Event.new(message)
|
||||
await handle_event(self, event)
|
||||
await handle_event(bot=self,
|
||||
event=Event.new({
|
||||
**message,
|
||||
'self_id': self.self_id,
|
||||
}))
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data):
|
||||
return super().call_api(api, **data)
|
||||
return await self.api.post('/' + api, params=data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self, event: "BaseEvent", message: str, **kwargs):
|
||||
return super().send(event, message, **kwargs)
|
||||
|
||||
def __del__(self):
|
||||
self.driver._bot_disconnect(self)
|
||||
pass
|
||||
|
220
nonebot/adapters/mirai/bot_ws.py
Normal file
220
nonebot/adapters/mirai/bot_ws.py
Normal file
@ -0,0 +1,220 @@
|
||||
import asyncio
|
||||
import json
|
||||
from ipaddress import IPv4Address
|
||||
from typing import (Any, Callable, Coroutine, Dict, NoReturn, Optional, Set,
|
||||
TypeVar)
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides
|
||||
|
||||
from .config import Config as MiraiConfig
|
||||
from .event import Event
|
||||
|
||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
||||
WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
||||
bound=WebsocketHandlerFunction)
|
||||
|
||||
|
||||
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
|
||||
qq: int) -> str:
|
||||
|
||||
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
|
||||
response = await client.request(method, path, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
about = await request('GET', path='/about')
|
||||
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
|
||||
f'<g><b>{about["data"]["version"]}</b></g>')
|
||||
|
||||
status = await request('POST', path='/auth', json={'authKey': auth_key})
|
||||
assert status['code'] == 0
|
||||
session_key = status['session']
|
||||
|
||||
verify = await request('POST',
|
||||
path='/verify',
|
||||
json={
|
||||
'sessionKey': session_key,
|
||||
'qq': qq
|
||||
})
|
||||
assert verify['code'] == 0, verify['msg']
|
||||
return session_key
|
||||
|
||||
|
||||
class WebSocket(BaseWebSocket):
|
||||
|
||||
@classmethod
|
||||
async def new(cls, *, host: IPv4Address, port: int,
|
||||
session_key: str) -> "WebSocket":
|
||||
listen_address = httpx.URL(f'ws://{host}:{port}/all',
|
||||
params={'sessionKey': session_key})
|
||||
websocket = await websockets.connect(uri=str(listen_address))
|
||||
return cls(websocket)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, websocket: websockets.WebSocketClientProtocol):
|
||||
self.event_handlers: Set[WebsocketHandlerFunction] = set()
|
||||
super().__init__(websocket)
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def websocket(self) -> websockets.WebSocketClientProtocol:
|
||||
return self._websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def closed(self) -> bool:
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: Dict[str, Any]):
|
||||
return await self.websocket.send(json.dumps(data))
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> Dict[str, Any]:
|
||||
received = await self.websocket.recv()
|
||||
return json.loads(received)
|
||||
|
||||
async def _dispatcher(self):
|
||||
while not self.closed:
|
||||
try:
|
||||
data = await self.receive()
|
||||
except websockets.ConnectionClosedOK:
|
||||
logger.debug(f'Websocket connection {self.websocket} closed')
|
||||
break
|
||||
except websockets.ConnectionClosedError:
|
||||
logger.exception(f'Websocket connection {self.websocket} '
|
||||
'connection closed abnormally:')
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.exception(f'Websocket client listened {self.websocket} '
|
||||
f'failed to decode data: {e}')
|
||||
continue
|
||||
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
|
||||
return_exceptions=True)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def accept(self):
|
||||
asyncio.create_task(self._dispatcher())
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def close(self):
|
||||
await self.websocket.close()
|
||||
|
||||
def handle(self, callable: WebsocketHandler_T) -> WebsocketHandler_T:
|
||||
self.event_handlers.add(callable)
|
||||
return callable
|
||||
|
||||
|
||||
class MiraiWebsocketBot(BaseBot):
|
||||
|
||||
def __init__(self, connection_type: str, self_id: str, *,
|
||||
websocket: WebSocket):
|
||||
super().__init__(connection_type, self_id, websocket=websocket)
|
||||
websocket.handle(self.handle_message)
|
||||
self.driver._bot_connect(self)
|
||||
|
||||
@property
|
||||
@overrides(BaseBot)
|
||||
def type(self) -> str:
|
||||
return "mirai"
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return not self.websocket.closed
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
async def check_permission(cls, driver: "Driver", connection_type: str,
|
||||
headers: dict, body: Optional[dict]) -> NoReturn:
|
||||
raise RequestDenied(
|
||||
status_code=501,
|
||||
reason=f'Connection {connection_type} not implented')
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseBot)
|
||||
def register(cls, driver: "Driver", config: "Config", qq: int):
|
||||
cls.mirai_config = MiraiConfig(**config.dict())
|
||||
cls.active = True
|
||||
assert cls.mirai_config.auth_key is not None
|
||||
assert cls.mirai_config.host is not None
|
||||
assert cls.mirai_config.port is not None
|
||||
super().register(driver, config)
|
||||
|
||||
async def _bot_connection():
|
||||
async with httpx.AsyncClient(
|
||||
base_url=
|
||||
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
|
||||
) as client:
|
||||
session_key = await _ws_authorization(
|
||||
client,
|
||||
auth_key=cls.mirai_config.auth_key, # type: ignore
|
||||
qq=qq) # type: ignore
|
||||
|
||||
websocket = await WebSocket.new(
|
||||
host=cls.mirai_config.host, # type: ignore
|
||||
port=cls.mirai_config.port, # type: ignore
|
||||
session_key=session_key)
|
||||
bot = cls(connection_type='forward_ws',
|
||||
self_id=str(qq),
|
||||
websocket=websocket)
|
||||
websocket.handle(bot.handle_message)
|
||||
driver._clients[str(qq)] = bot
|
||||
await websocket.accept()
|
||||
|
||||
async def _connection_ensure():
|
||||
if str(qq) not in driver._clients:
|
||||
await _bot_connection()
|
||||
elif not driver._clients[str(qq)].alive:
|
||||
driver._clients.pop(str(qq), None)
|
||||
await _bot_connection()
|
||||
|
||||
@driver.on_startup
|
||||
async def _startup():
|
||||
|
||||
async def _checker():
|
||||
while cls.active:
|
||||
try:
|
||||
await _connection_ensure()
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).warning(
|
||||
'Failed to create mirai connection to '
|
||||
f'<y>{qq}</y>, reason: <r>{e}</r>. '
|
||||
'Will retry after 3 seconds')
|
||||
await asyncio.sleep(3)
|
||||
|
||||
asyncio.create_task(_checker())
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _shutdown():
|
||||
cls.active = False
|
||||
bot = driver._clients.pop(str(qq), None)
|
||||
if bot is None:
|
||||
return
|
||||
await bot.websocket.close() #type:ignore
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def handle_message(self, message: dict):
|
||||
event = Event.new(message)
|
||||
await handle_event(self, event)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data):
|
||||
return super().call_api(api, **data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self, event: "BaseEvent", message: str, **kwargs):
|
||||
return super().send(event, message, **kwargs)
|
||||
|
||||
def __del__(self):
|
||||
self.driver._bot_disconnect(self)
|
@ -37,6 +37,7 @@ class PrivateSenderInfo(BaseModel):
|
||||
|
||||
|
||||
class Event(BaseEvent):
|
||||
self_id: int
|
||||
type: str
|
||||
|
||||
@classmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user