更新实例和工具模块,更换为OpenAI异步客户端进行聊天请求

This commit is contained in:
Asankilp 2025-01-26 00:48:55 +08:00
parent 132d219c59
commit 736a881071
3 changed files with 86 additions and 63 deletions

View File

@ -2,6 +2,7 @@
from azure.ai.inference.aio import ChatCompletionsClient from azure.ai.inference.aio import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver from nonebot import get_driver
from openai import AsyncOpenAI
from .config import config from .config import config
from .models import MarshoContext, MarshoTools from .models import MarshoContext, MarshoTools
@ -14,5 +15,6 @@ context = MarshoContext()
tools = MarshoTools() tools = MarshoTools()
token = config.marshoai_token token = config.marshoai_token
endpoint = config.marshoai_azure_endpoint endpoint = config.marshoai_azure_endpoint
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) # client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
client = AsyncOpenAI(base_url=endpoint, api_key=token)
target_list: list[list] = [] # 记录需保存历史上下文的列表 target_list: list[list] = [] # 记录需保存历史上下文的列表

View File

@ -2,6 +2,7 @@ import contextlib
import traceback import traceback
from typing import Optional from typing import Optional
import openai
from arclet.alconna import Alconna, AllParam, Args from arclet.alconna import Alconna, AllParam, Args
from azure.ai.inference.models import ( from azure.ai.inference.models import (
AssistantMessage, AssistantMessage,
@ -21,6 +22,7 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me from nonebot.rule import Rule, to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from openai import AsyncOpenAI
from .hooks import * from .hooks import *
from .instances import * from .instances import *
@ -253,7 +255,7 @@ async def marsho(
for i in text: # type: ignore for i in text: # type: ignore
if i.type == "text": if i.type == "text":
if is_support_image_model: if is_support_image_model:
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt)] # type: ignore usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore
else: else:
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
elif i.type == "image": elif i.type == "image":
@ -263,7 +265,7 @@ async def marsho(
image_url=ImageUrl( # type: ignore image_url=ImageUrl( # type: ignore
url=str(await get_image_b64(i.data["url"])) # type: ignore url=str(await get_image_b64(i.data["url"])) # type: ignore
) # type: ignore ) # type: ignore
) # type: ignore ).as_dict() # type: ignore
) # type: ignore ) # type: ignore
elif config.marshoai_enable_support_image_tip: elif config.marshoai_enable_support_image_tip:
await UniMessage( await UniMessage(
@ -282,10 +284,10 @@ async def marsho(
tools_lists = tools.tools_list + list( tools_lists = tools.tools_list + list(
map(lambda v: v.data(), get_function_calls().values()) map(lambda v: v.data(), get_function_calls().values())
) )
response = await make_chat( response = await make_chat_openai(
client=client, client=client,
model_name=model_name, model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore
tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化 tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化
) )
# await UniMessage(str(response)).send() # await UniMessage(str(response)).send()
@ -293,13 +295,13 @@ async def marsho(
# Sprint(choice) # Sprint(choice)
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS # 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls != None and config.marshoai_fix_toolcalls: if choice.message.tool_calls != None and config.marshoai_fix_toolcalls:
choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice["finish_reason"] == CompletionsFinishReason.STOPPED: if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 当对话成功时将dict的上下文添加到上下文类中 # 当对话成功时将dict的上下文添加到上下文类中
context.append( context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
) )
context.append(choice.message.as_dict(), target.id, target.private) context.append(choice.message, target.id, target.private)
if [target.id, target.private] not in target_list: if [target.id, target.private] not in target_list:
target_list.append([target.id, target.private]) target_list.append([target.id, target.private])
@ -310,7 +312,7 @@ async def marsho(
) )
else: else:
await UniMessage(str(choice.message.content)).send(reply_to=True) await UniMessage(str(choice.message.content)).send(reply_to=True)
elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED: elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
# 对话失败,消息过滤 # 对话失败,消息过滤
@ -318,7 +320,7 @@ async def marsho(
reply_to=True reply_to=True
) )
return return
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS: elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
# function call # function call
# 需要获取额外信息,调用函数工具 # 需要获取额外信息,调用函数工具
tool_msg = [] tool_msg = []
@ -332,11 +334,8 @@ async def marsho(
] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案 ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
except: except:
pass pass
tool_msg.append(choice.message.as_dict()) tool_msg.append(choice.message)
for tool_call in tool_calls: for tool_call in tool_calls:
if isinstance(
tool_call, ChatCompletionsToolCall
): # 循环调用工具直到不需要调用
try: try:
function_args = json.loads(tool_call.function.arguments) function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError: except json.JSONDecodeError:
@ -361,9 +360,7 @@ async def marsho(
tool_call.function.name, function_args tool_call.function.name, function_args
) # 获取返回值 ) # 获取返回值
else: else:
if caller := get_function_calls().get( if caller := get_function_calls().get(tool_call.function.name):
tool_call.function.name
):
logger.debug(f"调用插件函数 {caller.full_name}") logger.debug(f"调用插件函数 {caller.full_name}")
# 权限检查,规则检查 TODO # 权限检查,规则检查 TODO
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入 # 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
@ -386,7 +383,7 @@ async def marsho(
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
# await UniMessage(str(tool_msg)).send() # await UniMessage(str(tool_msg)).send()
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
response = await make_chat( response = await make_chat_openai(
client=client, client=client,
model_name=model_name, model_name=model_name,
msg=request_msg, # type: ignore msg=request_msg, # type: ignore
@ -397,15 +394,15 @@ async def marsho(
choice = response.choices[0] choice = response.choices[0]
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS # 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls != None: if choice.message.tool_calls != None:
choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice["finish_reason"] == CompletionsFinishReason.STOPPED: if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 对话成功 添加上下文 # 对话成功 添加上下文
context.append( context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
) )
# context.append(tool_msg, target.id, target.private) # context.append(tool_msg, target.id, target.private)
context.append(choice.message.as_dict(), target.id, target.private) context.append(choice.message, target.id, target.private)
# 发送消息 # 发送消息
if config.marshoai_enable_richtext_parse: if config.marshoai_enable_richtext_parse:
@ -448,7 +445,7 @@ with contextlib.suppress(ImportError): # 优化先不做()
], ],
) )
choice = response.choices[0] choice = response.choices[0]
if choice["finish_reason"] == CompletionsFinishReason.STOPPED: if choice.finish_reason == CompletionsFinishReason.STOPPED:
await UniMessage(" " + str(choice.message.content)).send( await UniMessage(" " + str(choice.message.content)).send(
at_sender=True at_sender=True
) )

