mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-01 01:25:07 +08:00
101 lines
3.1 KiB
Python
101 lines
3.1 KiB
Python
from types import TracebackType
|
|
from typing_extensions import TypeAlias
|
|
from collections.abc import Iterable, Awaitable
|
|
from typing import Any, Union, Callable, Optional, cast
|
|
|
|
import anyio
|
|
from anyio.abc import TaskGroup
|
|
from exceptiongroup import suppress
|
|
|
|
from nonebot.utils import run_sync, is_coroutine_callable
|
|
|
|
SYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Any]
|
|
ASYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Awaitable[Any]]
|
|
LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
|
|
|
|
|
|
class Lifespan:
|
|
def __init__(self) -> None:
|
|
self._task_group: Optional[TaskGroup] = None
|
|
|
|
self._startup_funcs: list[LIFESPAN_FUNC] = []
|
|
self._ready_funcs: list[LIFESPAN_FUNC] = []
|
|
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
|
|
|
|
@property
|
|
def task_group(self) -> TaskGroup:
|
|
if self._task_group is None:
|
|
raise RuntimeError("Lifespan not started")
|
|
return self._task_group
|
|
|
|
@task_group.setter
|
|
def task_group(self, task_group: TaskGroup) -> None:
|
|
if self._task_group is not None:
|
|
raise RuntimeError("Lifespan already started")
|
|
self._task_group = task_group
|
|
|
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
self._startup_funcs.append(func)
|
|
return func
|
|
|
|
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
self._shutdown_funcs.append(func)
|
|
return func
|
|
|
|
def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
|
self._ready_funcs.append(func)
|
|
return func
|
|
|
|
@staticmethod
|
|
async def _run_lifespan_func(
|
|
funcs: Iterable[LIFESPAN_FUNC],
|
|
) -> None:
|
|
for func in funcs:
|
|
if is_coroutine_callable(func):
|
|
await cast(ASYNC_LIFESPAN_FUNC, func)()
|
|
else:
|
|
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
|
|
|
|
async def startup(self) -> None:
|
|
# create background task group
|
|
self.task_group = anyio.create_task_group()
|
|
await self.task_group.__aenter__()
|
|
|
|
# run startup funcs
|
|
if self._startup_funcs:
|
|
await self._run_lifespan_func(self._startup_funcs)
|
|
|
|
# run ready funcs
|
|
if self._ready_funcs:
|
|
await self._run_lifespan_func(self._ready_funcs)
|
|
|
|
async def shutdown(
|
|
self,
|
|
*,
|
|
exc_type: Optional[type[BaseException]] = None,
|
|
exc_val: Optional[BaseException] = None,
|
|
exc_tb: Optional[TracebackType] = None,
|
|
) -> None:
|
|
if self._shutdown_funcs:
|
|
# reverse shutdown funcs to ensure stack order
|
|
await self._run_lifespan_func(reversed(self._shutdown_funcs))
|
|
|
|
# shutdown background task group
|
|
self.task_group.cancel_scope.cancel()
|
|
|
|
with suppress(Exception):
|
|
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
|
|
|
|
self._task_group = None
|
|
|
|
async def __aenter__(self) -> None:
|
|
await self.startup()
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: Optional[type[BaseException]],
|
|
exc_val: Optional[BaseException],
|
|
exc_tb: Optional[TracebackType],
|
|
) -> None:
|
|
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)
|