添加liteyuki.channel.Channel通道,可安全跨进程通信

This commit is contained in:
远野千束 2024-07-27 10:12:45 +08:00
parent 13692228c6
commit 39a9c39924
15 changed files with 436 additions and 139 deletions

View File

@ -2,15 +2,20 @@ from liteyuki.bot import (
LiteyukiBot, LiteyukiBot,
get_bot get_bot
) )
from liteyuki.comm import (
Channel,
chan,
Event
)
from liteyuki.plugin import ( from liteyuki.plugin import (
load_plugin, load_plugin,
load_plugins load_plugins
) )
# def get_bot_instance() -> LiteyukiBot | None: from liteyuki.log import (
# """ logger,
# 获取轻雪实例 init_log
# Returns:
# LiteyukiBot: 当前的轻雪实例 )
# """
# return _BOT_INSTANCE

View File

@ -1,20 +1,25 @@
import asyncio import asyncio
import multiprocessing import multiprocessing
import time
from typing import Any, Coroutine, Optional from typing import Any, Coroutine, Optional
import nonebot import nonebot
import liteyuki import liteyuki
from liteyuki.plugin.load import load_plugin, load_plugins from liteyuki.plugin.load import load_plugin, load_plugins
from liteyuki.utils import run_coroutine
from liteyuki.log import logger, init_log
from src.utils import ( from src.utils import (
adapter_manager, adapter_manager,
driver_manager, driver_manager,
) )
from src.utils.base.log import logger
from liteyuki.bot.lifespan import ( from liteyuki.bot.lifespan import (
Lifespan, Lifespan,
LIFESPAN_FUNC, LIFESPAN_FUNC,
) )
from liteyuki.core.spawn_process import nb_run, ProcessingManager from liteyuki.core.spawn_process import nb_run, ProcessingManager
__all__ = [ __all__ = [
@ -22,21 +27,18 @@ __all__ = [
"get_bot" "get_bot"
] ]
_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess" """是否为主进程"""
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
class LiteyukiBot: class LiteyukiBot:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
global _BOT_INSTANCE global _BOT_INSTANCE
_BOT_INSTANCE = self # 引用 _BOT_INSTANCE = self # 引用
self.running = False if not IS_MAIN_PROCESS:
self.config: dict[str, Any] = kwargs self.config: dict[str, Any] = kwargs
self.lifespan: Lifespan = Lifespan() self.lifespan: Lifespan = Lifespan()
self.init(**self.config) # 初始化 self.init(**self.config) # 初始化
if not _MAIN_PROCESS:
pass
else: else:
print("\033[34m" + r""" print("\033[34m" + r"""
__ ______ ________ ________ __ __ __ __ __ __ ______ __ ______ ________ ________ __ __ __ __ __ __ ______
@ -51,96 +53,57 @@ $$$$$$$$/ $$$$$$/ $$/ $$$$$$$$/ $$/ $$$$$$/ $$/ $$/ $$$$$$/
""" + "\033[0m") """ + "\033[0m")
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
if _MAIN_PROCESS: if IS_MAIN_PROCESS:
load_plugins("liteyuki/plugins")
asyncio.run(self.lifespan.before_start())
self._run_nb_in_spawn_process(*args, **kwargs) self._run_nb_in_spawn_process(*args, **kwargs)
else: else:
# 子进程启动 # 子进程启动
load_plugins("liteyuki/plugins") # 加载轻雪插件
driver_manager.init(config=self.config) driver_manager.init(config=self.config)
adapter_manager.init(self.config) adapter_manager.init(self.config)
adapter_manager.register() adapter_manager.register()
nonebot.load_plugin("src.liteyuki_main") nonebot.load_plugin("src.liteyuki_main")
run_coroutine(self.lifespan.after_start()) # 启动前
def _run_nb_in_spawn_process(self, *args, **kwargs): def _run_nb_in_spawn_process(self, *args, **kwargs):
""" """
在新的进程中运行nonebot.run方法 在新的进程中运行nonebot.run方法该函数在主进程中被调用
Args: Args:
*args: *args:
**kwargs: **kwargs:
Returns: Returns:
""" """
if IS_MAIN_PROCESS:
timeout_limit: int = 20 timeout_limit: int = 20
should_exit = False should_exit = False
while not should_exit:
ctx = multiprocessing.get_context("spawn")
event = ctx.Event()
ProcessingManager.event = event
process = ctx.Process(
target=nb_run,
args=(event,) + args,
kwargs=kwargs,
)
process.start() # 启动进程
asyncio.run(self.lifespan.after_start())
while not should_exit: while not should_exit:
if ProcessingManager.event.wait(1): ctx = multiprocessing.get_context("spawn")
logger.info("Receive reboot event") event = ctx.Event()
process.terminate() ProcessingManager.event = event
process.join(timeout_limit) process = ctx.Process(
if process.is_alive(): target=nb_run,
logger.warning( args=(event,) + args,
f"Process {process.pid} is still alive after {timeout_limit} seconds, force kill it." kwargs=kwargs,
) )
process.kill() process.start() # 启动进程
break
elif process.is_alive():
continue
else:
should_exit = True
@staticmethod while not should_exit:
def _run_coroutine(*coro: Coroutine): if ProcessingManager.event.wait(1):
""" logger.info("Receive reboot event")
运行协程 process.terminate()
Args: process.join(timeout_limit)
coro: if process.is_alive():
logger.warning(
Returns: f"Process {process.pid} is still alive after {timeout_limit} seconds, force kill it."
)
""" process.kill()
# 检测是否有现有的事件循环 break
new_loop = False elif process.is_alive():
try: liteyuki.chan.send("轻雪进程正常运行", "sub")
loop = asyncio.get_event_loop() continue
except RuntimeError: else:
new_loop = True should_exit = True
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if new_loop:
for c in coro:
loop.run_until_complete(c)
loop.close()
else:
for c in coro:
loop.create_task(c)
@property
def status(self) -> int:
"""
获取轻雪状态
Returns:
int: 0:未启动 1:运行中
"""
return 1 if self.running else 0
def restart(self): def restart(self):
""" """
@ -149,14 +112,12 @@ $$$$$$$$/ $$$$$$/ $$/ $$$$$$$$/ $$/ $$$$$$/ $$/ $$/ $$$$$$/
""" """
logger.info("Stopping LiteyukiBot...") logger.info("Stopping LiteyukiBot...")
logger.debug("Running before_restart functions...") logger.debug("Running before_restart functions...")
self._run_coroutine(self.lifespan.before_restart()) run_coroutine(self.lifespan.before_restart())
logger.debug("Running before_shutdown functions...") logger.debug("Running before_shutdown functions...")
self._run_coroutine(self.lifespan.before_shutdown()) run_coroutine(self.lifespan.before_shutdown())
ProcessingManager.restart() ProcessingManager.restart()
self.running = False
def init(self, *args, **kwargs): def init(self, *args, **kwargs):
""" """
@ -166,13 +127,12 @@ $$$$$$$$/ $$$$$$/ $$/ $$$$$$$$/ $$/ $$$$$$/ $$/ $$/ $$$$$$/
""" """
self.init_config() self.init_config()
self.init_logger() self.init_logger()
if not _MAIN_PROCESS: if not IS_MAIN_PROCESS:
nonebot.init(**kwargs) nonebot.init(**kwargs)
asyncio.run(self.lifespan.after_nonebot_init()) asyncio.run(self.lifespan.after_nonebot_init())
def init_logger(self): def init_logger(self):
from src.utils.base.log import init_log init_log(config=self.config)
init_log()
def init_config(self): def init_config(self):
pass pass

19
liteyuki/comm/__init__.py Normal file
View File

@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/26 下午10:36
@Author : snowykami
@Email : snowykami@outlook.com
@File : __init__.py
@Software: PyCharm
该模块用于轻雪主进程和Nonebot子进程之间的通信
"""
from liteyuki.comm.channel import Channel, chan
from liteyuki.comm.event import Event
__all__ = [
"Channel",
"chan",
"Event",
]

179
liteyuki/comm/channel.py Normal file
View File

@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/26 下午11:21
@Author : snowykami
@Email : snowykami@outlook.com
@File : channel.py
@Software: PyCharm
本模块定义了一个通用的通道类用于进程间通信
"""
import threading
from multiprocessing import Queue
from queue import Empty, Full
from typing import Any, Awaitable, Callable, List, Optional, TypeAlias
from nonebot import logger
from liteyuki.utils import is_coroutine_callable, run_coroutine
SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Any]
ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Awaitable[Any]]
ON_RECEIVE_FUNC: TypeAlias = SYNC_ON_RECEIVE_FUNC | ASYNC_ON_RECEIVE_FUNC
SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool]
ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]]
FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC
class Channel:
def __init__(self, buffer_size: int = 0):
self._queue = Queue(buffer_size)
self._closed = False
self._on_receive_funcs: List[ON_RECEIVE_FUNC] = []
self._on_receive_funcs_with_receiver: dict[str, List[ON_RECEIVE_FUNC]] = {}
self._receiving_thread = threading.Thread(target=self._start_receiver, daemon=True)
self._receiving_thread.start()
def send(
self,
data: Any,
receiver: Optional[str] = None,
block: bool = True,
timeout: Optional[float] = None
):
"""
发送数据
Args:
data: 数据
receiver: 接收者如果为None则广播
block: 是否阻塞
timeout: 超时时间
Returns:
"""
print(f"send {data} -> {receiver}")
if self._closed:
raise RuntimeError("Cannot send to a closed channel")
try:
self._queue.put((data, receiver), block, timeout)
except Full:
logger.warning("Channel buffer is full, send operation is blocked")
def receive(
self,
receiver: str = None,
block: bool = True,
timeout: Optional[float] = None
) -> Any:
"""
接收数据
Args:
receiver: 接收者如果为None则接收任意数据
block: 是否阻塞
timeout: 超时时间
Returns:
"""
if self._closed:
raise RuntimeError("Cannot receive from a closed channel")
try:
while True:
data, data_receiver = self._queue.get(block, timeout)
if receiver is None or receiver == data_receiver:
return data
except Empty:
if not block:
return None
raise
def close(self):
"""
关闭通道
Returns:
"""
self._closed = True
self._queue.close()
while not self._queue.empty():
self._queue.get()
def on_receive(
self,
filter_func: Optional[FILTER_FUNC] = None,
receiver: Optional[str] = None,
) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]:
"""
接收数据并执行函数
Args:
filter_func: 过滤函数为None则不过滤
receiver: 接收者, 为None则接收任意数据
Returns:
装饰器装饰一个函数在接收到数据后执行
"""
def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC:
async def wrapper(data: Any) -> Any:
if filter_func is not None:
if is_coroutine_callable(filter_func):
if not await filter_func(data):
return
else:
if not filter_func(data):
return
return await func(data)
if receiver is None:
self._on_receive_funcs.append(wrapper)
else:
if receiver not in self._on_receive_funcs_with_receiver:
self._on_receive_funcs_with_receiver[receiver] = []
self._on_receive_funcs_with_receiver[receiver].append(wrapper)
return func
return decorator
def _start_receiver(self):
"""
使用多线程启动接收循环在通道实例化时自动启动
Returns:
"""
while True:
data, receiver = self._queue.get(block=True, timeout=None)
self._run_on_receive_funcs(data, receiver)
def _run_on_receive_funcs(self, data: Any, receiver: Optional[str] = None):
"""
运行接收函数
Args:
data: 数据
Returns:
"""
if receiver is None:
for func in self._on_receive_funcs:
if is_coroutine_callable(func):
run_coroutine(func(data))
else:
func(data)
else:
for func in self._on_receive_funcs_with_receiver.get(receiver, []):
if is_coroutine_callable(func):
run_coroutine(func(data))
else:
func(data)
def __iter__(self):
return self
def __next__(self, timeout: Optional[float] = None) -> Any:
return self.receive(block=True, timeout=timeout)
"""默认通道实例,可直接从模块导入使用"""
chan = Channel()

21
liteyuki/comm/event.py Normal file
View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/26 下午10:47
@Author : snowykami
@Email : snowykami@outlook.com
@File : event.py
@Software: PyCharm
"""
from typing import Any
class Event:
"""
事件类
"""
def __init__(self, name: str, data: dict[str, Any]):
self.name = name
self.data = data

84
liteyuki/log.py Normal file
View File

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/27 上午9:12
@Author : snowykami
@Email : snowykami@outlook.com
@File : log.py
@Software: PyCharm
"""
import sys
import loguru
from typing import TYPE_CHECKING
logger = loguru.logger
if TYPE_CHECKING:
# avoid sphinx autodoc resolve annotation failed
# because loguru module do not have `Logger` class actually
from loguru import Record
def default_filter(record: "Record"):
"""默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
log_level = record["extra"].get("nonebot_log_level", "INFO")
levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
return record["level"].no >= levelno
# DEBUG日志格式
debug_format: str = (
"<c>{time:YYYY-MM-DD HH:mm:ss}</c> "
"<lvl>[{level.icon}]</lvl> "
"<c><{name}.{module}.{function}:{line}></c> "
"{message}"
)
# 默认日志格式
default_format: str = (
"<c>{time:MM-DD HH:mm:ss}</c> "
"<lvl>[{level.icon}]</lvl> "
"<c><{name}></c> "
"{message}"
)
def get_format(level: str) -> str:
if level == "DEBUG":
return debug_format
else:
return default_format
logger = loguru.logger.bind()
def init_log(config: dict):
"""
在语言加载完成后执行
Returns:
"""
global logger
logger.remove()
logger.add(
sys.stdout,
level=0,
diagnose=False,
filter=default_filter,
format=get_format(config.get("log_level", "INFO")),
)
show_icon = config.get("log_icon", True)
# debug = lang.get("log.debug", default="==DEBUG")
# info = lang.get("log.info", default="===INFO")
# success = lang.get("log.success", default="SUCCESS")
# warning = lang.get("log.warning", default="WARNING")
# error = lang.get("log.error", default="==ERROR")
#
# logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}{debug}")
# logger.level("INFO", color="<normal>", icon=f"{'' if show_icon else ''}{info}")
# logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}{success}")
# logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}{warning}")
# logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}{error}")

View File

@ -1,8 +1,11 @@
import asyncio
import multiprocessing import multiprocessing
import time import time
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from liteyuki.plugin import PluginMetadata from liteyuki.plugin import PluginMetadata
from liteyuki import get_bot from liteyuki import get_bot, chan
__plugin_metadata__ = PluginMetadata( __plugin_metadata__ = PluginMetadata(
name="plugin_loader", name="plugin_loader",
@ -13,6 +16,7 @@ __plugin_metadata__ = PluginMetadata(
) )
from src.utils import TempConfig, common_db from src.utils import TempConfig, common_db
liteyuki = get_bot() liteyuki = get_bot()
@ -31,6 +35,19 @@ def _():
print("轻雪启动中") print("轻雪启动中")
@liteyuki.on_after_start
async def _():
print("轻雪启动完成")
chan.send("轻雪启动完成")
@liteyuki.on_after_nonebot_init @liteyuki.on_after_nonebot_init
async def _(): async def _():
print("NoneBot初始化完成") print("NoneBot初始化完成")
@chan.on_receive(receiver="main")
async def _(data):
print("收到消息", data)
await asyncio.sleep(5)

View File

@ -2,9 +2,11 @@
""" """
一些常用的工具类部分来源于 nonebot 并遵循其许可进行修改 一些常用的工具类部分来源于 nonebot 并遵循其许可进行修改
""" """
import asyncio
import inspect import inspect
import threading
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable, Coroutine
def is_coroutine_callable(call: Callable[..., Any]) -> bool: def is_coroutine_callable(call: Callable[..., Any]) -> bool:
@ -23,6 +25,39 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool:
return inspect.iscoroutinefunction(func_) return inspect.iscoroutinefunction(func_)
def run_coroutine(*coro: Coroutine):
"""
运行协程
Args:
coro:
Returns:
"""
# 检测是否有现有的事件循环
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建任务
for c in coro:
asyncio.ensure_future(c)
else:
# 如果事件循环未运行,运行直到完成
for c in coro:
loop.run_until_complete(c)
except RuntimeError:
# 如果没有找到事件循环,创建一个新的
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*coro))
loop.close()
except Exception as e:
# 捕获其他异常,防止协程被重复等待
print(f"Exception occurred: {e}")
def path_to_module_name(path: Path) -> str: def path_to_module_name(path: Path) -> str:
""" """
转换路径为模块名 转换路径为模块名

