diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index f1d96f2f..eaab98c7 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -34,8 +34,6 @@ from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup -from ._lifespan import LIFESPAN_FUNC, Lifespan - try: import uvicorn from fastapi.responses import Response @@ -97,8 +95,6 @@ class Driver(BaseDriver, ASGIMixin): self.fastapi_config: Config = Config(**config.dict()) - self._lifespan = Lifespan() - self._server_app = FastAPI( lifespan=self._lifespan_manager, openapi_url=self.fastapi_config.fastapi_openapi_url, @@ -155,14 +151,6 @@ class Driver(BaseDriver, ASGIMixin): name=setup.name, ) - @override - def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: - return self._lifespan.on_startup(func) - - @override - def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: - return self._lifespan.on_shutdown(func) - @contextlib.asynccontextmanager async def _lifespan_manager(self, app: FastAPI): await self._lifespan.startup() diff --git a/nonebot/drivers/none.py b/nonebot/drivers/none.py index 73829561..3a784c55 100644 --- a/nonebot/drivers/none.py +++ b/nonebot/drivers/none.py @@ -19,8 +19,6 @@ from nonebot.consts import WINDOWS from nonebot.config import Env, Config from nonebot.drivers import Driver as BaseDriver -from ._lifespan import LIFESPAN_FUNC, Lifespan - HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGTERM, # Unix signal 15. Sent by `kill `. @@ -35,8 +33,6 @@ class Driver(BaseDriver): def __init__(self, env: Env, config: Config): super().__init__(env, config) - self._lifespan = Lifespan() - self.should_exit: asyncio.Event = asyncio.Event() self.force_exit: bool = False @@ -52,16 +48,6 @@ class Driver(BaseDriver): """none driver 使用的 logger""" return logger - @override - def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: - """注册一个启动时执行的函数""" - return self._lifespan.on_startup(func) - - @override - def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: - """注册一个停止时执行的函数""" - return self._lifespan.on_shutdown(func) - @override def run(self, *args, **kwargs): """启动 none driver""" diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 8c90db71..3081b0a3 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -18,18 +18,7 @@ FrontMatter: import asyncio from functools import wraps from typing_extensions import override -from typing import ( - Any, - Dict, - List, - Tuple, - Union, - TypeVar, - Callable, - Optional, - Coroutine, - cast, -) +from typing import Any, Dict, List, Tuple, Union, Optional, cast from pydantic import BaseSettings @@ -57,8 +46,6 @@ except ModuleNotFoundError as e: # pragma: no cover "Install with pip: `pip install nonebot2[quart]`" ) from e -_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) - def catch_closed(func): @wraps(func) @@ -102,6 +89,8 @@ class Driver(BaseDriver, ASGIMixin): self._server_app = Quart( self.__class__.__qualname__, **self.quart_config.quart_extra ) + self._server_app.before_serving(self._lifespan.startup) + self._server_app.after_serving(self._lifespan.shutdown) @property @override @@ -150,16 +139,6 @@ class Driver(BaseDriver, ASGIMixin): view_func=_handle, ) - @override - def on_startup(self, func: _AsyncCallable) -> _AsyncCallable: - """参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)""" - return self.server_app.before_serving(func) # type: ignore - - @override - def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable: - """参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)""" - return self.server_app.after_serving(func) # type: ignore - @override def run( self, diff --git a/nonebot/internal/adapter/adapter.py b/nonebot/internal/adapter/adapter.py index 6bcfb1a8..897a67c8 100644 --- a/nonebot/internal/adapter/adapter.py +++ b/nonebot/internal/adapter/adapter.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager from typing import Any, Dict, AsyncGenerator from nonebot.config import Config +from nonebot.internal.driver._lifespan import LIFESPAN_FUNC from nonebot.internal.driver import ( Driver, Request, @@ -97,6 +98,9 @@ class Adapter(abc.ABC): async with self.driver.websocket(setup) as ws: yield ws + def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + return self.driver._lifespan.on_ready(func) + @abc.abstractmethod async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any: """`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。 diff --git a/nonebot/drivers/_lifespan.py b/nonebot/internal/driver/_lifespan.py similarity index 85% rename from nonebot/drivers/_lifespan.py rename to nonebot/internal/driver/_lifespan.py index d4b9f61a..5d04973a 100644 --- a/nonebot/drivers/_lifespan.py +++ b/nonebot/internal/driver/_lifespan.py @@ -11,6 +11,7 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC] class Lifespan: def __init__(self) -> None: self._startup_funcs: List[LIFESPAN_FUNC] = [] + self._ready_funcs: List[LIFESPAN_FUNC] = [] self._shutdown_funcs: List[LIFESPAN_FUNC] = [] def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: @@ -21,6 +22,10 @@ class Lifespan: self._shutdown_funcs.append(func) return func + def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + self._ready_funcs.append(func) + return func + @staticmethod async def _run_lifespan_func( funcs: List[LIFESPAN_FUNC], @@ -35,6 +40,9 @@ class Lifespan: if self._startup_funcs: await self._run_lifespan_func(self._startup_funcs) + if self._ready_funcs: + await self._run_lifespan_func(self._ready_funcs) + async def shutdown(self) -> None: if self._shutdown_funcs: await self._run_lifespan_func(self._shutdown_funcs) diff --git a/nonebot/internal/driver/abstract.py b/nonebot/internal/driver/abstract.py index d5fd8352..e7191495 100644 --- a/nonebot/internal/driver/abstract.py +++ b/nonebot/internal/driver/abstract.py @@ -2,7 +2,7 @@ import abc import asyncio from typing_extensions import TypeAlias from contextlib import AsyncExitStack, asynccontextmanager -from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator +from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator from nonebot.log import logger from nonebot.config import Env, Config @@ -16,6 +16,7 @@ from nonebot.typing import ( T_BotDisconnectionHook, ) +from ._lifespan import LIFESPAN_FUNC, Lifespan from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup if TYPE_CHECKING: @@ -49,6 +50,7 @@ class Driver(abc.ABC): """全局配置对象""" self._bots: Dict[str, "Bot"] = {} self._bot_tasks: Set[asyncio.Task] = set() + self._lifespan = Lifespan() def __repr__(self) -> str: return ( @@ -100,15 +102,13 @@ class Driver(abc.ABC): self.on_shutdown(self._cleanup) - @abc.abstractmethod - def on_startup(self, func: Callable) -> Callable: - """注册一个在驱动器启动时执行的函数""" - raise NotImplementedError + def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + """注册一个启动时执行的函数""" + return self._lifespan.on_startup(func) - @abc.abstractmethod - def on_shutdown(self, func: Callable) -> Callable: - """注册一个在驱动器停止时执行的函数""" - raise NotImplementedError + def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: + """注册一个停止时执行的函数""" + return self._lifespan.on_shutdown(func) @classmethod def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook: diff --git a/tests/test_driver.py b/tests/test_driver.py index 70fbea6b..cd9bc3a8 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -5,11 +5,11 @@ from typing import Any, Set, Optional import pytest from nonebug import App +from utils import FakeAdapter from nonebot.adapters import Bot from nonebot.params import Depends from nonebot.dependencies import Dependent from nonebot.exception import WebSocketClosed -from nonebot.drivers._lifespan import Lifespan from nonebot.drivers import ( URL, Driver, @@ -25,34 +25,50 @@ from nonebot.drivers import ( @pytest.mark.asyncio -async def test_lifespan(): - lifespan = Lifespan() +@pytest.mark.parametrize( + "driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True +) +async def test_lifespan(driver: Driver): + adapter = FakeAdapter(driver) start_log = [] + ready_log = [] shutdown_log = [] - @lifespan.on_startup + @driver.on_startup async def _startup1(): assert start_log == [] start_log.append(1) - @lifespan.on_startup + @driver.on_startup async def _startup2(): assert start_log == [1] start_log.append(2) - @lifespan.on_shutdown + @adapter.on_ready + def _ready1(): + assert start_log == [1, 2] + assert ready_log == [] + ready_log.append(1) + + @adapter.on_ready + def _ready2(): + assert ready_log == [1] + ready_log.append(2) + + @driver.on_shutdown async def _shutdown1(): assert shutdown_log == [] shutdown_log.append(1) - @lifespan.on_shutdown + @driver.on_shutdown async def _shutdown2(): assert shutdown_log == [1] shutdown_log.append(2) - async with lifespan: + async with driver._lifespan: assert start_log == [1, 2] + assert ready_log == [1, 2] assert shutdown_log == [1, 2]