From 8e27f6b9b03a7d57ce4fa1b85a75a13c9e0273af Mon Sep 17 00:00:00 2001 From: snowy Date: Sat, 17 Aug 2024 19:10:03 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E5=AF=B9=E9=80=9A=E9=81=93?= =?UTF-8?q?=E7=B1=BB=E6=B7=BB=E5=8A=A0=E7=B1=BB=E5=9E=8B=E6=A3=80=E6=9F=A5?= =?UTF-8?q?=E5=92=8C=E6=B3=9B=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- liteyuki/comm/channel.py | 87 +++++++++++------------ liteyuki/comm/storage.py | 149 ++++++++++++++------------------------- liteyuki/core/manager.py | 53 +++++++++----- 3 files changed, 130 insertions(+), 159 deletions(-) diff --git a/liteyuki/comm/channel.py b/liteyuki/comm/channel.py index c5a422bb..b1ffa88a 100644 --- a/liteyuki/comm/channel.py +++ b/liteyuki/comm/channel.py @@ -12,23 +12,23 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved """ import threading from multiprocessing import Pipe -from typing import Any, Awaitable, Callable, Optional, TypeAlias, TypeVar, Generic, get_args -from uuid import uuid4 +from typing import Any, Callable, Coroutine, Generic, Optional, TypeAlias, TypeVar, get_args from liteyuki.utils import IS_MAIN_PROCESS, is_coroutine_callable, run_coroutine -SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Any] -ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Awaitable[Any]] +T = TypeVar("T") + +SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[T], Any] +ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[T], Coroutine[Any, Any, Any]] ON_RECEIVE_FUNC: TypeAlias = SYNC_ON_RECEIVE_FUNC | ASYNC_ON_RECEIVE_FUNC -SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool] -ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]] +SYNC_FILTER_FUNC: TypeAlias = Callable[[T], bool] +ASYNC_FILTER_FUNC: TypeAlias = Callable[[T], Coroutine[Any, Any, bool]] FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC +_func_id: int = 0 _channel: dict[str, "Channel"] = {} -_callback_funcs: dict[str, ON_RECEIVE_FUNC] = {} - -T = TypeVar("T") +_callback_funcs: dict[int, ON_RECEIVE_FUNC] = {} class Channel(Generic[T]): @@ -45,8 +45,8 @@ class Channel(Generic[T]): """ self.conn_send, self.conn_recv = Pipe() self._closed = False - self._on_main_receive_funcs: list[str] = [] - self._on_sub_receive_funcs: list[str] = [] + self._on_main_receive_funcs: list[int] = [] + self._on_sub_receive_funcs: list[int] = [] self.name: str = _id self.is_main_receive_loop_running = False @@ -68,7 +68,7 @@ class Channel(Generic[T]): return get_args(self.__orig_class__)[0] return None - def _validate_structure(self, data: any, structure: type | tuple | list | dict) -> bool: + def _validate_structure(self, data: Any, structure: type) -> bool: """ 验证数据结构 Args: @@ -105,7 +105,7 @@ class Channel(Generic[T]): """ if self.type_check: _type = self._get_generic_type() - if not self._validate_structure(data, _type): + if _type is not None and not self._validate_structure(data, _type): raise TypeError(f"Data must be an instance of {_type}, {type(data)} found") if self._closed: @@ -132,7 +132,7 @@ class Channel(Generic[T]): self.conn_send.close() self.conn_recv.close() - def on_receive(self, filter_func: Optional[FILTER_FUNC] = None) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]: + def on_receive(self, filter_func: Optional[FILTER_FUNC] = None) -> Callable[[Callable[[T], Any]], Callable[[T], Any]]: """ 接收数据并执行函数 Args: @@ -146,11 +146,13 @@ class Channel(Generic[T]): if (not self.is_main_receive_loop_running) and IS_MAIN_PROCESS: threading.Thread(target=self._start_main_receive_loop, daemon=True).start() - def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC: - async def wrapper(data: Any) -> Any: + def decorator(func: Callable[[T], Any]) -> Callable[[T], Any]: + global _func_id + + async def wrapper(data: T) -> Any: if filter_func is not None: if is_coroutine_callable(filter_func): - if not (await filter_func(data)): + if not (await filter_func(data)): # type: ignore return else: if not filter_func(data): @@ -161,12 +163,12 @@ class Channel(Generic[T]): else: return func(data) - function_id = str(uuid4()) - _callback_funcs[function_id] = wrapper + _callback_funcs[_func_id] = wrapper if IS_MAIN_PROCESS: - self._on_main_receive_funcs.append(function_id) + self._on_main_receive_funcs.append(_func_id) else: - self._on_sub_receive_funcs.append(function_id) + self._on_sub_receive_funcs.append(_func_id) + _func_id += 1 return func return decorator @@ -219,35 +221,21 @@ class Channel(Generic[T]): """子进程可用的主动和被动通道""" active_channel: Optional["Channel"] = None passive_channel: Optional["Channel"] = None -if not IS_MAIN_PROCESS: - """sub process only""" - active_channel: Optional["Channel"] = None - passive_channel: Optional["Channel"] = None -"""通道传递通道,主进程单例,子进程初始化时实例化""" -channel_deliver_active_channel: Optional["Channel"] -channel_deliver_passive_channel: Optional["Channel"] +"""通道传递通道,主进程创建单例,子进程初始化时实例化""" +channel_deliver_active_channel: Channel[Channel[Any]] +channel_deliver_passive_channel: Channel[tuple[str, dict[str, Any]]] if IS_MAIN_PROCESS: - channel_deliver_active_channel: Optional["Channel"] = Channel(_id="channel_deliver_active_channel") - channel_deliver_passive_channel: Optional["Channel"] = Channel(_id="channel_deliver_passive_channel") + channel_deliver_active_channel = Channel(_id="channel_deliver_active_channel") + channel_deliver_passive_channel = Channel(_id="channel_deliver_passive_channel") @channel_deliver_passive_channel.on_receive(filter_func=lambda data: data[0] == "set_channel") - def on_set_channel(data: tuple[str, str, Channel, Channel]): - name, channel, temp_channel = data[1:] + def on_set_channel(data: tuple[str, dict[str, Any]]): + name, channel, temp_channel = data[1]["name"], data[1]["channel"], _channel[data[0]] temp_channel.send(set_channel(name, channel)) - @channel_deliver_active_channel.on_receive(filter_func=lambda data: data[0] == "get_channel") - def on_get_channel(data: tuple[str, Channel]): - name = data[1:] - channel = get_channel() - return channel -else: - channel_deliver_active_channel = None - channel_deliver_passive_channel = None - - def set_channel(name: str, channel: Channel): """ 设置通道实例 @@ -262,9 +250,14 @@ def set_channel(name: str, channel: Channel): _channel[name] = channel else: # 请求主进程设置通道 - temp_channel = Channel(_id="temp_channel") - channel_deliver_passive_channel.send(("set_channel", name, channel, temp_channel)) - return temp_channel.receive() + channel_deliver_passive_channel.send( + ( + "set_channel", { + "name" : name, + "channel": channel, + } + ) + ) def set_channels(channels: dict[str, Channel]): @@ -280,7 +273,7 @@ def set_channels(channels: dict[str, Channel]): set_channel(name, channel) -def get_channel(name: str) -> Optional[Channel]: +def get_channel(name: str) -> Channel: """ 获取通道实例 Args: @@ -290,7 +283,7 @@ def get_channel(name: str) -> Optional[Channel]: if not IS_MAIN_PROCESS: raise RuntimeError(f"Function {__name__} should only be called in the main process.") - return _channel.get(name, None) + return _channel[name] def get_channels() -> dict[str, Channel]: diff --git a/liteyuki/comm/storage.py b/liteyuki/comm/storage.py index 5c8982ee..5ef62d2b 100644 --- a/liteyuki/comm/storage.py +++ b/liteyuki/comm/storage.py @@ -4,7 +4,7 @@ """ import threading -from typing import Optional +from typing import Any, Optional from liteyuki.comm.channel import Channel from liteyuki.utils import IS_MAIN_PROCESS @@ -28,11 +28,10 @@ def _get_lock(key) -> threading.Lock: class KeyValueStore: def __init__(self): self._store = {} + self.active_chan = Channel[tuple[str, Optional[dict[str, Any]]]](_id="shared_memory-active") + self.passive_chan = Channel[tuple[str, Optional[dict[str, Any]]]](_id="shared_memory-passive") - self.active_chan = Channel(_id="shared_memory-active") - self.passive_chan = Channel(_id="shared_memory-passive") - - def set(self, key: str, value: any) -> None: + def set(self, key: str, value: Any) -> None: """ 设置键值对 Args: @@ -46,9 +45,17 @@ class KeyValueStore: self._store[key] = value else: # 向主进程发送请求拿取 - self.passive_chan.send(("set", key, value)) + self.passive_chan.send( + ( + "set", + { + "key" : key, + "value": value + } + ) + ) - def get(self, key: str, default: Optional[any] = None) -> any: + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: """ 获取键值对 Args: @@ -56,15 +63,26 @@ class KeyValueStore: default: 默认值 Returns: - any: 值 + Any: 值 """ if IS_MAIN_PROCESS: lock = _get_lock(key) with lock: return self._store.get(key, default) else: - self.passive_chan.send(("get", key, default)) - return self.active_chan.receive() + recv_chan = Channel[Optional[Any]]("recv_chan") + self.passive_chan.send( + ( + "get", + { + "key" : key, + "default" : default, + "recv_chan": recv_chan + } + + ) + ) + return recv_chan.receive() def delete(self, key: str, ignore_key_error: bool = True) -> None: """ @@ -87,92 +105,34 @@ class KeyValueStore: raise e else: # 向主进程发送请求删除 - self.passive_chan.send(("delete", key)) + self.passive_chan.send( + ( + "delete", + { + "key": key + } + ) + ) - def get_all(self) -> dict[str, any]: + def get_all(self) -> dict[str, Any]: """ 获取所有键值对 Returns: - dict[str, any]: 键值对 + dict[str, Any]: 键值对 """ if IS_MAIN_PROCESS: return self._store else: - self.passive_chan.send(("get_all",)) - return self.active_chan.receive() - - -class KeyValueStoreNoLock: - def __init__(self): - self._store = {} - - self.active_chan = Channel(_id="shared_memory-active") - self.passive_chan = Channel(_id="shared_memory-passive") - - def set(self, key: str, value: any) -> None: - """ - 设置键值对 - Args: - key: 键 - value: 值 - - """ - if IS_MAIN_PROCESS: - self._store[key] = value - else: - # 向主进程发送请求拿取 - self.passive_chan.send(("set", key, value)) - - def get(self, key: str, default: Optional[any] = None) -> any: - """ - 获取键值对 - Args: - key: 键 - default: 默认值 - - Returns: - any: 值 - """ - if IS_MAIN_PROCESS: - return self._store.get(key, default) - else: - temp_chan = Channel("temp_chan") - self.passive_chan.send(("get", key, default, temp_chan)) - return temp_chan.receive() - - def delete(self, key: str, ignore_key_error: bool = True) -> None: - """ - 删除键值对 - Args: - key: 键 - ignore_key_error: 是否忽略键不存在的错误 - - Returns: - """ - if IS_MAIN_PROCESS: - if key in self._store: - try: - del self._store[key] - del _locks[key] - except KeyError as e: - if not ignore_key_error: - raise e - else: - # 向主进程发送请求删除 - self.passive_chan.send(("delete", key)) - - def get_all(self) -> dict[str, any]: - """ - 获取所有键值对 - Returns: - dict[str, any]: 键值对 - """ - if IS_MAIN_PROCESS: - return self._store - else: - temp_chan = Channel("temp_chan") - self.passive_chan.send(("get_all", temp_chan)) - return temp_chan.receive() + recv_chan = Channel[dict[str, Any]]("recv_chan") + self.passive_chan.send( + ( + "get_all", + { + "recv_chan": recv_chan + } + ) + ) + return recv_chan.receive() class GlobalKeyValueStore: @@ -191,19 +151,18 @@ class GlobalKeyValueStore: raise RuntimeError("Cannot get instance in sub process.") -shared_memory: Optional[KeyValueStore] = None - # 全局单例访问点 if IS_MAIN_PROCESS: - shared_memory = GlobalKeyValueStore.get_instance() + shared_memory: KeyValueStore = GlobalKeyValueStore.get_instance() @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get") - def on_get(data: tuple[str, str, any, Channel]): - data[3].send(shared_memory.get(data[1], data[2])) + def on_get(): + # TODO + pass @shared_memory.passive_chan.on_receive(lambda d: d[0] == "set") - def on_set(data: tuple[str, str, any]): + def on_set(data: tuple[str, str, Any]): shared_memory.set(data[1], data[2]) @@ -218,7 +177,7 @@ if IS_MAIN_PROCESS: data[1].send(shared_memory.get_all()) else: # 子进程在入口函数中对shared_memory进行初始化 - shared_memory = None + shared_memory: Optional[KeyValueStore] = None # type: ignore _ref_count = 0 # import 引用计数, 防止获取空指针 if not IS_MAIN_PROCESS: diff --git a/liteyuki/core/manager.py b/liteyuki/core/manager.py index e5ebe01c..1e445b03 100644 --- a/liteyuki/core/manager.py +++ b/liteyuki/core/manager.py @@ -9,9 +9,6 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved @Software: PyCharm """ -import atexit -import signal -import sys import threading from multiprocessing import Process from typing import Any, Callable, TYPE_CHECKING, TypeAlias @@ -21,13 +18,15 @@ from liteyuki.comm.storage import shared_memory from liteyuki.log import logger from liteyuki.utils import IS_MAIN_PROCESS -if IS_MAIN_PROCESS: - from liteyuki.comm.channel import channel_deliver_active_channel, channel_deliver_passive_channel - if TYPE_CHECKING: from liteyuki.bot import LiteyukiBot from liteyuki.comm.storage import KeyValueStore +if IS_MAIN_PROCESS: + from liteyuki.comm.channel import channel_deliver_active_channel, channel_deliver_passive_channel +else: + from liteyuki.comm import channel + TARGET_FUNC: TypeAlias = Callable[..., Any] TIMEOUT = 10 @@ -36,8 +35,22 @@ __all__ = [ ] +class ChannelDeliver: + def __init__( + self, + active: Channel[Any], + passive: Channel[Any], + channel_deliver_active: Channel[Channel[Any]], + channel_deliver_passive: Channel[tuple[str, dict]] + ): + self.active = active + self.passive = passive + self.channel_deliver_active = channel_deliver_active + self.channel_deliver_passive = channel_deliver_passive + + # 函数处理一些跨进程通道的 -def _delivery_channel_wrapper(func: TARGET_FUNC, chan_active: Channel, chan_passive: Channel, sm: "KeyValueStore", *args, **kwargs): +def _delivery_channel_wrapper(func: TARGET_FUNC, cd: ChannelDeliver, sm: "KeyValueStore", *args, **kwargs): """ 子进程入口函数 处理一些操作 @@ -46,11 +59,10 @@ def _delivery_channel_wrapper(func: TARGET_FUNC, chan_active: Channel, chan_pass if IS_MAIN_PROCESS: raise RuntimeError("Function should only be called in a sub process.") - from liteyuki.comm import channel # type Module - channel.active_channel = chan_active # 子进程主动通道 - channel.passive_channel = chan_passive # 子进程被动通道 - channel.channel_deliver_active_channel = channel_deliver_active_channel # 子进程通道传递主动通道 - channel.channel_deliver_passive_channel = channel_deliver_passive_channel # 子进程通道传递被动通道 + channel.active_channel = cd.active # 子进程主动通道 + channel.passive_channel = cd.passive # 子进程被动通道 + channel.channel_deliver_active_channel = cd.channel_deliver_active # 子进程通道传递主动通道 + channel.channel_deliver_passive_channel = cd.channel_deliver_passive # 子进程通道传递被动通道 # 给子进程创建共享内存实例 from liteyuki.comm import storage @@ -66,7 +78,7 @@ class ProcessManager: def __init__(self, bot: "LiteyukiBot"): self.bot = bot - self.targets: dict[str, tuple[callable, tuple, dict]] = {} + self.targets: dict[str, tuple[Callable, tuple, dict]] = {} self.processes: dict[str, Process] = {} def start(self, name: str): @@ -128,10 +140,17 @@ class ProcessManager: """ if kwargs is None: kwargs = {} - chan_active = Channel(_id=f"{name}-active") - chan_passive = Channel(_id=f"{name}-passive") + chan_active: Channel = Channel(_id=f"{name}-active") + chan_passive: Channel = Channel(_id=f"{name}-passive") - self.targets[name] = (_delivery_channel_wrapper, (target, chan_active, chan_passive, shared_memory, *args), kwargs) + channel_deliver = ChannelDeliver( + active=chan_active, + passive=chan_passive, + channel_deliver_active=channel_deliver_active_channel, + channel_deliver_passive=channel_deliver_passive_channel + ) + + self.targets[name] = (_delivery_channel_wrapper, (target, channel_deliver, shared_memory, *args), kwargs) # 主进程通道 set_channels( { @@ -177,5 +196,5 @@ class ProcessManager: """ if name not in self.targets: - raise logger.warning(f"Process {name} not found.") + logger.warning(f"Process {name} not found.") return self.processes[name].is_alive()