diff --git a/liteyuki/comm/storage.py b/liteyuki/comm/storage.py index 5ef62d2b..b94ec9d4 100644 --- a/liteyuki/comm/storage.py +++ b/liteyuki/comm/storage.py @@ -155,29 +155,36 @@ class GlobalKeyValueStore: if IS_MAIN_PROCESS: shared_memory: KeyValueStore = GlobalKeyValueStore.get_instance() + @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get") - def on_get(): - # TODO - pass + def on_get(data: tuple[str, dict[str, Any]]): + key = data[1]["key"] + default = data[1]["default"] + recv_chan = data[1]["recv_chan"] + recv_chan.send(shared_memory.get(key, default)) @shared_memory.passive_chan.on_receive(lambda d: d[0] == "set") - def on_set(data: tuple[str, str, Any]): - shared_memory.set(data[1], data[2]) + def on_set(data: tuple[str, dict[str, Any]]): + key = data[1]["key"] + value = data[1]["value"] + shared_memory.set(key, value) @shared_memory.passive_chan.on_receive(lambda d: d[0] == "delete") - def on_delete(data: tuple[str, str]): - shared_memory.delete(data[1]) + def on_delete(data: tuple[str, dict[str, Any]]): + key = data[1]["key"] + shared_memory.delete(key) @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all") - def on_get_all(data: tuple[str, Channel]): - if data[0] == "get_all": - data[1].send(shared_memory.get_all()) + def on_get_all(data: tuple[str, dict[str, Any]]): + recv_chan = data[1]["recv_chan"] + recv_chan.send(shared_memory.get_all()) + else: # 子进程在入口函数中对shared_memory进行初始化 - shared_memory: Optional[KeyValueStore] = None # type: ignore + shared_memory: Optional[KeyValueStore] = None # type: ignore _ref_count = 0 # import 引用计数, 防止获取空指针 if not IS_MAIN_PROCESS: