重构代码,准备将聊天请求逻辑移入MarshoHandler

This commit is contained in:
Asankilp 2025-02-22 20:39:03 +08:00
parent aaa4056482
commit 17551885f5
6 changed files with 130 additions and 56 deletions

View File

@ -1,4 +1,6 @@
from .instances import cache from ..models import Cache
cache = Cache()
def from_cache(key): def from_cache(key):

View File

@ -0,0 +1,109 @@
from typing import Optional, Union
from azure.ai.inference.models import (
AssistantMessage,
ImageContentItem,
ImageUrl,
TextContentItem,
ToolMessage,
UserMessage,
)
from nonebot.adapters import Event
from nonebot.log import logger
from nonebot.matcher import Matcher, current_event, current_matcher
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from .config import config
from .constants import SUPPORT_IMAGE_MODELS
from .models import MarshoContext
from .util import (
get_backup_context,
get_image_b64,
get_nickname_by_user_id,
get_prompt,
make_chat_openai,
)
class MarshoHandler:
def __init__(
self,
client: AsyncOpenAI,
context: MarshoContext,
):
self.client = client
self.context = context
self.event: Event = current_event.get()
self.matcher: Matcher = current_matcher.get()
self.message_id: str = UniMessage.get_message_id(self.event)
self.target = UniMessage.get_target(self.event)
async def process_user_input(
self, user_input: UniMsg, model_name: str
) -> Union[str, list]:
"""
处理用户输入
"""
is_support_image_model = (
model_name.lower()
in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
)
usermsg = [] if is_support_image_model else ""
user_nickname = await get_nickname_by_user_id(self.event.get_user_id())
if user_nickname:
nickname_prompt = f"此消息的说话者为: {user_nickname}"
else:
nickname_prompt = ""
for i in user_input: # type: ignore
if i.type == "text":
if is_support_image_model:
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore
else:
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
elif i.type == "image":
if is_support_image_model:
usermsg.append( # type: ignore
ImageContentItem(
image_url=ImageUrl( # type: ignore
url=str(await get_image_b64(i.data["url"])) # type: ignore
) # type: ignore
).as_dict() # type: ignore
) # type: ignore
logger.info(f"输入图片 {i.data['url']}")
elif config.marshoai_enable_support_image_tip:
await UniMessage(
"*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。"
).send()
return usermsg # type: ignore
async def handle_single_chat(
self,
user_message: Union[str, list],
model_name: str,
tools: list,
with_context: bool = True,
) -> ChatCompletion:
"""
处理单条聊天
"""
backup_context = await get_backup_context(self.target.id, self.target.private)
if backup_context:
self.context.set_context(
backup_context, self.target.id, self.target.private
) # 加载历史记录
logger.info(f"已恢复会话 {self.target.id} 的上下文备份~")
context_msg = (
get_prompt(model_name)
+ (self.context.build(self.target.id, self.target.private))
if with_context
else ""
)
response = await make_chat_openai(
client=self.client,
msg=context_msg + [UserMessage(content=user_message).as_dict()], # type: ignore
model_name=model_name,
tools=tools,
)
return response

View File

@ -6,7 +6,7 @@ import nonebot_plugin_localstore as store
from nonebot import logger from nonebot import logger
from .config import config from .config import config
from .instances import * from .instances import context, driver, target_list, tools
from .plugin import load_plugin, load_plugins from .plugin import load_plugin, load_plugins
from .util import get_backup_context, save_context_to_json from .util import get_backup_context, save_context_to_json

View File

@ -3,7 +3,8 @@ from nonebot import get_driver
from openai import AsyncOpenAI from openai import AsyncOpenAI
from .config import config from .config import config
from .models import Cache, MarshoContext, MarshoTools from .handler import MarshoHandler
from .models import MarshoContext, MarshoTools
driver = get_driver() driver = get_driver()
@ -11,7 +12,6 @@ command_start = driver.config.command_start
model_name = config.marshoai_default_model model_name = config.marshoai_default_model
context = MarshoContext() context = MarshoContext()
tools = MarshoTools() tools = MarshoTools()
cache = Cache()
token = config.marshoai_token token = config.marshoai_token
endpoint = config.marshoai_azure_endpoint endpoint = config.marshoai_azure_endpoint
# client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) # client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))

View File

