From e290a94cf38a3af22871e6109215c96616f8e4ce Mon Sep 17 00:00:00 2001 From: Chenric <91937041+ChenXu233@users.noreply.github.com> Date: Tue, 16 Jul 2024 13:56:50 +0800 Subject: [PATCH] =?UTF-8?q?:children=5Fcrossing:=20=E5=85=BC=E5=AE=B9pydan?= =?UTF-8?q?tic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_dialectlist/__init__.py | 12 ++++++++---- nonebot_plugin_dialectlist/utils.py | 27 +++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/nonebot_plugin_dialectlist/__init__.py b/nonebot_plugin_dialectlist/__init__.py index 0b08168..b50a7ba 100644 --- a/nonebot_plugin_dialectlist/__init__.py +++ b/nonebot_plugin_dialectlist/__init__.py @@ -19,6 +19,7 @@ from arclet.alconna.arparma import Arparma from nonebot.log import logger from nonebot.params import Arg, Depends from nonebot.typing import T_State +from nonebot.compat import model_dump from nonebot.adapters import Bot, Event, Message from nonebot.params import Arg, Depends from nonebot.permission import SUPERUSER @@ -54,6 +55,7 @@ from .utils import ( got_rank, msg_counter, persist_id2user_id, + user_id2persist_id, get_rank_image ) @@ -208,18 +210,20 @@ async def _group_message( prompt="请输入你要查询的群号。" ) async def handle_rank( + state: T_State, bot: Bot, event: Event, session: Session = Depends(extract_session), start: datetime = Arg(), stop: datetime = Arg(), - group_id: str = Arg(), + # group_id: str = Arg(), ): - if group_id: - id = group_id + if id := state["group_id"]: + # id = await user_id2persist_id(id) logger.debug(f"group_id: {id}") else: id = session.id2 + logger.debug(f"group_id: {id}") if not id: await saa.Text("没有指定群哦").finish() @@ -259,7 +263,7 @@ async def handle_rank( user_avatar = await user_info.user_avatar.get_image()\ if user_info.user_avatar\ else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read() - user = UserRankInfo(**user_info.model_dump(), + user = UserRankInfo(**model_dump(user_info), user_bnum=i[1], user_proportion= round(i[1] / total * 100, 2), user_index= rank.index(i) + 1, diff --git a/nonebot_plugin_dialectlist/utils.py b/nonebot_plugin_dialectlist/utils.py index 070f110..45360fc 100644 --- a/nonebot_plugin_dialectlist/utils.py +++ b/nonebot_plugin_dialectlist/utils.py @@ -57,17 +57,38 @@ async def persist_id2user_id(ids: List) -> List[str]: return [i.id1 for i in records] -async def user_id2persist_id(id: str) -> int: +async def user_id2persist_id(ids: List[str]) -> List[int]: whereclause: List[ColumnElement[bool]] = [] - whereclause.append(or_(*[SessionModel.id2 == id])) + whereclause.append(or_(*[SessionModel.id1 == id for id in ids])) statement = ( select(SessionModel).where(*whereclause) # .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) ) async with get_session() as db_session: records = (await db_session.scalars(statement)).all() - return records[0].id + return [i.id for i in records] +async def group_id2persist_id(ids: List[str]) -> List[int]: + whereclause: List[ColumnElement[bool]] = [] + whereclause.append(or_(*[SessionModel.id2 == id for id in ids])) + statement = ( + select(SessionModel).where(*whereclause) + # .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) + ) + async with get_session() as db_session: + records = (await db_session.scalars(statement)).all() + return [i.id for i in records] + +async def persist_id2group_id(ids: List[str]) -> List[str]: + whereclause: List[ColumnElement[bool]] = [] + whereclause.append(or_(*[SessionModel.id == id for id in ids])) + statement = ( + select(SessionModel).where(*whereclause) + # .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) + ) + async with get_session() as db_session: + records = (await db_session.scalars(statement)).all() + return [i.id2 for i in records] def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]: """### 计算每个人的消息量