🎨 重写基本完毕

This commit is contained in:
Asankilp 2025-02-23 14:50:35 +08:00
parent d9f22fa0f7
commit 5b315c46b1
5 changed files with 89 additions and 127 deletions

View File

@ -1,5 +1,5 @@
import json import json
from typing import Optional, Union from typing import Optional, Tuple, Union
from azure.ai.inference.models import ( from azure.ai.inference.models import (
CompletionsFinishReason, CompletionsFinishReason,
@ -21,11 +21,11 @@ from nonebot.matcher import (
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletion, ChatCompletionMessage
from .config import config from .config import config
from .constants import SUPPORT_IMAGE_MODELS from .constants import SUPPORT_IMAGE_MODELS
from .instances import target_list, tools from .instances import target_list
from .models import MarshoContext from .models import MarshoContext
from .plugin.func_call.caller import get_function_calls from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext from .plugin.func_call.models import SessionContext
@ -50,7 +50,7 @@ class MarshoHandler:
self.context = context self.context = context
self.bot: Bot = current_bot.get() self.bot: Bot = current_bot.get()
self.event: Event = current_event.get() self.event: Event = current_event.get()
self.state: T_State = current_handler.get().state # self.state: T_State = current_handler.get().state
self.matcher: Matcher = current_matcher.get() self.matcher: Matcher = current_matcher.get()
self.message_id: str = UniMessage.get_message_id(self.event) self.message_id: str = UniMessage.get_message_id(self.event)
self.target = UniMessage.get_target(self.event) self.target = UniMessage.get_target(self.event)
@ -97,36 +97,35 @@ class MarshoHandler:
self, self,
user_message: Union[str, list], user_message: Union[str, list],
model_name: str, model_name: str,
tools: list, tools_list: list,
with_context: bool = True, tool_message: Optional[list] = None,
) -> ChatCompletion: ) -> ChatCompletion:
""" """
处理单条聊天 处理单条聊天
""" """
context_msg = ( context_msg = get_prompt(model_name) + (
get_prompt(model_name) self.context.build(self.target.id, self.target.private)
+ (self.context.build(self.target.id, self.target.private))
if with_context
else ""
) )
response = await make_chat_openai( response = await make_chat_openai(
client=self.client, client=self.client,
msg=context_msg + [UserMessage(content=user_message).as_dict()], # type: ignore msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore
model_name=model_name, model_name=model_name,
tools=tools, tools=tools_list if tools_list else None,
) )
return response return response
async def handle_function_call( async def handle_function_call(
self, self,
completion: ChatCompletion, completion: ChatCompletion,
user_message: Union[str, list],
model_name: str,
tools_list: list,
): ):
# function call # function call
# 需要获取额外信息,调用函数工具 # 需要获取额外信息,调用函数工具
tool_msg = [] tool_msg = []
choice = completion.choices[0] choice = completion.choices[0]
while choice.message.tool_calls is not None:
# await UniMessage(str(response)).send() # await UniMessage(str(response)).send()
tool_calls = choice.message.tool_calls tool_calls = choice.message.tool_calls
# try: # try:
@ -155,13 +154,6 @@ class MarshoHandler:
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+ "\n".join([f"{k}={v}" for k, v in function_args.items()]) + "\n".join([f"{k}={v}" for k, v in function_args.items()])
).send() ).send()
# TODO 临时追加插件函数,若工具中没有则调用插件函数
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
else:
if caller := get_function_calls().get(tool_call.function.name): if caller := get_function_calls().get(tool_call.function.name):
logger.debug(f"调用插件函数 {caller.full_name}") logger.debug(f"调用插件函数 {caller.full_name}")
# 权限检查,规则检查 TODO # 权限检查,规则检查 TODO
@ -170,65 +162,32 @@ class MarshoHandler:
SessionContext( SessionContext(
bot=self.bot, bot=self.bot,
event=self.event, event=self.event,
state=self.state,
matcher=self.matcher, matcher=self.matcher,
) )
).call(**function_args) ).call(**function_args)
else: else:
logger.error( logger.error(f"未找到函数 {tool_call.function.name.replace('-', '.')}")
f"未找到函数 {tool_call.function.name.replace('-', '.')}" func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
func_return = (
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
)
tool_msg.append( tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore
) )
# tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
# await UniMessage(str(tool_msg)).send() # await UniMessage(str(tool_msg)).send()
request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore return await self.handle_common_chat(
response = await make_chat_openai( user_message=user_message,
client=client,
model_name=model_name, model_name=model_name,
msg=request_msg, # type: ignore tools_list=tools_list,
tools=( tool_message=tool_msg,
tools_lists if tools_lists else None
), # TODO 临时追加函数,后期优化
) )
choice = response.choices[0]
# 当tool_calls非空时将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls is not None:
choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
if choice.finish_reason == CompletionsFinishReason.STOPPED:
# 对话成功 添加上下文
context.append(
UserMessage(content=usermsg).as_dict(), self.target.id, self.target.private # type: ignore
)
# context.append(tool_msg, self.target.id, self.target.private)
choice_msg_dict = choice.message.to_dict()
if "reasoning_content" in choice_msg_dict:
del choice_msg_dict["reasoning_content"]
context.append(choice_msg_dict, self.target.id, self.target.private)
# 发送消息
if config.marshoai_enable_richtext_parse:
await (await parse_richtext(str(choice.message.content))).send(
reply_to=True
)
else:
await UniMessage(str(choice.message.content)).send(reply_to=True)
else:
await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
async def handle_common_chat( async def handle_common_chat(
self, self,
user_message: Union[str, list], user_message: Union[str, list],
model_name: str, model_name: str,
tools: list, tools_list: list,
with_context: bool = True,
stream: bool = False, stream: bool = False,
) -> ChatCompletion: tool_message: Optional[list] = None,
) -> Union[Tuple[UserMessage, ChatCompletionMessage], None]:
""" """
处理一般聊天 处理一般聊天
""" """
@ -238,8 +197,8 @@ class MarshoHandler:
response = await self.handle_single_chat( response = await self.handle_single_chat(
user_message=user_message, user_message=user_message,
model_name=model_name, model_name=model_name,
tools=tools, tools_list=tools_list,
with_context=with_context, tool_message=tool_message,
) )
choice = response.choices[0] choice = response.choices[0]
# Sprint(choice) # Sprint(choice)
@ -267,7 +226,7 @@ class MarshoHandler:
) )
else: else:
await UniMessage(str(choice_msg_content)).send(reply_to=True) await UniMessage(str(choice_msg_content)).send(reply_to=True)
return (UserMessage(context=user_message), choice_msg_after) return UserMessage(content=user_message), choice_msg_after
elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED: elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
# 对话失败,消息过滤 # 对话失败,消息过滤
@ -277,7 +236,9 @@ class MarshoHandler:
) )
return None return None
elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
pass return await self.handle_function_call(
response, user_message, model_name, tools_list
)
else: else:
await UniMessage(f"意外的完成原因:{choice.finish_reason}").send() await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
return None return None

