Update the instances and util modules, switching to the OpenAI async client for chat requests

Asankilp 2025-01-26 00:48:55 +08:00
parent 132d219c59
commit 736a881071
3 changed files with 86 additions and 63 deletions

View File

@@ -2,6 +2,7 @@
from azure.ai.inference.aio import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver
from openai import AsyncOpenAI
from .config import config
from .models import MarshoContext, MarshoTools
@@ -14,5 +15,6 @@ context = MarshoContext()
tools = MarshoTools()
token = config.marshoai_token
endpoint = config.marshoai_azure_endpoint
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
# client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
client = AsyncOpenAI(base_url=endpoint, api_key=token)
target_list: list[list] = []  # targets whose conversation context should be saved
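For context, a minimal sketch of how the new async client is wired up and called, assuming the configured endpoint is OpenAI-compatible; the endpoint, key, and model name below are placeholders rather than values from the repository:

```python
import asyncio

from openai import AsyncOpenAI

# Placeholder values standing in for config.marshoai_azure_endpoint and
# config.marshoai_token from the diff above; running this as-is will simply
# fail at the network call.
client = AsyncOpenAI(base_url="https://example.invalid/v1", api_key="sk-placeholder")

async def demo() -> None:
    # The async client is used directly with await; it manages its own HTTP pool.
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response.choices[0].message.content)

asyncio.run(demo())
```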

View File

@@ -2,6 +2,7 @@ import contextlib
import traceback
from typing import Optional
import openai
from arclet.alconna import Alconna, AllParam, Args
from azure.ai.inference.models import (
AssistantMessage,
@@ -21,6 +22,7 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me
from nonebot.typing import T_State
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from openai import AsyncOpenAI
from .hooks import *
from .instances import *
@@ -253,7 +255,7 @@ async def marsho(
for i in text: # type: ignore
if i.type == "text":
if is_support_image_model:
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt)] # type: ignore
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore
else:
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
elif i.type == "image":
@@ -263,7 +265,7 @@
image_url=ImageUrl( # type: ignore
url=str(await get_image_b64(i.data["url"])) # type: ignore
) # type: ignore
) # type: ignore
).as_dict() # type: ignore
) # type: ignore
elif config.marshoai_enable_support_image_tip:
await UniMessage(
@@ -282,10 +284,10 @@
tools_lists = tools.tools_list + list(
map(lambda v: v.data(), get_function_calls().values())
)
response = await make_chat(
response = await make_chat_openai(
client=client,
model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore
tools=tools_lists if tools_lists else None,  # TODO temporarily appended functions; optimize later
)
# await UniMessage(str(response)).send()
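The .as_dict() calls added in the hunks above convert the Azure SDK content items and the UserMessage wrapper into plain dicts before they reach the OpenAI client. Roughly, the resulting payload is the standard OpenAI multimodal chat format, sketched here by hand for illustration (the text and data URL are made up):

```python
# Hand-built equivalent of UserMessage(content=[...]).as_dict() with one text
# item and one image item, i.e. the shape the OpenAI endpoint expects.
user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}
```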
@@ -293,13 +295,13 @@ async def marsho(
# Sprint(choice)
# set finish_reason to TOOL_CALLS when tool_calls is not empty
if choice.message.tool_calls != None and config.marshoai_fix_toolcalls:
choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# on success, append the dict context to the context store
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
)
context.append(choice.message.as_dict(), target.id, target.private)
context.append(choice.message, target.id, target.private)
if [target.id, target.private] not in target_list:
target_list.append([target.id, target.private])
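The hunk above also switches from dict-style indexing (choice["finish_reason"]) to attribute access, because the OpenAI SDK returns typed pydantic response models rather than mappings; the Azure CompletionsFinishReason members are string enums, so comparisons like choice.finish_reason == CompletionsFinishReason.STOPPED presumably still hold against the plain "stop" / "tool_calls" strings the OpenAI SDK returns. A small self-contained sketch of that attribute access, built from a hypothetical serialized response:

```python
from openai.types.chat import ChatCompletion

# Hypothetical serialized response, used only to show how fields are read.
raw = {
    "id": "chatcmpl-demo",
    "object": "chat.completion",
    "created": 0,
    "model": "demo-model",
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "hi"},
        }
    ],
}

completion = ChatCompletion(**raw)
choice = completion.choices[0]
print(choice.finish_reason)         # "stop" -- attribute access, not choice["finish_reason"]
print(choice.message.content)       # "hi"
print(choice.message.model_dump())  # explicit dict conversion (pydantic v2), if still needed
```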
@@ -310,7 +312,7 @@ async def marsho(
)
else:
await UniMessage(str(choice.message.content)).send(reply_to=True)
elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED:
elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
# conversation failed: content was filtered
@@ -318,7 +320,7 @@
reply_to=True
)
return
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
# function call
# extra information is needed, so call the function tools
tool_msg = []
@@ -332,61 +334,56 @@ async def marsho(
] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
except:
pass
tool_msg.append(choice.message.as_dict())
tool_msg.append(choice.message)
for tool_call in tool_calls:
if isinstance(
tool_call, ChatCompletionsToolCall
):  # keep calling tools until no further calls are needed
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
function_args = json.loads(
tool_call.function.arguments.replace("'", '"')
)
# drop the placeholder argument from args
if "placeholder" in function_args:
del function_args["placeholder"]
logger.info(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError:
function_args = json.loads(
tool_call.function.arguments.replace("'", '"')
)
await UniMessage(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
).send()
# TODO temporarily append plugin functions; fall back to a plugin function if the tool set has none
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
)  # get the return value
else:
if caller := get_function_calls().get(
tool_call.function.name
):
logger.debug(f"调用插件函数 {caller.full_name}")
# permission check, rule check TODO
# dependency injection: inspect function parameters and annotation types, inject Event-typed arguments
func_return = await caller.with_ctx(
SessionContext(
bot=bot,
event=event,
state=state,
matcher=matcher,
)
).call(**function_args)
else:
logger.error(
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
# drop the placeholder argument from args
if "placeholder" in function_args:
del function_args["placeholder"]
logger.info(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
)
await UniMessage(
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
).send()
# TODO temporarily append plugin functions; fall back to a plugin function if the tool set has none
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
)  # get the return value
else:
if caller := get_function_calls().get(tool_call.function.name):
logger.debug(f"调用插件函数 {caller.full_name}")
# permission check, rule check TODO
# dependency injection: inspect function parameters and annotation types, inject Event-typed arguments
func_return = await caller.with_ctx(
SessionContext(
bot=bot,
event=event,
state=state,
matcher=matcher,
)
func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
)
).call(**function_args)
else:
logger.error(
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
)
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
# await UniMessage(str(tool_msg)).send()
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
response = await make_chat(
response = await make_chat_openai(
client=client,
model_name=model_name,
msg=request_msg, # type: ignore
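For readers following the tool-call loop above: each tool call carries a JSON-encoded arguments string that is parsed, dispatched to either a registered tool or a plugin function, and answered with a "tool" role message keyed by the call id before the model is queried again. A loose, self-contained sketch of that dispatch step, using plain dicts and a hypothetical handler rather than the plugin's actual types:

```python
import asyncio
import json

async def dispatch_tool_call(tool_call: dict, handlers: dict) -> dict:
    """Parse one tool call, run a handler, and build the 'tool' reply message."""
    name = tool_call["function"]["name"]
    try:
        args = json.loads(tool_call["function"]["arguments"])
    except json.JSONDecodeError:
        # Some models emit single-quoted pseudo-JSON; retry after a naive fix,
        # mirroring the replace("'", '"') fallback in the diff.
        args = json.loads(tool_call["function"]["arguments"].replace("'", '"'))
    args.pop("placeholder", None)  # drop the placeholder argument if present
    handler = handlers.get(name)
    result = await handler(**args) if handler else f"function {name} not found"
    return {"role": "tool", "tool_call_id": tool_call["id"], "content": str(result)}

async def _demo() -> None:
    async def get_time() -> str:  # hypothetical tool
        return "12:00"

    reply = await dispatch_tool_call(
        {"id": "call_1", "function": {"name": "get_time", "arguments": "{}"}},
        {"get_time": get_time},
    )
    print(reply)

asyncio.run(_demo())
```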
@@ -397,15 +394,15 @@ async def marsho(
choice = response.choices[0]
# set finish_reason to TOOL_CALLS when tool_calls is not empty
if choice.message.tool_calls != None:
choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# conversation succeeded; append the context
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
)
# context.append(tool_msg, target.id, target.private)
context.append(choice.message.as_dict(), target.id, target.private)
context.append(choice.message, target.id, target.private)
# send the message
if config.marshoai_enable_richtext_parse:
@@ -448,7 +445,7 @@ with contextlib.suppress(ImportError):  # not optimizing this yet ()
],
)
choice = response.choices[0]
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
if choice.finish_reason == CompletionsFinishReason.STOPPED:
await UniMessage(" " + str(choice.message.content)).send(
at_sender=True
)

View File

@@ -16,6 +16,7 @@ from nonebot.log import logger
from nonebot_plugin_alconna import Image as ImageMsg
from nonebot_plugin_alconna import Text as TextMsg
from nonebot_plugin_alconna import UniMessage
from openai import AsyncOpenAI
from .config import config
from .constants import *
@@ -102,6 +103,29 @@ async def make_chat(
)
async def make_chat_openai(
client: AsyncOpenAI,
msg: list,
model_name: str,
tools: Optional[list] = None,
):
"""使用 Openai SDK 调用ai获取回复
参数:
client: 用于与AI模型进行通信
msg: 消息内容
model_name: 指定AI模型名
tools: 工具列表"""
return await client.chat.completions.create(
messages=msg,
model=model_name,
tools=tools,
temperature=config.marshoai_temperature,
max_tokens=config.marshoai_max_tokens,
top_p=config.marshoai_top_p,
)
def get_praises():
global praises_json
if praises_json is None:
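The tools argument accepted by the new make_chat_openai helper is forwarded to chat.completions.create, so each entry needs to follow the OpenAI function-tool schema. A hand-written example entry; everything about this particular tool (name, description, parameters) is made up for illustration:

```python
# One entry of the kind tools_lists is expected to contain -- the OpenAI
# function-tool schema.
example_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
            },
            "required": ["city"],
        },
    },
}

tools_lists = [example_tool]  # passed as tools=tools_lists if tools_lists else None
```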