diff --git a/nonebot/drivers/_block_driver.py b/nonebot/drivers/_block_driver.py index bf87659b..0c47caa3 100644 --- a/nonebot/drivers/_block_driver.py +++ b/nonebot/drivers/_block_driver.py @@ -1,15 +1,16 @@ import signal import asyncio import threading -from typing import Set, Callable, Awaitable +from typing import Set, Union, Callable, Awaitable from nonebot.log import logger from nonebot.drivers import Driver from nonebot.typing import overrides from nonebot.config import Env, Config +from nonebot.utils import run_sync, is_coroutine_callable -STARTUP_FUNC = Callable[[], Awaitable[None]] -SHUTDOWN_FUNC = Callable[[], Awaitable[None]] +STARTUP_FUNC = Callable[[], Union[None, Awaitable[None]]] +SHUTDOWN_FUNC = Callable[[], Union[None, Awaitable[None]]] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGTERM, # Unix signal 15. Sent by `kill `. @@ -69,7 +70,10 @@ class BlockDriver(Driver): async def startup(self): # run startup - cors = [startup() for startup in self.startup_funcs] + cors = [ + startup() if is_coroutine_callable(startup) else run_sync(startup)() + for startup in self.startup_funcs + ] if cors: try: await asyncio.gather(*cors) @@ -89,7 +93,10 @@ class BlockDriver(Driver): logger.info("Waiting for application shutdown.") # run shutdown - cors = [shutdown() for shutdown in self.shutdown_funcs] + cors = [ + shutdown() if is_coroutine_callable(shutdown) else run_sync(shutdown)() + for shutdown in self.shutdown_funcs + ] if cors: try: await asyncio.gather(*cors)