2024-07-31 02:28:25 +08:00
# -*- coding: utf-8 -*-
"""
Copyright ( C ) 2020 - 2024 LiteyukiStudio . All Rights Reserved
@Time : 2024 / 7 / 27 上午11 : 12
@Author : snowykami
@Email : snowykami @outlook.com
@File : manager . py
@Software : PyCharm
"""
2024-08-18 11:51:18 +08:00
import multiprocessing
2024-08-16 21:38:22 +08:00
import threading
2024-07-31 02:28:25 +08:00
from multiprocessing import Process
2024-08-16 21:38:22 +08:00
from typing import Any , Callable , TYPE_CHECKING , TypeAlias
2024-07-31 02:28:25 +08:00
2024-08-19 23:47:39 +08:00
from liteyuki . comm . channel import Channel , get_channel , set_channels , publish_channel
2024-08-16 21:38:22 +08:00
from liteyuki . comm . storage import shared_memory
2024-07-31 02:28:25 +08:00
from liteyuki . log import logger
2024-08-16 21:38:22 +08:00
from liteyuki . utils import IS_MAIN_PROCESS
2024-07-31 02:28:25 +08:00
2024-08-08 18:06:03 +08:00
if TYPE_CHECKING :
2024-08-17 23:46:43 +08:00
from liteyuki . bot . lifespan import Lifespan
2024-08-16 21:38:22 +08:00
from liteyuki . comm . storage import KeyValueStore
2024-08-08 18:06:03 +08:00
2024-08-17 19:10:03 +08:00
if IS_MAIN_PROCESS :
from liteyuki . comm . channel import channel_deliver_active_channel , channel_deliver_passive_channel
else :
from liteyuki . comm import channel
2024-08-17 17:41:56 +08:00
# Type of a process target: any callable taking any arguments.
TARGET_FUNC: TypeAlias = Callable[..., Any]

# Seconds ProcessManager.terminate() waits for a process to exit after
# Process.terminate() before falling back to Process.kill().
TIMEOUT = 10
__all__ = [
    "ProcessManager"
]

# Module import side effect: force the "spawn" start method for every process
# created afterwards. force=True overrides any start method set earlier.
multiprocessing.set_start_method("spawn", force=True)
2024-07-31 02:28:25 +08:00
2024-08-17 19:10:03 +08:00
class ChannelDeliver:
    """
    Bundle of the channels handed from the main process to a child process.

    The child-process entry wrapper reads these attributes and installs them
    as the child's module-level channels.

    Attributes:
        active: the process's own active channel.
        passive: the process's own passive channel.
        channel_deliver_active: active channel used to deliver channels.
        channel_deliver_passive: passive channel used to deliver channels.
        publish: channel carrying ``(str, Any)`` publish messages.
    """

    def __init__(
            self,
            active: Channel[Any],
            passive: Channel[Any],
            channel_deliver_active: Channel[Channel[Any]],
            channel_deliver_passive: Channel[tuple[str, dict]],
            publish: Channel[tuple[str, Any]],
    ):
        # Per-process channel pair.
        self.active, self.passive = active, passive
        # Channel-delivery pair shared by all processes.
        self.channel_deliver_active, self.channel_deliver_passive = (
            channel_deliver_active,
            channel_deliver_passive,
        )
        # Publish/subscribe channel.
        self.publish = publish
2024-08-17 19:10:03 +08:00
2024-08-17 17:41:56 +08:00
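# Child-process entry wrapper: installs the cross-process channels and shared memory before running the real target.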
2024-08-17 19:10:03 +08:00
def _delivery_channel_wrapper(func: TARGET_FUNC, cd: ChannelDeliver, sm: "KeyValueStore", *args, **kwargs):
    """
    Entry point executed inside a child process.

    Installs the channels carried by ``cd`` onto the child's
    ``liteyuki.comm.channel`` module and the shared-memory store ``sm`` onto
    ``liteyuki.comm.storage``, then calls ``func(*args, **kwargs)``.

    Args:
        func: the real process target, run after setup.
        cd: channels handed over from the main process.
        sm: shared-memory store instance handed over from the main process.
        *args: positional arguments forwarded to ``func``.
        **kwargs: keyword arguments forwarded to ``func``.

    Raises:
        RuntimeError: if called in the main process.
    """
    if IS_MAIN_PROCESS:
        raise RuntimeError("Function should only be called in a sub process.")

    # Module attribute on liteyuki.comm.channel -> channel object from the parent.
    channel_bindings = (
        ("active_channel", cd.active),  # child's active channel
        ("passive_channel", cd.passive),  # child's passive channel
        ("channel_deliver_active_channel", cd.channel_deliver_active),  # channel-delivery active channel
        ("channel_deliver_passive_channel", cd.channel_deliver_passive),  # channel-delivery passive channel
        ("publish_channel", cd.publish),  # publish channel
    )
    for attr_name, chan in channel_bindings:
        setattr(channel, attr_name, chan)

    # Point the child's storage module at the shared-memory instance.
    from liteyuki.comm import storage
    storage.shared_memory = sm

    func(*args, **kwargs)
2024-07-31 02:28:25 +08:00
class ProcessManager:
    """
    Process manager.

    Keeps a registry of named process targets (``add_target``), starts them as
    daemon processes (``start`` / ``start_all``) and reacts to control messages
    received on each process's ``{name}-active`` channel: ``0`` stops the
    process, ``1`` restarts it.
    """

    def __init__(self, lifespan: "Lifespan"):
        self.lifespan = lifespan
        # name -> (callable, positional args, keyword args) passed to Process.
        self.targets: dict[str, tuple[Callable, tuple, dict]] = {}
        # name -> most recently started Process for that name.
        self.processes: dict[str, Process] = {}

    def start(self, name: str):
        """
        Start the named process and supervise it until told to stop.

        Blocks the calling thread: after starting the process it loops on the
        ``{name}-active`` channel. Receiving ``0`` runs the shutdown hook,
        terminates the process and returns. Receiving ``1`` runs the shutdown
        and restart hooks, terminates the process and starts a fresh one. Any
        other value is logged and ignored.

        Args:
            name: name previously registered with ``add_target``.

        Raises:
            KeyError: if ``name`` has not been registered.
        """
        if name not in self.targets:
            raise KeyError(f"Process {name} not found.")
        chan_active = get_channel(f"{name}-active")

        def _start_process():
            # A Process object can only be started once, so every (re)start
            # builds a new one and replaces the registry entry.
            target, args, kwargs = self.targets[name]
            process = Process(target=target, args=args, kwargs=kwargs, daemon=True)
            self.processes[name] = process
            process.start()

        # Start the process, then listen for control signals.
        _start_process()
        while True:
            data = chan_active.receive()
            if data == 0:
                # Stop
                logger.info(f"Stopping process {name}")
                self.lifespan.before_process_shutdown()
                self.terminate(name)
                break
            elif data == 1:
                # Restart
                logger.info(f"Restarting process {name}")
                self.lifespan.before_process_shutdown()
                self.lifespan.before_process_restart()
                self.terminate(name)
                _start_process()
            else:
                logger.warning("Unknown data received, ignored.")

    def start_all(self):
        """
        Start every registered process.

        Each ``start`` call runs in its own daemon thread, so this returns
        immediately.
        """
        for name in self.targets:
            threading.Thread(target=self.start, args=(name,), daemon=True).start()

    def add_target(self, name: str, target: TARGET_FUNC, args: tuple = (), kwargs=None):
        """
        Register a process target.

        The target is wrapped so that, inside the child process, the
        ``{name}-active`` / ``{name}-passive`` channels, the channel-delivery
        channels, the publish channel and the shared-memory store are installed
        on the child's ``liteyuki.comm`` modules before ``target`` runs. The
        two per-process channels are also registered in the main process via
        ``set_channels``.

        Args:
            name: process name, used as its unique identifier.
            target: the process function.
            args: positional arguments for ``target``.
            kwargs: keyword arguments for ``target``. Channels are NOT passed
                here; they are installed on the child's channel module.
        """
        if kwargs is None:
            kwargs = {}
        chan_active: Channel = Channel(_id=f"{name}-active")
        chan_passive: Channel = Channel(_id=f"{name}-passive")
        channel_deliver = ChannelDeliver(
            active=chan_active,
            passive=chan_passive,
            channel_deliver_active=channel_deliver_active_channel,
            channel_deliver_passive=channel_deliver_passive_channel,
            publish=publish_channel,
        )

        self.targets[name] = (_delivery_channel_wrapper, (target, channel_deliver, shared_memory, *args), kwargs)
        # Register the channels on the main-process side as well.
        set_channels(
            {
                f"{name}-active": chan_active,
                f"{name}-passive": chan_passive,
            }
        )

    def join_all(self):
        """
        Wait for every started process to exit.

        Iterates a snapshot because supervisor threads may replace entries
        in ``self.processes`` while we are blocked in ``join``.
        """
        for process in list(self.processes.values()):
            process.join()

    def terminate(self, name: str):
        """
        Terminate the named process.

        Calls ``Process.terminate()``, waits up to ``TIMEOUT`` seconds, and
        calls ``Process.kill()`` if it is still alive. The Process object stays
        in ``self.processes``, so ``is_process_alive`` keeps answering for it
        and a restart simply overwrites the entry.

        Args:
            name: process name. An unknown or never-started name is logged
                and ignored.
        """
        if name not in self.processes:
            logger.warning(f"Process {name} not found.")
            return
        process = self.processes[name]
        process.terminate()
        process.join(TIMEOUT)
        if process.is_alive():
            process.kill()
        logger.success(f"Process {name} terminated.")

    def terminate_all(self):
        """Terminate every registered process (names never started are logged)."""
        for name in self.targets:
            self.terminate(name)

    def is_process_alive(self, name: str) -> bool:
        """
        Check whether the named process is running.

        Args:
            name: process name.

        Returns:
            True if a process was started under ``name`` and is still alive.
            False (after logging a warning) if no process has been started
            under that name.
        """
        process = self.processes.get(name)
        if process is None:
            logger.warning(f"Process {name} not found.")
            return False
        return process.is_alive()