🐛 Ctrl+C 无法终止Channel接收线程的问题

This commit is contained in:
远野千束 2024-08-16 23:43:43 +08:00
parent 0417805e46
commit 85a13251a5
6 changed files with 123 additions and 51 deletions

View File

@ -1,6 +1,8 @@
import asyncio import asyncio
import atexit
import os import os
import platform import platform
import signal
import sys import sys
import threading import threading
import time import time
@ -38,11 +40,16 @@ class LiteyukiBot:
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
self.stop_event = threading.Event() self.stop_event = threading.Event()
self.call_restart_count = 0 self.call_restart_count = 0
load_plugins("liteyuki/plugins") # 加载轻雪插件 load_plugins("liteyuki/plugins") # 加载轻雪插件
signal.signal(signal.SIGINT, self._handle_exit)
signal.signal(signal.SIGTERM, self._handle_exit)
atexit.register(self.process_manager.terminate_all) # 注册退出时的函数
def run(self): def run(self):
""" """
启动逻辑 启动逻辑
@ -60,12 +67,24 @@ class LiteyukiBot:
""" """
try: try:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
time.sleep(1) time.sleep(0.5)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Liteyuki is stopping...") logger.info("Liteyuki is stopping...")
self.stop() self.stop()
finally:
self.lifespan.after_shutdown() def _handle_exit(self, signum, frame):
"""
信号处理
Args:
signum:
frame:
Returns:
"""
logger.info("Received signal, stopping all processes.")
self.stop()
sys.exit(0)
def restart(self, delay: int = 0): def restart(self, delay: int = 0):
""" """
@ -116,16 +135,12 @@ class LiteyukiBot:
Returns: Returns:
""" """
self.init_config()
self.init_logger() self.init_logger()
def init_logger(self): def init_logger(self):
# 修改nonebot的日志配置 # 修改nonebot的日志配置
init_log(config=self.config) init_log(config=self.config)
def init_config(self):
pass
def stop(self): def stop(self):
""" """
停止轻雪 停止轻雪

View File

@ -93,10 +93,10 @@ class Channel:
装饰器装饰一个函数在接收到数据后执行 装饰器装饰一个函数在接收到数据后执行
""" """
if (not self.is_sub_receive_loop_running) and not IS_MAIN_PROCESS: if (not self.is_sub_receive_loop_running) and not IS_MAIN_PROCESS:
threading.Thread(target=self._start_sub_receive_loop).start() threading.Thread(target=self._start_sub_receive_loop, daemon=True).start()
if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS: if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS:
threading.Thread(target=self._start_main_receive_loop).start() threading.Thread(target=self._start_main_receive_loop, daemon=True).start()
def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC: def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC:
async def wrapper(data: Any) -> Any: async def wrapper(data: Any) -> Any:

View File

@ -102,6 +102,77 @@ class KeyValueStore:
return self.active_chan.receive() return self.active_chan.receive()
class KeyValueStoreNoLock:
def __init__(self):
self._store = {}
self.active_chan = Channel(_id="shared_memory-active")
self.passive_chan = Channel(_id="shared_memory-passive")
def set(self, key: str, value: any) -> None:
"""
设置键值对
Args:
key:
value:
"""
if IS_MAIN_PROCESS:
self._store[key] = value
else:
# 向主进程发送请求拿取
self.passive_chan.send(("set", key, value))
def get(self, key: str, default: Optional[any] = None) -> any:
"""
获取键值对
Args:
key:
default: 默认值
Returns:
any:
"""
if IS_MAIN_PROCESS:
return self._store.get(key, default)
else:
self.passive_chan.send(("get", key, default))
return self.active_chan.receive()
def delete(self, key: str, ignore_key_error: bool = True) -> None:
"""
删除键值对
Args:
key:
ignore_key_error: 是否忽略键不存在的错误
Returns:
"""
if IS_MAIN_PROCESS:
if key in self._store:
try:
del self._store[key]
del _locks[key]
except KeyError as e:
if not ignore_key_error:
raise e
else:
# 向主进程发送请求删除
self.passive_chan.send(("delete", key))
def get_all(self) -> dict[str, any]:
"""
获取所有键值对
Returns:
dict[str, any]: 键值对
"""
if IS_MAIN_PROCESS:
return self._store
else:
self.passive_chan.send(("get_all",))
return self.active_chan.receive()
class GlobalKeyValueStore: class GlobalKeyValueStore:
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
@ -142,6 +213,7 @@ if IS_MAIN_PROCESS:
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all") @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all")
def on_get_all(d): def on_get_all(d):
if d[0] == "get_all":
shared_memory.active_chan.send(shared_memory.get_all()) shared_memory.active_chan.send(shared_memory.get_all())
else: else:
shared_memory = None shared_memory = None

