🐛 Fix: 新增 Lifespan._on_ready() 供适配器使用 (#2483)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Bryan不可思议 2023-12-10 18:12:10 +08:00 committed by GitHub
parent 915274081d
commit 8f3f385cb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 48 additions and 67 deletions

View File

@ -34,8 +34,6 @@ from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup
from ._lifespan import LIFESPAN_FUNC, Lifespan
try: try:
import uvicorn import uvicorn
from fastapi.responses import Response from fastapi.responses import Response
@ -97,8 +95,6 @@ class Driver(BaseDriver, ASGIMixin):
self.fastapi_config: Config = Config(**config.dict()) self.fastapi_config: Config = Config(**config.dict())
self._lifespan = Lifespan()
self._server_app = FastAPI( self._server_app = FastAPI(
lifespan=self._lifespan_manager, lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url, openapi_url=self.fastapi_config.fastapi_openapi_url,
@ -155,14 +151,6 @@ class Driver(BaseDriver, ASGIMixin):
name=setup.name, name=setup.name,
) )
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI): async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup() await self._lifespan.startup()

View File

@ -19,8 +19,6 @@ from nonebot.consts import WINDOWS
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver from nonebot.drivers import Driver as BaseDriver
from ._lifespan import LIFESPAN_FUNC, Lifespan
HANDLED_SIGNALS = ( HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`. signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(env, config) super().__init__(env, config)
self._lifespan = Lifespan()
self.should_exit: asyncio.Event = asyncio.Event() self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False self.force_exit: bool = False
@ -52,16 +48,6 @@ class Driver(BaseDriver):
"""none driver 使用的 logger""" """none driver 使用的 logger"""
return logger return logger
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@override @override
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
"""启动 none driver""" """启动 none driver"""

View File

@ -18,18 +18,7 @@ FrontMatter:
import asyncio import asyncio
from functools import wraps from functools import wraps
from typing_extensions import override from typing_extensions import override
from typing import ( from typing import Any, Dict, List, Tuple, Union, Optional, cast
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from pydantic import BaseSettings from pydantic import BaseSettings
@ -57,8 +46,6 @@ except ModuleNotFoundError as e: # pragma: no cover
"Install with pip: `pip install nonebot2[quart]`" "Install with pip: `pip install nonebot2[quart]`"
) from e ) from e
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
def catch_closed(func): def catch_closed(func):
@wraps(func) @wraps(func)
@ -102,6 +89,8 @@ class Driver(BaseDriver, ASGIMixin):
self._server_app = Quart( self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra self.__class__.__qualname__, **self.quart_config.quart_extra
) )
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)
@property @property
@override @override
@ -150,16 +139,6 @@ class Driver(BaseDriver, ASGIMixin):
view_func=_handle, view_func=_handle,
) )
@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore
@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore
@override @override
def run( def run(
self, self,

View File

@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
from typing import Any, Dict, AsyncGenerator from typing import Any, Dict, AsyncGenerator
from nonebot.config import Config from nonebot.config import Config
from nonebot.internal.driver._lifespan import LIFESPAN_FUNC
from nonebot.internal.driver import ( from nonebot.internal.driver import (
Driver, Driver,
Request, Request,
@ -97,6 +98,9 @@ class Adapter(abc.ABC):
async with self.driver.websocket(setup) as ws: async with self.driver.websocket(setup) as ws:
yield ws yield ws
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self.driver._lifespan.on_ready(func)
@abc.abstractmethod @abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any: async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。 """`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。

View File

@ -11,6 +11,7 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan: class Lifespan:
def __init__(self) -> None: def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = [] self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = [] self._shutdown_funcs: List[LIFESPAN_FUNC] = []
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
@ -21,6 +22,10 @@ class Lifespan:
self._shutdown_funcs.append(func) self._shutdown_funcs.append(func)
return func return func
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._ready_funcs.append(func)
return func
@staticmethod @staticmethod
async def _run_lifespan_func( async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC], funcs: List[LIFESPAN_FUNC],
@ -35,6 +40,9 @@ class Lifespan:
if self._startup_funcs: if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs) await self._run_lifespan_func(self._startup_funcs)
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self._shutdown_funcs: if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs) await self._run_lifespan_func(self._shutdown_funcs)

View File

@ -2,7 +2,7 @@ import abc
import asyncio import asyncio
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
@ -16,6 +16,7 @@ from nonebot.typing import (
T_BotDisconnectionHook, T_BotDisconnectionHook,
) )
from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
if TYPE_CHECKING: if TYPE_CHECKING:
@ -49,6 +50,7 @@ class Driver(abc.ABC):
"""全局配置对象""" """全局配置对象"""
self._bots: Dict[str, "Bot"] = {} self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set() self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -100,15 +102,13 @@ class Driver(abc.ABC):
self.on_shutdown(self._cleanup) self.on_shutdown(self._cleanup)
@abc.abstractmethod def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
def on_startup(self, func: Callable) -> Callable: """注册一个启动时执行的函数"""
"""注册一个在驱动器启动时执行的函数""" return self._lifespan.on_startup(func)
raise NotImplementedError
@abc.abstractmethod def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
def on_shutdown(self, func: Callable) -> Callable: """注册一个停止时执行的函数"""
"""注册一个在驱动器停止时执行的函数""" return self._lifespan.on_shutdown(func)
raise NotImplementedError
@classmethod @classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook: def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:

View File

@ -5,11 +5,11 @@ from typing import Any, Set, Optional
import pytest import pytest
from nonebug import App from nonebug import App
from utils import FakeAdapter
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.params import Depends from nonebot.params import Depends
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import ( from nonebot.drivers import (
URL, URL,
Driver, Driver,
@ -25,34 +25,50 @@ from nonebot.drivers import (
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_lifespan(): @pytest.mark.parametrize(
lifespan = Lifespan() "driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
)
async def test_lifespan(driver: Driver):
adapter = FakeAdapter(driver)
start_log = [] start_log = []
ready_log = []
shutdown_log = [] shutdown_log = []
@lifespan.on_startup @driver.on_startup
async def _startup1(): async def _startup1():
assert start_log == [] assert start_log == []
start_log.append(1) start_log.append(1)
@lifespan.on_startup @driver.on_startup
async def _startup2(): async def _startup2():
assert start_log == [1] assert start_log == [1]
start_log.append(2) start_log.append(2)
@lifespan.on_shutdown @adapter.on_ready
def _ready1():
assert start_log == [1, 2]
assert ready_log == []
ready_log.append(1)
@adapter.on_ready
def _ready2():
assert ready_log == [1]
ready_log.append(2)
@driver.on_shutdown
async def _shutdown1(): async def _shutdown1():
assert shutdown_log == [] assert shutdown_log == []
shutdown_log.append(1) shutdown_log.append(1)
@lifespan.on_shutdown @driver.on_shutdown
async def _shutdown2(): async def _shutdown2():
assert shutdown_log == [1] assert shutdown_log == [1]
shutdown_log.append(2) shutdown_log.append(2)
async with lifespan: async with driver._lifespan:
assert start_log == [1, 2] assert start_log == [1, 2]
assert ready_log == [1, 2]
assert shutdown_log == [1, 2] assert shutdown_log == [1, 2]