对通道类添加类型检查和泛型

This commit is contained in:
snowy 2024-08-17 17:41:56 +08:00
parent 03057c8ef9
commit f980e77a4a
3 changed files with 109 additions and 21 deletions

View File

@ -10,11 +10,9 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
本模块定义了一个通用的通道类用于进程间通信 本模块定义了一个通用的通道类用于进程间通信
""" """
import functools
import multiprocessing
import threading import threading
from multiprocessing import Pipe from multiprocessing import Pipe
from typing import Any, Optional, Callable, Awaitable, List, TypeAlias from typing import Any, Awaitable, Callable, Optional, TypeAlias, TypeVar, Generic, get_args
from uuid import uuid4 from uuid import uuid4
from liteyuki.utils import IS_MAIN_PROCESS, is_coroutine_callable, run_coroutine from liteyuki.utils import IS_MAIN_PROCESS, is_coroutine_callable, run_coroutine
@ -30,18 +28,21 @@ FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC
_channel: dict[str, "Channel"] = {} _channel: dict[str, "Channel"] = {}
_callback_funcs: dict[str, ON_RECEIVE_FUNC] = {} _callback_funcs: dict[str, ON_RECEIVE_FUNC] = {}
"""子进程可用的主动和被动通道""" T = TypeVar("T")
active_channel: Optional["Channel"] = None
passive_channel: Optional["Channel"] = None
class Channel: class Channel(Generic[T]):
""" """
通道类可以在进程间和进程内通信双向但同时只能有一个发送者和一个接收者 通道类可以在进程间和进程内通信双向但同时只能有一个发送者和一个接收者
有两种接收工作方式但是只能选择一种主动接收和被动接收主动接收使用 `receive` 方法被动接收使用 `on_receive` 装饰器 有两种接收工作方式但是只能选择一种主动接收和被动接收主动接收使用 `receive` 方法被动接收使用 `on_receive` 装饰器
""" """
def __init__(self, _id: str): def __init__(self, _id: str, type_check: bool = False):
"""
初始化通道
Args:
_id: 通道ID
"""
self.conn_send, self.conn_recv = Pipe() self.conn_send, self.conn_recv = Pipe()
self._closed = False self._closed = False
self._on_main_receive_funcs: list[str] = [] self._on_main_receive_funcs: list[str] = []
@ -51,20 +52,67 @@ class Channel:
self.is_main_receive_loop_running = False self.is_main_receive_loop_running = False
self.is_sub_receive_loop_running = False self.is_sub_receive_loop_running = False
if type_check:
if self._get_generic_type() is None:
raise TypeError("Type hint is required for enforcing type check.")
self.type_check = type_check
def _get_generic_type(self) -> Optional[type]:
"""
获取通道传递泛型类型
Returns:
Optional[type]: 泛型类型
"""
if hasattr(self, '__orig_class__'):
return get_args(self.__orig_class__)[0]
return None
def _validate_structure(self, data: any, structure: type | tuple | list | dict) -> bool:
"""
验证数据结构
Args:
data: 数据
structure: 结构
Returns:
bool: 是否通过验证
"""
if isinstance(structure, type):
return isinstance(data, structure)
elif isinstance(structure, tuple):
if not isinstance(data, tuple) or len(data) != len(structure):
return False
return all(self._validate_structure(d, s) for d, s in zip(data, structure))
elif isinstance(structure, list):
if not isinstance(data, list):
return False
return all(self._validate_structure(d, structure[0]) for d in data)
elif isinstance(structure, dict):
if not isinstance(data, dict):
return False
return all(k in data and self._validate_structure(data[k], structure[k]) for k in structure)
return False
def __str__(self): def __str__(self):
return f"Channel({self.name})" return f"Channel({self.name})"
def send(self, data: Any): def send(self, data: T):
""" """
发送数据 发送数据
Args: Args:
data: 数据 data: 数据
""" """
if self.type_check:
_type = self._get_generic_type()
if not self._validate_structure(data, _type):
raise TypeError(f"Data must be an instance of {_type}, {type(data)} found")
if self._closed: if self._closed:
raise RuntimeError("Cannot send to a closed channel") raise RuntimeError("Cannot send to a closed channel")
self.conn_send.send(data) self.conn_send.send(data)
def receive(self) -> Any: def receive(self) -> T:
""" """
接收数据 接收数据
Args: Args:
@ -168,6 +216,38 @@ class Channel:
return self.receive() return self.receive()
"""子进程可用的主动和被动通道"""
active_channel: Optional["Channel"] = None
passive_channel: Optional["Channel"] = None
if not IS_MAIN_PROCESS:
"""sub process only"""
active_channel: Optional["Channel"] = None
passive_channel: Optional["Channel"] = None
"""通道传递通道,主进程单例,子进程初始化时实例化"""
channel_deliver_active_channel: Optional["Channel"]
channel_deliver_passive_channel: Optional["Channel"]
if IS_MAIN_PROCESS:
channel_deliver_active_channel: Optional["Channel"] = Channel(_id="channel_deliver_active_channel")
channel_deliver_passive_channel: Optional["Channel"] = Channel(_id="channel_deliver_passive_channel")
@channel_deliver_passive_channel.on_receive(filter_func=lambda data: data[0] == "set_channel")
def on_set_channel(data: tuple[str, str, Channel, Channel]):
name, channel, temp_channel = data[1:]
temp_channel.send(set_channel(name, channel))
@channel_deliver_active_channel.on_receive(filter_func=lambda data: data[0] == "get_channel")
def on_get_channel(data: tuple[str, Channel]):
name = data[1:]
channel = get_channel()
return channel
else:
channel_deliver_active_channel = None
channel_deliver_passive_channel = None
def set_channel(name: str, channel: Channel): def set_channel(name: str, channel: Channel):
""" """
设置通道实例 设置通道实例
@ -175,12 +255,16 @@ def set_channel(name: str, channel: Channel):
name: 通道名称 name: 通道名称
channel: 通道实例 channel: 通道实例
""" """
if not IS_MAIN_PROCESS:
raise RuntimeError(f"Function {__name__} should only be called in the main process.")
if not isinstance(channel, Channel): if not isinstance(channel, Channel):
raise TypeError(f"channel must be an instance of Channel, {type(channel)} found") raise TypeError(f"channel must be an instance of Channel, {type(channel)} found")
_channel[name] = channel
if IS_MAIN_PROCESS:
_channel[name] = channel
else:
# 请求主进程设置通道
temp_channel = Channel(_id="temp_channel")
channel_deliver_passive_channel.send(("set_channel", name, channel, temp_channel))
return temp_channel.receive()
def set_channels(channels: dict[str, Channel]): def set_channels(channels: dict[str, Channel]):

View File

@ -197,7 +197,6 @@ shared_memory: Optional[KeyValueStore] = None
if IS_MAIN_PROCESS: if IS_MAIN_PROCESS:
shared_memory = GlobalKeyValueStore.get_instance() shared_memory = GlobalKeyValueStore.get_instance()
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get") @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get")
def on_get(data: tuple[str, str, any, Channel]): def on_get(data: tuple[str, str, any, Channel]):
data[3].send(shared_memory.get(data[1], data[2])) data[3].send(shared_memory.get(data[1], data[2]))

View File

@ -16,17 +16,19 @@ import threading
from multiprocessing import Process from multiprocessing import Process
from typing import Any, Callable, TYPE_CHECKING, TypeAlias from typing import Any, Callable, TYPE_CHECKING, TypeAlias
from liteyuki.comm import Channel, get_channel, set_channels from liteyuki.comm.channel import Channel, get_channel, set_channels
from liteyuki.comm.storage import shared_memory from liteyuki.comm.storage import shared_memory
from liteyuki.log import logger from liteyuki.log import logger
from liteyuki.utils import IS_MAIN_PROCESS from liteyuki.utils import IS_MAIN_PROCESS
TARGET_FUNC: TypeAlias = Callable[..., Any] if IS_MAIN_PROCESS:
from liteyuki.comm.channel import channel_deliver_active_channel, channel_deliver_passive_channel
if TYPE_CHECKING: if TYPE_CHECKING:
from liteyuki.bot import LiteyukiBot from liteyuki.bot import LiteyukiBot
from liteyuki.comm.storage import KeyValueStore from liteyuki.comm.storage import KeyValueStore
TARGET_FUNC: TypeAlias = Callable[..., Any]
TIMEOUT = 10 TIMEOUT = 10
__all__ = [ __all__ = [
@ -34,18 +36,21 @@ __all__ = [
] ]
# Update the delivery_channel_wrapper function to return the top-level wrapper # 函数处理一些跨进程通道的
def _delivery_channel_wrapper(func: TARGET_FUNC, chan_active: Channel, chan_passive: Channel, sm: "KeyValueStore", *args, **kwargs): def _delivery_channel_wrapper(func: TARGET_FUNC, chan_active: Channel, chan_passive: Channel, sm: "KeyValueStore", *args, **kwargs):
""" """
子进程入口函数 子进程入口函数
处理一些操作
""" """
# 给子进程设置通道 # 给子进程设置通道
if IS_MAIN_PROCESS: if IS_MAIN_PROCESS:
raise RuntimeError("Function should only be called in a sub process.") raise RuntimeError("Function should only be called in a sub process.")
from liteyuki.comm import channel from liteyuki.comm import channel # type Module
channel.active_channel = chan_active channel.active_channel = chan_active # 子进程主动通道
channel.passive_channel = chan_passive channel.passive_channel = chan_passive # 子进程被动通道
channel.channel_deliver_active_channel = channel_deliver_active_channel # 子进程通道传递主动通道
channel.channel_deliver_passive_channel = channel_deliver_passive_channel # 子进程通道传递被动通道
# 给子进程创建共享内存实例 # 给子进程创建共享内存实例
from liteyuki.comm import storage from liteyuki.comm import storage