添加liteyuki.channel.Channel通道,可安全跨进程通信

This commit is contained in:
snowy 2024-07-27 10:12:45 +08:00
parent 13692228c6
commit 39a9c39924
15 changed files with 436 additions and 139 deletions

View File

@ -2,15 +2,20 @@ from liteyuki.bot import (
LiteyukiBot,
get_bot
)
from liteyuki.comm import (
Channel,
chan,
Event
)
from liteyuki.plugin import (
load_plugin,
load_plugins
)
# def get_bot_instance() -> LiteyukiBot | None:
# """
# 获取轻雪实例
# Returns:
# LiteyukiBot: 当前的轻雪实例
# """
# return _BOT_INSTANCE
from liteyuki.log import (
logger,
init_log
)

View File

@ -1,20 +1,25 @@
import asyncio
import multiprocessing
import time
from typing import Any, Coroutine, Optional
import nonebot
import liteyuki
from liteyuki.plugin.load import load_plugin, load_plugins
from liteyuki.utils import run_coroutine
from liteyuki.log import logger, init_log
from src.utils import (
adapter_manager,
driver_manager,
)
from src.utils.base.log import logger
from liteyuki.bot.lifespan import (
Lifespan,
LIFESPAN_FUNC,
)
from liteyuki.core.spawn_process import nb_run, ProcessingManager
__all__ = [
@ -22,21 +27,18 @@ __all__ = [
"get_bot"
]
_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
"""是否为主进程"""
IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"
class LiteyukiBot:
def __init__(self, *args, **kwargs):
global _BOT_INSTANCE
_BOT_INSTANCE = self # 引用
self.running = False
self.config: dict[str, Any] = kwargs
self.lifespan: Lifespan = Lifespan()
self.init(**self.config) # 初始化
if not _MAIN_PROCESS:
pass
if not IS_MAIN_PROCESS:
self.config: dict[str, Any] = kwargs
self.lifespan: Lifespan = Lifespan()
self.init(**self.config) # 初始化
else:
print("\033[34m" + r"""
__ ______ ________ ________ __ __ __ __ __ __ ______
@ -51,96 +53,57 @@ $$$$$$$$/ $$$$$$/ $$/ $$$$$$$$/ $$/ $$$$$$/ $$/ $$/ $$$$$$/
""" + "\033[0m")
def run(self, *args, **kwargs):
if _MAIN_PROCESS:
load_plugins("liteyuki/plugins")
asyncio.run(self.lifespan.before_start())
if IS_MAIN_PROCESS:
self._run_nb_in_spawn_process(*args, **kwargs)
else:
# 子进程启动
load_plugins("liteyuki/plugins") # 加载轻雪插件
driver_manager.init(config=self.config)
adapter_manager.init(self.config)
adapter_manager.register()
nonebot.load_plugin("src.liteyuki_main")
run_coroutine(self.lifespan.after_start()) # 启动前
def _run_nb_in_spawn_process(self, *args, **kwargs):
"""
在新的进程中运行nonebot.run方法
在新的进程中运行nonebot.run方法该函数在主进程中被调用
Args:
*args:
**kwargs:
Returns:
"""
timeout_limit: int = 20
should_exit = False
while not should_exit:
ctx = multiprocessing.get_context("spawn")
event = ctx.Event()
ProcessingManager.event = event
process = ctx.Process(
target=nb_run,
args=(event,) + args,
kwargs=kwargs,
)
process.start() # 启动进程
asyncio.run(self.lifespan.after_start())
if IS_MAIN_PROCESS:
timeout_limit: int = 20
should_exit = False
while not should_exit:
if ProcessingManager.event.wait(1):
logger.info("Receive reboot event")
process.terminate()
process.join(timeout_limit)
if process.is_alive():
logger.warning(
f"Process {process.pid} is still alive after {timeout_limit} seconds, force kill it."
)
process.kill()
break
elif process.is_alive():
continue
else:
should_exit = True
ctx = multiprocessing.get_context("spawn")
event = ctx.Event()
ProcessingManager.event = event
process = ctx.Process(
target=nb_run,
args=(event,) + args,
kwargs=kwargs,
)
process.start() # 启动进程
@staticmethod
def _run_coroutine(*coro: Coroutine):
"""
运行协程
Args:
coro:
Returns:
"""
# 检测是否有现有的事件循环
new_loop = False
try:
loop = asyncio.get_event_loop()
except RuntimeError:
new_loop = True
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if new_loop:
for c in coro:
loop.run_until_complete(c)
loop.close()
else:
for c in coro:
loop.create_task(c)
@property
def status(self) -> int:
"""
获取轻雪状态
Returns:
int: 0:未启动 1:运行中
"""
return 1 if self.running else 0
while not should_exit:
if ProcessingManager.event.wait(1):
logger.info("Receive reboot event")
process.terminate()
process.join(timeout_limit)
if process.is_alive():
logger.warning(
f"Process {process.pid} is still alive after {timeout_limit} seconds, force kill it."
)
process.kill()
break
elif process.is_alive():
liteyuki.chan.send("轻雪进程正常运行", "sub")
continue
else:
should_exit = True
def restart(self):
"""
@ -149,14 +112,12 @@ $$$$$$$$/ $$$$$$/ $$/ $$$$$$$$/ $$/ $$$$$$/ $$/ $$/ $$$$$$/
"""
logger.info("Stopping LiteyukiBot...")
logger.debug("Running before_restart functions...")
self._run_coroutine(self.lifespan.before_restart())
run_coroutine(self.lifespan.before_restart())
logger.debug("Running before_shutdown functions...")
self._run_coroutine(self.lifespan.before_shutdown())
run_coroutine(self.lifespan.before_shutdown())
ProcessingManager.restart()
self.running = False
def init(self, *args, **kwargs):
"""
@ -166,13 +127,12 @@ $$$$$$$$/ $$$$$$/ $$/ $$$$$$$$/ $$/ $$$$$$/ $$/ $$/ $$$$$$/
"""
self.init_config()
self.init_logger()
if not _MAIN_PROCESS:
if not IS_MAIN_PROCESS:
nonebot.init(**kwargs)
asyncio.run(self.lifespan.after_nonebot_init())
def init_logger(self):
from src.utils.base.log import init_log
init_log()
init_log(config=self.config)
def init_config(self):
pass

19
liteyuki/comm/__init__.py Normal file
View File

@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/26 下午10:36
@Author : snowykami
@Email : snowykami@outlook.com
@File : __init__.py
@Software: PyCharm
该模块用于轻雪主进程和Nonebot子进程之间的通信
"""
from liteyuki.comm.channel import Channel, chan
from liteyuki.comm.event import Event
__all__ = [
"Channel",
"chan",
"Event",
]

179
liteyuki/comm/channel.py Normal file
View File

@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/26 下午11:21
@Author : snowykami
@Email : snowykami@outlook.com
@File : channel.py
@Software: PyCharm
本模块定义了一个通用的通道类用于进程间通信
"""
import threading
from multiprocessing import Queue
from queue import Empty, Full
from typing import Any, Awaitable, Callable, List, Optional, TypeAlias
from nonebot import logger
from liteyuki.utils import is_coroutine_callable, run_coroutine
SYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Any]
ASYNC_ON_RECEIVE_FUNC: TypeAlias = Callable[[Any], Awaitable[Any]]
ON_RECEIVE_FUNC: TypeAlias = SYNC_ON_RECEIVE_FUNC | ASYNC_ON_RECEIVE_FUNC
SYNC_FILTER_FUNC: TypeAlias = Callable[[Any], bool]
ASYNC_FILTER_FUNC: TypeAlias = Callable[[Any], Awaitable[bool]]
FILTER_FUNC: TypeAlias = SYNC_FILTER_FUNC | ASYNC_FILTER_FUNC
class Channel:
def __init__(self, buffer_size: int = 0):
self._queue = Queue(buffer_size)
self._closed = False
self._on_receive_funcs: List[ON_RECEIVE_FUNC] = []
self._on_receive_funcs_with_receiver: dict[str, List[ON_RECEIVE_FUNC]] = {}
self._receiving_thread = threading.Thread(target=self._start_receiver, daemon=True)
self._receiving_thread.start()
def send(
self,
data: Any,
receiver: Optional[str] = None,
block: bool = True,
timeout: Optional[float] = None
):
"""
发送数据
Args:
data: 数据
receiver: 接收者如果为None则广播
block: 是否阻塞
timeout: 超时时间
Returns:
"""
print(f"send {data} -> {receiver}")
if self._closed:
raise RuntimeError("Cannot send to a closed channel")
try:
self._queue.put((data, receiver), block, timeout)
except Full:
logger.warning("Channel buffer is full, send operation is blocked")
def receive(
self,
receiver: str = None,
block: bool = True,
timeout: Optional[float] = None
) -> Any:
"""
接收数据
Args:
receiver: 接收者如果为None则接收任意数据
block: 是否阻塞
timeout: 超时时间
Returns:
"""
if self._closed:
raise RuntimeError("Cannot receive from a closed channel")
try:
while True:
data, data_receiver = self._queue.get(block, timeout)
if receiver is None or receiver == data_receiver:
return data
except Empty:
if not block:
return None
raise
def close(self):
"""
关闭通道
Returns:
"""
self._closed = True
self._queue.close()
while not self._queue.empty():
self._queue.get()
def on_receive(
self,
filter_func: Optional[FILTER_FUNC] = None,
receiver: Optional[str] = None,
) -> Callable[[ON_RECEIVE_FUNC], ON_RECEIVE_FUNC]:
"""
接收数据并执行函数
Args:
filter_func: 过滤函数为None则不过滤
receiver: 接收者, 为None则接收任意数据
Returns:
装饰器装饰一个函数在接收到数据后执行
"""
def decorator(func: ON_RECEIVE_FUNC) -> ON_RECEIVE_FUNC:
async def wrapper(data: Any) -> Any:
if filter_func is not None:
if is_coroutine_callable(filter_func):
if not await filter_func(data):
return
else:
if not filter_func(data):
return
return await func(data)
if receiver is None:
self._on_receive_funcs.append(wrapper)
else:
if receiver not in self._on_receive_funcs_with_receiver:
self._on_receive_funcs_with_receiver[receiver] = []
self._on_receive_funcs_with_receiver[receiver].append(wrapper)
return func
return decorator
def _start_receiver(self):
"""
使用多线程启动接收循环在通道实例化时自动启动
Returns:
"""
while True:
data, receiver = self._queue.get(block=True, timeout=None)
self._run_on_receive_funcs(data, receiver)
def _run_on_receive_funcs(self, data: Any, receiver: Optional[str] = None):
"""
运行接收函数
Args:
data: 数据
Returns:
"""
if receiver is None:
for func in self._on_receive_funcs:
if is_coroutine_callable(func):
run_coroutine(func(data))
else:
func(data)
else:
for func in self._on_receive_funcs_with_receiver.get(receiver, []):
if is_coroutine_callable(func):
run_coroutine(func(data))
else:
func(data)
def __iter__(self):
return self
def __next__(self, timeout: Optional[float] = None) -> Any:
return self.receive(block=True, timeout=timeout)
"""默认通道实例,可直接从模块导入使用"""
chan = Channel()

