diff --git a/nonebot_plugin_marshoai/decos.py b/nonebot_plugin_marshoai/cache/decos.py similarity index 94% rename from nonebot_plugin_marshoai/decos.py rename to nonebot_plugin_marshoai/cache/decos.py index ac59531e..dd15e79a 100644 --- a/nonebot_plugin_marshoai/decos.py +++ b/nonebot_plugin_marshoai/cache/decos.py @@ -1,4 +1,6 @@ -from .instances import cache +from ..models import Cache + +cache = Cache() def from_cache(key): diff --git a/nonebot_plugin_marshoai/handler.py b/nonebot_plugin_marshoai/handler.py new file mode 100644 index 00000000..b13512a6 --- /dev/null +++ b/nonebot_plugin_marshoai/handler.py @@ -0,0 +1,109 @@ +from typing import Optional, Union + +from azure.ai.inference.models import ( + AssistantMessage, + ImageContentItem, + ImageUrl, + TextContentItem, + ToolMessage, + UserMessage, +) +from nonebot.adapters import Event +from nonebot.log import logger +from nonebot.matcher import Matcher, current_event, current_matcher +from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion + +from .config import config +from .constants import SUPPORT_IMAGE_MODELS +from .models import MarshoContext +from .util import ( + get_backup_context, + get_image_b64, + get_nickname_by_user_id, + get_prompt, + make_chat_openai, +) + + +class MarshoHandler: + def __init__( + self, + client: AsyncOpenAI, + context: MarshoContext, + ): + self.client = client + self.context = context + self.event: Event = current_event.get() + self.matcher: Matcher = current_matcher.get() + self.message_id: str = UniMessage.get_message_id(self.event) + self.target = UniMessage.get_target(self.event) + + async def process_user_input( + self, user_input: UniMsg, model_name: str + ) -> Union[str, list]: + """ + 处理用户输入 + """ + is_support_image_model = ( + model_name.lower() + in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models + ) + usermsg = [] if is_support_image_model else "" + user_nickname = await get_nickname_by_user_id(self.event.get_user_id()) + if user_nickname: + nickname_prompt = f"此消息的说话者为: {user_nickname}" + else: + nickname_prompt = "" + for i in user_input: # type: ignore + if i.type == "text": + if is_support_image_model: + usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore + else: + usermsg += str(i.data["text"] + nickname_prompt) # type: ignore + elif i.type == "image": + if is_support_image_model: + usermsg.append( # type: ignore + ImageContentItem( + image_url=ImageUrl( # type: ignore + url=str(await get_image_b64(i.data["url"])) # type: ignore + ) # type: ignore + ).as_dict() # type: ignore + ) # type: ignore + logger.info(f"输入图片 {i.data['url']}") + elif config.marshoai_enable_support_image_tip: + await UniMessage( + "*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。" + ).send() + return usermsg # type: ignore + + async def handle_single_chat( + self, + user_message: Union[str, list], + model_name: str, + tools: list, + with_context: bool = True, + ) -> ChatCompletion: + """ + 处理单条聊天 + """ + backup_context = await get_backup_context(self.target.id, self.target.private) + if backup_context: + self.context.set_context( + backup_context, self.target.id, self.target.private + ) # 加载历史记录 + logger.info(f"已恢复会话 {self.target.id} 的上下文备份~") + context_msg = ( + get_prompt(model_name) + + (self.context.build(self.target.id, self.target.private)) + if with_context + else "" + ) + response = await make_chat_openai( + client=self.client, + msg=context_msg + [UserMessage(content=user_message).as_dict()], # type: ignore + model_name=model_name, + tools=tools, + ) + return response diff --git a/nonebot_plugin_marshoai/hooks.py b/nonebot_plugin_marshoai/hooks.py index 8b6ddd0e..6e6f439e 100644 --- a/nonebot_plugin_marshoai/hooks.py +++ b/nonebot_plugin_marshoai/hooks.py @@ -6,7 +6,7 @@ import nonebot_plugin_localstore as store from nonebot import logger from .config import config -from .instances import * +from .instances import context, driver, target_list, tools from .plugin import load_plugin, load_plugins from .util import get_backup_context, save_context_to_json diff --git a/nonebot_plugin_marshoai/instances.py b/nonebot_plugin_marshoai/instances.py index d450d2d2..b4cba256 100644 --- a/nonebot_plugin_marshoai/instances.py +++ b/nonebot_plugin_marshoai/instances.py @@ -3,7 +3,8 @@ from nonebot import get_driver from openai import AsyncOpenAI from .config import config -from .models import Cache, MarshoContext, MarshoTools +from .handler import MarshoHandler +from .models import MarshoContext, MarshoTools driver = get_driver() @@ -11,7 +12,6 @@ command_start = driver.config.command_start model_name = config.marshoai_default_model context = MarshoContext() tools = MarshoTools() -cache = Cache() token = config.marshoai_token endpoint = config.marshoai_azure_endpoint # client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py index ce20fba6..bf1b0fe1 100644 --- a/nonebot_plugin_marshoai/marsho.py +++ b/nonebot_plugin_marshoai/marsho.py @@ -24,6 +24,7 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna from .config import config from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS +from .handler import MarshoHandler from .hooks import * from .instances import client, context, model_name, target_list, tools from .metadata import metadata @@ -232,16 +233,10 @@ async def marsho( # 发送说明 # await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send() await marsho_cmd.finish(INTRODUCTION) + handler = MarshoHandler(client, context) try: - user_id = event.get_user_id() - nicknames = await get_nicknames() - user_nickname = nicknames.get(user_id, "") - if user_nickname != "": - nickname_prompt = ( - f"\n*此消息的说话者id为:{user_id},名字为:{user_nickname}*" - ) - else: - nickname_prompt = "" + user_nickname = await get_nickname_by_user_id(event.get_user_id()) + if not user_nickname: # 用户名无法获取,暂时注释 # user_nickname = event.sender.nickname # 未设置昵称时获取用户名 # nickname_prompt = f"\n*此消息的说话者:{user_nickname}" @@ -255,49 +250,15 @@ async def marsho( "※你未设置自己的昵称。推荐使用「nickname [昵称]」命令设置昵称来获得个性化(可能)回答。" ).send() - is_support_image_model = ( - model_name.lower() - in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models - ) - usermsg = [] if is_support_image_model else "" - for i in text: # type: ignore - if i.type == "text": - if is_support_image_model: - usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()] # type: ignore - else: - usermsg += str(i.data["text"] + nickname_prompt) # type: ignore - elif i.type == "image": - if is_support_image_model: - usermsg.append( # type: ignore - ImageContentItem( - image_url=ImageUrl( # type: ignore - url=str(await get_image_b64(i.data["url"])) # type: ignore - ) # type: ignore - ).as_dict() # type: ignore - ) # type: ignore - logger.info(f"输入图片 {i.data['url']}") - elif config.marshoai_enable_support_image_tip: - await UniMessage( - "*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。" - ).send() - backup_context = await get_backup_context(target.id, target.private) - if backup_context: - context.set_context( - backup_context, target.id, target.private - ) # 加载历史记录 - logger.info(f"已恢复会话 {target.id} 的上下文备份~") - context_msg = get_prompt(model_name) + context.build(target.id, target.private) + usermsg = await handler.process_user_input(text, model_name) tools_lists = tools.tools_list + list( map(lambda v: v.data(), get_function_calls().values()) ) logger.info(f"正在获取回答,模型:{model_name}") # logger.info(f"上下文:{context_msg}") - response = await make_chat_openai( - client=client, - model_name=model_name, - msg=context_msg + [UserMessage(content=usermsg).as_dict()], # type: ignore - tools=tools_lists if tools_lists else None, # TODO 临时追加函数,后期优化 + response = await handler.handle_single_chat( + usermsg, model_name, tools_lists, with_context=True ) # await UniMessage(str(response)).send() choice = response.choices[0] @@ -451,12 +412,10 @@ with contextlib.suppress(ImportError): # 优化先不做() @poke_notify.handle() async def poke(event: Event): - user_id = event.get_user_id() - nicknames = await get_nicknames() - user_nickname = nicknames.get(user_id, "") + user_nickname = await get_nickname_by_user_id(event.get_user_id()) try: if config.marshoai_poke_suffix != "": - logger.info(f"收到戳一戳,用户昵称:{user_nickname},用户ID:{user_id}") + logger.info(f"收到戳一戳,用户昵称:{user_nickname}") response = await make_chat_openai( client=client, model_name=model_name, diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 33f8e4fe..1a35320f 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -20,11 +20,10 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage from zhDateTime import DateTime from ._types import DeveloperMessage +from .cache.decos import * from .config import config from .constants import CODE_BLOCK_PATTERN, IMG_LATEX_PATTERN, OPENAI_NEW_MODELS from .deal_latex import ConvertLatex -from .decos import from_cache, update_to_cache -from .instances import cache # nickname_json = None # 记录昵称 # praises_json = None # 记录夸赞名单 @@ -240,6 +239,11 @@ async def set_nickname(user_id: str, name: str): return data +async def get_nickname_by_user_id(user_id: str): + nickname_json = await get_nicknames() + return nickname_json.get(user_id, "") + + @update_to_cache("nickname") async def refresh_nickname_json(): """强制刷新nickname_json"""