新增线程安全共享内存储存器

This commit is contained in:
snowy 2024-08-16 21:43:29 +08:00
parent dd00e6ecec
commit 1b692dd13f
4 changed files with 43 additions and 13 deletions

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
from liteyuki.bot.lifespan import (LIFESPAN_FUNC, Lifespan) from liteyuki.bot.lifespan import (LIFESPAN_FUNC, Lifespan)
from liteyuki.comm import get_channel from liteyuki.comm import get_channel
from liteyuki.core import IS_MAIN_PROCESS from liteyuki.utils import IS_MAIN_PROCESS
from liteyuki.core.manager import ProcessManager from liteyuki.core.manager import ProcessManager
from liteyuki.log import init_log, logger from liteyuki.log import init_log, logger
from liteyuki.plugin import load_plugins from liteyuki.plugin import load_plugins

View File

@ -30,6 +30,13 @@ class KeyValueStore:
self.passive_chan = Channel(_id="shared_memory-passive") self.passive_chan = Channel(_id="shared_memory-passive")
def set(self, key: str, value: any) -> None: def set(self, key: str, value: any) -> None:
"""
设置键值对
Args:
key:
value:
"""
if IS_MAIN_PROCESS: if IS_MAIN_PROCESS:
lock = _get_lock(key) lock = _get_lock(key)
with lock: with lock:
@ -39,6 +46,15 @@ class KeyValueStore:
self.passive_chan.send(("set", key, value)) self.passive_chan.send(("set", key, value))
def get(self, key: str, default: Optional[any] = None) -> any: def get(self, key: str, default: Optional[any] = None) -> any:
"""
获取键值对
Args:
key:
default: 默认值
Returns:
any:
"""
if IS_MAIN_PROCESS: if IS_MAIN_PROCESS:
lock = _get_lock(key) lock = _get_lock(key)
with lock: with lock:
@ -47,18 +63,35 @@ class KeyValueStore:
self.passive_chan.send(("get", key, default)) self.passive_chan.send(("get", key, default))
return self.active_chan.receive() return self.active_chan.receive()
def delete(self, key: str) -> None: def delete(self, key: str, ignore_key_error: bool = True) -> None:
"""
删除键值对
Args:
key:
ignore_key_error: 是否忽略键不存在的错误
Returns:
"""
if IS_MAIN_PROCESS: if IS_MAIN_PROCESS:
lock = _get_lock(key) lock = _get_lock(key)
with lock: with lock:
if key in self._store: if key in self._store:
del self._store[key] try:
del _locks[key] del self._store[key]
del _locks[key]
except KeyError as e:
if not ignore_key_error:
raise e
else: else:
# 向主进程发送请求删除 # 向主进程发送请求删除
self.passive_chan.send(("delete", key)) self.passive_chan.send(("delete", key))
def get_all(self) -> dict[str, any]: def get_all(self) -> dict[str, any]:
"""
获取所有键值对
Returns:
dict[str, any]: 键值对
"""
if IS_MAIN_PROCESS: if IS_MAIN_PROCESS:
return self._store return self._store
else: else:
@ -91,9 +124,7 @@ if IS_MAIN_PROCESS:
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get") @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get")
def on_get(d): def on_get(d):
print(shared_memory.get_all())
shared_memory.active_chan.send(shared_memory.get(d[1], d[2])) shared_memory.active_chan.send(shared_memory.get(d[1], d[2]))
print("发送数据:", shared_memory.get(d[1], d[2]))
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "set") @shared_memory.passive_chan.on_receive(lambda d: d[0] == "set")
@ -104,6 +135,11 @@ if IS_MAIN_PROCESS:
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "delete") @shared_memory.passive_chan.on_receive(lambda d: d[0] == "delete")
def on_delete(d): def on_delete(d):
shared_memory.delete(d[1]) shared_memory.delete(d[1])
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all")
def on_get_all(d):
shared_memory.active_chan.send(shared_memory.get_all())
else: else:
shared_memory = None shared_memory = None

View File

@ -2,9 +2,3 @@ import multiprocessing
from .manager import * from .manager import *
__all__ = [
"IS_MAIN_PROCESS"
]
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"

View File

@ -11,7 +11,7 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
import nonebot import nonebot
from liteyuki.core import IS_MAIN_PROCESS from liteyuki.utils import IS_MAIN_PROCESS
from liteyuki.plugin import PluginMetadata from liteyuki.plugin import PluginMetadata
from .nb_utils import adapter_manager, driver_manager from .nb_utils import adapter_manager, driver_manager