nonebot2/nonebot/drivers/quart.py

240 lines
7.9 KiB
Python
Raw Normal View History

2021-02-06 10:46:17 +08:00
"""
Quart driver adapter
====================

For backend usage, see the `Quart documentation`_.

.. _Quart documentation:
    https://pgjones.gitlab.io/quart/index.html
"""
import asyncio
from json.decoder import JSONDecodeError
2021-02-06 11:42:40 +08:00
from typing import Any, Callable, Coroutine, Dict, Optional, Type, TypeVar
2021-02-06 10:34:52 +08:00
import uvicorn
from nonebot.config import Config as NoneBotConfig
from nonebot.config import Env
from nonebot.drivers import Driver as BaseDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
2021-02-06 10:34:52 +08:00
from nonebot.log import logger
from nonebot.typing import overrides
try:
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
from quart import exceptions
from quart import request as _request
from quart import websocket as _websocket
except ImportError:
raise ValueError(
'Please install Quart by using `pip install nonebot2[quart]`')
# Type variable bound to coroutine-returning callables. ``on_startup`` and
# ``on_shutdown`` use it so the registered hook is typed as being returned
# unchanged.
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
class Driver(BaseDriver):
    """
    Quart driver.

    :Report endpoints:

      * ``/{adapter name}/http``: HTTP POST reports
      * ``/{adapter name}/ws``: reverse WebSocket reports
    """

    @overrides(BaseDriver)
    def __init__(self, env: Env, config: NoneBotConfig):
        """Create the Quart application and register the two report routes.

        :param env: environment object forwarded to the base driver
        :param config: NoneBot configuration forwarded to the base driver
        """
        super().__init__(env, config)

        # Strong references to in-flight event handling tasks. The event loop
        # itself only keeps weak references to tasks, so a task whose result
        # is discarded could otherwise be garbage-collected before it ends.
        self._pending_tasks = set()

        self._server_app = Quart(self.__class__.__qualname__)
        self._server_app.add_url_rule('/<adapter>/http',
                                      methods=['POST'],
                                      view_func=self._handle_http)
        self._server_app.add_websocket('/<adapter>/ws',
                                       view_func=self._handle_ws_reverse)

    @property
    @overrides(BaseDriver)
    def type(self) -> str:
        """Driver name: ``quart``"""
        return 'quart'

    @property
    @overrides(BaseDriver)
    def server_app(self) -> Quart:
        """The ``Quart`` application object"""
        return self._server_app

    @property
    @overrides(BaseDriver)
    def asgi(self):
        """The ``Quart`` application object (it is itself the ASGI app)"""
        return self._server_app

    @property
    @overrides(BaseDriver)
    def logger(self):
        """The logger used by the ``Quart`` application"""
        return self._server_app.logger

    @overrides(BaseDriver)
    def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
        """Register ``func`` through Quart's ``before_serving`` hook.

        See: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_
        """
        return self.server_app.before_serving(func)  # type: ignore

    @overrides(BaseDriver)
    def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
        """Register ``func`` through Quart's ``after_serving`` hook.

        See: `Startup and Shutdown <https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html>`_
        """
        return self.server_app.after_serving(func)  # type: ignore

    @overrides(BaseDriver)
    def run(self,
            host: Optional[str] = None,
            port: Optional[int] = None,
            *,
            app: Optional[str] = None,
            **kwargs):
        """Start Quart with ``uvicorn``.

        :param host: bind address; falls back to ``config.host`` when falsy
        :param port: bind port; falls back to ``config.port`` when falsy
        :param app: value handed to ``uvicorn.run`` instead of the Quart
            object when given
        :param kwargs: extra keyword arguments forwarded to the base driver
            and to ``uvicorn.run``
        """
        super().run(host, port, app, **kwargs)

        # Route uvicorn's own loggers through NoneBot's loguru handler.
        LOGGING_CONFIG = {
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "default": {
                    "class": "nonebot.log.LoguruHandler",
                },
            },
            "loggers": {
                "uvicorn.error": {
                    "handlers": ["default"],
                    "level": "INFO"
                },
                "uvicorn.access": {
                    "handlers": ["default"],
                    "level": "INFO",
                },
            },
        }
        uvicorn.run(app or self.server_app,
                    host=host or str(self.config.host),
                    port=port or self.config.port,
                    # Reload is only requested when ``app`` was supplied and
                    # debug mode is on.
                    reload=bool(app) and self.config.debug,
                    debug=self.config.debug,
                    log_config=LOGGING_CONFIG,
                    **kwargs)

    def _spawn_task(self, coro: Coroutine) -> "asyncio.Task":
        """Schedule ``coro`` in the background and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._pending_tasks.add(task)
        # Drop the reference as soon as the task finishes.
        task.add_done_callback(self._pending_tasks.discard)
        return task

    @overrides(BaseDriver)
    async def _handle_http(self, adapter: str):
        """Handle an HTTP POST report for ``adapter``.

        :param adapter: adapter name taken from the URL
        :return: an empty ``204`` response; the event itself is handled in a
            background task
        :raises exceptions.BadRequest: reading the JSON body failed
        :raises exceptions.NotFound: ``adapter`` is not registered
        :raises exceptions.HTTPException: the adapter denied the request
        """
        request: Request = _request

        try:
            data: Dict[str, Any] = await request.get_json()
        except Exception as e:
            raise exceptions.BadRequest() from e

        if adapter not in self._adapters:
            logger.warning(f'Unknown adapter {adapter}. '
                           'Please register the adapter before use.')
            raise exceptions.NotFound()

        BotClass = self._adapters[adapter]
        headers = dict(request.headers.items(lower=True))

        try:
            self_id = await BotClass.check_permission(self, 'http', headers,
                                                      data)
        except RequestDenied as e:
            raise exceptions.HTTPException(status_code=e.status_code,
                                           description=e.reason,
                                           name='Request Denied') from e

        if self_id in self._clients:
            logger.warning("There's already a reverse websocket connection,"
                           "so the event may be handled twice.")

        bot = BotClass('http', self_id)
        self._spawn_task(bot.handle_message(data))
        return Response('', 204)

    @overrides(BaseDriver)
    async def _handle_ws_reverse(self, adapter: str):
        """Handle a reverse WebSocket connection for ``adapter``.

        Permission is checked before the connection is accepted. Received
        messages are then dispatched to background tasks until the socket is
        flagged closed; the bot is always unregistered on exit.

        :param adapter: adapter name taken from the URL
        :raises exceptions.NotFound: ``adapter`` is not registered
        :raises exceptions.HTTPException: the adapter denied the request
        """
        websocket: QuartWebSocket = _websocket

        if adapter not in self._adapters:
            logger.warning(
                f'Unknown adapter {adapter}. Please register the adapter before use.'
            )
            raise exceptions.NotFound()

        BotClass = self._adapters[adapter]
        headers = dict(websocket.headers.items(lower=True))

        try:
            self_id = await BotClass.check_permission(self, 'websocket',
                                                      headers, None)
        except RequestDenied as e:
            raise exceptions.HTTPException(status_code=e.status_code,
                                           description=e.reason,
                                           name='Request Denied') from e

        if self_id in self._clients:
            logger.warning("There's already a reverse websocket connection,"
                           "so the event may be handled twice.")

        ws = WebSocket(websocket)
        bot = BotClass('websocket', self_id, websocket=ws)
        await ws.accept()
        logger.opt(colors=True).info(
            f"WebSocket Connection from <y>{adapter.upper()} "
            f"Bot {self_id}</y> Accepted!")
        self._bot_connect(bot)

        try:
            while not ws.closed:
                data = await ws.receive()
                # ``receive`` yields ``None`` when nothing usable arrived;
                # re-check the closed flag instead of dispatching it.
                if data is None:
                    continue
                self._spawn_task(bot.handle_message(data))
        finally:
            self._bot_disconnect(bot)
class WebSocket(BaseWebSocket):
    """Wrap a Quart websocket behind NoneBot's ``BaseWebSocket`` interface."""

    @overrides(BaseWebSocket)
    def __init__(self, websocket: QuartWebSocket):
        """Wrap ``websocket`` and start with the closed flag cleared."""
        super().__init__(websocket)
        # Closed state is tracked locally: it is set by ``close`` and when a
        # receive is cancelled, and cleared again by ``accept``.
        self._closed = False

    @property
    @overrides(BaseWebSocket)
    def websocket(self) -> QuartWebSocket:
        """The wrapped Quart websocket object."""
        # NOTE(review): ``_websocket`` is presumably stored by the base class
        # constructor -- confirm against BaseWebSocket.
        return self._websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self):
        """Whether this connection has been flagged as closed."""
        return self._closed

    @overrides(BaseWebSocket)
    async def accept(self):
        """Accept the underlying connection and clear the closed flag."""
        await self.websocket.accept()
        self._closed = False

    @overrides(BaseWebSocket)
    async def close(self):
        """Flag the connection as closed (only the local flag is changed)."""
        self._closed = True

    @overrides(BaseWebSocket)
    async def receive(self) -> Optional[Dict[str, Any]]:
        """Receive one JSON message.

        :return: the decoded message, or ``None`` when the message is not
            valid JSON or the receive was cancelled (in which case the
            connection is also flagged closed)
        """
        try:
            return await self.websocket.receive_json()
        except JSONDecodeError:
            logger.warning('Received an invalid json message.')
        except asyncio.CancelledError:
            # Cancellation of the receive is treated as a peer disconnect.
            self._closed = True
            logger.warning('WebSocket disconnected by peer.')
        return None

    @overrides(BaseWebSocket)
    async def send(self, data: dict):
        """Send ``data`` to the peer as a JSON message."""
        await self.websocket.send_json(data)