新增线程安全共享内存储存器

This commit is contained in:
远野千束 2024-08-16 21:38:22 +08:00
parent adc9b76688
commit 222250bc41
15 changed files with 246 additions and 66 deletions

View File

@ -7,7 +7,6 @@ from liteyuki.bot import (
from liteyuki.comm import ( from liteyuki.comm import (
Channel, Channel,
chan,
Event Event
) )
@ -17,7 +16,19 @@ from liteyuki.plugin import (
) )
from liteyuki.log import ( from liteyuki.log import (
logger, init_log,
init_log logger
) )
__all__ = [
"LiteyukiBot",
"get_bot",
"get_config",
"get_config_with_compat",
"Channel",
"Event",
"load_plugin",
"load_plugins",
"init_log",
"logger"
]

View File

@ -47,9 +47,9 @@ class LiteyukiBot:
""" """
启动逻辑 启动逻辑
""" """
self.lifespan.before_start() # 启动前钩子 self.lifespan.before_start() # 启动前钩子
self.process_manager.start_all() self.process_manager.start_all()
self.lifespan.after_start() # 启动后钩子 self.lifespan.after_start() # 启动后钩子
self.keep_alive() self.keep_alive()
def keep_alive(self): def keep_alive(self):
@ -98,10 +98,9 @@ class LiteyukiBot:
Args: Args:
name: 进程名称, 默认为None, 所有进程 name: 进程名称, 默认为None, 所有进程
Returns: Returns:
""" """
self.loop.create_task(self.lifespan.before_process_shutdown()) # 重启前钩子 self.lifespan.before_process_shutdown() # 重启前钩子
self.loop.create_task(self.lifespan.before_process_shutdown()) # 停止前钩子 self.lifespan.before_process_shutdown() # 停止前钩子
if name is not None: if name is not None:
chan_active = get_channel(f"{name}-active") chan_active = get_channel(f"{name}-active")

View File

@ -11,20 +11,31 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
""" """
from liteyuki.comm.channel import ( from liteyuki.comm.channel import (
Channel, Channel,
chan,
get_channel, get_channel,
set_channel, set_channel,
set_channels, set_channels,
get_channels get_channels,
active_channel,
passive_channel
) )
from liteyuki.comm.event import Event from liteyuki.comm.event import Event
__all__ = [ __all__ = [
"Channel", "Channel",
"chan",
"Event", "Event",
"get_channel", "get_channel",
"set_channel", "set_channel",
"set_channels", "set_channels",
"get_channels" "get_channels",
"active_channel",
"passive_channel"
] ]
from liteyuki.utils import IS_MAIN_PROCESS
# 第一次引用必定为赋值
_ref_count = 0
if not IS_MAIN_PROCESS:
if (active_channel is None or passive_channel is None) and _ref_count > 0:
raise RuntimeError("Error: Channel not initialized in sub process")
_ref_count += 1

View File

