diff --git a/.gitignore b/.gitignore index a308d61..91d3bb2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ + +# Other Things +test.md + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8aef7b1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.typeCheckingMode": "standard" +} \ No newline at end of file diff --git a/nonebot_plugin_marshoai/__init__.py b/nonebot_plugin_marshoai/__init__.py index 5406a5f..b92a6ba 100644 --- a/nonebot_plugin_marshoai/__init__.py +++ b/nonebot_plugin_marshoai/__init__.py @@ -1,15 +1,19 @@ from nonebot.plugin import require + require("nonebot_plugin_alconna") require("nonebot_plugin_localstore") -from .azure import * -#from .hunyuan import * + from nonebot import get_driver, logger +import nonebot_plugin_localstore as store + +# from .hunyuan import * +from .azure import * from .config import config from .metadata import metadata -import nonebot_plugin_localstore as store __author__ = "Asankilp" __plugin_meta__ = metadata + driver = get_driver() diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index 458b860..f7ebc33 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -1,5 +1,6 @@ -import contextlib +import uuid import traceback +import contextlib from typing import Optional from pathlib import Path @@ -15,15 +16,21 @@ from azure.ai.inference.models import ( ChatCompletionsToolCall, ) from azure.core.credentials import AzureKeyCredential -from nonebot import on_command, on_message, logger +from nonebot import on_command, on_message, logger, get_driver from nonebot.adapters import Message, Event from nonebot.params import CommandArg from nonebot.permission import SUPERUSER from nonebot.rule import Rule, to_me -from nonebot_plugin_alconna import on_alconna, MsgTarget -from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg +from nonebot_plugin_alconna import ( + on_alconna, + MsgTarget, + UniMessage, + UniMsg, + Text as TextMsg, + Image as ImageMsg, +) import nonebot_plugin_localstore as store -from nonebot import get_driver + from .constants import * from .metadata import metadata @@ -37,15 +44,23 @@ async def at_enable(): driver = get_driver() -changemodel_cmd = on_command("changemodel", permission=SUPERUSER, priority=10, block=True) +changemodel_cmd = on_command( + "changemodel", permission=SUPERUSER, priority=10, block=True +) resetmem_cmd = on_command("reset", priority=10, block=True) # setprompt_cmd = on_command("prompt",permission=SUPERUSER) praises_cmd = on_command("praises", permission=SUPERUSER, priority=10, block=True) add_usermsg_cmd = on_command("usermsg", permission=SUPERUSER, priority=10, block=True) -add_assistantmsg_cmd = on_command("assistantmsg", permission=SUPERUSER, priority=10, block=True) +add_assistantmsg_cmd = on_command( + "assistantmsg", permission=SUPERUSER, priority=10, block=True +) contexts_cmd = on_command("contexts", permission=SUPERUSER, priority=10, block=True) -save_context_cmd = on_command("savecontext", permission=SUPERUSER, priority=10, block=True) -load_context_cmd = on_command("loadcontext", permission=SUPERUSER, priority=10, block=True) +save_context_cmd = on_command( + "savecontext", permission=SUPERUSER, priority=10, block=True +) +load_context_cmd = on_command( + "loadcontext", permission=SUPERUSER, priority=10, block=True +) marsho_cmd = on_alconna( Alconna( config.marshoai_default_name, @@ -53,18 +68,20 @@ marsho_cmd = on_alconna( ), aliases=config.marshoai_aliases, priority=10, - block=True + block=True, ) -marsho_at = on_message(rule=to_me()&at_enable, priority=11) +marsho_at = on_message(rule=to_me() & at_enable, priority=11) nickname_cmd = on_alconna( Alconna( "nickname", Args["name?", str], ), - priority = 10, - block = True + priority=10, + block=True, +) +refresh_data_cmd = on_command( + "refresh_data", permission=SUPERUSER, priority=10, block=True ) -refresh_data_cmd = on_command("refresh_data", permission=SUPERUSER, priority=10, block=True) command_start = driver.config.command_start model_name = config.marshoai_default_model @@ -86,7 +103,9 @@ async def _preload_tools(): tools.load_tools(store.get_plugin_data_dir() / "tools") for tool_dir in config.marshoai_toolset_dir: tools.load_tools(tool_dir) - logger.info("如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。") + logger.info( + "如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。" + ) @add_usermsg_cmd.handle() @@ -132,7 +151,9 @@ async def save_context(target: MsgTarget, arg: Message = CommandArg()): @load_context_cmd.handle() async def load_context(target: MsgTarget, arg: Message = CommandArg()): if msg := arg.extract_plain_text(): - await get_backup_context(target.id, target.private) # 为了将当前会话添加到"已恢复过备份"的列表而添加,防止上下文被覆盖(好奇怪QwQ + await get_backup_context( + target.id, target.private + ) # 为了将当前会话添加到"已恢复过备份"的列表而添加,防止上下文被覆盖(好奇怪QwQ context.set_context( await load_context_from_json(msg, "contexts"), target.id, target.private ) @@ -178,11 +199,86 @@ async def refresh_data(): await refresh_data_cmd.finish("已刷新数据") +""" +以下函数依照 Apache 2.0 协议授权 + +函数: get_back_uuidcodeblock、send_markdown + +版权所有 © 2024 金羿ELS +Copyright (R) 2024 Eilles(EillesWan@outlook.com) + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +async def get_back_uuidcodeblock(msg: str, code_blank_uuid_map: list[tuple[str, str]]): + + for torep, rep in code_blank_uuid_map: + msg = msg.replace(torep, rep) + + return msg + + +async def send_markdown(msg: str): + """ + 人工智能给出的回答一般不会包含 HTML 嵌入其中,但是包含图片或者 LaTeX 公式、代码块,都很正常。 + 这个函数会把这些都以图片形式嵌入消息体。 + """ + result_msg = UniMessage() + code_blank_uuid_map = [ + (uuid.uuid4().hex, cbp.group()) for cbp in CODE_BLOCK_PATTERN.finditer(msg) + ] + + # 代码块渲染麻烦,先不处理 + for rep, torep in code_blank_uuid_map: + msg = msg.replace(torep, rep) + + # 插入图片 + for each_img_tag in IMG_TAG_PATTERN.finditer(msg): + img_tag = await get_back_uuidcodeblock( + each_img_tag.group(), code_blank_uuid_map + ) + image_description = img_tag[2 : img_tag.find("]")] + image_url = img_tag[img_tag.find("(") + 1 : -1] + + result_msg.append( + TextMsg( + await get_back_uuidcodeblock( + msg[: msg.find(img_tag)], code_blank_uuid_map + ) + ) + ) + + result_msg.append(ImageMsg(url=image_url, name=image_description + ".png")) + + result_msg.append(TextMsg("({})".format(image_description))) + + await result_msg.send(reply_to=True) + + +""" +Apache 2.0 协议授权部分结束 +""" + + @marsho_at.handle() @marsho_cmd.handle() async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None): global target_list - if event.get_message().extract_plain_text() and (not text and event.get_message().extract_plain_text() != config.marshoai_default_name): + if event.get_message().extract_plain_text() and ( + not text + and event.get_message().extract_plain_text() != config.marshoai_default_name + ): text = event.get_message() if not text: # 发送说明 @@ -204,7 +300,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) "*你未设置自己的昵称。推荐使用'nickname [昵称]'命令设置昵称来获得个性化(可能)回答。" ).send() - is_support_image_model = model_name.lower() in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models + is_support_image_model = ( + model_name.lower() + in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models + ) is_reasoning_model = model_name.lower() in REASONING_MODELS usermsg = [] if is_support_image_model else "" for i in text: @@ -217,14 +316,18 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) if is_support_image_model: usermsg.append( ImageContentItem( - image_url=ImageUrl(url=str(await get_image_b64(i.data["url"]))) + image_url=ImageUrl( + url=str(await get_image_b64(i.data["url"])) + ) ) ) elif config.marshoai_enable_support_image_tip: await UniMessage("*此模型不支持图片处理。").send() backup_context = await get_backup_context(target.id, target.private) if backup_context: - context.set_context(backup_context, target.id, target.private) # 加载历史记录 + context.set_context( + backup_context, target.id, target.private + ) # 加载历史记录 logger.info(f"已恢复会话 {target.id} 的上下文备份~") context_msg = context.build(target.id, target.private) if not is_reasoning_model: @@ -234,46 +337,74 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) client=client, model_name=model_name, msg=context_msg + [UserMessage(content=usermsg)], - tools=tools.get_tools_list() + tools=tools.get_tools_list(), ) # await UniMessage(str(response)).send() choice = response.choices[0] - if (choice["finish_reason"] == CompletionsFinishReason.STOPPED): # 当对话成功时,将dict的上下文添加到上下文类中 + if choice["finish_reason"] == CompletionsFinishReason.STOPPED: + # 当对话成功时,将dict的上下文添加到上下文类中 context.append( UserMessage(content=usermsg).as_dict(), target.id, target.private ) context.append(choice.message.as_dict(), target.id, target.private) if [target.id, target.private] not in target_list: target_list.append([target.id, target.private]) - await UniMessage(str(choice.message.content)).send(reply_to=True) + + # 对话成功发送消息 + await send_markdown(str(choice.message.content)) elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED: - await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(reply_to=True) + + # 对话失败,消息过滤 + + await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send( + reply_to=True + ) return elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS: + + # 需要获取额外信息,调用函数工具 tool_msg = [] while choice.message.tool_calls != None: - tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls)) + tool_msg.append( + AssistantMessage(tool_calls=response.choices[0].message.tool_calls) + ) for tool_call in choice.message.tool_calls: - if isinstance(tool_call, ChatCompletionsToolCall): # 循环调用工具直到不需要调用 - function_args = json.loads(tool_call.function.arguments.replace("'", '"')) - logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}") - await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send() - func_return = await tools.call(tool_call.function.name, function_args) # 获取返回值 - tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return)) + if isinstance( + tool_call, ChatCompletionsToolCall + ): # 循环调用工具直到不需要调用 + function_args = json.loads( + tool_call.function.arguments.replace("'", '"') + ) + logger.info( + f"调用函数 {tool_call.function.name} ,参数为 {function_args}" + ) + await UniMessage( + f"调用函数 {tool_call.function.name} ,参数为 {function_args}" + ).send() + func_return = await tools.call( + tool_call.function.name, function_args + ) # 获取返回值 + tool_msg.append( + ToolMessage(tool_call_id=tool_call.id, content=func_return) + ) response = await make_chat( client=client, model_name=model_name, msg=context_msg + [UserMessage(content=usermsg)] + tool_msg, - tools=tools.get_tools_list() + tools=tools.get_tools_list(), ) choice = response.choices[0] if choice["finish_reason"] == CompletionsFinishReason.STOPPED: + + # 对话成功 添加上下文 context.append( UserMessage(content=usermsg).as_dict(), target.id, target.private ) - # context.append(tool_msg, target.id, target.private) + # context.append(tool_msg, target.id, target.private) context.append(choice.message.as_dict(), target.id, target.private) - await UniMessage(str(choice.message.content)).send(reply_to=True) + + # 发送消息 + await send_markdown(str(choice.message.content)) else: await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}") else: @@ -288,7 +419,6 @@ with contextlib.suppress(ImportError): # 优化先不做() import nonebot.adapters.onebot.v11 # type: ignore from .azure_onebot import poke_notify - @poke_notify.handle() async def poke(event: Event): @@ -327,5 +457,7 @@ async def auto_backup_context(): target_uid = "private_" + target_id else: target_uid = "group_" + target_id - await save_context_to_json(f"back_up_context_{target_uid}", contexts_data, "contexts/backup") + await save_context_to_json( + f"back_up_context_{target_uid}", contexts_data, "contexts/backup" + ) logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~") diff --git a/nonebot_plugin_marshoai/config.py b/nonebot_plugin_marshoai/config.py index e1d0490..bc42beb 100644 --- a/nonebot_plugin_marshoai/config.py +++ b/nonebot_plugin_marshoai/config.py @@ -55,19 +55,19 @@ destination_file = destination_folder / "config.yaml" def copy_config(source_template, destination_file): - ''' + """ 复制模板配置文件到config - ''' + """ shutil.copy(source_template, destination_file) def check_yaml_is_changed(source_template): - ''' + """ 检查配置文件是否需要更新 - ''' - with open(config_file_path, 'r', encoding="utf-8") as f: + """ + with open(config_file_path, "r", encoding="utf-8") as f: old = yaml.load(f) - with open(source_template, 'r', encoding="utf-8") as f: + with open(source_template, "r", encoding="utf-8") as f: example_ = yaml.load(f) keys1 = set(example_.keys()) keys2 = set(old.keys()) @@ -78,9 +78,9 @@ def check_yaml_is_changed(source_template): def merge_configs(old_config, new_config): - ''' + """ 合并配置文件 - ''' + """ for key, value in new_config.items(): if key in old_config: continue @@ -89,6 +89,7 @@ def merge_configs(old_config, new_config): old_config[key] = value return old_config + config: ConfigModel = get_plugin_config(ConfigModel) if config.marshoai_use_yaml_config: if not config_file_path.exists(): @@ -97,25 +98,27 @@ if config.marshoai_use_yaml_config: copy_config(source_template, destination_file) else: logger.info("配置文件存在,正在读取") - + if check_yaml_is_changed(source_template): yaml_2 = YAML() logger.info("插件新的配置已更新, 正在更新") - - with open(config_file_path, 'r', encoding="utf-8") as f: + + with open(config_file_path, "r", encoding="utf-8") as f: old_config = yaml_2.load(f) - - with open(source_template, 'r', encoding="utf-8") as f: + + with open(source_template, "r", encoding="utf-8") as f: new_config = yaml_2.load(f) - + merged_config = merge_configs(old_config, new_config) - - with open(destination_file, 'w', encoding="utf-8") as f: + + with open(destination_file, "w", encoding="utf-8") as f: yaml_2.dump(merged_config, f) - + with open(config_file_path, "r", encoding="utf-8") as f: yaml_config = yaml_.load(f, Loader=yaml_.FullLoader) - + config = ConfigModel(**yaml_config) else: - logger.info("MarshoAI 支持新的 YAML 配置系统,若要使用,请将 MARSHOAI_USE_YAML_CONFIG 配置项设置为 true。") + logger.info( + "MarshoAI 支持新的 YAML 配置系统,若要使用,请将 MARSHOAI_USE_YAML_CONFIG 配置项设置为 true。" + ) diff --git a/nonebot_plugin_marshoai/constants.py b/nonebot_plugin_marshoai/constants.py index 11b2810..9dac1a5 100644 --- a/nonebot_plugin_marshoai/constants.py +++ b/nonebot_plugin_marshoai/constants.py @@ -1,4 +1,6 @@ +import re from .config import config + USAGE: str = f"""MarshoAI-NoneBot Beta by Asankilp 用法: {config.marshoai_default_name} <聊天内容> : 与 Marsho 进行对话。当模型为 GPT-4o(-mini) 等时,可以带上图片进行对话。 @@ -15,9 +17,15 @@ USAGE: str = f"""MarshoAI-NoneBot Beta by Asankilp refresh_data : 从文件刷新已加载的昵称与夸赞名单。 ※本AI的回答"按原样"提供,不提供任何担保。AI也会犯错,请仔细甄别回答的准确性。""" -SUPPORT_IMAGE_MODELS: list = ["gpt-4o","gpt-4o-mini","phi-3.5-vision-instruct","llama-3.2-90b-vision-instruct","llama-3.2-11b-vision-instruct"] -REASONING_MODELS: list = ["o1-preview","o1-mini"] -INTRODUCTION: str = """你好喵~我是一只可爱的猫娘AI,名叫小棉~🐾! +SUPPORT_IMAGE_MODELS: list = [ + "gpt-4o", + "gpt-4o-mini", + "phi-3.5-vision-instruct", + "llama-3.2-90b-vision-instruct", + "llama-3.2-11b-vision-instruct", +] +REASONING_MODELS: list = ["o1-preview", "o1-mini"] +INTRODUCTION: str = """你好喵~我是一只可爱的猫娘AI,名叫小棉~🐾! 我的代码在这里哦~↓↓↓ https://github.com/LiteyukiStudio/nonebot-plugin-marshoai @@ -25,3 +33,19 @@ https://github.com/LiteyukiStudio/nonebot-plugin-marshoai https://github.com/Meloland/melobot 我与 Melobot 酱贴贴的代码在这里喵~↓↓↓ https://github.com/LiteyukiStudio/marshoai-melo""" + + +# 正则匹配代码块 +CODE_BLOCK_PATTERN = re.compile( + r"```(.*?)```|`(.*?)`", +) +# 正则匹配完整图片标签字段 +IMG_TAG_PATTERN = re.compile(r"!\[[^\]]*\]\([^()]*\)") +# # 正则匹配图片标签中的图片url字段 +# INTAG_URL_PATTERN = re.compile(r'\(([^)]*)') +# # 正则匹配图片标签中的文本描述字段 +# INTAG_TEXT_PATTERN = re.compile(r'!\[([^\]]*)\]') +# 正则匹配 LaTeX 公式内容 +LATEX_PATTERN = re.compile( + r"\\begin\{equation\}(.*?)\\\end\{equation\}|(? list: target_uid = f"group_{target_id}" if target_uid not in loaded_target_list: loaded_target_list.append(target_uid) - return await load_context_from_json(f"back_up_context_{target_uid}", "contexts/backup") + return await load_context_from_json( + f"back_up_context_{target_uid}", "contexts/backup" + ) return [] diff --git a/nonebot_plugin_marshoai/util_hunyuan.py b/nonebot_plugin_marshoai/util_hunyuan.py index 3d0f994..47e899f 100644 --- a/nonebot_plugin_marshoai/util_hunyuan.py +++ b/nonebot_plugin_marshoai/util_hunyuan.py @@ -3,11 +3,17 @@ import types from tencentcloud.common import credential from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( + TencentCloudSDKException, +) from tencentcloud.hunyuan.v20230901 import hunyuan_client, models from .config import config + + def generate_image(prompt: str): - cred = credential.Credential(config.marshoai_tencent_secretid, config.marshoai_tencent_secretkey) + cred = credential.Credential( + config.marshoai_tencent_secretid, config.marshoai_tencent_secretkey + ) # 实例化一个http选项,可选的,没有特殊需求可以跳过 httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -18,11 +24,7 @@ def generate_image(prompt: str): client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", clientProfile) req = models.TextToImageLiteRequest() - params = { - "Prompt": prompt, - "RspImgType": "url", - "Resolution": "1080:1920" - } + params = {"Prompt": prompt, "RspImgType": "url", "Resolution": "1080:1920"} req.from_json_string(json.dumps(params)) # 返回的resp是一个TextToImageLiteResponse的实例,与请求对象对应 diff --git a/pyproject.toml b/pyproject.toml index 21cdc91..7f19d0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "httpx>=0.27.0", "ruamel.yaml>=0.18.6", "pyyaml>=6.0.2" - + ] license = { text = "MIT" }