diff --git a/nonebot_plugin_marshoai/_types.py b/nonebot_plugin_marshoai/_types.py
new file mode 100644
index 00000000..a26a31ba
--- /dev/null
+++ b/nonebot_plugin_marshoai/_types.py
@@ -0,0 +1,33 @@
+# source: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py
+from typing import Any, Literal, Mapping, Optional, overload
+
+from azure.ai.inference._model_base import rest_discriminator, rest_field
+from azure.ai.inference.models import ChatRequestMessage
+
+
+class DeveloperMessage(ChatRequestMessage, discriminator="developer"):
+
+    role: Literal["developer"] = rest_discriminator(name="role")  # type: ignore
+    """The chat role associated with this message, which is always 'developer' for developer messages.
+    Required."""
+    content: Optional[str] = rest_field()
+    """The content of the message."""
+
+    @overload
+    def __init__(
+        self,
+        *,
+        content: Optional[str] = None,
+    ): ...
+
+    @overload
+    def __init__(self, mapping: Mapping[str, Any]):
+        """
+        :param mapping: raw JSON to initialize the model.
+        :type mapping: Mapping[str, Any]
+        """
+
+    def __init__(
+        self, *args: Any, **kwargs: Any
+    ) -> None:  # pylint: disable=useless-super-delegation
+        super().__init__(*args, role="developer", **kwargs)
diff --git a/nonebot_plugin_marshoai/constants.py b/nonebot_plugin_marshoai/constants.py
index 594bcf11..9f09513f 100755
--- a/nonebot_plugin_marshoai/constants.py
+++ b/nonebot_plugin_marshoai/constants.py
@@ -26,7 +26,14 @@ SUPPORT_IMAGE_MODELS: list = [
     "llama-3.2-11b-vision-instruct",
     "gemini-2.0-flash-exp",
 ]
-NO_SYSPROMPT_MODELS: list = ["o1", "o1-preview", "o1-mini"]
+OPENAI_NEW_MODELS: list = [
+    "o1",
+    "o1-preview",
+    "o1-mini",
+    "o3",
+    "o3-mini",
+    "o3-mini-large",
+]
 INTRODUCTION: str = f"""MarshoAI-NoneBot by LiteyukiStudio
 你好喵~我是一只可爱的猫娘AI,名叫小棉~🐾!
 我的主页在这里哦~↓↓↓
diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py
index ae2337bf..1295336a 100644
--- a/nonebot_plugin_marshoai/marsho.py
+++ b/nonebot_plugin_marshoai/marsho.py
@@ -257,7 +257,7 @@ async def marsho(
         model_name.lower()
         in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
     )
-    is_reasoning_model = model_name.lower() in NO_SYSPROMPT_MODELS
+    is_openai_new_model = model_name.lower() in OPENAI_NEW_MODELS
     usermsg = [] if is_support_image_model else ""
     for i in text:  # type: ignore
         if i.type == "text":
@@ -285,14 +285,13 @@ async def marsho(
                     backup_context, target.id, target.private
                 )  # 加载历史记录
                 logger.info(f"已恢复会话 {target.id} 的上下文备份~")
-            context_msg = context.build(target.id, target.private)
-            if not is_reasoning_model:
-                context_msg = [get_prompt()] + context_msg
-                # o1等推理模型不支持系统提示词, 故不添加
+            context_msg = get_prompt(model_name) + context.build(target.id, target.private)
+
             tools_lists = tools.tools_list + list(
                 map(lambda v: v.data(), get_function_calls().values())
             )
             logger.info(f"正在获取回答,模型:{model_name}")
+            # logger.info(f"上下文:{context_msg}")
             response = await make_chat_openai(
                 client=client,
                 model_name=model_name,
@@ -460,12 +459,8 @@ with contextlib.suppress(ImportError):  # 优化先不做()
            response = await make_chat_openai(
                client=client,
                model_name=model_name,
-                msg=[
-                    (
-                        get_prompt()
-                        if model_name.lower() not in NO_SYSPROMPT_MODELS
-                        else None
-                    ),
+                msg=get_prompt(model_name)
+                + [
                     UserMessage(
                         content=f"*{user_nickname}{config.marshoai_poke_suffix}"
                     ),
diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py
index 066144cf..79ab6fb3 100755
--- a/nonebot_plugin_marshoai/util.py
+++ b/nonebot_plugin_marshoai/util.py
@@ -3,7 +3,7 @@ import json
 import mimetypes
 import re
 import uuid
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional
 
 import aiofiles  # type: ignore
 import httpx
@@ -16,9 +16,10 @@ from nonebot_plugin_alconna import Image as ImageMsg
 from nonebot_plugin_alconna import Text as TextMsg
 from nonebot_plugin_alconna import UniMessage
 from openai import AsyncOpenAI, NotGiven
-from openai.types.chat import ChatCompletionMessage
+from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from zhDateTime import DateTime
 
+from ._types import DeveloperMessage
 from .config import config
 from .constants import *
 from .deal_latex import ConvertLatex
@@ -135,7 +136,7 @@ async def make_chat_openai(
     msg: list,
     model_name: str,
     tools: Optional[list] = None,
-):
+) -> ChatCompletion:
     """
     使用 Openai SDK 调用ai获取回复
 
@@ -252,7 +253,7 @@ async def refresh_nickname_json():
         logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件")
 
 
-def get_prompt():
+def get_prompt(model: str) -> List[Dict[str, Any]]:
     """获取系统提示词"""
     prompts = config.marshoai_additional_prompt
     if config.marshoai_enable_praises:
@@ -271,8 +272,13 @@
         )
 
     marsho_prompt = config.marshoai_prompt
-    spell = SystemMessage(content=marsho_prompt + prompts).as_dict()
-    return spell
+    sysprompt_content = marsho_prompt + prompts
+    prompt_list: List[Dict[str, Any]] = []
+    if model not in OPENAI_NEW_MODELS:
+        prompt_list += [SystemMessage(content=sysprompt_content).as_dict()]
+    else:
+        prompt_list += [DeveloperMessage(content=sysprompt_content).as_dict()]
+    return prompt_list
 
 
 def suggest_solution(errinfo: str) -> str:
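Editorial note (not part of the patch): a minimal sketch of what the new prompt construction yields. It assumes the plugin package is importable (i.e., inside a NoneBot environment with azure-ai-inference installed); "example prompt" is a placeholder, not the real prompt text.

```python
# Sketch only: the dict shape get_prompt() produces after this change.
from azure.ai.inference.models import SystemMessage

from nonebot_plugin_marshoai._types import DeveloperMessage

# Models not listed in OPENAI_NEW_MODELS keep the classic system prompt entry.
print(SystemMessage(content="example prompt").as_dict())
# expected: {'role': 'system', 'content': 'example prompt'}

# o1/o3-family models listed in OPENAI_NEW_MODELS get a developer-role entry instead.
print(DeveloperMessage(content="example prompt").as_dict())
# expected: {'role': 'developer', 'content': 'example prompt'}
```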