@ -17,7 +17,7 @@ from multiprocessing import Pipe
from typing import Any, Optional, Callable, Awaitable, List, TypeAlias from typing import Any, Optional, Callable, Awaitable, List, TypeAlias
from uuid import uuid4 from uuid import uuid4
from liteyuki.utils import is_coroutine_callable, run_coroutine from liteyuki.utils import IS_MAIN_PROCESS, is_coroutine_callable, run_coroutine
SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Any] SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Any]
ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Awaitable[Any]] ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Awaitable[Any]]
@ -27,11 +27,13 @@ SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool]
ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]] ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]]
FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC
_channel: dict[str, "Channel"] = {} _channel: dict[str, "Channel"] = {}
_callback_funcs: dict[str, ON_RECEIVE_FUNC] = {} _callback_funcs: dict[str, ON_RECEIVE_FUNC] = {}
"""子进程可用的主动和被动通道"""
active_channel: Optional["Channel"] = None
passive_channel: Optional["Channel"] = None
class Channel: class Channel:
""" """
@ -40,8 +42,6 @@ class Channel:
""" """
def __init__(self, _id: str): def __init__(self, _id: str):
# self.main_send_conn, self.sub_receive_conn = Pipe()
# self.sub_send_conn, self.main_receive_conn = Pipe()
self.conn_send, self.conn_recv = Pipe() self.conn_send, self.conn_recv = Pipe()
self._closed = False self._closed = False
self._on_main_receive_funcs: list[str] = [] self._on_main_receive_funcs: list[str] = []
@ -102,12 +102,16 @@ class Channel:
async def wrapper(data: Any) -> Any: async def wrapper(data: Any) -> Any:
if filter_func is not None: if filter_func is not None:
if is_coroutine_callable(filter_func): if is_coroutine_callable(filter_func):
if not await filter_func(data): if not (await filter_func(data)):
return return
else: else:
if not filter_func(data): if not filter_func(data):
return return
return await func(data)
if is_coroutine_callable(func):
return await func(data)
else:
return func(data)
function_id = str(uuid4()) function_id = str(uuid4())
_callback_funcs[function_id] = wrapper _callback_funcs[function_id] = wrapper
@ -164,10 +168,6 @@ class Channel:
return self.receive() return self.receive()
"""默认通道实例,可直接从模块导入使用"""
chan = Channel("default")
def set_channel(name: str, channel: Channel): def set_channel(name: str, channel: Channel):
""" """
设置通道实例 设置通道实例
@ -175,6 +175,9 @@ def set_channel(name: str, channel: Channel):
name: 通道名称 name: 通道名称
channel: 通道实例 channel: 通道实例
""" """
if not IS_MAIN_PROCESS:
raise RuntimeError(f"Function {__name__} should only be called in the main process.")
if not isinstance(channel, Channel): if not isinstance(channel, Channel):
raise TypeError(f"channel must be an instance of Channel, {type(channel)} found") raise TypeError(f"channel must be an instance of Channel, {type(channel)} found")
_channel[name] = channel _channel[name] = channel
@ -186,6 +189,9 @@ def set_channels(channels: dict[str, Channel]):
Args: Args:
channels: 通道名称 channels: 通道名称
""" """
if not IS_MAIN_PROCESS:
raise RuntimeError(f"Function {__name__} should only be called in the main process.")
for name, channel in channels.items(): for name, channel in channels.items():
set_channel(name, channel) set_channel(name, channel)
@ -197,6 +203,9 @@ def get_channel(name: str) -> Optional[Channel]:
name: 通道名称 name: 通道名称
Returns: Returns:
""" """
if not IS_MAIN_PROCESS:
raise RuntimeError(f"Function {__name__} should only be called in the main process.")
return _channel.get(name, None) return _channel.get(name, None)
@ -205,4 +214,7 @@ def get_channels() -> dict[str, Channel]:
获取通道实例 获取通道实例
Returns: Returns:
""" """
if not IS_MAIN_PROCESS:
raise RuntimeError(f"Function {__name__} should only be called in the main process.")
return _channel return _channel

View File

@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
"""
共享内存模块类似于redis但是更加轻量级
"""
memory_database = {}
def set_memory(key: str, value: any) -> None:
pass
def get_mem_data(key: str) -> any:
pass

