diff --git a/nonebot_plugin_marshoai/instances.py b/nonebot_plugin_marshoai/instances.py index ead05cd5..3120c6e4 100644 --- a/nonebot_plugin_marshoai/instances.py +++ b/nonebot_plugin_marshoai/instances.py @@ -2,6 +2,7 @@ from azure.ai.inference.aio import ChatCompletionsClient from azure.core.credentials import AzureKeyCredential from nonebot import get_driver +from openai import AsyncOpenAI from .config import config from .models import MarshoContext, MarshoTools @@ -14,5 +15,6 @@ context = MarshoContext() tools = MarshoTools() token = config.marshoai_token endpoint = config.marshoai_azure_endpoint -client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) +# client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) +client = AsyncOpenAI(base_url=endpoint, api_key=token) target_list: list[list] = [] # 记录需保存历史上下文的列表 diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py index 1534b646..44742b2a 100644 --- a/nonebot_plugin_marshoai/marsho.py +++ b/nonebot_plugin_marshoai/marsho.py @@ -2,6 +2,7 @@ import contextlib import traceback from typing import Optional +import openai from arclet.alconna import Alconna, AllParam, Args from azure.ai.inference.models import ( AssistantMessage, @@ -21,6 +22,7 @@ from nonebot.permission import SUPERUSER from nonebot.rule import Rule, to_me from nonebot.typing import T_State from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna +from openai import AsyncOpenAI from .hooks import * from .instances import * @@ -253,7 +255,7 @@ async def marsho( for i in text: # type: ignore if i.type == "text": if is_support_image_model: - usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt)] # type: ignore + usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore else: usermsg += str(i.data["text"] + nickname_prompt) # type: ignore elif i.type == "image": @@ -263,7 +265,7 @@ async def marsho( image_url=ImageUrl( # type: ignore url=str(await get_image_b64(i.data["url"])) # type: ignore ) # type: ignore - ) # type: ignore + ).as_dict() # type: ignore ) # type: ignore elif config.marshoai_enable_support_image_tip: await UniMessage( @@ -282,10 +284,10 @@ async def marsho( tools_lists = tools.tools_list + list( map(lambda v: v.data(), get_function_calls().values()) ) - response = await make_chat( + response = await make_chat_openai( client=client, model_name=model_name, - msg=context_msg + [UserMessage(content=usermsg)], # type: ignore + msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化 ) # await UniMessage(str(response)).send() @@ -293,13 +295,13 @@ async def marsho( # Sprint(choice) # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS if choice.message.tool_calls != None and config.marshoai_fix_toolcalls: - choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS - if choice["finish_reason"] == CompletionsFinishReason.STOPPED: + choice.finish_reason = CompletionsFinishReason.TOOL_CALLS + if choice.finish_reason == CompletionsFinishReason.STOPPED: # 当对话成功时,将dict的上下文添加到上下文类中 context.append( UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore ) - context.append(choice.message.as_dict(), target.id, target.private) + context.append(choice.message, target.id, target.private) if [target.id, target.private] not in target_list: target_list.append([target.id, target.private]) @@ -310,7 +312,7 @@ async def marsho( ) else: await UniMessage(str(choice.message.content)).send(reply_to=True) - elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED: + elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED: # 对话失败,消息过滤 @@ -318,7 +320,7 @@ async def marsho( reply_to=True ) return - elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS: + elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: # function call # 需要获取额外信息,调用函数工具 tool_msg = [] @@ -332,61 +334,56 @@ async def marsho( ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案 except: pass - tool_msg.append(choice.message.as_dict()) + tool_msg.append(choice.message) for tool_call in tool_calls: - if isinstance( - tool_call, ChatCompletionsToolCall - ): # 循环调用工具直到不需要调用 - try: - function_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - function_args = json.loads( - tool_call.function.arguments.replace("'", '"') - ) - # 删除args的placeholder参数 - if "placeholder" in function_args: - del function_args["placeholder"] - logger.info( - f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" - + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + try: + function_args = json.loads(tool_call.function.arguments) + except json.JSONDecodeError: + function_args = json.loads( + tool_call.function.arguments.replace("'", '"') ) - await UniMessage( - f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" - + "\n".join([f"{k}={v}" for k, v in function_args.items()]) - ).send() - # TODO 临时追加插件函数,若工具中没有则调用插件函数 - if tools.has_function(tool_call.function.name): - logger.debug(f"调用工具函数 {tool_call.function.name}") - func_return = await tools.call( - tool_call.function.name, function_args - ) # 获取返回值 - else: - if caller := get_function_calls().get( - tool_call.function.name - ): - logger.debug(f"调用插件函数 {caller.full_name}") - # 权限检查,规则检查 TODO - # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 - func_return = await caller.with_ctx( - SessionContext( - bot=bot, - event=event, - state=state, - matcher=matcher, - ) - ).call(**function_args) - else: - logger.error( - f"未找到函数 {tool_call.function.name.replace('-', '.')}" + # 删除args的placeholder参数 + if "placeholder" in function_args: + del function_args["placeholder"] + logger.info( + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + ) + await UniMessage( + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + ).send() + # TODO 临时追加插件函数,若工具中没有则调用插件函数 + if tools.has_function(tool_call.function.name): + logger.debug(f"调用工具函数 {tool_call.function.name}") + func_return = await tools.call( + tool_call.function.name, function_args + ) # 获取返回值 + else: + if caller := get_function_calls().get(tool_call.function.name): + logger.debug(f"调用插件函数 {caller.full_name}") + # 权限检查,规则检查 TODO + # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 + func_return = await caller.with_ctx( + SessionContext( + bot=bot, + event=event, + state=state, + matcher=matcher, ) - func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}" - tool_msg.append( - ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore - ) + ).call(**function_args) + else: + logger.error( + f"未找到函数 {tool_call.function.name.replace('-', '.')}" + ) + func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}" + tool_msg.append( + ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore + ) # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" # await UniMessage(str(tool_msg)).send() request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore - response = await make_chat( + response = await make_chat_openai( client=client, model_name=model_name, msg=request_msg, # type: ignore @@ -397,15 +394,15 @@ async def marsho( choice = response.choices[0] # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS if choice.message.tool_calls != None: - choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS - if choice["finish_reason"] == CompletionsFinishReason.STOPPED: + choice.finish_reason = CompletionsFinishReason.TOOL_CALLS + if choice.finish_reason == CompletionsFinishReason.STOPPED: # 对话成功 添加上下文 context.append( UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore ) # context.append(tool_msg, target.id, target.private) - context.append(choice.message.as_dict(), target.id, target.private) + context.append(choice.message, target.id, target.private) # 发送消息 if config.marshoai_enable_richtext_parse: @@ -448,7 +445,7 @@ with contextlib.suppress(ImportError): # 优化先不做() ], ) choice = response.choices[0] - if choice["finish_reason"] == CompletionsFinishReason.STOPPED: + if choice.finish_reason == CompletionsFinishReason.STOPPED: await UniMessage(" " + str(choice.message.content)).send( at_sender=True ) diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 08df81cc..b26e5e27 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -16,6 +16,7 @@ from nonebot.log import logger from nonebot_plugin_alconna import Image as ImageMsg from nonebot_plugin_alconna import Text as TextMsg from nonebot_plugin_alconna import UniMessage +from openai import AsyncOpenAI from .config import config from .constants import * @@ -102,6 +103,29 @@ async def make_chat( ) +async def make_chat_openai( + client: AsyncOpenAI, + msg: list, + model_name: str, + tools: Optional[list] = None, +): + """使用 Openai SDK 调用ai获取回复 + + 参数: + client: 用于与AI模型进行通信 + msg: 消息内容 + model_name: 指定AI模型名 + tools: 工具列表""" + return await client.chat.completions.create( + messages=msg, + model=model_name, + tools=tools, + temperature=config.marshoai_temperature, + max_tokens=config.marshoai_max_tokens, + top_p=config.marshoai_top_p, + ) + + def get_praises(): global praises_json if praises_json is None: