nonebot2/nonebot/drivers/quart.py
2021-12-22 16:53:55 +08:00

294 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Quart 驱动适配
================
后端使用方法请参考: `Quart 文档`_
.. _Quart 文档:
https://pgjones.gitlab.io/quart/index.html
"""
from functools import partial
from typing import List, TypeVar, Callable, Optional, Coroutine
import uvicorn
from pydantic import BaseSettings
from nonebot.config import Env
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
try:
from quart import request as _request
import werkzeug.exceptions as exceptions
from quart import websocket as _websocket
from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket
except ImportError:
raise ValueError("Please install Quart by using `pip install nonebot2[quart]`")
# Type variable bound to callables returning a coroutine. ``Driver.on_startup``
# and ``Driver.on_shutdown`` accept and return this type, so decorating a
# hook with them keeps the hook's own signature for type checkers.
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
class Config(BaseSettings):
    """
    Quart driver settings.

    Every option defaults to ``None``; ``None`` defers to the behaviour of
    ``Driver.run`` / uvicorn. Unknown keys are ignored.
    """

    quart_reload: Optional[bool] = None
    """
    :type:
        ``Optional[bool]``
    :description:
        Turn uvicorn's reloading on or off. When left as ``None``, reloading
        is switched on only if an ``app`` is passed to ``Driver.run`` and
        debug mode is enabled.
    """

    quart_reload_dirs: Optional[List[str]] = None
    """
    :type:
        ``Optional[List[str]]``
    :description:
        Directories watched for reloading; ``None`` keeps uvicorn's default.
    """

    quart_reload_delay: Optional[float] = None
    """
    :type:
        ``Optional[float]``
    :description:
        Value forwarded as uvicorn's ``reload_delay``; ``None`` keeps
        uvicorn's default.
    """

    quart_reload_includes: Optional[List[str]] = None
    """
    :type:
        ``Optional[List[str]]``
    :description:
        Files to watch, glob patterns allowed; ``None`` keeps uvicorn's
        default.
    """

    quart_reload_excludes: Optional[List[str]] = None
    """
    :type:
        ``Optional[List[str]]``
    :description:
        Files to exclude from watching, glob patterns allowed; ``None`` keeps
        uvicorn's default.
    """

    class Config:
        # keys belonging to other components of the global config are dropped
        extra = "ignore"
class Driver(ReverseDriver):
    """Quart driver framework.

    Registers the reverse (server-side) HTTP and WebSocket routes requested
    through ``setup_http_server`` / ``setup_websocket_server`` on a ``Quart``
    application and runs that application with ``uvicorn``.
    """

    def __init__(self, env: Env, config: NoneBotConfig):
        super().__init__(env, config)

        # Quart-specific options are parsed out of the global config dict.
        self.quart_config = Config(**config.dict())

        self._server_app = Quart(self.__class__.__qualname__)

    @property
    @overrides(ReverseDriver)
    def type(self) -> str:
        """Driver name: ``quart``"""
        return "quart"

    @property
    @overrides(ReverseDriver)
    def server_app(self) -> Quart:
        """The ``Quart`` application object."""
        return self._server_app

    @property
    @overrides(ReverseDriver)
    def asgi(self):
        """The ``Quart`` application object (served as the ASGI app)."""
        return self._server_app

    @property
    @overrides(ReverseDriver)
    def logger(self):
        """The logger of the Quart application."""
        return self._server_app.logger

    @overrides(ReverseDriver)
    def setup_http_server(self, setup: HTTPServerSetup):
        """Register an HTTP route for ``setup`` on the Quart application.

        The route matches ``setup.path.path`` with method ``setup.method``;
        each request is handled by ``_handle_http``, which calls
        ``setup.handle_func``.
        """

        # A nested coroutine function replaces ``functools.partial``: when no
        # endpoint is given Quart names it after ``view_func.__name__``, which
        # a partial does not have, so registration used to fail.
        async def _handle() -> Response:
            return await self._handle_http(setup)

        self._server_app.add_url_rule(
            setup.path.path,
            # explicit endpoint, unique per (method, path), so registering
            # several routes never collides on a shared function name
            endpoint=f"{setup.method} {setup.path.path}",
            methods=[setup.method],
            view_func=_handle,
        )

    @overrides(ReverseDriver)
    def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
        """Register a WebSocket route for ``setup`` on the Quart application.

        Each connection is handled by ``_handle_ws``, which calls
        ``setup.handle_func``.
        """

        # Same reasoning as in ``setup_http_server``: a real coroutine
        # function plus an explicit endpoint instead of a nameless partial.
        async def _handle() -> None:
            return await self._handle_ws(setup)

        self._server_app.add_websocket(
            setup.path.path,
            endpoint=f"websocket {setup.path.path}",
            view_func=_handle,
        )

    @overrides(ReverseDriver)
    def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
        """Register ``func`` with Quart's ``before_serving`` hook.

        See `Startup and Shutdown`_.

        .. _Startup and Shutdown:
            https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html
        """
        return self.server_app.before_serving(func)  # type: ignore

    @overrides(ReverseDriver)
    def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
        """Register ``func`` with Quart's ``after_serving`` hook.

        See `Startup and Shutdown`_.
        """
        return self.server_app.after_serving(func)  # type: ignore

    @overrides(ReverseDriver)
    def run(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        app: Optional[str] = None,
        **kwargs,
    ):
        """Start Quart with ``uvicorn``.

        :param host: bind host; falls back to ``config.host``.
        :param port: bind port; falls back to ``config.port``.
        :param app: optional app reference handed to ``uvicorn.run`` instead
            of the Quart object. Unless ``quart_reload`` is set explicitly,
            reloading is enabled only when ``app`` is given and debug is on.
        :param kwargs: forwarded to ``uvicorn.run``.
        """
        super().run(host, port, app, **kwargs)

        # Route uvicorn's error/access logs through nonebot's LoguruHandler.
        LOGGING_CONFIG = {
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "default": {
                    "class": "nonebot.log.LoguruHandler",
                },
            },
            "loggers": {
                "uvicorn.error": {"handlers": ["default"], "level": "INFO"},
                "uvicorn.access": {
                    "handlers": ["default"],
                    "level": "INFO",
                },
            },
        }

        uvicorn.run(
            app or self.server_app,  # type: ignore
            host=host or str(self.config.host),
            port=port or self.config.port,
            reload=self.quart_config.quart_reload
            if self.quart_config.quart_reload is not None
            else (bool(app) and self.config.debug),
            reload_dirs=self.quart_config.quart_reload_dirs,
            reload_delay=self.quart_config.quart_reload_delay,
            reload_includes=self.quart_config.quart_reload_includes,
            reload_excludes=self.quart_config.quart_reload_excludes,
            debug=self.config.debug,
            log_config=LOGGING_CONFIG,
            **kwargs,
        )

    async def _handle_http(self, setup: HTTPServerSetup) -> Response:
        """Translate the current Quart request into a NoneBot request, run
        ``setup.handle_func`` on it and translate the result back."""
        request: Request = _request

        http_request = BaseRequest(
            request.method,
            request.url,
            # materialise the header pairs, like the cookies below
            headers=list(request.headers.items()),
            cookies=list(request.cookies.items()),
            # raw body bytes: not cached on the request and not form-parsed
            content=await request.get_data(
                cache=False, as_text=False, parse_form_data=False
            ),
            version=request.http_version,
        )

        response = await setup.handle_func(http_request)

        # empty content / missing status fall back to "" / 200
        return Response(
            response.content or "",
            response.status_code or 200,
            headers=dict(response.headers),
        )

    async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
        """Wrap the current Quart websocket and pass it to
        ``setup.handle_func``."""
        websocket: QuartWebSocket = _websocket

        http_request = BaseRequest(
            websocket.method,
            websocket.url,
            headers=list(websocket.headers.items()),
            cookies=list(websocket.cookies.items()),
            version=websocket.http_version,
        )

        ws = WebSocket(request=http_request, websocket=websocket)
        await setup.handle_func(ws)
class WebSocket(BaseWebSocket):
    """NoneBot ``WebSocket`` implementation backed by a Quart ``Websocket``.

    Every operation is delegated to the wrapped Quart connection.
    """

    def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket):
        super().__init__(request=request)
        # underlying Quart connection that all calls below go through
        self.websocket = websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self):
        """Whether the connection is closed.

        FIXME: the Quart connection state is not inspected; this always
        reports ``True``.
        """
        return True

    @overrides(BaseWebSocket)
    async def accept(self):
        """Accept the pending WebSocket handshake."""
        await self.websocket.accept()

    @overrides(BaseWebSocket)
    async def close(self, code: int = 1000, reason: str = ""):
        """Close the connection with the given close ``code`` and ``reason``."""
        await self.websocket.close(code, reason)

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        """Receive one text frame.

        :raises TypeError: if a binary frame arrives instead.
        """
        frame = await self.websocket.receive()
        if not isinstance(frame, bytes):
            return frame
        raise TypeError("WebSocket received unexpected frame type: bytes")

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        """Receive one binary frame.

        :raises TypeError: if a text frame arrives instead.
        """
        frame = await self.websocket.receive()
        if not isinstance(frame, str):
            return frame
        raise TypeError("WebSocket received unexpected frame type: str")

    @overrides(BaseWebSocket)
    async def send(self, data: str):
        """Send ``data`` as a text frame."""
        await self.websocket.send(data)

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes):
        """Send ``data`` as a binary frame."""
        await self.websocket.send(data)
# ``Driver`` combined with ``HttpxMixin`` and ``WebSocketsMixin`` through
# ``combine_driver``. Judging by the mixin names, this presumably adds
# client-side (forward) HTTP and WebSocket support on top of the Quart
# server — NOTE(review): confirm against the mixin implementations.
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)