mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-03-10 21:41:03 +08:00
更新OpenAI模型列表,重构获取系统提示词逻辑,添加开发者消息类型,兼容 OpenAI o1 以上模型的系统提示词
This commit is contained in:
parent
0c57ace798
commit
50567e1f57
33
nonebot_plugin_marshoai/_types.py
Normal file
33
nonebot_plugin_marshoai/_types.py
Normal file
@ -0,0 +1,33 @@
|
||||
# source: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py
|
||||
from typing import Any, Literal, Mapping, Optional, overload
|
||||
|
||||
from azure.ai.inference._model_base import rest_discriminator, rest_field
|
||||
from azure.ai.inference.models import ChatRequestMessage
|
||||
|
||||
|
||||
class DeveloperMessage(ChatRequestMessage, discriminator="developer"):
|
||||
|
||||
role: Literal["developer"] = rest_discriminator(name="role") # type: ignore
|
||||
"""The chat role associated with this message, which is always 'developer' for developer messages.
|
||||
Required."""
|
||||
content: Optional[str] = rest_field()
|
||||
"""The content of the message."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
content: Optional[str] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(self, mapping: Mapping[str, Any]):
|
||||
"""
|
||||
:param mapping: raw JSON to initialize the model.
|
||||
:type mapping: Mapping[str, Any]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> None: # pylint: disable=useless-super-delegation
|
||||
super().__init__(*args, role="developer", **kwargs)
|
@ -26,7 +26,14 @@ SUPPORT_IMAGE_MODELS: list = [
|
||||
"llama-3.2-11b-vision-instruct",
|
||||
"gemini-2.0-flash-exp",
|
||||
]
|
||||
NO_SYSPROMPT_MODELS: list = ["o1", "o1-preview", "o1-mini"]
|
||||
OPENAI_NEW_MODELS: list = [
|
||||
"o1",
|
||||
"o1-preview",
|
||||
"o1-mini",
|
||||
"o3",
|
||||
"o3-mini",
|
||||
"o3-mini-large",
|
||||
]
|
||||
INTRODUCTION: str = f"""MarshoAI-NoneBot by LiteyukiStudio
|
||||
你好喵~我是一只可爱的猫娘AI,名叫小棉~🐾!
|
||||
我的主页在这里哦~↓↓↓
|
||||
|
@ -257,7 +257,7 @@ async def marsho(
|
||||
model_name.lower()
|
||||
in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
|
||||
)
|
||||
is_reasoning_model = model_name.lower() in NO_SYSPROMPT_MODELS
|
||||
is_openai_new_model = model_name.lower() in OPENAI_NEW_MODELS
|
||||
usermsg = [] if is_support_image_model else ""
|
||||
for i in text: # type: ignore
|
||||
if i.type == "text":
|
||||
@ -285,14 +285,13 @@ async def marsho(
|
||||
backup_context, target.id, target.private
|
||||
) # 加载历史记录
|
||||
logger.info(f"已恢复会话 {target.id} 的上下文备份~")
|
||||
context_msg = context.build(target.id, target.private)
|
||||
if not is_reasoning_model:
|
||||
context_msg = [get_prompt()] + context_msg
|
||||
# o1等推理模型不支持系统提示词, 故不添加
|
||||
context_msg = get_prompt(model_name) + context.build(target.id, target.private)
|
||||
|
||||
tools_lists = tools.tools_list + list(
|
||||
map(lambda v: v.data(), get_function_calls().values())
|
||||
)
|
||||
logger.info(f"正在获取回答,模型:{model_name}")
|
||||
# logger.info(f"上下文:{context_msg}")
|
||||
response = await make_chat_openai(
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
@ -460,12 +459,8 @@ with contextlib.suppress(ImportError): # 优化先不做()
|
||||
response = await make_chat_openai(
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
msg=[
|
||||
(
|
||||
get_prompt()
|
||||
if model_name.lower() not in NO_SYSPROMPT_MODELS
|
||||
else None
|
||||
),
|
||||
msg=get_prompt(model_name)
|
||||
+ [
|
||||
UserMessage(
|
||||
content=f"*{user_nickname}{config.marshoai_poke_suffix}"
|
||||
),
|
||||
|
@ -3,7 +3,7 @@ import json
|
||||
import mimetypes
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiofiles # type: ignore
|
||||
import httpx
|
||||
@ -16,9 +16,10 @@ from nonebot_plugin_alconna import Image as ImageMsg
|
||||
from nonebot_plugin_alconna import Text as TextMsg
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from openai import AsyncOpenAI, NotGiven
|
||||
from openai.types.chat import ChatCompletionMessage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from zhDateTime import DateTime
|
||||
|
||||
from ._types import DeveloperMessage
|
||||
from .config import config
|
||||
from .constants import *
|
||||
from .deal_latex import ConvertLatex
|
||||
@ -135,7 +136,7 @@ async def make_chat_openai(
|
||||
msg: list,
|
||||
model_name: str,
|
||||
tools: Optional[list] = None,
|
||||
):
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
使用 Openai SDK 调用ai获取回复
|
||||
|
||||
@ -252,7 +253,7 @@ async def refresh_nickname_json():
|
||||
logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件")
|
||||
|
||||
|
||||
def get_prompt():
|
||||
def get_prompt(model: str) -> List[Dict[str, Any]]:
|
||||
"""获取系统提示词"""
|
||||
prompts = config.marshoai_additional_prompt
|
||||
if config.marshoai_enable_praises:
|
||||
@ -271,8 +272,13 @@ def get_prompt():
|
||||
)
|
||||
|
||||
marsho_prompt = config.marshoai_prompt
|
||||
spell = SystemMessage(content=marsho_prompt + prompts).as_dict()
|
||||
return spell
|
||||
sysprompt_content = marsho_prompt + prompts
|
||||
prompt_list: List[Dict[str, Any]] = []
|
||||
if model not in OPENAI_NEW_MODELS:
|
||||
prompt_list += [SystemMessage(content=sysprompt_content).as_dict()]
|
||||
else:
|
||||
prompt_list += [DeveloperMessage(content=sysprompt_content).as_dict()]
|
||||
return prompt_list
|
||||
|
||||
|
||||
def suggest_solution(errinfo: str) -> str:
|
||||
|
Loading…
x
Reference in New Issue
Block a user