LiteyukiBot/liteyuki/comm/storage.py
2024-08-17 00:18:06 +08:00

229 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
共享内存模块。类似于redis但是更加轻量级并且线程安全
"""
import threading
from typing import Any, Optional

from liteyuki.comm.channel import Channel
from liteyuki.utils import IS_MAIN_PROCESS
if IS_MAIN_PROCESS:
    # Registry of per-key locks; it only exists in the main process.
    _locks = {}
    # Serialises the creation of new entries in ``_locks``.
    _locks_guard = threading.Lock()


def _get_lock(key) -> threading.Lock:
    """Return the lock associated with ``key``, creating it on first use.

    Creation is serialised by ``_locks_guard`` so that two threads asking
    for the same new key at the same time always receive the same lock
    object instead of each building their own.

    Args:
        key: Key whose lock is requested.

    Returns:
        threading.Lock: The lock registered for ``key``.

    Raises:
        RuntimeError: If called outside the main process.
    """
    if not IS_MAIN_PROCESS:
        raise RuntimeError("Cannot get lock in sub process.")
    try:
        # Fast path: the lock is already registered.
        return _locks[key]
    except KeyError:
        # Slow path: setdefault under the guard keeps whichever lock was
        # registered first if another thread won the race.
        with _locks_guard:
            return _locks.setdefault(key, threading.Lock())
class KeyValueStore:
    """Key-value store whose main-process accesses take a per-key lock.

    In the main process the data lives in ``self._store``. In a sub process
    nothing is stored locally: every call is turned into a message sent on
    ``passive_chan``.
    """

    def __init__(self):
        self._store = {}
        self.active_chan = Channel(_id="shared_memory-active")
        self.passive_chan = Channel(_id="shared_memory-passive")

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Key to write.
            value: Value to store.
        """
        if IS_MAIN_PROCESS:
            with _get_lock(key):
                self._store[key] = value
        else:
            # Ask the main process to perform the write.
            self.passive_chan.send(("set", key, value))

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Return the value stored under ``key``.

        Args:
            key: Key to read.
            default: Value returned when ``key`` is absent.

        Returns:
            Any: The stored value, or ``default``.
        """
        if IS_MAIN_PROCESS:
            with _get_lock(key):
                return self._store.get(key, default)
        # The "get" handler in this module replies on the channel found at
        # index 3 of the request tuple, so a reply channel must be sent
        # along with the request.
        reply_chan = Channel("temp_chan")
        self.passive_chan.send(("get", key, default, reply_chan))
        return reply_chan.receive()

    def delete(self, key: str, ignore_key_error: bool = True) -> None:
        """Remove ``key`` from the store.

        Args:
            key: Key to remove.
            ignore_key_error: When true (the default) a missing key is
                silently ignored. Only honoured in the main process; the
                flag is not part of the message sent from a sub process.

        Raises:
            KeyError: If ``key`` is absent and ``ignore_key_error`` is false.
        """
        if IS_MAIN_PROCESS:
            with _get_lock(key):
                try:
                    del self._store[key]
                except KeyError:
                    if not ignore_key_error:
                        raise
                finally:
                    # The key no longer exists either way, so its lock entry
                    # is dropped to keep the registry from growing.
                    _locks.pop(key, None)
        else:
            # Ask the main process to perform the deletion.
            self.passive_chan.send(("delete", key))

    def get_all(self) -> dict[str, Any]:
        """Return every key-value pair.

        Returns:
            dict[str, Any]: In the main process this is the internal
            dictionary itself, not a copy. In a sub process it is whatever
            the main process sends back.
        """
        if IS_MAIN_PROCESS:
            return self._store
        # The "get_all" handler in this module replies on the channel found
        # at index 1 of the request tuple.
        reply_chan = Channel("temp_chan")
        self.passive_chan.send(("get_all", reply_chan))
        return reply_chan.receive()
class KeyValueStoreNoLock:
    """Key-value store that performs no locking of its own.

    In the main process the data lives in ``self._store``. In a sub process
    every call is turned into a message sent on ``passive_chan``.
    """

    def __init__(self):
        self._store = {}
        self.active_chan = Channel(_id="shared_memory-active")
        self.passive_chan = Channel(_id="shared_memory-passive")

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Key to write.
            value: Value to store.
        """
        if IS_MAIN_PROCESS:
            self._store[key] = value
        else:
            # Ask the main process to perform the write.
            self.passive_chan.send(("set", key, value))

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Return the value stored under ``key``.

        Args:
            key: Key to read.
            default: Value returned when ``key`` is absent.

        Returns:
            Any: The stored value, or ``default``.
        """
        if IS_MAIN_PROCESS:
            return self._store.get(key, default)
        # The reply channel travels inside the request tuple.
        reply_chan = Channel("temp_chan")
        self.passive_chan.send(("get", key, default, reply_chan))
        return reply_chan.receive()

    def delete(self, key: str, ignore_key_error: bool = True) -> None:
        """Remove ``key`` from the store.

        This class never registers per-key locks, so only ``self._store``
        is touched.

        Args:
            key: Key to remove.
            ignore_key_error: When true (the default) a missing key is
                silently ignored. Only honoured in the main process; the
                flag is not part of the message sent from a sub process.

        Raises:
            KeyError: If ``key`` is absent and ``ignore_key_error`` is false.
        """
        if IS_MAIN_PROCESS:
            try:
                del self._store[key]
            except KeyError:
                if not ignore_key_error:
                    raise
        else:
            # Ask the main process to perform the deletion.
            self.passive_chan.send(("delete", key))

    def get_all(self) -> dict[str, Any]:
        """Return every key-value pair.

        Returns:
            dict[str, Any]: In the main process this is the internal
            dictionary itself, not a copy. In a sub process it is whatever
            the main process sends back.
        """
        if IS_MAIN_PROCESS:
            return self._store
        # The reply channel travels inside the request tuple.
        reply_chan = Channel("temp_chan")
        self.passive_chan.send(("get_all", reply_chan))
        return reply_chan.receive()
class GlobalKeyValueStore:
    """Holder for a single, lazily created ``KeyValueStore``."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """Return the shared ``KeyValueStore``, creating it on first call.

        Returns:
            The ``KeyValueStore`` kept in ``cls._instance``.

        Raises:
            RuntimeError: If called outside the main process.
        """
        if not IS_MAIN_PROCESS:
            raise RuntimeError("Cannot get instance in sub process.")
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            # Checked again under the lock: another thread may have created
            # the instance while this one was waiting.
            if cls._instance is None:
                cls._instance = KeyValueStore()
        return cls._instance
# Module-level handle to the shared store. It is assigned below in the main
# process and left as None in a sub process.
shared_memory: Optional[KeyValueStore] = None
# Global singleton access point
if IS_MAIN_PROCESS:
    shared_memory = GlobalKeyValueStore.get_instance()

    # The handlers below are registered on passive_chan. Each registration
    # passes a predicate on the first tuple element, which names the
    # requested operation.
    # NOTE(review): presumably a handler only runs for messages whose
    # predicate returns true - confirm against Channel.on_receive.
    @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get")
    def on_get(data: tuple[str, str, any, Channel]):
        # data is ("get", key, default, reply channel): the looked-up value
        # is sent back on the channel carried in the request.
        data[3].send(shared_memory.get(data[1], data[2]))

    @shared_memory.passive_chan.on_receive(lambda d: d[0] == "set")
    def on_set(data: tuple[str, str, any]):
        # data is ("set", key, value).
        shared_memory.set(data[1], data[2])

    @shared_memory.passive_chan.on_receive(lambda d: d[0] == "delete")
    def on_delete(data: tuple[str, str]):
        # data is ("delete", key); delete() is called with its default
        # ignore_key_error.
        shared_memory.delete(data[1])

    @shared_memory.passive_chan.on_receive(lambda d: d[0] == "get_all")
    def on_get_all(data: tuple[str, Channel]):
        # data is ("get_all", reply channel). The operation name is checked
        # again here in addition to the registration predicate.
        if data[0] == "get_all":
            data[1].send(shared_memory.get_all())
else:
    # Sub processes initialise shared_memory in their entry function
    shared_memory = None
_ref_count = 0  # import reference count, meant to prevent using an unset store
if not IS_MAIN_PROCESS:
    # NOTE(review): _ref_count is assigned 0 on the line above, so the
    # `_ref_count > 1` test is false whenever this check is reached.
    if (shared_memory is None) and _ref_count > 1:
        raise RuntimeError("Shared memory not initialized.")
    _ref_count += 1