Feature: 重构驱动器 lifespan 方法 (#1860)

This commit is contained in:
Ju4tCode 2023-03-29 15:59:54 +08:00 committed by GitHub
parent 0d0bc656c8
commit a8a76393a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 126 additions and 52 deletions

View File

@ -0,0 +1,45 @@
from typing import Any, List, Union, Callable, Awaitable, cast
from nonebot.utils import run_sync, is_coroutine_callable
SYNC_LIFESPAN_FUNC = Callable[[], Any]
ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]]
LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._shutdown_funcs.append(func)
return func
@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
await cast(ASYNC_LIFESPAN_FUNC, func)()
else:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()

View File

@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr
try:
import aiohttp
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
) from e

View File

@ -19,7 +19,7 @@ FrontMatter:
import logging
import contextlib
from functools import wraps
from typing import Any, Dict, List, Tuple, Union, Callable, Optional
from typing import Any, Dict, List, Tuple, Union, Optional
from pydantic import BaseSettings
@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
from ._lifespan import LIFESPAN_FUNC, Lifespan
try:
import uvicorn
from fastapi.responses import Response
from fastapi import FastAPI, Request, UploadFile, status
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install FastAPI by using `pip install nonebot2[fastapi]`"
) from e
@ -92,7 +94,10 @@ class Driver(ReverseDriver):
self.fastapi_config: Config = Config(**config.dict())
self._lifespan = Lifespan()
self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
docs_url=self.fastapi_config.fastapi_docs_url,
redoc_url=self.fastapi_config.fastapi_redoc_url,
@ -148,14 +153,20 @@ class Driver(ReverseDriver):
)
@overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
return self.server_app.on_event("startup")(func)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)
@overrides(ReverseDriver)
def on_shutdown(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_"""
return self.server_app.on_event("shutdown")(func)
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)
@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
try:
yield
finally:
await self._lifespan.shutdown()
@overrides(ReverseDriver)
def run(

View File

@ -31,7 +31,7 @@ from nonebot.drivers import (
try:
import httpx
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install httpx by using `pip install nonebot2[httpx]`"
) from e

View File

@ -13,7 +13,6 @@ FrontMatter:
import signal
import asyncio
import threading
from typing import Set, Union, Callable, Awaitable, cast
from nonebot.log import logger
from nonebot.consts import WINDOWS
@ -22,7 +21,8 @@ from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import run_sync, is_coroutine_callable
HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]]
from ._lifespan import LIFESPAN_FUNC, Lifespan
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -36,8 +36,9 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[HOOK_FUNC] = set()
self.shutdown_funcs: Set[HOOK_FUNC] = set()
self._lifespan = Lifespan()
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
@ -54,20 +55,18 @@ class Driver(BaseDriver):
return logger
@overrides(BaseDriver)
def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC:
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""
注册一个启动时执行的函数
"""
self.startup_funcs.add(func)
return func
return self._lifespan.on_startup(func)
@overrides(BaseDriver)
def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC:
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""
注册一个停止时执行的函数
"""
self.shutdown_funcs.add(func)
return func
return self._lifespan.on_shutdown(func)
@overrides(BaseDriver)
def run(self, *args, **kwargs):
@ -85,21 +84,13 @@ class Driver(BaseDriver):
await self._shutdown()
async def _startup(self):
# run startup
cors = [
cast(Callable[..., Awaitable[None]], startup)()
if is_coroutine_callable(startup)
else run_sync(startup)()
for startup in self.startup_funcs
]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
logger.info("Application startup completed.")
@ -110,21 +101,14 @@ class Driver(BaseDriver):
logger.info("Shutting down")
logger.info("Waiting for application shutdown.")
# run shutdown
cors = [
cast(Callable[..., Awaitable[None]], shutdown)()
if is_coroutine_callable(shutdown)
else run_sync(shutdown)()
for shutdown in self.shutdown_funcs
]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():

View File

@ -37,7 +37,7 @@ try:
from quart import Quart, Request, Response
from quart.datastructures import FileStorage
from quart import Websocket as QuartWebSocket
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install Quart by using `pip install nonebot2[quart]`"
) from e

View File

@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
try:
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install websockets by using `pip install nonebot2[websockets]`"
) from e

View File

@ -12,6 +12,7 @@ from nonebot.params import Depends
from nonebot import _resolve_combine_expr
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
return DriverClass(Env(environment=global_driver.env), global_driver.config)
@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
start_log = []
shutdown_log = []
@lifespan.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)
@lifespan.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)
@lifespan.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)
@lifespan.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)
async with lifespan:
assert start_log == [1, 2]
assert shutdown_log == [1, 2]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",