View File

@ -16,6 +16,7 @@ from nonebot.log import logger
from nonebot_plugin_alconna import Image as ImageMsg from nonebot_plugin_alconna import Image as ImageMsg
from nonebot_plugin_alconna import Text as TextMsg from nonebot_plugin_alconna import Text as TextMsg
from nonebot_plugin_alconna import UniMessage from nonebot_plugin_alconna import UniMessage
from openai import AsyncOpenAI
from .config import config from .config import config
from .constants import * from .constants import *
@ -102,6 +103,29 @@ async def make_chat(
) )
async def make_chat_openai(
    client: AsyncOpenAI,
    msg: list,
    model_name: str,
    tools: Optional[list] = None,
):
    """Request a chat completion from the model via the OpenAI SDK.

    Args:
        client: AsyncOpenAI client used to communicate with the AI endpoint.
        msg: chat history as a list of message dicts in OpenAI chat format.
        model_name: identifier of the model to query.
        tools: optional list of tool (function-calling) definitions.

    Returns:
        The awaited ChatCompletion response object from the SDK.
    """
    # Sampling parameters come from the plugin config so behaviour matches
    # the Azure-based make_chat() helper above.
    kwargs: dict = dict(
        messages=msg,
        model=model_name,
        temperature=config.marshoai_temperature,
        max_tokens=config.marshoai_max_tokens,
        top_p=config.marshoai_top_p,
    )
    # Only send "tools" when actually provided: passing tools=None makes the
    # SDK serialize `"tools": null` (only NOT_GIVEN is omitted), which some
    # OpenAI-compatible endpoints reject.
    if tools:
        kwargs["tools"] = tools
    return await client.chat.completions.create(**kwargs)
def get_praises(): def get_praises():
global praises_json global praises_json
if praises_json is None: if praises_json is None: