mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-01 01:25:07 +08:00
✨ Feature: 重构驱动器 lifespan 方法 (#1860)
This commit is contained in:
parent
0d0bc656c8
commit
a8a76393a5
45
nonebot/drivers/_lifespan.py
Normal file
45
nonebot/drivers/_lifespan.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from typing import Any, List, Union, Callable, Awaitable, cast
|
||||||
|
|
||||||
|
from nonebot.utils import run_sync, is_coroutine_callable
|
||||||
|
|
||||||
|
SYNC_LIFESPAN_FUNC = Callable[[], Any]
|
||||||
|
ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]]
|
||||||
|
LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
|
||||||
|
|
||||||
|
|
||||||
|
class Lifespan:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._startup_funcs: List[LIFESPAN_FUNC] = []
|
||||||
|
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
|
||||||
|
|
||||||
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
|
self._startup_funcs.append(func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
|
self._shutdown_funcs.append(func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _run_lifespan_func(
|
||||||
|
funcs: List[LIFESPAN_FUNC],
|
||||||
|
) -> None:
|
||||||
|
for func in funcs:
|
||||||
|
if is_coroutine_callable(func):
|
||||||
|
await cast(ASYNC_LIFESPAN_FUNC, func)()
|
||||||
|
else:
|
||||||
|
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
|
||||||
|
|
||||||
|
async def startup(self) -> None:
|
||||||
|
if self._startup_funcs:
|
||||||
|
await self._run_lifespan_func(self._startup_funcs)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
if self._shutdown_funcs:
|
||||||
|
await self._run_lifespan_func(self._shutdown_funcs)
|
||||||
|
|
||||||
|
async def __aenter__(self) -> None:
|
||||||
|
await self.startup()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
|
await self.shutdown()
|
@ -27,7 +27,7 @@ from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_dr
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
except ImportError as e: # pragma: no cover
|
except ModuleNotFoundError as e: # pragma: no cover
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
|
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
|
||||||
) from e
|
) from e
|
||||||
|
@ -19,7 +19,7 @@ FrontMatter:
|
|||||||
import logging
|
import logging
|
||||||
import contextlib
|
import contextlib
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, List, Tuple, Union, Callable, Optional
|
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||||
|
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
@ -32,12 +32,14 @@ from nonebot.drivers import Request as BaseRequest
|
|||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||||
|
|
||||||
|
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from fastapi import FastAPI, Request, UploadFile, status
|
from fastapi import FastAPI, Request, UploadFile, status
|
||||||
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
|
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
|
||||||
except ImportError as e: # pragma: no cover
|
except ModuleNotFoundError as e: # pragma: no cover
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install FastAPI by using `pip install nonebot2[fastapi]`"
|
"Please install FastAPI by using `pip install nonebot2[fastapi]`"
|
||||||
) from e
|
) from e
|
||||||
@ -92,7 +94,10 @@ class Driver(ReverseDriver):
|
|||||||
|
|
||||||
self.fastapi_config: Config = Config(**config.dict())
|
self.fastapi_config: Config = Config(**config.dict())
|
||||||
|
|
||||||
|
self._lifespan = Lifespan()
|
||||||
|
|
||||||
self._server_app = FastAPI(
|
self._server_app = FastAPI(
|
||||||
|
lifespan=self._lifespan_manager,
|
||||||
openapi_url=self.fastapi_config.fastapi_openapi_url,
|
openapi_url=self.fastapi_config.fastapi_openapi_url,
|
||||||
docs_url=self.fastapi_config.fastapi_docs_url,
|
docs_url=self.fastapi_config.fastapi_docs_url,
|
||||||
redoc_url=self.fastapi_config.fastapi_redoc_url,
|
redoc_url=self.fastapi_config.fastapi_redoc_url,
|
||||||
@ -148,14 +153,20 @@ class Driver(ReverseDriver):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
def on_startup(self, func: Callable) -> Callable:
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
|
return self._lifespan.on_startup(func)
|
||||||
return self.server_app.on_event("startup")(func)
|
|
||||||
|
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
def on_shutdown(self, func: Callable) -> Callable:
|
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_"""
|
return self._lifespan.on_shutdown(func)
|
||||||
return self.server_app.on_event("shutdown")(func)
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _lifespan_manager(self, app: FastAPI):
|
||||||
|
await self._lifespan.startup()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
await self._lifespan.shutdown()
|
||||||
|
|
||||||
@overrides(ReverseDriver)
|
@overrides(ReverseDriver)
|
||||||
def run(
|
def run(
|
||||||
|
@ -31,7 +31,7 @@ from nonebot.drivers import (
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import httpx
|
import httpx
|
||||||
except ImportError as e: # pragma: no cover
|
except ModuleNotFoundError as e: # pragma: no cover
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install httpx by using `pip install nonebot2[httpx]`"
|
"Please install httpx by using `pip install nonebot2[httpx]`"
|
||||||
) from e
|
) from e
|
||||||
|
@ -13,7 +13,6 @@ FrontMatter:
|
|||||||
import signal
|
import signal
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
from typing import Set, Union, Callable, Awaitable, cast
|
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.consts import WINDOWS
|
from nonebot.consts import WINDOWS
|
||||||
@ -22,7 +21,8 @@ from nonebot.config import Env, Config
|
|||||||
from nonebot.drivers import Driver as BaseDriver
|
from nonebot.drivers import Driver as BaseDriver
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable
|
from nonebot.utils import run_sync, is_coroutine_callable
|
||||||
|
|
||||||
HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]]
|
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||||
|
|
||||||
HANDLED_SIGNALS = (
|
HANDLED_SIGNALS = (
|
||||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||||
@ -36,8 +36,9 @@ class Driver(BaseDriver):
|
|||||||
|
|
||||||
def __init__(self, env: Env, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
super().__init__(env, config)
|
super().__init__(env, config)
|
||||||
self.startup_funcs: Set[HOOK_FUNC] = set()
|
|
||||||
self.shutdown_funcs: Set[HOOK_FUNC] = set()
|
self._lifespan = Lifespan()
|
||||||
|
|
||||||
self.should_exit: asyncio.Event = asyncio.Event()
|
self.should_exit: asyncio.Event = asyncio.Event()
|
||||||
self.force_exit: bool = False
|
self.force_exit: bool = False
|
||||||
|
|
||||||
@ -54,20 +55,18 @@ class Driver(BaseDriver):
|
|||||||
return logger
|
return logger
|
||||||
|
|
||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC:
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
"""
|
"""
|
||||||
注册一个启动时执行的函数
|
注册一个启动时执行的函数
|
||||||
"""
|
"""
|
||||||
self.startup_funcs.add(func)
|
return self._lifespan.on_startup(func)
|
||||||
return func
|
|
||||||
|
|
||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC:
|
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
"""
|
"""
|
||||||
注册一个停止时执行的函数
|
注册一个停止时执行的函数
|
||||||
"""
|
"""
|
||||||
self.shutdown_funcs.add(func)
|
return self._lifespan.on_shutdown(func)
|
||||||
return func
|
|
||||||
|
|
||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
@ -85,21 +84,13 @@ class Driver(BaseDriver):
|
|||||||
await self._shutdown()
|
await self._shutdown()
|
||||||
|
|
||||||
async def _startup(self):
|
async def _startup(self):
|
||||||
# run startup
|
try:
|
||||||
cors = [
|
await self._lifespan.startup()
|
||||||
cast(Callable[..., Awaitable[None]], startup)()
|
except Exception as e:
|
||||||
if is_coroutine_callable(startup)
|
logger.opt(colors=True, exception=e).error(
|
||||||
else run_sync(startup)()
|
"<r><bg #f8bbd0>Error when running startup function. "
|
||||||
for startup in self.startup_funcs
|
"Ignored!</bg #f8bbd0></r>"
|
||||||
]
|
)
|
||||||
if cors:
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*cors)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running startup function. "
|
|
||||||
"Ignored!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Application startup completed.")
|
logger.info("Application startup completed.")
|
||||||
|
|
||||||
@ -110,21 +101,14 @@ class Driver(BaseDriver):
|
|||||||
logger.info("Shutting down")
|
logger.info("Shutting down")
|
||||||
|
|
||||||
logger.info("Waiting for application shutdown.")
|
logger.info("Waiting for application shutdown.")
|
||||||
# run shutdown
|
|
||||||
cors = [
|
try:
|
||||||
cast(Callable[..., Awaitable[None]], shutdown)()
|
await self._lifespan.shutdown()
|
||||||
if is_coroutine_callable(shutdown)
|
except Exception as e:
|
||||||
else run_sync(shutdown)()
|
logger.opt(colors=True, exception=e).error(
|
||||||
for shutdown in self.shutdown_funcs
|
"<r><bg #f8bbd0>Error when running shutdown function. "
|
||||||
]
|
"Ignored!</bg #f8bbd0></r>"
|
||||||
if cors:
|
)
|
||||||
try:
|
|
||||||
await asyncio.gather(*cors)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
|
||||||
"Ignored!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
for task in asyncio.all_tasks():
|
for task in asyncio.all_tasks():
|
||||||
if task is not asyncio.current_task() and not task.done():
|
if task is not asyncio.current_task() and not task.done():
|
||||||
|
@ -37,7 +37,7 @@ try:
|
|||||||
from quart import Quart, Request, Response
|
from quart import Quart, Request, Response
|
||||||
from quart.datastructures import FileStorage
|
from quart.datastructures import FileStorage
|
||||||
from quart import Websocket as QuartWebSocket
|
from quart import Websocket as QuartWebSocket
|
||||||
except ImportError as e: # pragma: no cover
|
except ModuleNotFoundError as e: # pragma: no cover
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install Quart by using `pip install nonebot2[quart]`"
|
"Please install Quart by using `pip install nonebot2[quart]`"
|
||||||
) from e
|
) from e
|
||||||
|
@ -30,7 +30,7 @@ from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
|
|||||||
try:
|
try:
|
||||||
from websockets.exceptions import ConnectionClosed
|
from websockets.exceptions import ConnectionClosed
|
||||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||||
except ImportError as e: # pragma: no cover
|
except ModuleNotFoundError as e: # pragma: no cover
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install websockets by using `pip install nonebot2[websockets]`"
|
"Please install websockets by using `pip install nonebot2[websockets]`"
|
||||||
) from e
|
) from e
|
||||||
|
@ -12,6 +12,7 @@ from nonebot.params import Depends
|
|||||||
from nonebot import _resolve_combine_expr
|
from nonebot import _resolve_combine_expr
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
|
from nonebot.drivers._lifespan import Lifespan
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
URL,
|
URL,
|
||||||
Driver,
|
Driver,
|
||||||
@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
|
|||||||
return DriverClass(Env(environment=global_driver.env), global_driver.config)
|
return DriverClass(Env(environment=global_driver.env), global_driver.config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lifespan():
|
||||||
|
lifespan = Lifespan()
|
||||||
|
|
||||||
|
start_log = []
|
||||||
|
shutdown_log = []
|
||||||
|
|
||||||
|
@lifespan.on_startup
|
||||||
|
async def _startup1():
|
||||||
|
assert start_log == []
|
||||||
|
start_log.append(1)
|
||||||
|
|
||||||
|
@lifespan.on_startup
|
||||||
|
async def _startup2():
|
||||||
|
assert start_log == [1]
|
||||||
|
start_log.append(2)
|
||||||
|
|
||||||
|
@lifespan.on_shutdown
|
||||||
|
async def _shutdown1():
|
||||||
|
assert shutdown_log == []
|
||||||
|
shutdown_log.append(1)
|
||||||
|
|
||||||
|
@lifespan.on_shutdown
|
||||||
|
async def _shutdown2():
|
||||||
|
assert shutdown_log == [1]
|
||||||
|
shutdown_log.append(2)
|
||||||
|
|
||||||
|
async with lifespan:
|
||||||
|
assert start_log == [1, 2]
|
||||||
|
|
||||||
|
assert shutdown_log == [1, 2]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
|
Loading…
Reference in New Issue
Block a user