nonebot-plugin-marshoai/nonebot_plugin_marshoai/azure.py

491 lines
18 KiB
Python
Raw Normal View History

import uuid
import traceback
import contextlib
from typing import Optional
from pathlib import Path
from arclet.alconna import Alconna, Args, AllParam
from azure.ai.inference.models import (
UserMessage,
AssistantMessage,
ToolMessage,
TextContentItem,
ImageContentItem,
ImageUrl,
CompletionsFinishReason,
ChatCompletionsToolCall,
)
2024-09-17 20:20:31 +08:00
from azure.core.credentials import AzureKeyCredential
from nonebot import on_command, on_message, logger, get_driver
from nonebot.adapters import Message, Event
from nonebot.params import CommandArg
from nonebot.permission import SUPERUSER
2024-11-24 15:00:44 +08:00
from nonebot.rule import Rule, to_me
from nonebot_plugin_alconna import (
on_alconna,
MsgTarget,
UniMessage,
UniMsg,
Text as TextMsg,
Image as ImageMsg,
)
import nonebot_plugin_localstore as store
from .constants import *
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .util import *
2024-11-24 15:00:44 +08:00
async def at_enable():
    """Rule check used by ``marsho_at``: passes when ``config.marshoai_at`` is truthy."""
    enabled = config.marshoai_at
    return enabled
# NoneBot driver; used below for the startup/shutdown hooks and `command_start`.
driver = get_driver()
# --- Command matchers. All below use priority=10, block=True except `marsho_at`. ---
# Switch the active chat model (superuser only).
changemodel_cmd = on_command(
    "changemodel", permission=SUPERUSER, priority=10, block=True
)
# Reset the current session's context (any user).
resetmem_cmd = on_command("reset", priority=10, block=True)
# Disabled: prompt-setting command.
# setprompt_cmd = on_command("prompt",permission=SUPERUSER)
# Show the praise list (superuser only).
praises_cmd = on_command("praises", permission=SUPERUSER, priority=10, block=True)
# Manually append a user-role message to the session context (superuser only).
add_usermsg_cmd = on_command("usermsg", permission=SUPERUSER, priority=10, block=True)
# Manually append an assistant-role message to the session context (superuser only).
add_assistantmsg_cmd = on_command(
    "assistantmsg", permission=SUPERUSER, priority=10, block=True
)
# Dump the current session context (superuser only).
contexts_cmd = on_command("contexts", permission=SUPERUSER, priority=10, block=True)
# Save the session context to a named JSON file (superuser only).
save_context_cmd = on_command(
    "savecontext", permission=SUPERUSER, priority=10, block=True
)
# Load and overwrite the session context from a named JSON file (superuser only).
load_context_cmd = on_command(
    "loadcontext", permission=SUPERUSER, priority=10, block=True
)
# Main chat command. The command name comes from config.marshoai_default_name,
# aliases come from config.marshoai_aliases, and everything after the name is
# captured as the optional `text` argument (AllParam).
marsho_cmd = on_alconna(
    Alconna(
        config.marshoai_default_name,
        Args["text?", AllParam],
    ),
    aliases=config.marshoai_aliases,
    priority=10,
    block=True,
)
# Chat triggered by messages addressed to the bot (to_me), only when at_enable()
# passes. Priority 11 means it is checked after the priority-10 commands; it is
# not registered with block=True.
marsho_at = on_message(rule=to_me() & at_enable, priority=11)
# Show, set or reset the caller's nickname; optional string argument `name`.
nickname_cmd = on_alconna(
    Alconna(
        "nickname",
        Args["name?", str],
    ),
    priority=10,
    block=True,
)
# Refresh nickname/praise data (superuser only).
refresh_data_cmd = on_command(
    "refresh_data", permission=SUPERUSER, priority=10, block=True
)
# Command prefixes from the driver config (not referenced elsewhere in this chunk).
command_start = driver.config.command_start
# Active model name. It starts as the configured default; `changemodel` rebinds it.
model_name = config.marshoai_default_model
# Per-session chat history store, keyed by (target.id, target.private) at call sites.
context = MarshoContext()
# Tool registry. Tools are loaded in `_preload_tools` and invoked during tool calls.
tools = MarshoTools()
token = config.marshoai_token
endpoint = config.marshoai_azure_endpoint
# Azure AI Inference chat client. NOTE(review): `ChatCompletionsClient` is not in
# the visible imports; presumably it comes from `from .util import *` — verify.
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
target_list = [] # Sessions ([id, private]) whose context is backed up on shutdown
2024-11-24 15:00:44 +08:00
@driver.on_startup
async def _preload_tools():
    """Load MarshoAI tool packages when the bot starts.

    The plugin's ``tools`` data directory is always created if missing. When
    ``config.marshoai_enable_tools`` is set, tools are loaded in this order:

    1. the built-in tools next to this file (only if ``marshoai_load_builtin_tools``);
    2. the tools in the plugin data directory;
    3. every extra directory listed in ``marshoai_toolset_dir``.
    """
    tools_dir = store.get_plugin_data_dir() / "tools"
    # pathlib equivalent of os.makedirs(tools_dir, exist_ok=True). `os` does not
    # appear in this module's visible import block, so avoid depending on it.
    tools_dir.mkdir(parents=True, exist_ok=True)
    if config.marshoai_enable_tools:
        if config.marshoai_load_builtin_tools:
            tools.load_tools(Path(__file__).parent / "tools")
        # Reuse the path computed above instead of rebuilding it.
        tools.load_tools(tools_dir)
        for tool_dir in config.marshoai_toolset_dir:
            tools.load_tools(tool_dir)
        logger.info(
            "如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
        )
2024-11-24 15:00:44 +08:00
@add_usermsg_cmd.handle()
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
    """Append a user-role message to the current session's context.

    The command argument's plain text becomes the message. Nothing happens
    when the argument has no plain text.
    """
    content = arg.extract_plain_text()
    if not content:
        return
    user_entry = UserMessage(content=content).as_dict()
    context.append(user_entry, target.id, target.private)
    await add_usermsg_cmd.finish("已添加用户消息")
@add_assistantmsg_cmd.handle()
async def add_assistantmsg(target: MsgTarget, arg: Message = CommandArg()):
    """Append an assistant-role message to the current session's context.

    The command argument's plain text becomes the message. Nothing happens
    when the argument has no plain text.
    """
    content = arg.extract_plain_text()
    if not content:
        return
    assistant_entry = AssistantMessage(content=content).as_dict()
    context.append(assistant_entry, target.id, target.private)
    await add_assistantmsg_cmd.finish("已添加助手消息")
2024-09-28 12:24:20 +08:00
@praises_cmd.handle()
async def praises():
    """Reply with the text produced by ``build_praises()``."""
    praise_text = build_praises()
    await praises_cmd.finish(praise_text)
2024-09-17 20:20:31 +08:00
@contexts_cmd.handle()
async def contexts(target: MsgTarget):
    """Reply with the current session's built context rendered as a string.

    If ``get_backup_context`` returns a non-empty backup for this session, that
    backup is loaded into ``context`` first.
    """
    session_id, is_private = target.id, target.private
    if backup := await get_backup_context(session_id, is_private):
        # restore saved history before dumping
        context.set_context(backup, session_id, is_private)
    history = context.build(session_id, is_private)
    await contexts_cmd.finish(str(history))
2024-10-03 15:16:32 +08:00
@save_context_cmd.handle()
async def save_context(target: MsgTarget, arg: Message = CommandArg()):
    """Save the current session's context to a named JSON file.

    The command argument's plain text is used as the file name under
    ``contexts``. If the session has no context, the command replies
    that there is nothing to save.
    """
    contexts_data = context.build(target.id, target.private)
    # Check this session's built history. The previous code tested the global
    # `context` store object, which is truthy regardless of whether this session
    # has any history (unless MarshoContext defines __bool__/__len__), so the
    # guard effectively never fired.
    if not contexts_data:
        await save_context_cmd.finish("暂无上下文可以保存")
    if msg := arg.extract_plain_text():
        await save_context_to_json(msg, contexts_data, "contexts")
        await save_context_cmd.finish("已保存上下文")
2024-09-17 20:20:31 +08:00
2024-10-03 15:16:32 +08:00
@load_context_cmd.handle()
async def load_context(target: MsgTarget, arg: Message = CommandArg()):
    """Replace the session's context with one saved earlier under the given name.

    The command argument's plain text names the file under ``contexts``.
    Nothing happens when no name is given.
    """
    name = arg.extract_plain_text()
    if not name:
        return
    # Query the backup first so this session is recorded as "backup already
    # restored". Otherwise a later automatic restore would overwrite the
    # context loaded here.
    await get_backup_context(target.id, target.private)
    loaded = await load_context_from_json(name, "contexts")
    context.set_context(loaded, target.id, target.private)
    await load_context_cmd.finish("已加载并覆盖上下文")
2024-09-17 20:20:31 +08:00
@resetmem_cmd.handle()
async def resetmem(target: MsgTarget):
    """Clear the current session's context.

    The session is also registered in ``target_list``, so the (now empty)
    context is written to backup on shutdown.
    """
    session = [target.id, target.private]
    if session not in target_list:
        target_list.append(session)
    context.reset(target.id, target.private)
    await resetmem_cmd.finish("上下文已重置")
@changemodel_cmd.handle()
async def changemodel(arg: Message = CommandArg()):
    """Rebind the module-level ``model_name`` used by later chat requests.

    Nothing happens when the command argument has no plain text.
    """
    global model_name
    new_model = arg.extract_plain_text()
    if not new_model:
        return
    model_name = new_model
    await changemodel_cmd.finish("已切换")
2024-10-03 15:16:32 +08:00
@nickname_cmd.handle()
async def nickname(event: Event, name=None):
    """Show, set or reset the caller's nickname.

    - no ``name``: reply with the stored nickname, or say that none is set;
    - ``name == "reset"``: store an empty nickname;
    - otherwise: store ``name`` as the nickname.
    """
    nicknames = await get_nicknames()
    user_id = event.get_user_id()
    if name:
        if name == "reset":
            await set_nickname(user_id, "")
            await nickname_cmd.finish("已重置昵称")
        else:
            await set_nickname(user_id, name)
            await nickname_cmd.finish("已设置昵称为:" + name)
    # No argument: report the current nickname. finish() ends the handler.
    if user_id in nicknames:
        await nickname_cmd.finish("你的昵称为:" + str(nicknames[user_id]))
    await nickname_cmd.finish("你未设置昵称")
@refresh_data_cmd.handle()
async def refresh_data():
    """Run the nickname and praise refresh helpers in turn, then confirm."""
    for refresher in (refresh_nickname_json, refresh_praises_json):
        await refresher()
    await refresh_data_cmd.finish("已刷新数据")