View File

@ -3,7 +3,6 @@ from nonebot import get_driver
from openai import AsyncOpenAI from openai import AsyncOpenAI
from .config import config from .config import config
from .handler import MarshoHandler
from .models import MarshoContext, MarshoTools from .models import MarshoContext, MarshoTools
driver = get_driver() driver = get_driver()

View File

@ -263,11 +263,14 @@ async def marsho(
) )
logger.info(f"正在获取回答,模型:{model_name}") logger.info(f"正在获取回答,模型:{model_name}")
# logger.info(f"上下文:{context_msg}") # logger.info(f"上下文:{context_msg}")
response = await handler.handle_single_chat( response = await handler.handle_common_chat(usermsg, model_name, tools_lists)
usermsg, model_name, tools_lists, with_context=True
)
# await UniMessage(str(response)).send() # await UniMessage(str(response)).send()
if response is not None:
context_user, context_assistant = response
context.append(context_user.as_dict(), target.id, target.private)
context.append(context_assistant.to_dict(), target.id, target.private)
else:
await UniMessage("没有回答").send()
except Exception as e: except Exception as e:
await UniMessage(str(e) + suggest_solution(str(e))).send() await UniMessage(str(e) + suggest_solution(str(e))).send()
traceback.print_exc() traceback.print_exc()

View File

@ -70,11 +70,9 @@ class Caller:
): ):
return False, "告诉用户 Permission Denied 权限不足" return False, "告诉用户 Permission Denied 权限不足"
if self.ctx.state is None: # if self.ctx.state is None:
return False, "State is None" # return False, "State is None"
if self._rule and not await self._rule( if self._rule and not await self._rule(self.ctx.bot, self.ctx.event):
self.ctx.bot, self.ctx.event, self.ctx.state
):
return False, "告诉用户 Rule Denied 规则不匹配" return False, "告诉用户 Rule Denied 规则不匹配"
return True, "" return True, ""
@ -115,6 +113,10 @@ class Caller:
# 检查函数签名,确定依赖注入参数 # 检查函数签名,确定依赖注入参数
sig = inspect.signature(func) sig = inspect.signature(func)
for name, param in sig.parameters.items(): for name, param in sig.parameters.items():
if param.annotation == T_State:
self.di.state = name
continue # 防止后续判断T_State子类时报错
if issubclass(param.annotation, Event) or isinstance( if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event param.annotation, Event
): ):
@ -133,9 +135,6 @@ class Caller:
): ):
self.di.matcher = name self.di.matcher = name
if param.annotation == T_State:
self.di.state = name
# 检查默认值情况 # 检查默认值情况
for name, param in sig.parameters.items(): for name, param in sig.parameters.items():
if param.default is not inspect.Parameter.empty: if param.default is not inspect.Parameter.empty:

View File

@ -19,7 +19,7 @@ class SessionContext(BaseModel):
bot: Bot bot: Bot
event: Event event: Event
matcher: Matcher matcher: Matcher
state: T_State # state: T_State
caller: Any = None caller: Any = None
class Config: class Config:
@ -30,5 +30,5 @@ class SessionContextDepends(BaseModel):
bot: str | None = None bot: str | None = None
event: str | None = None event: str | None = None
matcher: str | None = None matcher: str | None = None
state: str | None = None # state: str | None = None
caller: str | None = None caller: str | None = None