From 065619f0a8d0fc461cefa231e374b2ee04a2b3a9 Mon Sep 17 00:00:00 2001 From: Chen_Xu233 Date: Sun, 28 Jul 2024 09:00:37 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20=E4=BF=AE=E5=A4=8D=E6=8E=92=E5=90=8D?= =?UTF-8?q?=E9=94=99=E4=BD=8D=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_dialectlist/utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/nonebot_plugin_dialectlist/utils.py b/nonebot_plugin_dialectlist/utils.py index 45360fc..4a0da0c 100644 --- a/nonebot_plugin_dialectlist/utils.py +++ b/nonebot_plugin_dialectlist/utils.py @@ -46,15 +46,12 @@ async def ensure_group(matcher: Matcher, session: Session = Depends(extract_sess async def persist_id2user_id(ids: List) -> List[str]: - whereclause: List[ColumnElement[bool]] = [] - whereclause.append(or_(*[SessionModel.id == id for id in ids])) - statement = ( - select(SessionModel).where(*whereclause) - # .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) - ) + user_ids = [] async with get_session() as db_session: - records = (await db_session.scalars(statement)).all() - return [i.id1 for i in records] + for i in ids: + user_id = (await db_session.scalar(select(SessionModel).where(or_(*[SessionModel.id == i])))).id1 # type: ignore + user_ids.append(user_id) + return user_ids async def user_id2persist_id(ids: List[str]) -> List[int]: