mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-03-10 21:41:03 +08:00
流式调用 30%
This commit is contained in:
parent
a61d13426e
commit
780df08a65
@ -18,8 +18,8 @@ from nonebot.matcher import (
|
|||||||
current_matcher,
|
current_matcher,
|
||||||
)
|
)
|
||||||
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
|
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI, AsyncStream
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .constants import SUPPORT_IMAGE_MODELS
|
from .constants import SUPPORT_IMAGE_MODELS
|
||||||
@ -96,7 +96,8 @@ class MarshoHandler:
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
tools_list: list,
|
tools_list: list,
|
||||||
tool_message: Optional[list] = None,
|
tool_message: Optional[list] = None,
|
||||||
) -> ChatCompletion:
|
stream: bool = False,
|
||||||
|
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
|
||||||
"""
|
"""
|
||||||
处理单条聊天
|
处理单条聊天
|
||||||
"""
|
"""
|
||||||
@ -109,12 +110,13 @@ class MarshoHandler:
|
|||||||
msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore
|
msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
tools=tools_list if tools_list else None,
|
tools=tools_list if tools_list else None,
|
||||||
|
stream=stream,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def handle_function_call(
|
async def handle_function_call(
|
||||||
self,
|
self,
|
||||||
completion: ChatCompletion,
|
completion: Union[ChatCompletion, AsyncStream[ChatCompletionChunk]],
|
||||||
user_message: Union[str, list],
|
user_message: Union[str, list],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
tools_list: list,
|
tools_list: list,
|
||||||
@ -122,7 +124,10 @@ class MarshoHandler:
|
|||||||
# function call
|
# function call
|
||||||
# 需要获取额外信息,调用函数工具
|
# 需要获取额外信息,调用函数工具
|
||||||
tool_msg = []
|
tool_msg = []
|
||||||
choice = completion.choices[0]
|
if isinstance(completion, ChatCompletion):
|
||||||
|
choice = completion.choices[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("Unexpected completion type")
|
||||||
# await UniMessage(str(response)).send()
|
# await UniMessage(str(response)).send()
|
||||||
tool_calls = choice.message.tool_calls
|
tool_calls = choice.message.tool_calls
|
||||||
# try:
|
# try:
|
||||||
@ -198,7 +203,10 @@ class MarshoHandler:
|
|||||||
tools_list=tools_list,
|
tools_list=tools_list,
|
||||||
tool_message=tool_message,
|
tool_message=tool_message,
|
||||||
)
|
)
|
||||||
choice = response.choices[0]
|
if isinstance(response, ChatCompletion):
|
||||||
|
choice = response.choices[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("Unexpected response type")
|
||||||
# Sprint(choice)
|
# Sprint(choice)
|
||||||
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
||||||
if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
|
if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
|
||||||
@ -240,3 +248,33 @@ class MarshoHandler:
|
|||||||
else:
|
else:
|
||||||
await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
|
await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def handle_stream_request(
|
||||||
|
self, user_message: Union[str, list], model_name: str, tools_list: list
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
处理流式请求
|
||||||
|
"""
|
||||||
|
response = await self.handle_single_chat(
|
||||||
|
user_message=user_message,
|
||||||
|
model_name=model_name,
|
||||||
|
tools_list=tools_list,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, AsyncStream):
|
||||||
|
reasoning_contents = ""
|
||||||
|
answer_contents = ""
|
||||||
|
async for chunk in response:
|
||||||
|
if not chunk.choices:
|
||||||
|
logger.info("Usage:", chunk.usage)
|
||||||
|
else:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if (
|
||||||
|
hasattr(delta, "reasoning_content")
|
||||||
|
and delta.reasoning_content is not None
|
||||||
|
):
|
||||||
|
reasoning_contents += delta.reasoning_content
|
||||||
|
else:
|
||||||
|
if delta.content is not None:
|
||||||
|
answer_contents += delta.content
|
||||||
|
@ -3,7 +3,7 @@ import json
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import aiofiles # type: ignore
|
import aiofiles # type: ignore
|
||||||
import httpx
|
import httpx
|
||||||
@ -15,8 +15,8 @@ from nonebot.log import logger
|
|||||||
from nonebot_plugin_alconna import Image as ImageMsg
|
from nonebot_plugin_alconna import Image as ImageMsg
|
||||||
from nonebot_plugin_alconna import Text as TextMsg
|
from nonebot_plugin_alconna import Text as TextMsg
|
||||||
from nonebot_plugin_alconna import UniMessage
|
from nonebot_plugin_alconna import UniMessage
|
||||||
from openai import AsyncOpenAI, NotGiven
|
from openai import AsyncOpenAI, AsyncStream, NotGiven
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
|
||||||
from zhDateTime import DateTime
|
from zhDateTime import DateTime
|
||||||
|
|
||||||
from ._types import DeveloperMessage
|
from ._types import DeveloperMessage
|
||||||
@ -109,35 +109,13 @@ async def get_image_b64(url: str, timeout: int = 10) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def make_chat(
|
|
||||||
client: ChatCompletionsClient,
|
|
||||||
msg: list,
|
|
||||||
model_name: str,
|
|
||||||
tools: Optional[list] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
调用ai获取回复
|
|
||||||
|
|
||||||
参数:
|
|
||||||
client: 用于与AI模型进行通信
|
|
||||||
msg: 消息内容
|
|
||||||
model_name: 指定AI模型名
|
|
||||||
tools: 工具列表
|
|
||||||
"""
|
|
||||||
return await client.complete(
|
|
||||||
messages=msg,
|
|
||||||
model=model_name,
|
|
||||||
tools=tools,
|
|
||||||
**config.marshoai_model_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def make_chat_openai(
|
async def make_chat_openai(
|
||||||
client: AsyncOpenAI,
|
client: AsyncOpenAI,
|
||||||
msg: list,
|
msg: list,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
tools: Optional[list] = None,
|
tools: Optional[list] = None,
|
||||||
) -> ChatCompletion:
|
stream: bool = False,
|
||||||
|
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
|
||||||
"""
|
"""
|
||||||
使用 Openai SDK 调用ai获取回复
|
使用 Openai SDK 调用ai获取回复
|
||||||
|
|
||||||
@ -152,6 +130,7 @@ async def make_chat_openai(
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
tools=tools or NOT_GIVEN,
|
tools=tools or NOT_GIVEN,
|
||||||
timeout=config.marshoai_timeout,
|
timeout=config.marshoai_timeout,
|
||||||
|
stream=stream,
|
||||||
**config.marshoai_model_args,
|
**config.marshoai_model_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user