mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 09:05:04 +08:00
🐛 Fix: 新增 Lifespan._on_ready()
供适配器使用 (#2483)
Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
915274081d
commit
8f3f385cb6
@ -34,8 +34,6 @@ from nonebot.drivers import Request as BaseRequest
|
|||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup
|
from nonebot.drivers import HTTPServerSetup, WebSocketServerSetup
|
||||||
|
|
||||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
@ -97,8 +95,6 @@ class Driver(BaseDriver, ASGIMixin):
|
|||||||
|
|
||||||
self.fastapi_config: Config = Config(**config.dict())
|
self.fastapi_config: Config = Config(**config.dict())
|
||||||
|
|
||||||
self._lifespan = Lifespan()
|
|
||||||
|
|
||||||
self._server_app = FastAPI(
|
self._server_app = FastAPI(
|
||||||
lifespan=self._lifespan_manager,
|
lifespan=self._lifespan_manager,
|
||||||
openapi_url=self.fastapi_config.fastapi_openapi_url,
|
openapi_url=self.fastapi_config.fastapi_openapi_url,
|
||||||
@ -155,14 +151,6 @@ class Driver(BaseDriver, ASGIMixin):
|
|||||||
name=setup.name,
|
name=setup.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
|
||||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
||||||
return self._lifespan.on_startup(func)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
||||||
return self._lifespan.on_shutdown(func)
|
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _lifespan_manager(self, app: FastAPI):
|
async def _lifespan_manager(self, app: FastAPI):
|
||||||
await self._lifespan.startup()
|
await self._lifespan.startup()
|
||||||
|
@ -19,8 +19,6 @@ from nonebot.consts import WINDOWS
|
|||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.drivers import Driver as BaseDriver
|
from nonebot.drivers import Driver as BaseDriver
|
||||||
|
|
||||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
|
||||||
|
|
||||||
HANDLED_SIGNALS = (
|
HANDLED_SIGNALS = (
|
||||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||||
@ -35,8 +33,6 @@ class Driver(BaseDriver):
|
|||||||
def __init__(self, env: Env, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
super().__init__(env, config)
|
super().__init__(env, config)
|
||||||
|
|
||||||
self._lifespan = Lifespan()
|
|
||||||
|
|
||||||
self.should_exit: asyncio.Event = asyncio.Event()
|
self.should_exit: asyncio.Event = asyncio.Event()
|
||||||
self.force_exit: bool = False
|
self.force_exit: bool = False
|
||||||
|
|
||||||
@ -52,16 +48,6 @@ class Driver(BaseDriver):
|
|||||||
"""none driver 使用的 logger"""
|
"""none driver 使用的 logger"""
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
@override
|
|
||||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
||||||
"""注册一个启动时执行的函数"""
|
|
||||||
return self._lifespan.on_startup(func)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
||||||
"""注册一个停止时执行的函数"""
|
|
||||||
return self._lifespan.on_shutdown(func)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
"""启动 none driver"""
|
"""启动 none driver"""
|
||||||
|
@ -18,18 +18,7 @@ FrontMatter:
|
|||||||
import asyncio
|
import asyncio
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from typing import (
|
from typing import Any, Dict, List, Tuple, Union, Optional, cast
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
TypeVar,
|
|
||||||
Callable,
|
|
||||||
Optional,
|
|
||||||
Coroutine,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
|
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
@ -57,8 +46,6 @@ except ModuleNotFoundError as e: # pragma: no cover
|
|||||||
"Install with pip: `pip install nonebot2[quart]`"
|
"Install with pip: `pip install nonebot2[quart]`"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
|
|
||||||
|
|
||||||
|
|
||||||
def catch_closed(func):
|
def catch_closed(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@ -102,6 +89,8 @@ class Driver(BaseDriver, ASGIMixin):
|
|||||||
self._server_app = Quart(
|
self._server_app = Quart(
|
||||||
self.__class__.__qualname__, **self.quart_config.quart_extra
|
self.__class__.__qualname__, **self.quart_config.quart_extra
|
||||||
)
|
)
|
||||||
|
self._server_app.before_serving(self._lifespan.startup)
|
||||||
|
self._server_app.after_serving(self._lifespan.shutdown)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
@ -150,16 +139,6 @@ class Driver(BaseDriver, ASGIMixin):
|
|||||||
view_func=_handle,
|
view_func=_handle,
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
|
||||||
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
|
|
||||||
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
|
|
||||||
return self.server_app.before_serving(func) # type: ignore
|
|
||||||
|
|
||||||
@override
|
|
||||||
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
|
|
||||||
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
|
|
||||||
return self.server_app.after_serving(func) # type: ignore
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
|
@ -3,6 +3,7 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import Any, Dict, AsyncGenerator
|
from typing import Any, Dict, AsyncGenerator
|
||||||
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
|
from nonebot.internal.driver._lifespan import LIFESPAN_FUNC
|
||||||
from nonebot.internal.driver import (
|
from nonebot.internal.driver import (
|
||||||
Driver,
|
Driver,
|
||||||
Request,
|
Request,
|
||||||
@ -97,6 +98,9 @@ class Adapter(abc.ABC):
|
|||||||
async with self.driver.websocket(setup) as ws:
|
async with self.driver.websocket(setup) as ws:
|
||||||
yield ws
|
yield ws
|
||||||
|
|
||||||
|
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
|
return self.driver._lifespan.on_ready(func)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
|
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
|
||||||
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
"""`Adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||||
|
@ -11,6 +11,7 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
|
|||||||
class Lifespan:
|
class Lifespan:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._startup_funcs: List[LIFESPAN_FUNC] = []
|
self._startup_funcs: List[LIFESPAN_FUNC] = []
|
||||||
|
self._ready_funcs: List[LIFESPAN_FUNC] = []
|
||||||
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
|
self._shutdown_funcs: List[LIFESPAN_FUNC] = []
|
||||||
|
|
||||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
@ -21,6 +22,10 @@ class Lifespan:
|
|||||||
self._shutdown_funcs.append(func)
|
self._shutdown_funcs.append(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
|
self._ready_funcs.append(func)
|
||||||
|
return func
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _run_lifespan_func(
|
async def _run_lifespan_func(
|
||||||
funcs: List[LIFESPAN_FUNC],
|
funcs: List[LIFESPAN_FUNC],
|
||||||
@ -35,6 +40,9 @@ class Lifespan:
|
|||||||
if self._startup_funcs:
|
if self._startup_funcs:
|
||||||
await self._run_lifespan_func(self._startup_funcs)
|
await self._run_lifespan_func(self._startup_funcs)
|
||||||
|
|
||||||
|
if self._ready_funcs:
|
||||||
|
await self._run_lifespan_func(self._ready_funcs)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if self._shutdown_funcs:
|
if self._shutdown_funcs:
|
||||||
await self._run_lifespan_func(self._shutdown_funcs)
|
await self._run_lifespan_func(self._shutdown_funcs)
|
@ -2,7 +2,7 @@ import abc
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
|
from typing import TYPE_CHECKING, Any, Set, Dict, Type, AsyncGenerator
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
@ -16,6 +16,7 @@ from nonebot.typing import (
|
|||||||
T_BotDisconnectionHook,
|
T_BotDisconnectionHook,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||||
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -49,6 +50,7 @@ class Driver(abc.ABC):
|
|||||||
"""全局配置对象"""
|
"""全局配置对象"""
|
||||||
self._bots: Dict[str, "Bot"] = {}
|
self._bots: Dict[str, "Bot"] = {}
|
||||||
self._bot_tasks: Set[asyncio.Task] = set()
|
self._bot_tasks: Set[asyncio.Task] = set()
|
||||||
|
self._lifespan = Lifespan()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
@ -100,15 +102,13 @@ class Driver(abc.ABC):
|
|||||||
|
|
||||||
self.on_shutdown(self._cleanup)
|
self.on_shutdown(self._cleanup)
|
||||||
|
|
||||||
@abc.abstractmethod
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
def on_startup(self, func: Callable) -> Callable:
|
"""注册一个启动时执行的函数"""
|
||||||
"""注册一个在驱动器启动时执行的函数"""
|
return self._lifespan.on_startup(func)
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
def on_shutdown(self, func: Callable) -> Callable:
|
"""注册一个停止时执行的函数"""
|
||||||
"""注册一个在驱动器停止时执行的函数"""
|
return self._lifespan.on_shutdown(func)
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
|
def on_bot_connect(cls, func: T_BotConnectionHook) -> T_BotConnectionHook:
|
||||||
|
@ -5,11 +5,11 @@ from typing import Any, Set, Optional
|
|||||||
import pytest
|
import pytest
|
||||||
from nonebug import App
|
from nonebug import App
|
||||||
|
|
||||||
|
from utils import FakeAdapter
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
from nonebot.params import Depends
|
from nonebot.params import Depends
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
from nonebot.drivers._lifespan import Lifespan
|
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
URL,
|
URL,
|
||||||
Driver,
|
Driver,
|
||||||
@ -25,34 +25,50 @@ from nonebot.drivers import (
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_lifespan():
|
@pytest.mark.parametrize(
|
||||||
lifespan = Lifespan()
|
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
|
||||||
|
)
|
||||||
|
async def test_lifespan(driver: Driver):
|
||||||
|
adapter = FakeAdapter(driver)
|
||||||
|
|
||||||
start_log = []
|
start_log = []
|
||||||
|
ready_log = []
|
||||||
shutdown_log = []
|
shutdown_log = []
|
||||||
|
|
||||||
@lifespan.on_startup
|
@driver.on_startup
|
||||||
async def _startup1():
|
async def _startup1():
|
||||||
assert start_log == []
|
assert start_log == []
|
||||||
start_log.append(1)
|
start_log.append(1)
|
||||||
|
|
||||||
@lifespan.on_startup
|
@driver.on_startup
|
||||||
async def _startup2():
|
async def _startup2():
|
||||||
assert start_log == [1]
|
assert start_log == [1]
|
||||||
start_log.append(2)
|
start_log.append(2)
|
||||||
|
|
||||||
@lifespan.on_shutdown
|
@adapter.on_ready
|
||||||
|
def _ready1():
|
||||||
|
assert start_log == [1, 2]
|
||||||
|
assert ready_log == []
|
||||||
|
ready_log.append(1)
|
||||||
|
|
||||||
|
@adapter.on_ready
|
||||||
|
def _ready2():
|
||||||
|
assert ready_log == [1]
|
||||||
|
ready_log.append(2)
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
async def _shutdown1():
|
async def _shutdown1():
|
||||||
assert shutdown_log == []
|
assert shutdown_log == []
|
||||||
shutdown_log.append(1)
|
shutdown_log.append(1)
|
||||||
|
|
||||||
@lifespan.on_shutdown
|
@driver.on_shutdown
|
||||||
async def _shutdown2():
|
async def _shutdown2():
|
||||||
assert shutdown_log == [1]
|
assert shutdown_log == [1]
|
||||||
shutdown_log.append(2)
|
shutdown_log.append(2)
|
||||||
|
|
||||||
async with lifespan:
|
async with driver._lifespan:
|
||||||
assert start_log == [1, 2]
|
assert start_log == [1, 2]
|
||||||
|
assert ready_log == [1, 2]
|
||||||
|
|
||||||
assert shutdown_log == [1, 2]
|
assert shutdown_log == [1, 2]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user