View File

@ -88,7 +88,7 @@ async def _(matcher: Matcher, bot: T_Bot, event: T_MessageEvent):
"reload_bot_id" : bot.self_id, "reload_bot_id" : bot.self_id,
"reload_session_type": event_utils.get_message_type(event), "reload_session_type": event_utils.get_message_type(event),
"reload_session_id" : (event.group_id if event.message_type == "group" else event.user_id) if not isinstance(event, "reload_session_id" : (event.group_id if event.message_type == "group" else event.user_id) if not isinstance(event,
satori.event.Event) else event.channel.id, satori.event.Event) else event.chan.id,
"delta_time" : 0 "delta_time" : 0
} }
) )

View File

@ -1,3 +1,5 @@
import asyncio
import nonebot.plugin import nonebot.plugin
from nonebot import get_driver from nonebot import get_driver
from src.utils import init_log from src.utils import init_log
@ -6,7 +8,9 @@ from src.utils.base.data_manager import InstalledPlugin, plugin_db
from src.utils.base.resource import load_resources from src.utils.base.resource import load_resources
from src.utils.message.tools import check_for_package from src.utils.message.tools import check_for_package
from liteyuki import get_bot from liteyuki import get_bot, chan
from nonebot_plugin_apscheduler import scheduler
load_resources() load_resources()
init_log() init_log()
@ -32,33 +36,3 @@ async def load_plugins():
nonebot.plugin.load_plugins("plugins") nonebot.plugin.load_plugins("plugins")
else: else:
nonebot.logger.info("Safe mode is on, no plugin loaded.") nonebot.logger.info("Safe mode is on, no plugin loaded.")
@liteyuki_bot.on_before_start
async def _():
print("启动前")
@liteyuki_bot.on_after_start
async def _():
print("启动后")
@liteyuki_bot.on_before_shutdown
async def _():
print("停止前")
@liteyuki_bot.on_after_shutdown
async def _():
print("停止后")
@liteyuki_bot.on_before_restart
async def _():
print("重启前")
@liteyuki_bot.on_after_restart
async def _():
print("重启后")

