mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
✨ Feature: 重构驱动器 lifespan 方法 (#1860)
This commit is contained in:
parent
0d0bc656c8
commit
a8a76393a5
45
nonebot/drivers/_lifespan.py
Normal file
45
nonebot/drivers/_lifespan.py
Normal file
@ -0,0 +1,45 @@
|
||||
from typing import Any, List, Union, Callable, Awaitable, cast
|
||||
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
|
||||
SYNC_LIFESPAN_FUNC = Callable[[], Any]
|
||||
ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]]
|
||||
LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
|
||||
|
||||
|
||||
class Lifespan:
|
||||
def __init__(self) -> None:
|
||||
self._startup_funcs: List[LIFESPAN_FUNC] = []
|
||||
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
|
||||
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
self._startup_funcs.append(func)
|
||||
return func
|
||||
|
||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
self._shutdown_funcs.append(func)
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
async def _run_lifespan_func(
|
||||
funcs: List[LIFESPAN_FUNC],
|
||||
) -> None:
|
||||
for func in funcs:
|
||||
if is_coroutine_callable(func):
|
||||
await cast(ASYNC_LIFESPAN_FUNC, func)()
|
||||
else:
|
||||
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
|
||||
|
||||
async def startup(self) -> None:
|
||||
if self._startup_funcs:
|
||||
await self._run_lifespan_func(self._startup_funcs)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._shutdown_funcs:
|
||||
await self._run_lifespan_func(self._shutdown_funcs)
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.startup()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
await self.shutdown()
|
@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError as e: # pragma: no cover
|
||||
except ModuleNotFoundError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
|
||||
) from e
|
||||
|
@ -19,7 +19,7 @@ FrontMatter:
|
||||
import logging
|
||||
import contextlib
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Tuple, Union, Callable, Optional
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||
|
||||
try:
|
||||
import uvicorn
|
||||
from fastapi.responses import Response
|
||||
from fastapi import FastAPI, Request, UploadFile, status
|
||||
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
|
||||
except ImportError as e: # pragma: no cover
|
||||
except ModuleNotFoundError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Please install FastAPI by using `pip install nonebot2[fastapi]`"
|
||||
) from e
|
||||
@ -92,7 +94,10 @@ class Driver(ReverseDriver):
|
||||
|
||||
self.fastapi_config: Config = Config(**config.dict())
|
||||
|
||||
self._lifespan = Lifespan()
|
||||
|
||||
self._server_app = FastAPI(
|
||||
lifespan=self._lifespan_manager,
|
||||
openapi_url=self.fastapi_config.fastapi_openapi_url,
|
||||
docs_url=self.fastapi_config.fastapi_docs_url,
|
||||
redoc_url=self.fastapi_config.fastapi_redoc_url,
|
||||
@ -148,14 +153,20 @@ class Driver(ReverseDriver):
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
|
||||
return self.server_app.on_event("startup")(func)
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
return self._lifespan.on_startup(func)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def on_shutdown(self, func: Callable) -> Callable:
|
||||
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_"""
|
||||
return self.server_app.on_event("shutdown")(func)
|
||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
return self._lifespan.on_shutdown(func)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lifespan_manager(self, app: FastAPI):
|
||||
await self._lifespan.startup()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await self._lifespan.shutdown()
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def run(
|
||||
|
@ -31,7 +31,7 @@ from nonebot.drivers import (
|
||||
|
||||
try:
|
||||
import httpx
|
||||
except ImportError as e: # pragma: no cover
|
||||
except ModuleNotFoundError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Please install httpx by using `pip install nonebot2[httpx]`"
|
||||
) from e
|
||||
|
@ -13,7 +13,6 @@ FrontMatter:
|
||||
import signal
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Set, Union, Callable, Awaitable, cast
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.consts import WINDOWS
|
||||
@ -22,7 +21,8 @@ from nonebot.config import Env, Config
|
||||
from nonebot.drivers import Driver as BaseDriver
|
||||
from nonebot.utils import run_sync, is_coroutine_callable
|
||||
|
||||
HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]]
|
||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
@ -36,8 +36,9 @@ class Driver(BaseDriver):
|
||||
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
self.startup_funcs: Set[HOOK_FUNC] = set()
|
||||
self.shutdown_funcs: Set[HOOK_FUNC] = set()
|
||||
|
||||
self._lifespan = Lifespan()
|
||||
|
||||
self.should_exit: asyncio.Event = asyncio.Event()
|
||||
self.force_exit: bool = False
|
||||
|
||||
@ -54,20 +55,18 @@ class Driver(BaseDriver):
|
||||
return logger
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC:
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""
|
||||
注册一个启动时执行的函数
|
||||
"""
|
||||
self.startup_funcs.add(func)
|
||||
return func
|
||||
return self._lifespan.on_startup(func)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC:
|
||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""
|
||||
注册一个停止时执行的函数
|
||||
"""
|
||||
self.shutdown_funcs.add(func)
|
||||
return func
|
||||
return self._lifespan.on_shutdown(func)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def run(self, *args, **kwargs):
|
||||
@ -85,21 +84,13 @@ class Driver(BaseDriver):
|
||||
await self._shutdown()
|
||||
|
||||
async def _startup(self):
|
||||
# run startup
|
||||
cors = [
|
||||
cast(Callable[..., Awaitable[None]], startup)()
|
||||
if is_coroutine_callable(startup)
|
||||
else run_sync(startup)()
|
||||
for startup in self.startup_funcs
|
||||
]
|
||||
if cors:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running startup function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
try:
|
||||
await self._lifespan.startup()
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running startup function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
logger.info("Application startup completed.")
|
||||
|
||||
@ -110,21 +101,14 @@ class Driver(BaseDriver):
|
||||
logger.info("Shutting down")
|
||||
|
||||
logger.info("Waiting for application shutdown.")
|
||||
# run shutdown
|
||||
cors = [
|
||||
cast(Callable[..., Awaitable[None]], shutdown)()
|
||||
if is_coroutine_callable(shutdown)
|
||||
else run_sync(shutdown)()
|
||||
for shutdown in self.shutdown_funcs
|
||||
]
|
||||
if cors:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
try:
|
||||
await self._lifespan.shutdown()
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
for task in asyncio.all_tasks():
|
||||
if task is not asyncio.current_task() and not task.done():
|
||||
|
@ -37,7 +37,7 @@ try:
|
||||
from quart import Quart, Request, Response
|
||||
from quart.datastructures import FileStorage
|
||||
from quart import Websocket as QuartWebSocket
|
||||
except ImportError as e: # pragma: no cover
|
||||
except ModuleNotFoundError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Please install Quart by using `pip install nonebot2[quart]`"
|
||||
) from e
|
||||
|
@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
|
||||
try:
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
except ImportError as e: # pragma: no cover
|
||||
except ModuleNotFoundError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"Please install websockets by using `pip install nonebot2[websockets]`"
|
||||
) from e
|
||||
|
@ -12,6 +12,7 @@ from nonebot.params import Depends
|
||||
from nonebot import _resolve_combine_expr
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers._lifespan import Lifespan
|
||||
from nonebot.drivers import (
|
||||
URL,
|
||||
Driver,
|
||||
@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
|
||||
return DriverClass(Env(environment=global_driver.env), global_driver.config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan():
|
||||
lifespan = Lifespan()
|
||||
|
||||
start_log = []
|
||||
shutdown_log = []
|
||||
|
||||
@lifespan.on_startup
|
||||
async def _startup1():
|
||||
assert start_log == []
|
||||
start_log.append(1)
|
||||
|
||||
@lifespan.on_startup
|
||||
async def _startup2():
|
||||
assert start_log == [1]
|
||||
start_log.append(2)
|
||||
|
||||
@lifespan.on_shutdown
|
||||
async def _shutdown1():
|
||||
assert shutdown_log == []
|
||||
shutdown_log.append(1)
|
||||
|
||||
@lifespan.on_shutdown
|
||||
async def _shutdown2():
|
||||
assert shutdown_log == [1]
|
||||
shutdown_log.append(2)
|
||||
|
||||
async with lifespan:
|
||||
assert start_log == [1, 2]
|
||||
|
||||
assert shutdown_log == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
|
Loading…
Reference in New Issue
Block a user