unfinished

This commit is contained in:
Asankilp 2025-02-23 11:37:29 +08:00
parent 091e88fe81
commit d9f22fa0f7
2 changed files with 170 additions and 147 deletions

View File

@ -1,29 +1,42 @@
import json
from typing import Optional, Union from typing import Optional, Union
from azure.ai.inference.models import ( from azure.ai.inference.models import (
AssistantMessage, CompletionsFinishReason,
ImageContentItem, ImageContentItem,
ImageUrl, ImageUrl,
TextContentItem, TextContentItem,
ToolMessage, ToolMessage,
UserMessage, UserMessage,
) )
from nonebot.adapters import Event from nonebot.adapters import Bot, Event
from nonebot.log import logger from nonebot.log import logger
from nonebot.matcher import Matcher, current_event, current_matcher from nonebot.matcher import (
Matcher,
current_bot,
current_event,
current_handler,
current_matcher,
)
from nonebot.typing import T_State
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion
from .config import config from .config import config
from .constants import SUPPORT_IMAGE_MODELS from .constants import SUPPORT_IMAGE_MODELS
from .instances import target_list, tools
from .models import MarshoContext from .models import MarshoContext
from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext
from .util import ( from .util import (
extract_content_and_think,
get_backup_context, get_backup_context,
get_image_b64, get_image_b64,
get_nickname_by_user_id, get_nickname_by_user_id,
get_prompt, get_prompt,
make_chat_openai, make_chat_openai,
parse_richtext,
) )
@ -35,7 +48,9 @@ class MarshoHandler:
): ):
self.client = client self.client = client
self.context = context self.context = context
self.bot: Bot = current_bot.get()
self.event: Event = current_event.get() self.event: Event = current_event.get()
self.state: T_State = current_handler.get().state
self.matcher: Matcher = current_matcher.get() self.matcher: Matcher = current_matcher.get()
self.message_id: str = UniMessage.get_message_id(self.event) self.message_id: str = UniMessage.get_message_id(self.event)
self.target = UniMessage.get_target(self.event) self.target = UniMessage.get_target(self.event)
@ -44,7 +59,7 @@ class MarshoHandler:
self, user_input: UniMsg, model_name: str self, user_input: UniMsg, model_name: str
) -> Union[str, list]: ) -> Union[str, list]:
""" """
处理用户输入 处理用户输入为可输入 API 的格式并添加昵称提示
""" """
is_support_image_model = ( is_support_image_model = (
model_name.lower() model_name.lower()
@ -88,12 +103,7 @@ class MarshoHandler:
""" """
处理单条聊天 处理单条聊天
""" """
backup_context = await get_backup_context(self.target.id, self.target.private)
if backup_context:
self.context.set_context(
backup_context, self.target.id, self.target.private
) # 加载历史记录
logger.info(f"已恢复会话 {self.target.id} 的上下文备份~")
context_msg = ( context_msg = (
get_prompt(model_name) get_prompt(model_name)
+ (self.context.build(self.target.id, self.target.private)) + (self.context.build(self.target.id, self.target.private))
@ -108,6 +118,109 @@ class MarshoHandler:
) )
return response return response
async def handle_function_call(
self,
completion: ChatCompletion,
):
# function call
# 需要获取额外信息,调用函数工具
tool_msg = []
choice = completion.choices[0]
while choice.message.tool_calls is not None:
# await UniMessage(str(response)).send()
tool_calls = choice.message.tool_calls
# try:
# if tool_calls[0]["function"]["name"].startswith("$"):
# choice.message.tool_calls[0][
# "type"
# ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
# except:
# pass
tool_msg.append(choice.message)
for tool_call in tool_calls:
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
function_args = json.loads(
tool_call.function.arguments.replace("'", '"')
)
# 删除args的placeholder参数
if "placeholder" in function_args:
del function_args["placeholder"]
logger.info(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
)
await UniMessage(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
).send()
# TODO 临时追加插件函数,若工具中没有则调用插件函数
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
else:
if caller := get_function_calls().get(tool_call.function.name):
logger.debug(f"调用插件函数 {caller.full_name}")
# 权限检查,规则检查 TODO
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
func_return = await caller.with_ctx(
SessionContext(
bot=self.bot,
event=self.event,
state=self.state,
matcher=self.matcher,
)
).call(**function_args)
else:
logger.error(
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
func_return = (
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
)
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
# await UniMessage(str(tool_msg)).send()
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
response = await make_chat_openai(
client=client,
model_name=model_name,
msg=request_msg, # type: ignore
tools=(
tools_lists if tools_lists else None
), # TODO 临时追加函数,后期优化
)
choice = response.choices[0]
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls is not None:
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 对话成功 添加上下文
context.append(
UserMessage(content=usermsg).as_dict(), self.target.id, self.target.private # type: ignore
)
# context.append(tool_msg, self.target.id, self.target.private)
choice_msg_dict = choice.message.to_dict()
if "reasoning_content" in choice_msg_dict:
del choice_msg_dict["reasoning_content"]
context.append(choice_msg_dict, self.target.id, self.target.private)
# 发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice.message.content))).send(
reply_to=True
)
else:
await UniMessage(str(choice.message.content)).send(reply_to=True)
else:
await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
async def handle_common_chat( async def handle_common_chat(
self, self,
user_message: Union[str, list], user_message: Union[str, list],
@ -119,6 +232,7 @@ class MarshoHandler:
""" """
处理一般聊天 处理一般聊天
""" """
global target_list
if stream: if stream:
raise NotImplementedError raise NotImplementedError
response = await self.handle_single_chat( response = await self.handle_single_chat(
@ -127,4 +241,43 @@ class MarshoHandler:
tools=tools, tools=tools,
with_context=with_context, with_context=with_context,
) )
return response choice = response.choices[0]
# Sprint(choice)
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
choice.finish_reason = "tool_calls"
logger.info(f"完成原因:{choice.finish_reason}")
if choice.finish_reason == CompletionsFinishReason.STOPPED:
##### DeepSeek-R1 兼容部分 #####
choice_msg_content, choice_msg_thinking, choice_msg_after = (
extract_content_and_think(choice.message)
)
if choice_msg_thinking and config.marshoai_send_thinking:
await UniMessage("思维链:\n" + choice_msg_thinking).send()
##### 兼容部分结束 #####
if [self.target.id, self.target.private] not in target_list:
target_list.append([self.target.id, self.target.private])
# 对话成功发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice_msg_content))).send(
reply_to=True
)
else:
await UniMessage(str(choice_msg_content)).send(reply_to=True)
return (UserMessage(context=user_message), choice_msg_after)
elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
# 对话失败,消息过滤
await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
reply_to=True
)
return None
elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
pass
else:
await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
return None

View File

@ -233,6 +233,12 @@ async def marsho(
# 发送说明 # 发送说明
# await UniMessage(metadata.usage + "\n当前使用的模型" + model_name).send() # await UniMessage(metadata.usage + "\n当前使用的模型" + model_name).send()
await marsho_cmd.finish(INTRODUCTION) await marsho_cmd.finish(INTRODUCTION)
backup_context = await get_backup_context(target.id, target.private)
if backup_context:
context.set_context(
backup_context, target.id, target.private
) # 加载历史记录
logger.info(f"已恢复会话 {target.id} 的上下文备份~")
handler = MarshoHandler(client, context) handler = MarshoHandler(client, context)
try: try:
user_nickname = await get_nickname_by_user_id(event.get_user_id()) user_nickname = await get_nickname_by_user_id(event.get_user_id())
@ -261,143 +267,7 @@ async def marsho(
usermsg, model_name, tools_lists, with_context=True usermsg, model_name, tools_lists, with_context=True
) )
# await UniMessage(str(response)).send() # await UniMessage(str(response)).send()
choice = response.choices[0]
# Sprint(choice)
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
choice.finish_reason = "tool_calls"
logger.info(f"完成原因:{choice.finish_reason}")
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 当对话成功时将dict的上下文添加到上下文类中
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
)
##### DeepSeek-R1 兼容部分 #####
choice_msg_content, choice_msg_thinking, choice_msg_after = (
extract_content_and_think(choice.message)
)
if choice_msg_thinking and config.marshoai_send_thinking:
await UniMessage("思维链:\n" + choice_msg_thinking).send()
##### 兼容部分结束 #####
context.append(choice_msg_after.to_dict(), target.id, target.private)
if [target.id, target.private] not in target_list:
target_list.append([target.id, target.private])
# 对话成功发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice_msg_content))).send(
reply_to=True
)
else:
await UniMessage(str(choice_msg_content)).send(reply_to=True)
elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
# 对话失败,消息过滤
await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
reply_to=True
)
return
elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
# function call
# 需要获取额外信息,调用函数工具
tool_msg = []
while choice.message.tool_calls is not None:
# await UniMessage(str(response)).send()
tool_calls = choice.message.tool_calls
# try:
# if tool_calls[0]["function"]["name"].startswith("$"):
# choice.message.tool_calls[0][
# "type"
# ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
# except:
# pass
tool_msg.append(choice.message)
for tool_call in tool_calls:
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
function_args = json.loads(
tool_call.function.arguments.replace("'", '"')
)
# 删除args的placeholder参数
if "placeholder" in function_args:
del function_args["placeholder"]
logger.info(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
)
await UniMessage(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
).send()
# TODO 临时追加插件函数,若工具中没有则调用插件函数
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
else:
if caller := get_function_calls().get(tool_call.function.name):
logger.debug(f"调用插件函数 {caller.full_name}")
# 权限检查,规则检查 TODO
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
func_return = await caller.with_ctx(
SessionContext(
bot=bot,
event=event,
state=state,
matcher=matcher,
)
).call(**function_args)
else:
logger.error(
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
)
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
# await UniMessage(str(tool_msg)).send()
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
response = await make_chat_openai(
client=client,
model_name=model_name,
msg=request_msg, # type: ignore
tools=(
tools_lists if tools_lists else None
), # TODO 临时追加函数,后期优化
)
choice = response.choices[0]
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls is not None:
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 对话成功 添加上下文
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
)
# context.append(tool_msg, target.id, target.private)
choice_msg_dict = choice.message.to_dict()
if "reasoning_content" in choice_msg_dict:
del choice_msg_dict["reasoning_content"]
context.append(choice_msg_dict, target.id, target.private)
# 发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice.message.content))).send(
reply_to=True
)
else:
await UniMessage(str(choice.message.content)).send(reply_to=True)
else:
await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
else:
await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
except Exception as e: except Exception as e:
await UniMessage(str(e) + suggest_solution(str(e))).send() await UniMessage(str(e) + suggest_solution(str(e))).send()
traceback.print_exc() traceback.print_exc()