mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-01-26 18:12:47 +08:00
✨ 更新实例和工具模块,更换为OpenAI异步客户端进行聊天请求
This commit is contained in:
parent
132d219c59
commit
736a881071
@ -2,6 +2,7 @@
|
|||||||
from azure.ai.inference.aio import ChatCompletionsClient
|
from azure.ai.inference.aio import ChatCompletionsClient
|
||||||
from azure.core.credentials import AzureKeyCredential
|
from azure.core.credentials import AzureKeyCredential
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .models import MarshoContext, MarshoTools
|
from .models import MarshoContext, MarshoTools
|
||||||
@ -14,5 +15,6 @@ context = MarshoContext()
|
|||||||
tools = MarshoTools()
|
tools = MarshoTools()
|
||||||
token = config.marshoai_token
|
token = config.marshoai_token
|
||||||
endpoint = config.marshoai_azure_endpoint
|
endpoint = config.marshoai_azure_endpoint
|
||||||
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
# client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
||||||
|
client = AsyncOpenAI(base_url=endpoint, api_key=token)
|
||||||
target_list: list[list] = [] # 记录需保存历史上下文的列表
|
target_list: list[list] = [] # 记录需保存历史上下文的列表
|
||||||
|
@ -2,6 +2,7 @@ import contextlib
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import openai
|
||||||
from arclet.alconna import Alconna, AllParam, Args
|
from arclet.alconna import Alconna, AllParam, Args
|
||||||
from azure.ai.inference.models import (
|
from azure.ai.inference.models import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
@ -21,6 +22,7 @@ from nonebot.permission import SUPERUSER
|
|||||||
from nonebot.rule import Rule, to_me
|
from nonebot.rule import Rule, to_me
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from .hooks import *
|
from .hooks import *
|
||||||
from .instances import *
|
from .instances import *
|
||||||
@ -253,7 +255,7 @@ async def marsho(
|
|||||||
for i in text: # type: ignore
|
for i in text: # type: ignore
|
||||||
if i.type == "text":
|
if i.type == "text":
|
||||||
if is_support_image_model:
|
if is_support_image_model:
|
||||||
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt)] # type: ignore
|
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore
|
||||||
else:
|
else:
|
||||||
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
|
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
|
||||||
elif i.type == "image":
|
elif i.type == "image":
|
||||||
@ -263,7 +265,7 @@ async def marsho(
|
|||||||
image_url=ImageUrl( # type: ignore
|
image_url=ImageUrl( # type: ignore
|
||||||
url=str(await get_image_b64(i.data["url"])) # type: ignore
|
url=str(await get_image_b64(i.data["url"])) # type: ignore
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
) # type: ignore
|
).as_dict() # type: ignore
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
elif config.marshoai_enable_support_image_tip:
|
elif config.marshoai_enable_support_image_tip:
|
||||||
await UniMessage(
|
await UniMessage(
|
||||||
@ -282,10 +284,10 @@ async def marsho(
|
|||||||
tools_lists = tools.tools_list + list(
|
tools_lists = tools.tools_list + list(
|
||||||
map(lambda v: v.data(), get_function_calls().values())
|
map(lambda v: v.data(), get_function_calls().values())
|
||||||
)
|
)
|
||||||
response = await make_chat(
|
response = await make_chat_openai(
|
||||||
client=client,
|
client=client,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
|
msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore
|
||||||
tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化
|
tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化
|
||||||
)
|
)
|
||||||
# await UniMessage(str(response)).send()
|
# await UniMessage(str(response)).send()
|
||||||
@ -293,13 +295,13 @@ async def marsho(
|
|||||||
# Sprint(choice)
|
# Sprint(choice)
|
||||||
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
||||||
if choice.message.tool_calls != None and config.marshoai_fix_toolcalls:
|
if choice.message.tool_calls != None and config.marshoai_fix_toolcalls:
|
||||||
choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS
|
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
|
||||||
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
|
if choice.finish_reason == CompletionsFinishReason.STOPPED:
|
||||||
# 当对话成功时,将dict的上下文添加到上下文类中
|
# 当对话成功时,将dict的上下文添加到上下文类中
|
||||||
context.append(
|
context.append(
|
||||||
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
|
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
|
||||||
)
|
)
|
||||||
context.append(choice.message.as_dict(), target.id, target.private)
|
context.append(choice.message, target.id, target.private)
|
||||||
if [target.id, target.private] not in target_list:
|
if [target.id, target.private] not in target_list:
|
||||||
target_list.append([target.id, target.private])
|
target_list.append([target.id, target.private])
|
||||||
|
|
||||||
@ -310,7 +312,7 @@ async def marsho(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||||
elif choice["finish_reason"] == CompletionsFinishReason.CONTENT_FILTERED:
|
elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
|
||||||
|
|
||||||
# 对话失败,消息过滤
|
# 对话失败,消息过滤
|
||||||
|
|
||||||
@ -318,7 +320,7 @@ async def marsho(
|
|||||||
reply_to=True
|
reply_to=True
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
|
elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
|
||||||
# function call
|
# function call
|
||||||
# 需要获取额外信息,调用函数工具
|
# 需要获取额外信息,调用函数工具
|
||||||
tool_msg = []
|
tool_msg = []
|
||||||
@ -332,11 +334,8 @@ async def marsho(
|
|||||||
] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
|
] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
tool_msg.append(choice.message.as_dict())
|
tool_msg.append(choice.message)
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
if isinstance(
|
|
||||||
tool_call, ChatCompletionsToolCall
|
|
||||||
): # 循环调用工具直到不需要调用
|
|
||||||
try:
|
try:
|
||||||
function_args = json.loads(tool_call.function.arguments)
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
@ -361,9 +360,7 @@ async def marsho(
|
|||||||
tool_call.function.name, function_args
|
tool_call.function.name, function_args
|
||||||
) # 获取返回值
|
) # 获取返回值
|
||||||
else:
|
else:
|
||||||
if caller := get_function_calls().get(
|
if caller := get_function_calls().get(tool_call.function.name):
|
||||||
tool_call.function.name
|
|
||||||
):
|
|
||||||
logger.debug(f"调用插件函数 {caller.full_name}")
|
logger.debug(f"调用插件函数 {caller.full_name}")
|
||||||
# 权限检查,规则检查 TODO
|
# 权限检查,规则检查 TODO
|
||||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||||
@ -386,7 +383,7 @@ async def marsho(
|
|||||||
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
|
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
|
||||||
# await UniMessage(str(tool_msg)).send()
|
# await UniMessage(str(tool_msg)).send()
|
||||||
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
|
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore
|
||||||
response = await make_chat(
|
response = await make_chat_openai(
|
||||||
client=client,
|
client=client,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
msg=request_msg, # type: ignore
|
msg=request_msg, # type: ignore
|
||||||
@ -397,15 +394,15 @@ async def marsho(
|
|||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
||||||
if choice.message.tool_calls != None:
|
if choice.message.tool_calls != None:
|
||||||
choice["finish_reason"] = CompletionsFinishReason.TOOL_CALLS
|
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
|
||||||
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
|
if choice.finish_reason == CompletionsFinishReason.STOPPED:
|
||||||
|
|
||||||
# 对话成功 添加上下文
|
# 对话成功 添加上下文
|
||||||
context.append(
|
context.append(
|
||||||
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
|
UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore
|
||||||
)
|
)
|
||||||
# context.append(tool_msg, target.id, target.private)
|
# context.append(tool_msg, target.id, target.private)
|
||||||
context.append(choice.message.as_dict(), target.id, target.private)
|
context.append(choice.message, target.id, target.private)
|
||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
if config.marshoai_enable_richtext_parse:
|
if config.marshoai_enable_richtext_parse:
|
||||||
@ -448,7 +445,7 @@ with contextlib.suppress(ImportError): # 优化先不做()
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
|
if choice.finish_reason == CompletionsFinishReason.STOPPED:
|
||||||
await UniMessage(" " + str(choice.message.content)).send(
|
await UniMessage(" " + str(choice.message.content)).send(
|
||||||
at_sender=True
|
at_sender=True
|
||||||
)
|
)
|
||||||
|
@ -16,6 +16,7 @@ from nonebot.log import logger
|
|||||||
from nonebot_plugin_alconna import Image as ImageMsg
|
from nonebot_plugin_alconna import Image as ImageMsg
|
||||||
from nonebot_plugin_alconna import Text as TextMsg
|
from nonebot_plugin_alconna import Text as TextMsg
|
||||||
from nonebot_plugin_alconna import UniMessage
|
from nonebot_plugin_alconna import UniMessage
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .constants import *
|
from .constants import *
|
||||||
@ -102,6 +103,29 @@ async def make_chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def make_chat_openai(
|
||||||
|
client: AsyncOpenAI,
|
||||||
|
msg: list,
|
||||||
|
model_name: str,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
):
|
||||||
|
"""使用 Openai SDK 调用ai获取回复
|
||||||
|
|
||||||
|
参数:
|
||||||
|
client: 用于与AI模型进行通信
|
||||||
|
msg: 消息内容
|
||||||
|
model_name: 指定AI模型名
|
||||||
|
tools: 工具列表"""
|
||||||
|
return await client.chat.completions.create(
|
||||||
|
messages=msg,
|
||||||
|
model=model_name,
|
||||||
|
tools=tools,
|
||||||
|
temperature=config.marshoai_temperature,
|
||||||
|
max_tokens=config.marshoai_max_tokens,
|
||||||
|
top_p=config.marshoai_top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_praises():
|
def get_praises():
|
||||||
global praises_json
|
global praises_json
|
||||||
if praises_json is None:
|
if praises_json is None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user