21
liteyuki/comm/event.py Normal file
View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/26 下午10:47
@Author : snowykami
@Email : snowykami@outlook.com
@File : event.py
@Software: PyCharm
"""
from typing import Any
class Event:
"""
事件类
"""
def __init__(self, name: str, data: dict[str, Any]):
self.name = name
self.data = data

84
liteyuki/log.py Normal file
View File

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/7/27 上午9:12
@Author : snowykami
@Email : snowykami@outlook.com
@File : log.py
@Software: PyCharm
"""
import sys
import loguru
from typing import TYPE_CHECKING
logger = loguru.logger
if TYPE_CHECKING:
# avoid sphinx autodoc resolve annotation failed
# because loguru module do not have `Logger` class actually
from loguru import Record
def default_filter(record: "Record"):
"""默认的日志过滤器,根据 `config.log_level` 配置改变日志等级。"""
log_level = record["extra"].get("nonebot_log_level", "INFO")
levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
return record["level"].no >= levelno
# DEBUG日志格式
debug_format: str = (
"<c>{time:YYYY-MM-DD HH:mm:ss}</c> "
"<lvl>[{level.icon}]</lvl> "
"<c><{name}.{module}.{function}:{line}></c> "
"{message}"
)
# 默认日志格式
default_format: str = (
"<c>{time:MM-DD HH:mm:ss}</c> "
"<lvl>[{level.icon}]</lvl> "
"<c><{name}></c> "
"{message}"
)
def get_format(level: str) -> str:
if level == "DEBUG":
return debug_format
else:
return default_format
logger = loguru.logger.bind()
def init_log(config: dict):
"""
在语言加载完成后执行
Returns:
"""
global logger
logger.remove()
logger.add(
sys.stdout,
level=0,
diagnose=False,
filter=default_filter,
format=get_format(config.get("log_level", "INFO")),
)
show_icon = config.get("log_icon", True)
# debug = lang.get("log.debug", default="==DEBUG")
# info = lang.get("log.info", default="===INFO")
# success = lang.get("log.success", default="SUCCESS")
# warning = lang.get("log.warning", default="WARNING")
# error = lang.get("log.error", default="==ERROR")
#
# logger.level("DEBUG", color="<blue>", icon=f"{'🐛' if show_icon else ''}{debug}")
# logger.level("INFO", color="<normal>", icon=f"{'' if show_icon else ''}{info}")
# logger.level("SUCCESS", color="<green>", icon=f"{'✅' if show_icon else ''}{success}")
# logger.level("WARNING", color="<yellow>", icon=f"{'⚠️' if show_icon else ''}{warning}")
# logger.level("ERROR", color="<red>", icon=f"{'⭕' if show_icon else ''}{error}")

