from collections.abc import Awaitable
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast

from nonebot.utils import run_sync, is_coroutine_callable

SYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Any]
ASYNC_LIFESPAN_FUNC: TypeAlias = Callable[[], Awaitable[Any]]
LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]


class Lifespan:
    """Registry and runner for application lifespan hooks.

    Hooks are registered with the ``on_startup``, ``on_ready`` and
    ``on_shutdown`` decorators and are executed one after another in
    registration order. A hook for which ``is_coroutine_callable`` is true
    is called and awaited directly; any other hook is wrapped with
    ``run_sync`` and the resulting call is awaited.

    An instance also works as an async context manager: entering runs
    :meth:`startup`, exiting runs :meth:`shutdown`.
    """

    def __init__(self) -> None:
        # One ordered bucket of hooks per lifecycle phase.
        self._startup_funcs: list[LIFESPAN_FUNC] = []
        self._ready_funcs: list[LIFESPAN_FUNC] = []
        self._shutdown_funcs: list[LIFESPAN_FUNC] = []

    @staticmethod
    def _register(
        bucket: list[LIFESPAN_FUNC], func: LIFESPAN_FUNC
    ) -> LIFESPAN_FUNC:
        """Append *func* to *bucket* and return it unchanged (decorator use)."""
        bucket.append(func)
        return func

    def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register *func* as a startup hook; returns *func* unchanged."""
        return self._register(self._startup_funcs, func)

    def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register *func* as a shutdown hook; returns *func* unchanged."""
        return self._register(self._shutdown_funcs, func)

    def on_ready(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
        """Register *func* as a ready hook; returns *func* unchanged."""
        return self._register(self._ready_funcs, func)

    @staticmethod
    async def _run_lifespan_func(
        funcs: list[LIFESPAN_FUNC],
    ) -> None:
        """Invoke every hook in *funcs*, awaiting each before the next starts."""
        for hook in funcs:
            runner = (
                cast(ASYNC_LIFESPAN_FUNC, hook)
                if is_coroutine_callable(hook)
                else run_sync(cast(SYNC_LIFESPAN_FUNC, hook))
            )
            await runner()

    async def startup(self) -> None:
        """Run all startup hooks, then all ready hooks."""
        startup_hooks = self._startup_funcs
        if startup_hooks:
            await self._run_lifespan_func(startup_hooks)

        # The ready list is only inspected once the startup hooks have
        # finished, so ``on_ready`` registrations made by a startup hook
        # are still executed here.
        ready_hooks = self._ready_funcs
        if ready_hooks:
            await self._run_lifespan_func(ready_hooks)

    async def shutdown(self) -> None:
        """Run all shutdown hooks."""
        shutdown_hooks = self._shutdown_funcs
        if shutdown_hooks:
            await self._run_lifespan_func(shutdown_hooks)

    async def __aenter__(self) -> None:
        """Run :meth:`startup` on entering ``async with``; binds ``None``."""
        await self.startup()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Run :meth:`shutdown` on leaving ``async with``.

        Returns ``None``, so an exception raised in the ``async with`` body is
        never suppressed. Per Python semantics this is not called if
        :meth:`startup` itself raised inside ``__aenter__``.
        """
        await self.shutdown()