🐛 fix bugs in quart driver

This commit is contained in:
Mix 2021-02-06 10:34:52 +08:00
parent 9e0862bc97
commit 496f64f103
2 changed files with 64 additions and 34 deletions

View File

@ -1,19 +1,21 @@
import asyncio import asyncio
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from logging import getLogger, warn from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional,
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar Type, TypeVar)
import uvicorn
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.config import Env from nonebot.config import Env
from nonebot.drivers import Driver as BaseDriver from nonebot.drivers import Driver as BaseDriver
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied from nonebot.exception import RequestDenied
from nonebot.log import LoguruHandler, logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
if TYPE_CHECKING:
from nonebot.adapters import Bot
try: try:
from hypercorn.asyncio import serve
from hypercorn.config import Config as HypercornConfig
from quart import Quart, Request, Response from quart import Quart, Request, Response
from quart import Websocket as QuartWebSocket from quart import Websocket as QuartWebSocket
from quart import exceptions from quart import exceptions
@ -32,11 +34,21 @@ class Driver(BaseDriver):
super().__init__(env, config) super().__init__(env, config)
self._server_app = Quart(self.__class__.__qualname__) self._server_app = Quart(self.__class__.__qualname__)
self._server_app.logger.handlers.clear()
self._server_app.logger.addHandler(LoguruHandler()) @overrides(BaseDriver)
self._server_app.route('/<adapter>/http', def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
methods=['POST'])(self._handle_http) if name in self._adapters:
self._server_app.websocket('/<adapter>/ws')(self._handle_ws_reverse) return
super().register_adapter(name, adapter, **kwargs)
@self.server_app.route(f'/{name}/http', endpoint=name + '_http')
async def _http_handler():
await self._handle_http(name)
@self.server_app.websocket(f'/{name}/ws', endpoint=name + '_ws')
async def _ws_handler():
await self._handle_ws_reverse(name)
@property @property
@overrides(BaseDriver) @overrides(BaseDriver)
@ -55,7 +67,7 @@ class Driver(BaseDriver):
@property @property
@overrides(BaseDriver) @overrides(BaseDriver)
def loggers(self): def logger(self):
return self._server_app.logger return self._server_app.logger
@overrides(BaseDriver) @overrides(BaseDriver)
@ -66,43 +78,60 @@ class Driver(BaseDriver):
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
return self.server_app.after_serving(func) # type: ignore return self.server_app.after_serving(func) # type: ignore
@overrides(BaseDriver)
@overrides(BaseDriver) @overrides(BaseDriver)
def run(self, def run(self,
host: Optional[str] = None, host: Optional[str] = None,
port: Optional[int] = None, port: Optional[int] = None,
*,
app: Optional[str] = None,
**kwargs): **kwargs):
super().run(host, port, **kwargs) """使用 ``uvicorn`` 启动 Quart"""
config = HypercornConfig() super().run(host, port, app, **kwargs)
for k, v in kwargs.items(): LOGGING_CONFIG = {
if not hasattr(config, k): "version": 1,
warn(f'Config {k!r} is not available for quart driver.') "disable_existing_loggers": False,
continue "handlers": {
setattr(config, k, v) "default": {
config.bind.append( "class": "nonebot.log.LoguruHandler",
f'{host or self.config.host}:{port or self.config.port}') },
},
serve_task = asyncio.run_coroutine_threadsafe( "loggers": {
coro=serve(self.server_app, config), "uvicorn.error": {
loop=asyncio.get_running_loop(), "handlers": ["default"],
) "level": "INFO"
try: },
serve_task.result() "uvicorn.access": {
finally: "handlers": ["default"],
serve_task.cancel() "level": "INFO",
},
},
}
uvicorn.run(app or self.server_app,
host=host or str(self.config.host),
port=port or self.config.port,
reload=bool(app) and self.config.debug,
debug=self.config.debug,
log_config=LOGGING_CONFIG,
**kwargs)
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_http(self, adapter: str): async def _handle_http(self, adapter: str):
request: Request = _request request: Request = _request
try: try:
data: Dict[str, Any] = await request.get_json() data: Dict[str, Any] = await request.get_json()
except Exception as e: except Exception as e:
raise exceptions.BadRequest() raise exceptions.BadRequest()
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning(f'Unknown adapter {adapter}. ' logger.warning(f'Unknown adapter {adapter}. '
'Please register the adapter before use.') 'Please register the adapter before use.')
raise exceptions.NotFound() raise exceptions.NotFound()
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
headers = dict(request.headers) headers = {k: v for k, v in request.headers.items(lower=True)}
try: try:
self_id = await BotClass.check_permission(self, 'http', headers, self_id = await BotClass.check_permission(self, 'http', headers,
data) data)
@ -120,7 +149,6 @@ class Driver(BaseDriver):
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_ws_reverse(self, adapter: str): async def _handle_ws_reverse(self, adapter: str):
websocket: QuartWebSocket = _websocket websocket: QuartWebSocket = _websocket
if adapter not in self._adapters: if adapter not in self._adapters:
logger.warning( logger.warning(
f'Unknown adapter {adapter}. Please register the adapter before use.' f'Unknown adapter {adapter}. Please register the adapter before use.'
@ -128,10 +156,12 @@ class Driver(BaseDriver):
raise exceptions.NotFound() raise exceptions.NotFound()
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
headers = dict(websocket.headers) headers = {k: v for k, v in websocket.headers.items(lower=True)}
try: try:
self_id = await BotClass.check_permission(self, 'ws', headers, None) self_id = await BotClass.check_permission(self, 'websocket',
headers, None)
except RequestDenied as e: except RequestDenied as e:
print(e.reason)
raise exceptions.HTTPException(status_code=e.status_code, raise exceptions.HTTPException(status_code=e.status_code,
description=e.reason, description=e.reason,
name='Request Denied') name='Request Denied')

View File

@ -1,4 +1,4 @@
DRIVER=nonebot.drivers.fastapi DRIVER=nonebot.drivers.quart
HOST=0.0.0.0 HOST=0.0.0.0
PORT=2333 PORT=2333
DEBUG=true DEBUG=true