114
liteyuki/comm/storage.py Normal file
View File

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""
共享内存模块类似于redis但是更加轻量级并且线程安全
"""
import threading
from typing import Any, Optional
from liteyuki.utils import IS_MAIN_PROCESS
from liteyuki.comm.channel import Channel
if IS_MAIN_PROCESS:
_locks = {}
def _get_lock(key):
if IS_MAIN_PROCESS:
if key not in _locks:
_locks[key] = threading.Lock()
return _locks[key]
else:
raise RuntimeError("Cannot get lock in sub process.")
class KeyValueStore:
def __init__(self):
self._store = {}
self.active_chan = Channel(_id="shared_memory-active")
self.passive_chan = Channel(_id="shared_memory-passive")
def set(self, key: str, value: any) -> None:
if IS_MAIN_PROCESS:
lock = _get_lock(key)
with lock:
self._store[key] = value
else:
# 向主进程发送请求拿取
self.passive_chan.send(("set", key, value))
def get(self, key: str, default: Optional[any] = None) -> any:
if IS_MAIN_PROCESS:
lock = _get_lock(key)
with lock:
return self._store.get(key, default)
else:
self.passive_chan.send(("get", key, default))
return self.active_chan.receive()
def delete(self, key: str) -> None:
if IS_MAIN_PROCESS:
lock = _get_lock(key)
with lock:
if key in self._store:
del self._store[key]
del _locks[key]
else:
# 向主进程发送请求删除
self.passive_chan.send(("delete", key))
def get_all(self) -> dict[str, any]:
if IS_MAIN_PROCESS:
return self._store
else:
self.passive_chan.send(("get_all",))
return self.active_chan.receive()
class GlobalKeyValueStore:
_instance = None
_lock = threading.Lock()
@classmethod
def get_instance(cls):
if IS_MAIN_PROCESS:
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = KeyValueStore()
return cls._instance
else:
raise RuntimeError("Cannot get instance in sub process.")
shared_memory: Optional[KeyValueStore] = None
# 全局单例访问点
if IS_MAIN_PROCESS:
shared_memory = GlobalKeyValueStore.get_instance()
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "get")
def on_get(d):
print(shared_memory.get_all())
shared_memory.active_chan.send(shared_memory.get(d[1], d[2]))
print("发送数据:", shared_memory.get(d[1], d[2]))
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "set")
def on_set(d):
shared_memory.set(d[1], d[2])
@shared_memory.passive_chan.on_receive(lambda d: d[0] == "delete")
def on_delete(d):
shared_memory.delete(d[1])
else:
shared_memory = None
_ref_count = 0 # 引用计数
if not IS_MAIN_PROCESS:
if (shared_memory is None) and _ref_count > 1:
raise RuntimeError("Shared memory not initialized.")
_ref_count += 1

View File

@ -16,7 +16,7 @@ import toml
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
from liteyuki import logger from liteyuki.log import logger
_SUPPORTED_CONFIG_FORMATS = (".yaml", ".yml", ".json", ".toml") _SUPPORTED_CONFIG_FORMATS = (".yaml", ".yml", ".json", ".toml")

View File

@ -10,18 +10,21 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
""" """
import atexit import atexit
import threading
import signal import signal
import threading
from multiprocessing import Process from multiprocessing import Process
from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, TypeAlias from typing import Any, Callable, TYPE_CHECKING, TypeAlias
from liteyuki.comm import Channel, get_channel, set_channels from liteyuki.comm import Channel, get_channel, set_channels
from liteyuki.comm.storage import shared_memory
from liteyuki.log import logger from liteyuki.log import logger
from liteyuki.utils import IS_MAIN_PROCESS
TARGET_FUNC: TypeAlias = Callable[[Channel, Channel, ...], Any] TARGET_FUNC: TypeAlias = Callable[..., Any]
if TYPE_CHECKING: if TYPE_CHECKING:
from liteyuki.bot import LiteyukiBot from liteyuki.bot import LiteyukiBot
from liteyuki.comm.storage import KeyValueStore
TIMEOUT = 10 TIMEOUT = 10
@ -30,9 +33,29 @@ __all__ = [
] ]
# Update the delivery_channel_wrapper function to return the top-level wrapper
def _delivery_channel_wrapper(func: TARGET_FUNC, chan_active: Channel, chan_passive: Channel, sm: "KeyValueStore", *args, **kwargs):
"""
子进程入口函数
"""
# 给子进程设置通道
if IS_MAIN_PROCESS:
raise RuntimeError("Function should only be called in a sub process.")
from liteyuki.comm import channel
channel.active_channel = chan_active
channel.passive_channel = chan_passive
# 给子进程创建共享内存实例
from liteyuki.comm import storage
storage.shared_memory = sm
func(*args, **kwargs)
class ProcessManager: class ProcessManager:
""" """
在主进程中被调用 进程管理器
""" """
def __init__(self, bot: "LiteyukiBot"): def __init__(self, bot: "LiteyukiBot"):
@ -61,7 +84,6 @@ class ProcessManager:
process = Process(target=self.targets[name][0], args=self.targets[name][1], process = Process(target=self.targets[name][0], args=self.targets[name][1],
kwargs=self.targets[name][2]) kwargs=self.targets[name][2])
self.processes[name] = process self.processes[name] = process
process.start() process.start()
# 启动进程并监听信号 # 启动进程并监听信号
@ -114,9 +136,9 @@ class ProcessManager:
kwargs = {} kwargs = {}
chan_active = Channel(_id=f"{name}-active") chan_active = Channel(_id=f"{name}-active")
chan_passive = Channel(_id=f"{name}-passive") chan_passive = Channel(_id=f"{name}-passive")
kwargs["chan_active"] = chan_active
kwargs["chan_passive"] = chan_passive self.targets[name] = (_delivery_channel_wrapper, (target, chan_active, chan_passive, shared_memory, *args), kwargs)
self.targets[name] = (target, args, kwargs) # 主进程通道
set_channels( set_channels(
{ {
f"{name}-active" : chan_active, f"{name}-active" : chan_active,
@ -124,7 +146,7 @@ class ProcessManager:
} }
) )
def join(self): def join_all(self):
for name, process in self.targets: for name, process in self.targets:
process.join() process.join()

View File

@ -8,7 +8,7 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@File : __init__.py.py @File : __init__.py.py
@Software: PyCharm @Software: PyCharm
""" """
from liteyuki import get_config_with_compat, load_plugin from liteyuki import get_config, load_plugin, get_bot
from liteyuki.plugin import PluginMetadata, load_plugins from liteyuki.plugin import PluginMetadata, load_plugins
__plugin_meta__ = PluginMetadata( __plugin_meta__ = PluginMetadata(
@ -18,9 +18,17 @@ __plugin_meta__ = PluginMetadata(
description="插件加载器,用于加载轻雪原生插件" description="插件加载器,用于加载轻雪原生插件"
) )
load_plugins("src/liteyuki_plugins")
for plugin in get_config_with_compat("liteyuki.plugins", ("plugins", ), []):
load_plugin(plugin)
for plugin_dir in get_config_with_compat("liteyuki.plugin_dirs", ("plugins_dirs", ), []): def default_plugins_loader():
load_plugins(plugin_dir) """
默认插件加载器应在初始化时调用
"""
load_plugins("src/liteyuki_plugins")
for plugin in get_config("liteyuki.plugins", []):
load_plugin(plugin)
for plugin_dir in get_config("liteyuki.plugin_dirs", []):
load_plugins(plugin_dir)
default_plugins_loader()

View File

@ -12,6 +12,7 @@ from liteyuki.log import logger
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess" IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_coroutine_callable(call: Callable[..., Any]) -> bool:
""" """
判断是否为协程可调用对象 判断是否为协程可调用对象
@ -87,5 +88,6 @@ def async_wrapper(func: Callable[..., Any]) -> Callable[..., Coroutine]:
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
wrapper.__signature__ = inspect.signature(func) wrapper.__signature__ = inspect.signature(func)
return wrapper return wrapper

View File

@ -8,7 +8,7 @@ from src.utils.base.data_manager import InstalledPlugin, plugin_db
from src.utils.base.resource import load_resources from src.utils.base.resource import load_resources
from src.utils.message.tools import check_for_package from src.utils.message.tools import check_for_package
from liteyuki import get_bot, chan from liteyuki import get_bot
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler

View File

@ -8,12 +8,9 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@File : __init__.py.py @File : __init__.py.py
@Software: PyCharm @Software: PyCharm
""" """
import asyncio
import time
import nonebot import nonebot
from liteyuki.comm import Channel, set_channel
from liteyuki.core import IS_MAIN_PROCESS from liteyuki.core import IS_MAIN_PROCESS
from liteyuki.plugin import PluginMetadata from liteyuki.plugin import PluginMetadata
from .nb_utils import adapter_manager, driver_manager from .nb_utils import adapter_manager, driver_manager
@ -23,20 +20,15 @@ __plugin_meta__ = PluginMetadata(
) )
def nb_run(chan_active: "Channel", chan_passive: "Channel", *args, **kwargs): def nb_run(*args, **kwargs):
""" """
初始化NoneBot并运行在子进程 初始化NoneBot并运行在子进程
Args: Args:
chan_active:
chan_passive:
**kwargs: **kwargs:
Returns: Returns:
""" """
# 给子进程传递通道对象 # 给子进程传递通道对象
set_channel("nonebot-active", chan_active)
set_channel("nonebot-passive", chan_passive)
kwargs.update(kwargs.get("nonebot", {})) # nonebot配置优先 kwargs.update(kwargs.get("nonebot", {})) # nonebot配置优先
nonebot.init(**kwargs) nonebot.init(**kwargs)

View File

@ -7,4 +7,5 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Email : snowykami@outlook.com @Email : snowykami@outlook.com
@File : __init__.py @File : __init__.py
@Software: PyCharm @Software: PyCharm
""" """
from .after_start import *

View File

@ -11,13 +11,15 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
import time import time
from liteyuki import get_bot from liteyuki import get_bot
from liteyuki.comm.storage import shared_memory
liteyuki = get_bot() liteyuki = get_bot()
@liteyuki.on_after_start @liteyuki.on_before_start
def save_startup_timestamp(): def save_startup_timestamp():
""" """
储存启动的时间戳 储存启动的时间戳
""" """
startup_timestamp = time.time() startup_timestamp = time.time()
shared_memory.set("startup_timestamp", startup_timestamp)

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/8/16 下午8:30
@Author : snowykami
@Email : snowykami@outlook.com
@File : ts_ly_comm.py
@Software: PyCharm
"""
from nonebot.plugin import PluginMetadata
from liteyuki.comm.storage import shared_memory
__plugin_meta__ = PluginMetadata(
name="轻雪通信测试",
description="用于测试轻雪插件通信",
usage="不面向用户",
)
print("共享内存数据:", shared_memory.get("startup_timestamp", default=None))