mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-03-10 18:31:03 +08:00
流式调用 30%
This commit is contained in:
parent
a61d13426e
commit
780df08a65
@ -18,8 +18,8 @@ from nonebot.matcher import (
|
||||
current_matcher,
|
||||
)
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai import AsyncOpenAI, AsyncStream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
|
||||
|
||||
from .config import config
|
||||
from .constants import SUPPORT_IMAGE_MODELS
|
||||
@ -96,7 +96,8 @@ class MarshoHandler:
|
||||
model_name: str,
|
||||
tools_list: list,
|
||||
tool_message: Optional[list] = None,
|
||||
) -> ChatCompletion:
|
||||
stream: bool = False,
|
||||
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
|
||||
"""
|
||||
处理单条聊天
|
||||
"""
|
||||
@ -109,12 +110,13 @@ class MarshoHandler:
|
||||
msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore
|
||||
model_name=model_name,
|
||||
tools=tools_list if tools_list else None,
|
||||
stream=stream,
|
||||
)
|
||||
return response
|
||||
|
||||
async def handle_function_call(
|
||||
self,
|
||||
completion: ChatCompletion,
|
||||
completion: Union[ChatCompletion, AsyncStream[ChatCompletionChunk]],
|
||||
user_message: Union[str, list],
|
||||
model_name: str,
|
||||
tools_list: list,
|
||||
@ -122,7 +124,10 @@ class MarshoHandler:
|
||||
# function call
|
||||
# 需要获取额外信息,调用函数工具
|
||||
tool_msg = []
|
||||
choice = completion.choices[0]
|
||||
if isinstance(completion, ChatCompletion):
|
||||
choice = completion.choices[0]
|
||||
else:
|
||||
raise ValueError("Unexpected completion type")
|
||||
# await UniMessage(str(response)).send()
|
||||
tool_calls = choice.message.tool_calls
|
||||
# try:
|
||||
@ -198,7 +203,10 @@ class MarshoHandler:
|
||||
tools_list=tools_list,
|
||||
tool_message=tool_message,
|
||||
)
|
||||
choice = response.choices[0]
|
||||
if isinstance(response, ChatCompletion):
|
||||
choice = response.choices[0]
|
||||
else:
|
||||
raise ValueError("Unexpected response type")
|
||||
# Sprint(choice)
|
||||
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
|
||||
if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
|
||||
@ -240,3 +248,33 @@ class MarshoHandler:
|
||||
else:
|
||||
await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
|
||||
return None
|
||||
|
||||
async def handle_stream_request(
|
||||
self, user_message: Union[str, list], model_name: str, tools_list: list
|
||||
):
|
||||
"""
|
||||
处理流式请求
|
||||
"""
|
||||
response = await self.handle_single_chat(
|
||||
user_message=user_message,
|
||||
model_name=model_name,
|
||||
tools_list=tools_list,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if isinstance(response, AsyncStream):
|
||||
reasoning_contents = ""
|
||||
answer_contents = ""
|
||||
async for chunk in response:
|
||||
if not chunk.choices:
|
||||
logger.info("Usage:", chunk.usage)
|
||||
else:
|
||||
delta = chunk.choices[0].delta
|
||||
if (
|
||||
hasattr(delta, "reasoning_content")
|
||||
and delta.reasoning_content is not None
|
||||
):
|
||||
reasoning_contents += delta.reasoning_content
|
||||
else:
|
||||
if delta.content is not None:
|
||||
answer_contents += delta.content
|
||||
|
@ -3,7 +3,7 @@ import json
|
||||
import mimetypes
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import aiofiles # type: ignore
|
||||
import httpx
|
||||
@ -15,8 +15,8 @@ from nonebot.log import logger
|
||||
from nonebot_plugin_alconna import Image as ImageMsg
|
||||
from nonebot_plugin_alconna import Text as TextMsg
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from openai import AsyncOpenAI, NotGiven
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai import AsyncOpenAI, AsyncStream, NotGiven
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
|
||||
from zhDateTime import DateTime
|
||||
|
||||
from ._types import DeveloperMessage
|
||||
@ -109,35 +109,13 @@ async def get_image_b64(url: str, timeout: int = 10) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
async def make_chat(
|
||||
client: ChatCompletionsClient,
|
||||
msg: list,
|
||||
model_name: str,
|
||||
tools: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
调用ai获取回复
|
||||
|
||||
参数:
|
||||
client: 用于与AI模型进行通信
|
||||
msg: 消息内容
|
||||
model_name: 指定AI模型名
|
||||
tools: 工具列表
|
||||
"""
|
||||
return await client.complete(
|
||||
messages=msg,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
**config.marshoai_model_args,
|
||||
)
|
||||
|
||||
|
||||
async def make_chat_openai(
|
||||
client: AsyncOpenAI,
|
||||
msg: list,
|
||||
model_name: str,
|
||||
tools: Optional[list] = None,
|
||||
) -> ChatCompletion:
|
||||
stream: bool = False,
|
||||
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
|
||||
"""
|
||||
使用 Openai SDK 调用ai获取回复
|
||||
|
||||
@ -152,6 +130,7 @@ async def make_chat_openai(
|
||||
model=model_name,
|
||||
tools=tools or NOT_GIVEN,
|
||||
timeout=config.marshoai_timeout,
|
||||
stream=stream,
|
||||
**config.marshoai_model_args,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user