add block driver startup/shutdown sync support (#1104)

Feature: 正向驱动器 startup/shutdown hook 支持同步函数
This commit is contained in:
synodriver 2022-07-15 10:11:19 +08:00 committed by GitHub
parent fe5cf5624c
commit 9bd07b9ced
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,15 +1,16 @@
import signal import signal
import asyncio import asyncio
import threading import threading
from typing import Set, Callable, Awaitable from typing import Set, Union, Callable, Awaitable
from nonebot.log import logger from nonebot.log import logger
from nonebot.drivers import Driver from nonebot.drivers import Driver
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.utils import run_sync, is_coroutine_callable
STARTUP_FUNC = Callable[[], Awaitable[None]] STARTUP_FUNC = Callable[[], Union[None, Awaitable[None]]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Union[None, Awaitable[None]]]
HANDLED_SIGNALS = ( HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`. signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
@ -69,7 +70,10 @@ class BlockDriver(Driver):
async def startup(self): async def startup(self):
# run startup # run startup
cors = [startup() for startup in self.startup_funcs] cors = [
startup() if is_coroutine_callable(startup) else run_sync(startup)()
for startup in self.startup_funcs
]
if cors: if cors:
try: try:
await asyncio.gather(*cors) await asyncio.gather(*cors)
@ -89,7 +93,10 @@ class BlockDriver(Driver):
logger.info("Waiting for application shutdown.") logger.info("Waiting for application shutdown.")
# run shutdown # run shutdown
cors = [shutdown() for shutdown in self.shutdown_funcs] cors = [
shutdown() if is_coroutine_callable(shutdown) else run_sync(shutdown)()
for shutdown in self.shutdown_funcs
]
if cors: if cors:
try: try:
await asyncio.gather(*cors) await asyncio.gather(*cors)