@ -24,6 +24,7 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from .config import config from .config import config
from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS
from .handler import MarshoHandler
from .hooks import * from .hooks import *
from .instances import client, context, model_name, target_list, tools from .instances import client, context, model_name, target_list, tools
from .metadata import metadata from .metadata import metadata
@ -232,16 +233,10 @@ async def marsho(
# 发送说明 # 发送说明
# await UniMessage(metadata.usage + "\n当前使用的模型" + model_name).send() # await UniMessage(metadata.usage + "\n当前使用的模型" + model_name).send()
await marsho_cmd.finish(INTRODUCTION) await marsho_cmd.finish(INTRODUCTION)
handler = MarshoHandler(client, context)
try: try:
user_id = event.get_user_id() user_nickname = await get_nickname_by_user_id(event.get_user_id())
nicknames = await get_nicknames() if not user_nickname:
user_nickname = nicknames.get(user_id, "")
if user_nickname != "":
nickname_prompt = (
f"\n*此消息的说话者id为:{user_id},名字为:{user_nickname}*"
)
else:
nickname_prompt = ""
# 用户名无法获取,暂时注释 # 用户名无法获取,暂时注释
# user_nickname = event.sender.nickname # 未设置昵称时获取用户名 # user_nickname = event.sender.nickname # 未设置昵称时获取用户名
# nickname_prompt = f"\n*此消息的说话者:{user_nickname}" # nickname_prompt = f"\n*此消息的说话者:{user_nickname}"
@ -255,49 +250,15 @@ async def marsho(
"※你未设置自己的昵称。推荐使用「nickname [昵称]」命令设置昵称来获得个性化(可能)回答。" "※你未设置自己的昵称。推荐使用「nickname [昵称]」命令设置昵称来获得个性化(可能)回答。"
).send() ).send()
is_support_image_model = ( usermsg = await handler.process_user_input(text, model_name)
model_name.lower()
in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
)
usermsg = [] if is_support_image_model else ""
for i in text: # type: ignore
if i.type == "text":
if is_support_image_model:
usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore
else:
usermsg += str(i.data["text"] + nickname_prompt) # type: ignore
elif i.type == "image":
if is_support_image_model:
usermsg.append( # type: ignore
ImageContentItem(
image_url=ImageUrl( # type: ignore
url=str(await get_image_b64(i.data["url"])) # type: ignore
) # type: ignore
).as_dict() # type: ignore
) # type: ignore
logger.info(f"输入图片 {i.data['url']}")
elif config.marshoai_enable_support_image_tip:
await UniMessage(
"*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。"
).send()
backup_context = await get_backup_context(target.id, target.private)
if backup_context:
context.set_context(
backup_context, target.id, target.private
) # 加载历史记录
logger.info(f"已恢复会话 {target.id} 的上下文备份~")
context_msg = get_prompt(model_name) + context.build(target.id, target.private)
tools_lists = tools.tools_list + list( tools_lists = tools.tools_list + list(
map(lambda v: v.data(), get_function_calls().values()) map(lambda v: v.data(), get_function_calls().values())
) )
logger.info(f"正在获取回答,模型:{model_name}") logger.info(f"正在获取回答,模型:{model_name}")
# logger.info(f"上下文:{context_msg}") # logger.info(f"上下文:{context_msg}")
response = await make_chat_openai( response = await handler.handle_single_chat(
client=client, usermsg, model_name, tools_lists, with_context=True
model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore
tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化
) )
# await UniMessage(str(response)).send() # await UniMessage(str(response)).send()
choice = response.choices[0] choice = response.choices[0]
@ -451,12 +412,10 @@ with contextlib.suppress(ImportError): # 优化先不做()
@poke_notify.handle() @poke_notify.handle()
async def poke(event: Event): async def poke(event: Event):
user_id = event.get_user_id() user_nickname = await get_nickname_by_user_id(event.get_user_id())
nicknames = await get_nicknames()
user_nickname = nicknames.get(user_id, "")
try: try:
if config.marshoai_poke_suffix != "": if config.marshoai_poke_suffix != "":
logger.info(f"收到戳一戳,用户昵称:{user_nickname}用户ID{user_id}") logger.info(f"收到戳一戳,用户昵称:{user_nickname}")
response = await make_chat_openai( response = await make_chat_openai(
client=client, client=client,
model_name=model_name, model_name=model_name,

View File

@ -20,11 +20,10 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
from zhDateTime import DateTime from zhDateTime import DateTime
from ._types import DeveloperMessage from ._types import DeveloperMessage
from .cache.decos import *
from .config import config from .config import config
from .constants import CODE_BLOCK_PATTERN, IMG_LATEX_PATTERN, OPENAI_NEW_MODELS from .constants import CODE_BLOCK_PATTERN, IMG_LATEX_PATTERN, OPENAI_NEW_MODELS
from .deal_latex import ConvertLatex from .deal_latex import ConvertLatex
from .decos import from_cache, update_to_cache
from .instances import cache
# nickname_json = None # 记录昵称 # nickname_json = None # 记录昵称
# praises_json = None # 记录夸赞名单 # praises_json = None # 记录夸赞名单
@ -240,6 +239,11 @@ async def set_nickname(user_id: str, name: str):
return data return data
async def get_nickname_by_user_id(user_id: str):
nickname_json = await get_nicknames()
return nickname_json.get(user_id, "")
@update_to_cache("nickname") @update_to_cache("nickname")
async def refresh_nickname_json(): async def refresh_nickname_json():
"""强制刷新nickname_json""" """强制刷新nickname_json"""