"""
以下函数依照 Apache 2.0 协议授权
函数 get_back_uuidcodeblock、send_markdown
版权所有 © 2024 金羿ELS
Copyright (R) 2024 Eilles(EillesWan@outlook.com)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
async def get_back_uuidcodeblock(msg: str, code_blank_uuid_map: list[tuple[str, str]]):
    """Swap placeholders in *msg* back to the text they stand for.

    ``code_blank_uuid_map`` holds ``(placeholder, original_text)`` pairs. They
    are applied in order with ``str.replace``, and the restored string is
    returned.
    """
    restored = msg
    for placeholder, original_text in code_blank_uuid_map:
        restored = restored.replace(placeholder, original_text)
    return restored
async def send_markdown(msg: str):
    """
    Send an AI reply, turning Markdown image tags into real image segments.

    AI replies rarely embed HTML, but images, LaTeX formulas and code blocks
    are common. If ``IMG_TAG_PATTERN`` finds no image tag, the text is sent
    unchanged. Otherwise:

    - code blocks (``CODE_BLOCK_PATTERN``) are first swapped for random UUID
      placeholders, so image-like syntax inside them is not treated as an image;
    - for each image tag, the image is fetched with ``get_image_raw_and_type``
      and sent as an image segment followed by its description;
    - if fetching fails, the raw tag text is kept.

    Placeholders are restored before any text is sent. The message is sent as
    a reply.
    """
    if not IMG_TAG_PATTERN.search(msg):  # no image tags: send as-is
        await UniMessage(msg).send(reply_to=True)
        return

    result_msg = UniMessage()
    # (placeholder, original code block) pairs
    code_blank_uuid_map = [
        (uuid.uuid4().hex, cbp.group()) for cbp in CODE_BLOCK_PATTERN.finditer(msg)
    ]
    # Code-block rendering is not implemented; hide blocks behind placeholders.
    for rep, torep in code_blank_uuid_map:
        msg = msg.replace(torep, rep)

    last_tag_index = 0
    for each_img_tag in IMG_TAG_PATTERN.finditer(msg):
        # Use the match's own span. msg.find(img_tag) returned the FIRST
        # occurrence, which broke the slicing when the same tag appeared more
        # than once. It also failed if restoring placeholders changed the tag.
        tag_start, tag_end = each_img_tag.span()
        img_tag = await get_back_uuidcodeblock(
            each_img_tag.group(), code_blank_uuid_map
        )
        # NOTE(review): assumes the tag looks like "![description](url)" — confirm
        # against IMG_TAG_PATTERN.
        image_description = img_tag[2 : img_tag.find("]")]
        image_url = img_tag[img_tag.find("(") + 1 : -1]
        # Text between the previous tag and this one.
        result_msg.append(
            TextMsg(
                await get_back_uuidcodeblock(
                    msg[last_tag_index:tag_start], code_blank_uuid_map
                )
            )
        )
        last_tag_index = tag_end
        if image_ := await get_image_raw_and_type(image_url):
            result_msg.append(
                ImageMsg(
                    raw=image_[0], mimetype=image_[1], name=image_description + ".png"
                )
            )
            result_msg.append(TextMsg(image_description))
        else:
            # Image could not be fetched; keep the original tag text.
            result_msg.append(TextMsg(img_tag))

    # Trailing text after the last image tag.
    result_msg.append(
        TextMsg(await get_back_uuidcodeblock(msg[last_tag_index:], code_blank_uuid_map))
    )
    await result_msg.send(reply_to=True)
"""
Apache 2.0 协议授权部分结束
"""
2024-11-24 15:00:44 +08:00
def _parse_tool_arguments(raw_arguments: str):
    """Decode the JSON argument string of a model tool call.

    Valid JSON is decoded as-is. Only if that fails is the legacy heuristic
    applied: every ``'`` is swapped for ``"`` to tolerate models that emit
    Python-style single-quoted dicts. The old code applied the swap
    unconditionally, which corrupted valid JSON containing apostrophes
    (e.g. ``{"q": "it's"}``).

    :raises json.JSONDecodeError: if neither form can be decoded.
    """
    import json

    try:
        return json.loads(raw_arguments)
    except json.JSONDecodeError:
        return json.loads(raw_arguments.replace("'", '"'))


@marsho_at.handle()
@marsho_cmd.handle()
async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None):
    """Main chat handler: send the user's message to the model and reply.

    Runs for the Marsho command/aliases and for ``marsho_at`` triggers. With
    no content, it sends the usage text and introduction instead. Otherwise it:

    1. builds the user message (text plus images when the model supports them);
    2. restores any context backup;
    3. calls the model, running requested tool calls in a loop until the model
       stops asking for them;
    4. on success, records the exchange in ``context`` and registers the
       session in ``target_list`` so it is backed up on shutdown.

    Any exception is reported to the chat together with ``suggest_solution``.

    :param target: originating session (``target.id`` / ``target.private``).
    :param event: the incoming event.
    :param text: Alconna-parsed message argument; may be ``None``.
    """
    if event.get_message().extract_plain_text() and (
        not text
        and event.get_message().extract_plain_text() != config.marshoai_default_name
    ):
        # No parsed argument (e.g. triggered via marsho_at): use the raw message.
        text = event.get_message()
    if not text:
        # Nothing to ask: send usage information.
        await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send()
        await marsho_cmd.finish(INTRODUCTION)
    try:
        user_id = event.get_user_id()
        nicknames = await get_nicknames()
        user_nickname = nicknames.get(user_id, "")
        if user_nickname != "":
            nickname_prompt = f"\n*此消息的说话者:{user_nickname}*"
        else:
            nickname_prompt = ""
            # The platform username can't be obtained yet, so there is no
            # fallback to event.sender.nickname here.
            if config.marshoai_enable_nickname_tip:
                await UniMessage(
                    "*你未设置自己的昵称。推荐使用'nickname [昵称]'命令设置昵称来获得个性化(可能)回答。"
                ).send()

        is_support_image_model = (
            model_name.lower()
            in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
        )
        is_reasoning_model = model_name.lower() in REASONING_MODELS
        # Multimodal models take a list of content items; others take a plain string.
        usermsg = [] if is_support_image_model else ""
        for i in text:
            if i.type == "text":
                if is_support_image_model:
                    usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt)]
                else:
                    usermsg += str(i.data["text"] + nickname_prompt)
            elif i.type == "image":
                if is_support_image_model:
                    usermsg.append(
                        ImageContentItem(
                            image_url=ImageUrl(
                                url=str(await get_image_b64(i.data["url"]))
                            )
                        )
                    )
                elif config.marshoai_enable_support_image_tip:
                    await UniMessage("*此模型不支持图片处理。").send()

        backup_context = await get_backup_context(target.id, target.private)
        if backup_context:
            # restore saved history
            context.set_context(backup_context, target.id, target.private)
            logger.info(f"已恢复会话 {target.id} 的上下文备份~")
        context_msg = context.build(target.id, target.private)
        if not is_reasoning_model:
            # Reasoning models such as o1 do not accept a system prompt, so it
            # is only prepended for other models.
            context_msg = [get_prompt()] + context_msg
        response = await make_chat(
            client=client,
            model_name=model_name,
            msg=context_msg + [UserMessage(content=usermsg)],
            tools=tools.get_tools_list(),
        )
        choice = response.choices[0]
        if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
            # Success: store the exchange as dicts in the session context.
            context.append(
                UserMessage(content=usermsg).as_dict(), target.id, target.private
            )
            context.append(choice.message.as_dict(), target.id, target.private)
            if [target.id, target.private] not in target_list:
                target_list.append([target.id, target.private])
            await send_markdown(str(choice.message.content))
        elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED:
            # Blocked by the content filter.
            await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
                reply_to=True
            )
            return
        elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
            # The model needs more information: run the requested tools, then
            # ask again until it stops requesting tool calls.
            tool_msg = []
            while choice.message.tool_calls is not None:
                tool_msg.append(AssistantMessage(tool_calls=choice.message.tool_calls))
                for tool_call in choice.message.tool_calls:
                    if not isinstance(tool_call, ChatCompletionsToolCall):
                        continue
                    function_args = _parse_tool_arguments(tool_call.function.arguments)
                    logger.info(
                        f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
                    )
                    await UniMessage(
                        f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
                    ).send()
                    func_return = await tools.call(
                        tool_call.function.name, function_args
                    )
                    tool_msg.append(
                        ToolMessage(tool_call_id=tool_call.id, content=func_return)
                    )
                response = await make_chat(
                    client=client,
                    model_name=model_name,
                    msg=context_msg + [UserMessage(content=usermsg)] + tool_msg,
                    tools=tools.get_tools_list(),
                )
                choice = response.choices[0]
            if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
                # Success: store the exchange. Intermediate tool messages are not kept.
                context.append(
                    UserMessage(content=usermsg).as_dict(), target.id, target.private
                )
                context.append(choice.message.as_dict(), target.id, target.private)
                # Register for shutdown backup, as the non-tool path does.
                if [target.id, target.private] not in target_list:
                    target_list.append([target.id, target.private])
                await send_markdown(str(choice.message.content))
            else:
                # Not marsho_cmd.finish(): it raises FinishedException, an
                # Exception subclass that the broad handler below would catch
                # and report as a second, bogus error message.
                await UniMessage(f"意外的完成原因:{choice['finish_reason']}").send()
        else:
            # See the note above on why finish() is not used inside this try.
            await UniMessage(f"意外的完成原因:{choice['finish_reason']}").send()
    except Exception as e:
        await UniMessage(str(e) + suggest_solution(str(e))).send()
        traceback.print_exc()
        return
with contextlib.suppress(ImportError):  # optimization deferred for now
    # Only register the poke handler if the OneBot v11 adapter is importable.
    import nonebot.adapters.onebot.v11  # type: ignore

    from .azure_onebot import poke_notify

    @poke_notify.handle()
    async def poke(event: Event):
        """Answer a poke with a model-generated line, mentioning the poker.

        Nothing is sent when ``config.marshoai_poke_suffix`` is empty. The user
        prompt is ``"*" + <caller's nickname> + <suffix>``. A reply is only sent
        when the model finishes with STOPPED. Errors are reported to the chat.
        """
        user_id = event.get_user_id()
        nicknames = await get_nicknames()
        user_nickname = nicknames.get(user_id, "")
        try:
            if config.marshoai_poke_suffix == "":
                return
            poke_prompt = f"*{user_nickname}{config.marshoai_poke_suffix}"
            response = await make_chat(
                client=client,
                model_name=model_name,
                msg=[get_prompt(), UserMessage(content=poke_prompt)],
            )
            choice = response.choices[0]
            if choice["finish_reason"] != CompletionsFinishReason.STOPPED:
                return
            await UniMessage(" " + str(choice.message.content)).send(at_sender=True)
        except Exception as e:
            await UniMessage(str(e) + suggest_solution(str(e))).send()
            traceback.print_exc()
            return
@driver.on_shutdown
async def auto_backup_context():
    """On shutdown, save the context of every session in ``target_list``.

    Each session is written to ``contexts/backup`` as
    ``back_up_context_private_<id>`` or ``back_up_context_group_<id>``. The
    log message says the backup is restored on the session's next chat.
    """
    for session_id, is_private in target_list:
        history = context.build(session_id, is_private)
        session_uid = ("private_" if is_private else "group_") + session_id
        await save_context_to_json(
            f"back_up_context_{session_uid}", history, "contexts/backup"
        )
        logger.info(f"已保存会话 {session_id} 的上下文备份,将在下次对话时恢复~")