diff --git a/liteyuki/comm/channel.py b/liteyuki/comm/channel.py index df4fa496..c5a422bb 100644 --- a/liteyuki/comm/channel.py +++ b/liteyuki/comm/channel.py @@ -10,11 +10,9 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved 本模块定义了一个通用的通道类,用于进程间通信 """ -import functools -import multiprocessing import threading from multiprocessing import Pipe -from typing import Any, Optional, Callable, Awaitable, List, TypeAlias +from typing import Any, Awaitable, Callable, Optional, TypeAlias, TypeVar, Generic, get_args from uuid import uuid4 from liteyuki.utils import IS_MAIN_PROCESS, is_coroutine_callable, run_coroutine @@ -30,18 +28,21 @@ FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC _channel: dict[str, "Channel"] = {} _callback_funcs: dict[str, ON_RECEIVE_FUNC] = {} -"""子进程可用的主动和被动通道""" -active_channel: Optional["Channel"] = None -passive_channel: Optional["Channel"] = None +T = TypeVar("T") -class Channel: +class Channel(Generic[T]): """ 通道类,可以在进程间和进程内通信,双向但同时只能有一个发送者和一个接收者 有两种接收工作方式,但是只能选择一种,主动接收和被动接收,主动接收使用 `receive` 方法,被动接收使用 `on_receive` 装饰器 """ - def __init__(self, _id: str): + def __init__(self, _id: str, type_check: bool = False): + """ + 初始化通道 + Args: + _id: 通道ID + """ self.conn_send, self.conn_recv = Pipe() self._closed = False self._on_main_receive_funcs: list[str] = [] @@ -51,20 +52,67 @@ class Channel: self.is_main_receive_loop_running = False self.is_sub_receive_loop_running = False + if type_check: + if self._get_generic_type() is None: + raise TypeError("Type hint is required for enforcing type check.") + self.type_check = type_check + + def _get_generic_type(self) -> Optional[type]: + """ + 获取通道传递泛型类型 + + Returns: + Optional[type]: 泛型类型 + """ + if hasattr(self, '__orig_class__'): + return get_args(self.__orig_class__)[0] + return None + + def _validate_structure(self, data: any, structure: type | tuple | list | dict) -> bool: + """ + 验证数据结构 + Args: + data: 数据 + structure: 结构 + + Returns: + bool: 是否通过验证 + """ + if isinstance(structure, type): + return isinstance(data, structure) + elif isinstance(structure, tuple): + if not isinstance(data, tuple) or len(data) != len(structure): + return False + return all(self._validate_structure(d, s) for d, s in zip(data, structure)) + elif isinstance(structure, list): + if not isinstance(data, list): + return False + return all(self._validate_structure(d, structure[0]) for d in data) + elif isinstance(structure, dict): + if not isinstance(data, dict): + return False + return all(k in data and self._validate_structure(data[k], structure[k]) for k in structure) + return False + def __str__(self): return f"Channel({self.name})" - def send(self, data: Any): + def send(self, data: T): """ 发送数据 Args: data: 数据 """ + if self.type_check: + _type = self._get_generic_type() + if not self._validate_structure(data, _type): + raise TypeError(f"Data must be an instance of {_type}, {type(data)} found") + if self._closed: raise RuntimeError("Cannot send to a closed channel") self.conn_send.send(data) - def receive(self) -> Any: + def receive(self) -> T: """ 接收数据 Args: @@ -168,6 +216,38 @@ class Channel: return self.receive() +"""子进程可用的主动和被动通道""" +active_channel: Optional["Channel"] = None +passive_channel: Optional["Channel"] = None +if not IS_MAIN_PROCESS: + """sub process only""" + active_channel: Optional["Channel"] = None + passive_channel: Optional["Channel"] = None + +"""通道传递通道,主进程单例,子进程初始化时实例化""" +channel_deliver_active_channel: Optional["Channel"] +channel_deliver_passive_channel: Optional["Channel"] +if IS_MAIN_PROCESS: + channel_deliver_active_channel: Optional["Channel"] = Channel(_id="channel_deliver_active_channel") + channel_deliver_passive_channel: Optional["Channel"] = Channel(_id="channel_deliver_passive_channel") + + + @channel_deliver_passive_channel.on_receive(filter_func=lambda data: data[0] == "set_channel") + def on_set_channel(data: tuple[str, str, Channel, Channel]): + name, channel, temp_channel = data[1:] + temp_channel.send(set_channel(name, channel)) + + + @channel_deliver_active_channel.on_receive(filter_func=lambda data: data[0] == "get_channel") + def on_get_channel(data: tuple[str, Channel]): + name = data[1:] + channel = get_channel() + return channel +else: + channel_deliver_active_channel = None + channel_deliver_passive_channel = None + + def set_channel(name: str, channel: Channel): """ 设置通道实例 @@ -175,12 +255,16 @@ def set_channel(name: str, channel: Channel): name: 通道名称 channel: 通道实例 """ - if not IS_MAIN_PROCESS: - raise RuntimeError(f"Function {__name__} should only be called in the main process.") - if not isinstance(channel, Channel): raise TypeError(f"channel must be an instance of Channel, {type(channel)} found") - _channel[name] = channel + + if IS_MAIN_PROCESS: + _channel[name] = channel + else: + # 请求主进程设置通道 + temp_channel = Channel(_id="temp_channel") + channel_deliver_passive_channel.send(("set_channel", name, channel, temp_channel)) + return temp_channel.receive() def set_channels(channels: dict[str, Channel]): diff --git a/liteyuki/comm/storage.py b/liteyuki/comm/storage.py index b01b2083..5c8982ee 100644 --- a/liteyuki/comm/storage.py +++ b/liteyuki/comm/storage.py @@ -197,7 +197,6 @@ shared_memory: Optional[KeyValueStore] = None if IS_MAIN_PROCESS: shared_memory = GlobalKeyValueStore.get_instance() - @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get") def on_get(data: tuple[str, str, any, Channel]): data[3].send(shared_memory.get(data[1], data[2])) diff --git a/liteyuki/core/manager.py b/liteyuki/core/manager.py index 9fc52f29..e5ebe01c 100644 --- a/liteyuki/core/manager.py +++ b/liteyuki/core/manager.py @@ -16,17 +16,19 @@ import threading from multiprocessing import Process from typing import Any, Callable, TYPE_CHECKING, TypeAlias -from liteyuki.comm import Channel, get_channel, set_channels +from liteyuki.comm.channel import Channel, get_channel, set_channels from liteyuki.comm.storage import shared_memory from liteyuki.log import logger from liteyuki.utils import IS_MAIN_PROCESS -TARGET_FUNC: TypeAlias = Callable[..., Any] +if IS_MAIN_PROCESS: + from liteyuki.comm.channel import channel_deliver_active_channel, channel_deliver_passive_channel if TYPE_CHECKING: from liteyuki.bot import LiteyukiBot from liteyuki.comm.storage import KeyValueStore +TARGET_FUNC: TypeAlias = Callable[..., Any] TIMEOUT = 10 __all__ = [ @@ -34,18 +36,21 @@ __all__ = [ ] -# Update the delivery_channel_wrapper function to return the top-level wrapper +# 函数处理一些跨进程通道的 def _delivery_channel_wrapper(func: TARGET_FUNC, chan_active: Channel, chan_passive: Channel, sm: "KeyValueStore", *args, **kwargs): """ 子进程入口函数 + 处理一些操作 """ # 给子进程设置通道 if IS_MAIN_PROCESS: raise RuntimeError("Function should only be called in a sub process.") - from liteyuki.comm import channel - channel.active_channel = chan_active - channel.passive_channel = chan_passive + from liteyuki.comm import channel # type Module + channel.active_channel = chan_active # 子进程主动通道 + channel.passive_channel = chan_passive # 子进程被动通道 + channel.channel_deliver_active_channel = channel_deliver_active_channel # 子进程通道传递主动通道 + channel.channel_deliver_passive_channel = channel_deliver_passive_channel # 子进程通道传递被动通道 # 给子进程创建共享内存实例 from liteyuki.comm import storage