Feature: 重构驱动器 lifespan 方法 (#1860)

This commit is contained in:
Ju4tCode 2023-03-29 15:59:54 +08:00 committed by GitHub
parent 0d0bc656c8
commit a8a76393a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 126 additions and 52 deletions

View File

@ -0,0 +1,45 @@
from typing import Any, List, Union, Callable, Awaitable, cast
from nonebot.utils import run_sync, is_coroutine_callable
SYNC_LIFESPAN_FUNC = Callable[[], Any]
ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]]
LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._shutdown_funcs.append(func)
return func
@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
await cast(ASYNC_LIFESPAN_FUNC, func)()
else:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()

View File

@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr
try: try:
import aiohttp import aiohttp
except ImportError as e: # pragma: no cover except ModuleNotFoundError as e: # pragma: no cover
raise ImportError( raise ImportError(
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`" "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
) from e ) from e

View File

@ -19,7 +19,7 @@ FrontMatter:
import logging import logging
import contextlib import contextlib
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Tuple, Union, Callable, Optional from typing import Any, Dict, List, Tuple, Union, Optional
from pydantic import BaseSettings from pydantic import BaseSettings
@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
from ._lifespan import LIFESPAN_FUNC, Lifespan
try: try:
import uvicorn import uvicorn
from fastapi.responses import Response from fastapi.responses import Response
from fastapi import FastAPI, Request, UploadFile, status from fastapi import FastAPI, Request, UploadFile, status
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
except ImportError as e: # pragma: no cover except ModuleNotFoundError as e: # pragma: no cover
raise ImportError( raise ImportError(
"Please install FastAPI by using `pip install nonebot2[fastapi]`" "Please install FastAPI by using `pip install nonebot2[fastapi]`"
) from e ) from e
@ -92,7 +94,10 @@ class Driver(ReverseDriver):
self.fastapi_config: Config = Config(**config.dict()) self.fastapi_config: Config = Config(**config.dict())
self._lifespan = Lifespan()
self._server_app = FastAPI( self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url, openapi_url=self.fastapi_config.fastapi_openapi_url,
docs_url=self.fastapi_config.fastapi_docs_url, docs_url=self.fastapi_config.fastapi_docs_url,
redoc_url=self.fastapi_config.fastapi_redoc_url, redoc_url=self.fastapi_config.fastapi_redoc_url,
@ -148,14 +153,20 @@ class Driver(ReverseDriver):
) )
@overrides(ReverseDriver) @overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable: def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" return self._lifespan.on_startup(func)
return self.server_app.on_event("startup")(func)
@overrides(ReverseDriver) @overrides(ReverseDriver)
def on_shutdown(self, func: Callable) -> Callable: def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_""" return self._lifespan.on_shutdown(func)
return self.server_app.on_event("shutdown")(func)
@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
try:
yield
finally:
await self._lifespan.shutdown()
@overrides(ReverseDriver) @overrides(ReverseDriver)
def run( def run(

View File

@ -31,7 +31,7 @@ from nonebot.drivers import (
try: try:
import httpx import httpx
except ImportError as e: # pragma: no cover except ModuleNotFoundError as e: # pragma: no cover
raise ImportError( raise ImportError(
"Please install httpx by using `pip install nonebot2[httpx]`" "Please install httpx by using `pip install nonebot2[httpx]`"
) from e ) from e

View File

@ -13,7 +13,6 @@ FrontMatter:
import signal import signal
import asyncio import asyncio
import threading import threading
from typing import Set, Union, Callable, Awaitable, cast
from nonebot.log import logger from nonebot.log import logger
from nonebot.consts import WINDOWS from nonebot.consts import WINDOWS
@ -22,7 +21,8 @@ from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import run_sync, is_coroutine_callable from nonebot.utils import run_sync, is_coroutine_callable
HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]] from ._lifespan import LIFESPAN_FUNC, Lifespan
HANDLED_SIGNALS = ( HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`. signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -36,8 +36,9 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(env, config) super().__init__(env, config)
self.startup_funcs: Set[HOOK_FUNC] = set()
self.shutdown_funcs: Set[HOOK_FUNC] = set() self._lifespan = Lifespan()
self.should_exit: asyncio.Event = asyncio.Event() self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False self.force_exit: bool = False
@ -54,20 +55,18 @@ class Driver(BaseDriver):
return logger return logger
@overrides(BaseDriver) @overrides(BaseDriver)
def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC: def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
""" """
注册一个启动时执行的函数 注册一个启动时执行的函数
""" """
self.startup_funcs.add(func) return self._lifespan.on_startup(func)
return func
@overrides(BaseDriver) @overrides(BaseDriver)
def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC: def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
""" """
注册一个停止时执行的函数 注册一个停止时执行的函数
""" """
self.shutdown_funcs.add(func) return self._lifespan.on_shutdown(func)
return func
@overrides(BaseDriver) @overrides(BaseDriver)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
@ -85,16 +84,8 @@ class Driver(BaseDriver):
await self._shutdown() await self._shutdown()
async def _startup(self): async def _startup(self):
# run startup
cors = [
cast(Callable[..., Awaitable[None]], startup)()
if is_coroutine_callable(startup)
else run_sync(startup)()
for startup in self.startup_funcs
]
if cors:
try: try:
await asyncio.gather(*cors) await self._lifespan.startup()
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. " "<r><bg #f8bbd0>Error when running startup function. "
@ -110,16 +101,9 @@ class Driver(BaseDriver):
logger.info("Shutting down") logger.info("Shutting down")
logger.info("Waiting for application shutdown.") logger.info("Waiting for application shutdown.")
# run shutdown
cors = [
cast(Callable[..., Awaitable[None]], shutdown)()
if is_coroutine_callable(shutdown)
else run_sync(shutdown)()
for shutdown in self.shutdown_funcs
]
if cors:
try: try:
await asyncio.gather(*cors) await self._lifespan.shutdown()
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. " "<r><bg #f8bbd0>Error when running shutdown function. "

View File

@ -37,7 +37,7 @@ try:
from quart import Quart, Request, Response from quart import Quart, Request, Response
from quart.datastructures import FileStorage from quart.datastructures import FileStorage
from quart import Websocket as QuartWebSocket from quart import Websocket as QuartWebSocket
except ImportError as e: # pragma: no cover except ModuleNotFoundError as e: # pragma: no cover
raise ImportError( raise ImportError(
"Please install Quart by using `pip install nonebot2[quart]`" "Please install Quart by using `pip install nonebot2[quart]`"
) from e ) from e

View File

@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
try: try:
from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError as e: # pragma: no cover except ModuleNotFoundError as e: # pragma: no cover
raise ImportError( raise ImportError(
"Please install websockets by using `pip install nonebot2[websockets]`" "Please install websockets by using `pip install nonebot2[websockets]`"
) from e ) from e

View File

@ -12,6 +12,7 @@ from nonebot.params import Depends
from nonebot import _resolve_combine_expr from nonebot import _resolve_combine_expr
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import ( from nonebot.drivers import (
URL, URL,
Driver, Driver,
@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
return DriverClass(Env(environment=global_driver.env), global_driver.config) return DriverClass(Env(environment=global_driver.env), global_driver.config)
@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
start_log = []
shutdown_log = []
@lifespan.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)
@lifespan.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)
@lifespan.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)
@lifespan.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)
async with lifespan:
assert start_log == [1, 2]
assert shutdown_log == [1, 2]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",