View File

@ -7,6 +7,9 @@ from pydantic import BaseModel
from src.utils.base.config import get_config from src.utils.base.config import get_config
from src.utils.io import fetch from src.utils.io import fetch
NONEBOT_PLUGIN_STORE_URL: str = "https://registry.nonebot.dev/plugins.json" # NoneBot商店地址
LITEYUKI_PLUGIN_STORE_URL: str = "https://bot.liteyuki.icu/assets/plugins.json" # 轻雪商店地址
class Session: class Session:
def __init__(self, session_type: str, session_id: int | str): def __init__(self, session_type: str, session_id: int | str):

View File

@ -1,12 +1,12 @@
import inspect
import os import os
import pickle import pickle
import sqlite3 import sqlite3
from types import NoneType from types import NoneType
from typing import Any, Callable from typing import Any, Callable
from packaging.version import parse
import inspect from nonebot import logger
import nonebot from nonebot.compat import PYDANTIC_V2
import pydantic
from pydantic import BaseModel from pydantic import BaseModel
@ -15,10 +15,10 @@ class LiteModel(BaseModel):
id: int = None id: int = None
def dump(self, *args, **kwargs): def dump(self, *args, **kwargs):
if parse(pydantic.__version__) < parse("2.0.0"): if PYDANTIC_V2:
return self.dict(*args, **kwargs)
else:
return self.model_dump(*args, **kwargs) return self.model_dump(*args, **kwargs)
else:
return self.dict(*args, **kwargs)
class Database: class Database:
@ -60,7 +60,7 @@ class Database:
""" """
table_name = model.TABLE_NAME table_name = model.TABLE_NAME
model_type = type(model) model_type = type(model)
nonebot.logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}") logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}")
if not table_name: if not table_name:
raise ValueError(f"数据模型{model_type.__name__}未提供表名") raise ValueError(f"数据模型{model_type.__name__}未提供表名")
@ -88,7 +88,7 @@ class Database:
""" """
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()] table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
for model in args: for model in args:
nonebot.logger.debug(f"Upserting {model}") logger.debug(f"Upserting {model}")
if not model.TABLE_NAME: if not model.TABLE_NAME:
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名") raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
elif model.TABLE_NAME not in table_list: elif model.TABLE_NAME not in table_list:
@ -206,7 +206,7 @@ class Database:
""" """
table_name = model.TABLE_NAME table_name = model.TABLE_NAME
nonebot.logger.debug(f"Deleting {model} WHERE {condition} {args}") logger.debug(f"Deleting {model} WHERE {condition} {args}")
if not table_name: if not table_name:
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名") raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
if model.id is not None: if model.id is not None: