From 85a13251a5cc97e6351f60efcba8b547021694d2 Mon Sep 17 00:00:00 2001 From: snowy Date: Fri, 16 Aug 2024 23:43:43 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20Ctrl+C=20=E6=97=A0=E6=B3=95=E7=BB=88?= =?UTF-8?q?=E6=AD=A2Channel=E6=8E=A5=E6=94=B6=E7=BA=BF=E7=A8=8B=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- liteyuki/bot/__init__.py | 29 ++++++-- liteyuki/comm/channel.py | 4 +- liteyuki/comm/storage.py | 74 ++++++++++++++++++- liteyuki/core/manager.py | 56 ++++++-------- .../nonebot_launcher/__init__.py | 1 - src/utils/base/__init__.py | 10 +-- 6 files changed, 123 insertions(+), 51 deletions(-) diff --git a/liteyuki/bot/__init__.py b/liteyuki/bot/__init__.py index d7a2ce35..49a85570 100644 --- a/liteyuki/bot/__init__.py +++ b/liteyuki/bot/__init__.py @@ -1,6 +1,8 @@ import asyncio +import atexit import os import platform +import signal import sys import threading import time @@ -38,11 +40,16 @@ class LiteyukiBot: self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) + self.stop_event = threading.Event() self.call_restart_count = 0 load_plugins("liteyuki/plugins") # 加载轻雪插件 + signal.signal(signal.SIGINT, self._handle_exit) + signal.signal(signal.SIGTERM, self._handle_exit) + atexit.register(self.process_manager.terminate_all) # 注册退出时的函数 + def run(self): """ 启动逻辑 @@ -60,12 +67,24 @@ class LiteyukiBot: """ try: while not self.stop_event.is_set(): - time.sleep(1) + time.sleep(0.5) except KeyboardInterrupt: logger.info("Liteyuki is stopping...") self.stop() - finally: - self.lifespan.after_shutdown() + + def _handle_exit(self, signum, frame): + """ + 信号处理 + Args: + signum: + frame: + + Returns: + + """ + logger.info("Received signal, stopping all processes.") + self.stop() + sys.exit(0) def restart(self, delay: int = 0): """ @@ -116,16 +135,12 @@ class LiteyukiBot: Returns: """ - self.init_config() self.init_logger() def init_logger(self): # 修改nonebot的日志配置 init_log(config=self.config) - def init_config(self): - pass - def stop(self): """ 停止轻雪 diff --git a/liteyuki/comm/channel.py b/liteyuki/comm/channel.py index b7ae95d2..df4fa496 100644 --- a/liteyuki/comm/channel.py +++ b/liteyuki/comm/channel.py @@ -93,10 +93,10 @@ class Channel: 装饰器,装饰一个函数在接收到数据后执行 """ if (not self.is_sub_receive_loop_running) and not IS_MAIN_PROCESS: - threading.Thread(target=self._start_sub_receive_loop).start() + threading.Thread(target=self._start_sub_receive_loop, daemon=True).start() if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS: - threading.Thread(target=self._start_main_receive_loop).start() + threading.Thread(target=self._start_main_receive_loop, daemon=True).start() def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC: async def wrapper(data: Any) -> Any: diff --git a/liteyuki/comm/storage.py b/liteyuki/comm/storage.py index 4821d470..d4fa5ee3 100644 --- a/liteyuki/comm/storage.py +++ b/liteyuki/comm/storage.py @@ -102,6 +102,77 @@ class KeyValueStore: return self.active_chan.receive() +class KeyValueStoreNoLock: + def __init__(self): + self._store = {} + + self.active_chan = Channel(_id="shared_memory-active") + self.passive_chan = Channel(_id="shared_memory-passive") + + def set(self, key: str, value: any) -> None: + """ + 设置键值对 + Args: + key: 键 + value: 值 + + """ + if IS_MAIN_PROCESS: + self._store[key] = value + else: + # 向主进程发送请求拿取 + self.passive_chan.send(("set", key, value)) + + def get(self, key: str, default: Optional[any] = None) -> any: + """ + 获取键值对 + Args: + key: 键 + default: 默认值 + + Returns: + any: 值 + """ + if IS_MAIN_PROCESS: + return self._store.get(key, default) + else: + self.passive_chan.send(("get", key, default)) + return self.active_chan.receive() + + def delete(self, key: str, ignore_key_error: bool = True) -> None: + """ + 删除键值对 + Args: + key: 键 + ignore_key_error: 是否忽略键不存在的错误 + + Returns: + """ + if IS_MAIN_PROCESS: + if key in self._store: + try: + del self._store[key] + del _locks[key] + except KeyError as e: + if not ignore_key_error: + raise e + else: + # 向主进程发送请求删除 + self.passive_chan.send(("delete", key)) + + def get_all(self) -> dict[str, any]: + """ + 获取所有键值对 + Returns: + dict[str, any]: 键值对 + """ + if IS_MAIN_PROCESS: + return self._store + else: + self.passive_chan.send(("get_all",)) + return self.active_chan.receive() + + class GlobalKeyValueStore: _instance = None _lock = threading.Lock() @@ -142,7 +213,8 @@ if IS_MAIN_PROCESS: @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all") def on_get_all(d): - shared_memory.active_chan.send(shared_memory.get_all()) + if d[0] == "get_all": + shared_memory.active_chan.send(shared_memory.get_all()) else: shared_memory = None diff --git a/liteyuki/core/manager.py b/liteyuki/core/manager.py index c28346a5..9fc52f29 100644 --- a/liteyuki/core/manager.py +++ b/liteyuki/core/manager.py @@ -11,6 +11,7 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved import atexit import signal +import sys import threading from multiprocessing import Process from typing import Any, Callable, TYPE_CHECKING, TypeAlias @@ -63,10 +64,6 @@ class ProcessManager: self.targets: dict[str, tuple[callable, tuple, dict]] = {} self.processes: dict[str, Process] = {} - atexit.register(self.terminate_all) - signal.signal(signal.SIGINT, self._handle_exit) - signal.signal(signal.SIGTERM, self._handle_exit) - def start(self, name: str): """ 开启后自动监控进程,并添加到进程字典中 @@ -82,34 +79,31 @@ class ProcessManager: def _start_process(): process = Process(target=self.targets[name][0], args=self.targets[name][1], - kwargs=self.targets[name][2]) + kwargs=self.targets[name][2], daemon=True) self.processes[name] = process process.start() # 启动进程并监听信号 _start_process() - def _start_monitor(): - while True: - data = chan_active.receive() - if data == 0: - # 停止 - logger.info(f"Stopping process {name}") - self.bot.lifespan.before_process_shutdown() - self.terminate(name) - break - elif data == 1: - # 重启 - logger.info(f"Restarting process {name}") - self.bot.lifespan.before_process_shutdown() - self.bot.lifespan.before_process_restart() - self.terminate(name) - _start_process() - continue - else: - logger.warning("Unknown data received, ignored.") - - _start_monitor() + while True: + data = chan_active.receive() + if data == 0: + # 停止 + logger.info(f"Stopping process {name}") + self.bot.lifespan.before_process_shutdown() + self.terminate(name) + break + elif data == 1: + # 重启 + logger.info(f"Restarting process {name}") + self.bot.lifespan.before_process_shutdown() + self.bot.lifespan.before_process_restart() + self.terminate(name) + _start_process() + continue + else: + logger.warning("Unknown data received, ignored.") def start_all(self): """ @@ -118,11 +112,6 @@ class ProcessManager: for name in self.targets: threading.Thread(target=self.start, args=(name,), daemon=True).start() - def _handle_exit(self, signum, frame): - logger.info("Received signal, stopping all processes.") - self.terminate_all() - exit(0) - def add_target(self, name: str, target: TARGET_FUNC, args: tuple = (), kwargs=None): """ 添加进程 @@ -159,8 +148,9 @@ class ProcessManager: Returns: """ - if name not in self.targets: - raise logger.warning(f"Process {name} not found.") + if name not in self.processes: + logger.warning(f"Process {name} not found.") + return process = self.processes[name] process.terminate() process.join(TIMEOUT) diff --git a/src/liteyuki_plugins/nonebot_launcher/__init__.py b/src/liteyuki_plugins/nonebot_launcher/__init__.py index 9330bd81..31ff07ae 100644 --- a/src/liteyuki_plugins/nonebot_launcher/__init__.py +++ b/src/liteyuki_plugins/nonebot_launcher/__init__.py @@ -10,7 +10,6 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved """ import nonebot - from liteyuki.utils import IS_MAIN_PROCESS from liteyuki.plugin import PluginMetadata from .nb_utils import adapter_manager, driver_manager diff --git a/src/utils/base/__init__.py b/src/utils/base/__init__.py index 184e58d1..6f15eca5 100644 --- a/src/utils/base/__init__.py +++ b/src/utils/base/__init__.py @@ -1,7 +1,7 @@ import threading from nonebot import logger -from liteyuki.comm.channel import get_channel +from liteyuki.comm.channel import active_channel def reload(delay: float = 0.0, receiver: str = "nonebot"): @@ -14,13 +14,9 @@ def reload(delay: float = 0.0, receiver: str = "nonebot"): Returns: """ - chan = get_channel(receiver + "-active") - if chan is None: - logger.error(f"Channel {receiver}-active not found, cannot reload.") - return if delay > 0: - threading.Timer(delay, chan.send, args=(1,)).start() + threading.Timer(delay, active_channel.send, args=(1,)).start() return else: - chan.send(1) + active_channel.send(1)