From 780df08a6565be7261ba151d0812748719a23c45 Mon Sep 17 00:00:00 2001 From: Asankilp Date: Fri, 7 Mar 2025 13:32:30 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=81=E5=BC=8F=E8=B0=83=E7=94=A8=2030%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_marshoai/handler.py | 50 ++++++++++++++++++++++++++---- nonebot_plugin_marshoai/util.py | 33 ++++---------------- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/nonebot_plugin_marshoai/handler.py b/nonebot_plugin_marshoai/handler.py index cccd02fa..a802c45b 100644 --- a/nonebot_plugin_marshoai/handler.py +++ b/nonebot_plugin_marshoai/handler.py @@ -18,8 +18,8 @@ from nonebot.matcher import ( current_matcher, ) from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg -from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai import AsyncOpenAI, AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from .config import config from .constants import SUPPORT_IMAGE_MODELS @@ -96,7 +96,8 @@ class MarshoHandler: model_name: str, tools_list: list, tool_message: Optional[list] = None, - ) -> ChatCompletion: + stream: bool = False, + ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]: """ 处理单条聊天 """ @@ -109,12 +110,13 @@ class MarshoHandler: msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore model_name=model_name, tools=tools_list if tools_list else None, + stream=stream, ) return response async def handle_function_call( self, - completion: ChatCompletion, + completion: Union[ChatCompletion, AsyncStream[ChatCompletionChunk]], user_message: Union[str, list], model_name: str, tools_list: list, @@ -122,7 +124,10 @@ class MarshoHandler: # function call # 需要获取额外信息,调用函数工具 tool_msg = [] - choice = completion.choices[0] + if isinstance(completion, ChatCompletion): + choice = completion.choices[0] + else: + raise ValueError("Unexpected completion type") # await UniMessage(str(response)).send() tool_calls = choice.message.tool_calls # try: @@ -198,7 +203,10 @@ class MarshoHandler: tools_list=tools_list, tool_message=tool_message, ) - choice = response.choices[0] + if isinstance(response, ChatCompletion): + choice = response.choices[0] + else: + raise ValueError("Unexpected response type") # Sprint(choice) # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls: @@ -240,3 +248,33 @@ class MarshoHandler: else: await UniMessage(f"意外的完成原因:{choice.finish_reason}").send() return None + + async def handle_stream_request( + self, user_message: Union[str, list], model_name: str, tools_list: list + ): + """ + 处理流式请求 + """ + response = await self.handle_single_chat( + user_message=user_message, + model_name=model_name, + tools_list=tools_list, + stream=True, + ) + + if isinstance(response, AsyncStream): + reasoning_contents = "" + answer_contents = "" + async for chunk in response: + if not chunk.choices: + logger.info("Usage:", chunk.usage) + else: + delta = chunk.choices[0].delta + if ( + hasattr(delta, "reasoning_content") + and delta.reasoning_content is not None + ): + reasoning_contents += delta.reasoning_content + else: + if delta.content is not None: + answer_contents += delta.content diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 1a35320f..7efa8129 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -3,7 +3,7 @@ import json import mimetypes import re import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import aiofiles # type: ignore import httpx @@ -15,8 +15,8 @@ from nonebot.log import logger from nonebot_plugin_alconna import Image as ImageMsg from nonebot_plugin_alconna import Text as TextMsg from nonebot_plugin_alconna import UniMessage -from openai import AsyncOpenAI, NotGiven -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai import AsyncOpenAI, AsyncStream, NotGiven +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from zhDateTime import DateTime from ._types import DeveloperMessage @@ -109,35 +109,13 @@ async def get_image_b64(url: str, timeout: int = 10) -> Optional[str]: return None -async def make_chat( - client: ChatCompletionsClient, - msg: list, - model_name: str, - tools: Optional[list] = None, -): - """ - 调用ai获取回复 - - 参数: - client: 用于与AI模型进行通信 - msg: 消息内容 - model_name: 指定AI模型名 - tools: 工具列表 - """ - return await client.complete( - messages=msg, - model=model_name, - tools=tools, - **config.marshoai_model_args, - ) - - async def make_chat_openai( client: AsyncOpenAI, msg: list, model_name: str, tools: Optional[list] = None, -) -> ChatCompletion: + stream: bool = False, +) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]: """ 使用 Openai SDK 调用ai获取回复 @@ -152,6 +130,7 @@ async def make_chat_openai( model=model_name, tools=tools or NOT_GIVEN, timeout=config.marshoai_timeout, + stream=stream, **config.marshoai_model_args, )