🐛 Fix: 新增 Lifespan._on_ready() 供适配器使用 (#2483)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Bryan不可思议 2023-12-10 18:12:10 +08:00 committed by GitHub
parent 915274081d
commit 8f3f385cb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 48 additions and 67 deletions

View File

@ -34,8 +34,6 @@ from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup
from ._lifespan import LIFESPAN_FUNC, Lifespan
try:
import uvicorn
from fastapi.responses import Response
@ -97,8 +95,6 @@ class Driver(BaseDriver, ASGIMixin):
self.fastapi_config: Config = Config(**config.dict())
self._lifespan = Lifespan()
self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
@ -155,14 +151,6 @@ class Driver(BaseDriver, ASGIMixin):
name=setup.name,
)
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)
@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()

View File

@ -19,8 +19,6 @@ from nonebot.consts import WINDOWS
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from ._lifespan import LIFESPAN_FUNC, Lifespan
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -35,8 +33,6 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self._lifespan = Lifespan()
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
@ -52,16 +48,6 @@ class Driver(BaseDriver):
"""none driver 使用的 logger"""
return logger
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@override
def run(self, *args, **kwargs):
"""启动 none driver"""

View File

@ -18,18 +18,7 @@ FrontMatter:
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from typing import Any, Dict, List, Tuple, Union, Optional, cast
from pydantic import BaseSettings
@ -57,8 +46,6 @@ except ModuleNotFoundError as e: # pragma: no cover
"Install with pip: `pip install nonebot2[quart]`"
) from e
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
def catch_closed(func):
@wraps(func)
@ -102,6 +89,8 @@ class Driver(BaseDriver, ASGIMixin):
self._server_app = Quart(
self.__class__.__qualname__, **self.quart_config.quart_extra
)
self._server_app.before_serving(self._lifespan.startup)
self._server_app.after_serving(self._lifespan.shutdown)
@property
@override
@ -150,16 +139,6 @@ class Driver(BaseDriver, ASGIMixin):
view_func=_handle,
)
@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore
@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore
@override
def run(
self,

View File

@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
from typing import Any, Dict, AsyncGenerator
from nonebot.config import Config
from nonebot.internal.driver._lifespan import LIFESPAN_FUNC
from nonebot.internal.driver import (
Driver,
Request,
@ -97,6 +98,9 @@ class Adapter(abc.ABC):
async with self.driver.websocket(setup) as ws:
yield ws
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self.driver._lifespan.on_ready(func)
@abc.abstractmethod
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。

View File

@ -11,6 +11,7 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._ready_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
@ -21,6 +22,10 @@ class Lifespan:
self._shutdown_funcs.append(func)
return func
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._ready_funcs.append(func)
return func
@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
@ -35,6 +40,9 @@ class Lifespan:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)

View File

@ -2,7 +2,7 @@ import abc
import asyncio
from typing_extensions import TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator
from nonebot.log import logger
from nonebot.config import Env, Config
@ -16,6 +16,7 @@ from nonebot.typing import (
T_BotDisconnectionHook,
)
from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
if TYPE_CHECKING:
@ -49,6 +50,7 @@ class Driver(abc.ABC):
"""全局配置对象"""
self._bots: Dict[str, "Bot"] = {}
self._bot_tasks: Set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
return (
@ -100,15 +102,13 @@ class Driver(abc.ABC):
self.on_shutdown(self._cleanup)
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
"""注册一个在驱动器启动时执行的函数"""
raise NotImplementedError
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
"""注册一个在驱动器停止时执行的函数"""
raise NotImplementedError
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@classmethod
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:

View File

@ -5,11 +5,11 @@ from typing import Any, Set, Optional
import pytest
from nonebug import App
from utils import FakeAdapter
from nonebot.adapters import Bot
from nonebot.params import Depends
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
@ -25,34 +25,50 @@ from nonebot.drivers import (
@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()
@pytest.mark.parametrize(
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
)
async def test_lifespan(driver: Driver):
adapter = FakeAdapter(driver)
start_log = []
ready_log = []
shutdown_log = []
@lifespan.on_startup
@driver.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)
@lifespan.on_startup
@driver.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)
@lifespan.on_shutdown
@adapter.on_ready
def _ready1():
assert start_log == [1, 2]
assert ready_log == []
ready_log.append(1)
@adapter.on_ready
def _ready2():
assert ready_log == [1]
ready_log.append(2)
@driver.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)
@lifespan.on_shutdown
@driver.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)
async with lifespan:
async with driver._lifespan:
assert start_log == [1, 2]
assert ready_log == [1, 2]
assert shutdown_log == [1, 2]