diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index b8073739..2ee9eb1a 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -1,19 +1,21 @@ import asyncio from json.decoder import JSONDecodeError -from logging import getLogger, warn -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar +from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional, + Type, TypeVar) + +import uvicorn from nonebot.config import Config as NoneBotConfig from nonebot.config import Env from nonebot.drivers import Driver as BaseDriver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.exception import RequestDenied -from nonebot.log import LoguruHandler, logger +from nonebot.log import logger from nonebot.typing import overrides +if TYPE_CHECKING: + from nonebot.adapters import Bot try: - from hypercorn.asyncio import serve - from hypercorn.config import Config as HypercornConfig from quart import Quart, Request, Response from quart import Websocket as QuartWebSocket from quart import exceptions @@ -32,11 +34,21 @@ class Driver(BaseDriver): super().__init__(env, config) self._server_app = Quart(self.__class__.__qualname__) - self._server_app.logger.handlers.clear() - self._server_app.logger.addHandler(LoguruHandler()) - self._server_app.route('//http', - methods=['POST'])(self._handle_http) - self._server_app.websocket('//ws')(self._handle_ws_reverse) + + @overrides(BaseDriver) + def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): + if name in self._adapters: + return + + super().register_adapter(name, adapter, **kwargs) + + @self.server_app.route(f'/{name}/http', endpoint=name + '_http') + async def _http_handler(): + await self._handle_http(name) + + @self.server_app.websocket(f'/{name}/ws', endpoint=name + '_ws') + async def _ws_handler(): + await self._handle_ws_reverse(name) @property @overrides(BaseDriver) @@ -55,7 +67,7 @@ class Driver(BaseDriver): @property @overrides(BaseDriver) - def loggers(self): + def logger(self): return self._server_app.logger @overrides(BaseDriver) @@ -66,43 +78,60 @@ class Driver(BaseDriver): def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: return self.server_app.after_serving(func) # type: ignore + @overrides(BaseDriver) @overrides(BaseDriver) def run(self, host: Optional[str] = None, port: Optional[int] = None, + *, + app: Optional[str] = None, **kwargs): - super().run(host, port, **kwargs) - config = HypercornConfig() - for k, v in kwargs.items(): - if not hasattr(config, k): - warn(f'Config {k!r} is not available for quart driver.') - continue - setattr(config, k, v) - config.bind.append( - f'{host or self.config.host}:{port or self.config.port}') - - serve_task = asyncio.run_coroutine_threadsafe( - coro=serve(self.server_app, config), - loop=asyncio.get_running_loop(), - ) - try: - serve_task.result() - finally: - serve_task.cancel() + """使用 ``uvicorn`` 启动 Quart""" + super().run(host, port, app, **kwargs) + LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "handlers": { + "default": { + "class": "nonebot.log.LoguruHandler", + }, + }, + "loggers": { + "uvicorn.error": { + "handlers": ["default"], + "level": "INFO" + }, + "uvicorn.access": { + "handlers": ["default"], + "level": "INFO", + }, + }, + } + uvicorn.run(app or self.server_app, + host=host or str(self.config.host), + port=port or self.config.port, + reload=bool(app) and self.config.debug, + debug=self.config.debug, + log_config=LOGGING_CONFIG, + **kwargs) @overrides(BaseDriver) async def _handle_http(self, adapter: str): request: Request = _request + try: data: Dict[str, Any] = await request.get_json() except Exception as e: raise exceptions.BadRequest() + if adapter not in self._adapters: logger.warning(f'Unknown adapter {adapter}. ' 'Please register the adapter before use.') raise exceptions.NotFound() + BotClass = self._adapters[adapter] - headers = dict(request.headers) + headers = {k: v for k, v in request.headers.items(lower=True)} + try: self_id = await BotClass.check_permission(self, 'http', headers, data) @@ -120,7 +149,6 @@ class Driver(BaseDriver): @overrides(BaseDriver) async def _handle_ws_reverse(self, adapter: str): websocket: QuartWebSocket = _websocket - if adapter not in self._adapters: logger.warning( f'Unknown adapter {adapter}. Please register the adapter before use.' @@ -128,10 +156,12 @@ class Driver(BaseDriver): raise exceptions.NotFound() BotClass = self._adapters[adapter] - headers = dict(websocket.headers) + headers = {k: v for k, v in websocket.headers.items(lower=True)} try: - self_id = await BotClass.check_permission(self, 'ws', headers, None) + self_id = await BotClass.check_permission(self, 'websocket', + headers, None) except RequestDenied as e: + print(e.reason) raise exceptions.HTTPException(status_code=e.status_code, description=e.reason, name='Request Denied') diff --git a/tests/.env.dev b/tests/.env.dev index 33e6f835..ef16df99 100644 --- a/tests/.env.dev +++ b/tests/.env.dev @@ -1,4 +1,4 @@ -DRIVER=nonebot.drivers.fastapi +DRIVER=nonebot.drivers.quart HOST=0.0.0.0 PORT=2333 DEBUG=true