From 9bd07b9cede18179775dda00cb965afc63b0813c Mon Sep 17 00:00:00 2001 From: synodriver <624805065@qq.com> Date: Fri, 15 Jul 2022 10:11:19 +0800 Subject: [PATCH] :sparkles: add block driver startup/shutdown sync support (#1104) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Feature: 正向驱动器 startup/shutdown hook 支持同步函数 --- nonebot/drivers/_block_driver.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/nonebot/drivers/_block_driver.py b/nonebot/drivers/_block_driver.py index bf87659b..0c47caa3 100644 --- a/nonebot/drivers/_block_driver.py +++ b/nonebot/drivers/_block_driver.py @@ -1,15 +1,16 @@ import signal import asyncio import threading -from typing import Set, Callable, Awaitable +from typing import Set, Union, Callable, Awaitable from nonebot.log import logger from nonebot.drivers import Driver from nonebot.typing import overrides from nonebot.config import Env, Config +from nonebot.utils import run_sync, is_coroutine_callable -STARTUP_FUNC = Callable[[], Awaitable[None]] -SHUTDOWN_FUNC = Callable[[], Awaitable[None]] +STARTUP_FUNC = Callable[[], Union[None, Awaitable[None]]] +SHUTDOWN_FUNC = Callable[[], Union[None, Awaitable[None]]] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGTERM, # Unix signal 15. Sent by `kill `. @@ -69,7 +70,10 @@ class BlockDriver(Driver): async def startup(self): # run startup - cors = [startup() for startup in self.startup_funcs] + cors = [ + startup() if is_coroutine_callable(startup) else run_sync(startup)() + for startup in self.startup_funcs + ] if cors: try: await asyncio.gather(*cors) @@ -89,7 +93,10 @@ class BlockDriver(Driver): logger.info("Waiting for application shutdown.") # run shutdown - cors = [shutdown() for shutdown in self.shutdown_funcs] + cors = [ + shutdown() if is_coroutine_callable(shutdown) else run_sync(shutdown)() + for shutdown in self.shutdown_funcs + ] if cors: try: await asyncio.gather(*cors)