View File

@ -11,6 +11,7 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
import atexit import atexit
import signal import signal
import sys
import threading import threading
from multiprocessing import Process from multiprocessing import Process
from typing import Any, Callable, TYPE_CHECKING, TypeAlias from typing import Any, Callable, TYPE_CHECKING, TypeAlias
@ -63,10 +64,6 @@ class ProcessManager:
self.targets: dict[str, tuple[callable, tuple, dict]] = {} self.targets: dict[str, tuple[callable, tuple, dict]] = {}
self.processes: dict[str, Process] = {} self.processes: dict[str, Process] = {}
atexit.register(self.terminate_all)
signal.signal(signal.SIGINT, self._handle_exit)
signal.signal(signal.SIGTERM, self._handle_exit)
def start(self, name: str): def start(self, name: str):
""" """
开启后自动监控进程并添加到进程字典中 开启后自动监控进程并添加到进程字典中
@ -82,14 +79,13 @@ class ProcessManager:
def _start_process(): def _start_process():
process = Process(target=self.targets[name][0], args=self.targets[name][1], process = Process(target=self.targets[name][0], args=self.targets[name][1],
kwargs=self.targets[name][2]) kwargs=self.targets[name][2], daemon=True)
self.processes[name] = process self.processes[name] = process
process.start() process.start()
# 启动进程并监听信号 # 启动进程并监听信号
_start_process() _start_process()
def _start_monitor():
while True: while True:
data = chan_active.receive() data = chan_active.receive()
if data == 0: if data == 0:
@ -109,8 +105,6 @@ class ProcessManager:
else: else:
logger.warning("Unknown data received, ignored.") logger.warning("Unknown data received, ignored.")
_start_monitor()
def start_all(self): def start_all(self):
""" """
启动所有进程 启动所有进程
@ -118,11 +112,6 @@ class ProcessManager:
for name in self.targets: for name in self.targets:
threading.Thread(target=self.start, args=(name,), daemon=True).start() threading.Thread(target=self.start, args=(name,), daemon=True).start()
def _handle_exit(self, signum, frame):
logger.info("Received signal, stopping all processes.")
self.terminate_all()
exit(0)
def add_target(self, name: str, target: TARGET_FUNC, args: tuple = (), kwargs=None): def add_target(self, name: str, target: TARGET_FUNC, args: tuple = (), kwargs=None):
""" """
添加进程 添加进程
@ -159,8 +148,9 @@ class ProcessManager:
Returns: Returns:
""" """
if name not in self.targets: if name not in self.processes:
raise logger.warning(f"Process {name} not found.") logger.warning(f"Process {name} not found.")
return
process = self.processes[name] process = self.processes[name]
process.terminate() process.terminate()
process.join(TIMEOUT) process.join(TIMEOUT)

View File

@ -10,7 +10,6 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
""" """
import nonebot import nonebot
from liteyuki.utils import IS_MAIN_PROCESS from liteyuki.utils import IS_MAIN_PROCESS
from liteyuki.plugin import PluginMetadata from liteyuki.plugin import PluginMetadata
from .nb_utils import adapter_manager, driver_manager from .nb_utils import adapter_manager, driver_manager

View File

@ -1,7 +1,7 @@
import threading import threading
from nonebot import logger from nonebot import logger
from liteyuki.comm.channel import get_channel from liteyuki.comm.channel import active_channel
def reload(delay: float = 0.0, receiver: str = "nonebot"): def reload(delay: float = 0.0, receiver: str = "nonebot"):
@ -14,13 +14,9 @@ def reload(delay: float = 0.0, receiver: str = "nonebot"):
Returns: Returns:
""" """
chan = get_channel(receiver + "-active")
if chan is None:
logger.error(f"Channel {receiver}-active not found, cannot reload.")
return
if delay > 0: if delay > 0:
threading.Timer(delay, chan.send, args=(1,)).start() threading.Timer(delay, active_channel.send, args=(1,)).start()
return return
else: else:
chan.send(1) active_channel.send(1)