View File

@ -1,8 +1,11 @@
import asyncio
import multiprocessing
import time
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from liteyuki.plugin import PluginMetadata
from liteyuki import get_bot
from liteyuki import get_bot, chan
__plugin_metadata__ = PluginMetadata(
name="plugin_loader",
@ -13,6 +16,7 @@ __plugin_metadata__ = PluginMetadata(
)
from src.utils import TempConfig, common_db
liteyuki = get_bot()
@ -31,6 +35,19 @@ def _():
print("轻雪启动中")
@liteyuki.on_after_start
async def _():
print("轻雪启动完成")
chan.send("轻雪启动完成")
@liteyuki.on_after_nonebot_init
async def _():
print("NoneBot初始化完成")
@chan.on_receive(receiver="main")
async def _(data):
print("收到消息", data)
await asyncio.sleep(5)

View File

@ -2,9 +2,11 @@
"""
一些常用的工具类部分来源于 nonebot 并遵循其许可进行修改
"""
import asyncio
import inspect
import threading
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, Coroutine
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
@ -23,6 +25,39 @@ def is_coroutine_callable(call: Callable[..., Any]) -> bool:
return inspect.iscoroutinefunction(func_)
def run_coroutine(*coro: Coroutine):
"""
运行协程
Args:
coro:
Returns:
"""
# 检测是否有现有的事件循环
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建任务
for c in coro:
asyncio.ensure_future(c)
else:
# 如果事件循环未运行,运行直到完成
for c in coro:
loop.run_until_complete(c)
except RuntimeError:
# 如果没有找到事件循环,创建一个新的
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.gather(*coro))
loop.close()
except Exception as e:
# 捕获其他异常,防止协程被重复等待
print(f"Exception occurred: {e}")
def path_to_module_name(path: Path) -> str:
"""
转换路径为模块名

View File

@ -88,7 +88,7 @@ async def _(matcher: Matcher, bot: T_Bot, event: T_MessageEvent):
"reload_bot_id" : bot.self_id,
"reload_session_type": event_utils.get_message_type(event),
"reload_session_id" : (event.group_id if event.message_type == "group" else event.user_id) if not isinstance(event,
satori.event.Event) else event.channel.id,
satori.event.Event) else event.chan.id,
"delta_time" : 0
}
)

View File

@ -1,3 +1,5 @@
import asyncio
import nonebot.plugin
from nonebot import get_driver
from src.utils import init_log
@ -6,7 +8,9 @@ from src.utils.base.data_manager import InstalledPlugin, plugin_db
from src.utils.base.resource import load_resources
from src.utils.message.tools import check_for_package
from liteyuki import get_bot
from liteyuki import get_bot, chan
from nonebot_plugin_apscheduler import scheduler
load_resources()
init_log()
@ -32,33 +36,3 @@ async def load_plugins():
nonebot.plugin.load_plugins("plugins")
else:
nonebot.logger.info("Safe mode is on, no plugin loaded.")
@liteyuki_bot.on_before_start
async def _():
print("启动前")
@liteyuki_bot.on_after_start
async def _():
print("启动后")
@liteyuki_bot.on_before_shutdown
async def _():
print("停止前")
@liteyuki_bot.on_after_shutdown
async def _():
print("停止后")
@liteyuki_bot.on_before_restart
async def _():
print("重启前")
@liteyuki_bot.on_after_restart
async def _():
print("重启后")

View File

@ -7,6 +7,9 @@ from pydantic import BaseModel
from src.utils.base.config import get_config
from src.utils.io import fetch
NONEBOT_PLUGIN_STORE_URL: str = "https://registry.nonebot.dev/plugins.json" # NoneBot商店地址
LITEYUKI_PLUGIN_STORE_URL: str = "https://bot.liteyuki.icu/assets/plugins.json" # 轻雪商店地址
class Session:
def __init__(self, session_type: str, session_id: int | str):

View File

@ -1,12 +1,12 @@
import inspect
import os
import pickle
import sqlite3
from types import NoneType
from typing import Any, Callable
from packaging.version import parse
import inspect
import nonebot
import pydantic
from nonebot import logger
from nonebot.compat import PYDANTIC_V2
from pydantic import BaseModel
@ -15,10 +15,10 @@ class LiteModel(BaseModel):
id: int = None
def dump(self, *args, **kwargs):
if parse(pydantic.__version__) < parse("2.0.0"):
return self.dict(*args, **kwargs)
else:
if PYDANTIC_V2:
return self.model_dump(*args, **kwargs)
else:
return self.dict(*args, **kwargs)
class Database:
@ -60,7 +60,7 @@ class Database:
"""
table_name = model.TABLE_NAME
model_type = type(model)
nonebot.logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}")
logger.debug(f"Selecting {model.TABLE_NAME} WHERE {condition.replace('?', '%s') % args}")
if not table_name:
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
@ -88,7 +88,7 @@ class Database:
"""
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
for model in args:
nonebot.logger.debug(f"Upserting {model}")
logger.debug(f"Upserting {model}")
if not model.TABLE_NAME:
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
elif model.TABLE_NAME not in table_list:
@ -206,7 +206,7 @@ class Database:
"""
table_name = model.TABLE_NAME
nonebot.logger.debug(f"Deleting {model} WHERE {condition} {args}")
logger.debug(f"Deleting {model} WHERE {condition} {args}")
if not table_name:
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
if model.id is not None: