更新OpenAI模型列表,重构获取系统提示词逻辑,添加开发者消息类型,兼容 OpenAI o1 以上模型的系统提示词

This commit is contained in:
Asankilp 2025-02-15 00:31:20 +08:00
parent 0c57ace798
commit 50567e1f57
4 changed files with 59 additions and 18 deletions

View File

@@ -0,0 +1,33 @@
# source: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py
from typing import Any, Literal, Mapping, Optional, overload
from azure.ai.inference._model_base import rest_discriminator, rest_field
from azure.ai.inference.models import ChatRequestMessage
class DeveloperMessage(ChatRequestMessage, discriminator="developer"):
    """A chat request message carrying the 'developer' role.

    Mirrors the Azure SDK's own message-model pattern (see the source link
    at the top of this file). The ``discriminator="developer"`` class
    argument registers this type on ``ChatRequestMessage``'s polymorphic
    hierarchy, so the SDK treats it as the ``role == "developer"`` variant.
    """

    # rest_discriminator ties this field to the SDK's polymorphic
    # (de)serialization machinery; the value is always "developer".
    role: Literal["developer"] = rest_discriminator(name="role")  # type: ignore
    """The chat role associated with this message, which is always 'developer' for developer messages.
    Required."""
    content: Optional[str] = rest_field()
    """The content of the message."""

    @overload
    def __init__(
        self,
        *,
        content: Optional[str] = None,
    ): ...
    @overload
    def __init__(self, mapping: Mapping[str, Any]):
        """
        :param mapping: raw JSON to initialize the model.
        :type mapping: Mapping[str, Any]
        """

    def __init__(
        self, *args: Any, **kwargs: Any
    ) -> None:  # pylint: disable=useless-super-delegation
        # Delegates to the base class, always supplying role="developer".
        super().__init__(*args, role="developer", **kwargs)

View File

@@ -26,7 +26,14 @@ SUPPORT_IMAGE_MODELS: list = [
     "llama-3.2-11b-vision-instruct",
     "gemini-2.0-flash-exp",
 ]
-NO_SYSPROMPT_MODELS: list = ["o1", "o1-preview", "o1-mini"]
+OPENAI_NEW_MODELS: list = [
+    "o1",
+    "o1-preview",
+    "o1-mini",
+    "o3",
+    "o3-mini",
+    "o3-mini-large",
+]
 INTRODUCTION: str = f"""MarshoAI-NoneBot by LiteyukiStudio
 你好喵~我是一只可爱的猫娘AI名叫小棉~🐾
 我的主页在这里哦~

View File

@@ -257,7 +257,7 @@ async def marsho(
         model_name.lower()
         in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
     )
-    is_reasoning_model = model_name.lower() in NO_SYSPROMPT_MODELS
+    is_openai_new_model = model_name.lower() in OPENAI_NEW_MODELS
     usermsg = [] if is_support_image_model else ""
     for i in text:  # type: ignore
         if i.type == "text":
@@ -285,14 +285,13 @@ async def marsho(
                 backup_context, target.id, target.private
             )  # 加载历史记录
             logger.info(f"已恢复会话 {target.id} 的上下文备份~")
-        context_msg = context.build(target.id, target.private)
-        if not is_reasoning_model:
-            context_msg = [get_prompt()] + context_msg
-            # o1等推理模型不支持系统提示词, 故不添加
+        context_msg = get_prompt(model_name) + context.build(target.id, target.private)
         tools_lists = tools.tools_list + list(
             map(lambda v: v.data(), get_function_calls().values())
         )
         logger.info(f"正在获取回答,模型:{model_name}")
+        # logger.info(f"上下文:{context_msg}")
         response = await make_chat_openai(
             client=client,
             model_name=model_name,
@@ -460,12 +459,8 @@ with contextlib.suppress(ImportError):  # 优化先不做()
             response = await make_chat_openai(
                 client=client,
                 model_name=model_name,
-                msg=[
-                    (
-                        get_prompt()
-                        if model_name.lower() not in NO_SYSPROMPT_MODELS
-                        else None
-                    ),
+                msg=get_prompt(model_name)
+                + [
                     UserMessage(
                         content=f"*{user_nickname}{config.marshoai_poke_suffix}"
                     ),

View File

@@ -3,7 +3,7 @@ import json
 import mimetypes
 import re
 import uuid
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional

 import aiofiles  # type: ignore
 import httpx
@@ -16,9 +16,10 @@ from nonebot_plugin_alconna import Image as ImageMsg
 from nonebot_plugin_alconna import Text as TextMsg
 from nonebot_plugin_alconna import UniMessage
 from openai import AsyncOpenAI, NotGiven
-from openai.types.chat import ChatCompletionMessage
+from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from zhDateTime import DateTime

+from ._types import DeveloperMessage
 from .config import config
 from .constants import *
 from .deal_latex import ConvertLatex
@@ -135,7 +136,7 @@ async def make_chat_openai(
     msg: list,
     model_name: str,
     tools: Optional[list] = None,
-):
+) -> ChatCompletion:
     """
     使用 Openai SDK 调用ai获取回复
@@ -252,7 +253,7 @@ async def refresh_nickname_json():
         logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件")


-def get_prompt():
+def get_prompt(model: str) -> List[Dict[str, Any]]:
     """获取系统提示词"""
     prompts = config.marshoai_additional_prompt
     if config.marshoai_enable_praises:
@@ -271,8 +272,13 @@ def get_prompt():
     )
     marsho_prompt = config.marshoai_prompt
-    spell = SystemMessage(content=marsho_prompt + prompts).as_dict()
-    return spell
+    sysprompt_content = marsho_prompt + prompts
+    prompt_list: List[Dict[str, Any]] = []
+    if model not in OPENAI_NEW_MODELS:
+        prompt_list += [SystemMessage(content=sysprompt_content).as_dict()]
+    else:
+        prompt_list += [DeveloperMessage(content=sysprompt_content).as_dict()]
+    return prompt_list


 def suggest_solution(errinfo: str) -> str: