mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-30 17:15:08 +08:00
🐛 fix bugs in quart driver
This commit is contained in:
parent
9e0862bc97
commit
496f64f103
@ -1,19 +1,21 @@
|
||||
import asyncio
|
||||
from json.decoder import JSONDecodeError
|
||||
from logging import getLogger, warn
|
||||
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional,
|
||||
Type, TypeVar)
|
||||
|
||||
import uvicorn
|
||||
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.config import Env
|
||||
from nonebot.drivers import Driver as BaseDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.log import LoguruHandler, logger
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot
|
||||
try:
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config as HypercornConfig
|
||||
from quart import Quart, Request, Response
|
||||
from quart import Websocket as QuartWebSocket
|
||||
from quart import exceptions
|
||||
@ -32,11 +34,21 @@ class Driver(BaseDriver):
|
||||
super().__init__(env, config)
|
||||
|
||||
self._server_app = Quart(self.__class__.__qualname__)
|
||||
self._server_app.logger.handlers.clear()
|
||||
self._server_app.logger.addHandler(LoguruHandler())
|
||||
self._server_app.route('/<adapter>/http',
|
||||
methods=['POST'])(self._handle_http)
|
||||
self._server_app.websocket('/<adapter>/ws')(self._handle_ws_reverse)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
|
||||
if name in self._adapters:
|
||||
return
|
||||
|
||||
super().register_adapter(name, adapter, **kwargs)
|
||||
|
||||
@self.server_app.route(f'/{name}/http', endpoint=name + '_http')
|
||||
async def _http_handler():
|
||||
await self._handle_http(name)
|
||||
|
||||
@self.server_app.websocket(f'/{name}/ws', endpoint=name + '_ws')
|
||||
async def _ws_handler():
|
||||
await self._handle_ws_reverse(name)
|
||||
|
||||
@property
|
||||
@overrides(BaseDriver)
|
||||
@ -55,7 +67,7 @@ class Driver(BaseDriver):
|
||||
|
||||
@property
|
||||
@overrides(BaseDriver)
|
||||
def loggers(self):
|
||||
def logger(self):
|
||||
return self._server_app.logger
|
||||
|
||||
@overrides(BaseDriver)
|
||||
@ -66,43 +78,60 @@ class Driver(BaseDriver):
|
||||
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
|
||||
return self.server_app.after_serving(func) # type: ignore
|
||||
|
||||
@overrides(BaseDriver)
|
||||
@overrides(BaseDriver)
|
||||
def run(self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
*,
|
||||
app: Optional[str] = None,
|
||||
**kwargs):
|
||||
super().run(host, port, **kwargs)
|
||||
config = HypercornConfig()
|
||||
for k, v in kwargs.items():
|
||||
if not hasattr(config, k):
|
||||
warn(f'Config {k!r} is not available for quart driver.')
|
||||
continue
|
||||
setattr(config, k, v)
|
||||
config.bind.append(
|
||||
f'{host or self.config.host}:{port or self.config.port}')
|
||||
|
||||
serve_task = asyncio.run_coroutine_threadsafe(
|
||||
coro=serve(self.server_app, config),
|
||||
loop=asyncio.get_running_loop(),
|
||||
)
|
||||
try:
|
||||
serve_task.result()
|
||||
finally:
|
||||
serve_task.cancel()
|
||||
"""使用 ``uvicorn`` 启动 Quart"""
|
||||
super().run(host, port, app, **kwargs)
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"handlers": {
|
||||
"default": {
|
||||
"class": "nonebot.log.LoguruHandler",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"uvicorn.error": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO"
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
},
|
||||
},
|
||||
}
|
||||
uvicorn.run(app or self.server_app,
|
||||
host=host or str(self.config.host),
|
||||
port=port or self.config.port,
|
||||
reload=bool(app) and self.config.debug,
|
||||
debug=self.config.debug,
|
||||
log_config=LOGGING_CONFIG,
|
||||
**kwargs)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_http(self, adapter: str):
|
||||
request: Request = _request
|
||||
|
||||
try:
|
||||
data: Dict[str, Any] = await request.get_json()
|
||||
except Exception as e:
|
||||
raise exceptions.BadRequest()
|
||||
|
||||
if adapter not in self._adapters:
|
||||
logger.warning(f'Unknown adapter {adapter}. '
|
||||
'Please register the adapter before use.')
|
||||
raise exceptions.NotFound()
|
||||
|
||||
BotClass = self._adapters[adapter]
|
||||
headers = dict(request.headers)
|
||||
headers = {k: v for k, v in request.headers.items(lower=True)}
|
||||
|
||||
try:
|
||||
self_id = await BotClass.check_permission(self, 'http', headers,
|
||||
data)
|
||||
@ -120,7 +149,6 @@ class Driver(BaseDriver):
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_ws_reverse(self, adapter: str):
|
||||
websocket: QuartWebSocket = _websocket
|
||||
|
||||
if adapter not in self._adapters:
|
||||
logger.warning(
|
||||
f'Unknown adapter {adapter}. Please register the adapter before use.'
|
||||
@ -128,10 +156,12 @@ class Driver(BaseDriver):
|
||||
raise exceptions.NotFound()
|
||||
|
||||
BotClass = self._adapters[adapter]
|
||||
headers = dict(websocket.headers)
|
||||
headers = {k: v for k, v in websocket.headers.items(lower=True)}
|
||||
try:
|
||||
self_id = await BotClass.check_permission(self, 'ws', headers, None)
|
||||
self_id = await BotClass.check_permission(self, 'websocket',
|
||||
headers, None)
|
||||
except RequestDenied as e:
|
||||
print(e.reason)
|
||||
raise exceptions.HTTPException(status_code=e.status_code,
|
||||
description=e.reason,
|
||||
name='Request Denied')
|
||||
|
@ -1,4 +1,4 @@
|
||||
DRIVER=nonebot.drivers.fastapi
|
||||
DRIVER=nonebot.drivers.quart
|
||||
HOST=0.0.0.0
|
||||
PORT=2333
|
||||
DEBUG=true
|
||||
|
Loading…
Reference in New Issue
Block a user