🐛 修复计数缓存的bug

This commit is contained in:
Chenric 2024-09-28 09:04:01 +08:00
parent ea682ef18d
commit a765267cf1
2 changed files with 7 additions and 7 deletions

View File

@ -244,7 +244,6 @@ async def handle_rank(
await saa.Text("明明这个时间段都没有人说话怎么会有话痨榜呢?").finish() await saa.Text("明明这个时间段都没有人说话怎么会有话痨榜呢?").finish()
rank = got_rank(raw_rank) rank = got_rank(raw_rank)
logger.debug(rank)
ids = await persist_id2user_id([int(i[0]) for i in rank]) ids = await persist_id2user_id([int(i[0]) for i in rank])
for i in range(len(rank)): for i in range(len(rank)):
rank[i][0] = str(ids[i]) rank[i][0] = str(ids[i])

View File

@ -28,8 +28,9 @@ async def get_cache(time_start: datetime, time_stop: datetime, group_id: str):
sessions = (await db_session.scalars(statement)).all() sessions = (await db_session.scalars(statement)).all()
where = [ where = [
or_(MessageCountCache.session_id == session.id) for session in sessions or_(*[MessageCountCache.session_id == session.id for session in sessions])
] ]
statement = select(MessageCountCache).where(*where)
where.append(or_(MessageCountCache.time >= remove_timezone(time_start))) where.append(or_(MessageCountCache.time >= remove_timezone(time_start)))
where.append(or_(MessageCountCache.time <= remove_timezone(time_stop))) where.append(or_(MessageCountCache.time <= remove_timezone(time_stop)))
statement = select(MessageCountCache).where(*where) statement = select(MessageCountCache).where(*where)
@ -37,7 +38,7 @@ async def get_cache(time_start: datetime, time_stop: datetime, group_id: str):
user_caches = (await db_session.scalars(statement)).all() user_caches = (await db_session.scalars(statement)).all()
raw_rank = {} raw_rank = {}
for i in user_caches: for i in user_caches:
raw_rank[i.session_id] = i.session_bnum raw_rank[i.session_id] = raw_rank.get(i.session_id, 0) + i.session_bnum
return raw_rank return raw_rank
@ -111,13 +112,13 @@ async def _(bot: Bot, event: Event,session: Session = Depends(extract_session)):
async with get_session() as db_session: async with get_session() as db_session:
session_id = await get_session_persist_id(session) session_id = await get_session_persist_id(session)
logger.debug("session_id:"+str(session_id))
where = [or_(MessageCountCache.session_id == session_id)] where = [or_(MessageCountCache.session_id == session_id)]
where = [or_(MessageCountCache.time == remove_timezone(now))] where.append(or_(MessageCountCache.time == remove_timezone(now)))
statement = select(MessageCountCache).where(*where) statement = select(MessageCountCache).where(*where)
user_cache = (await db_session.scalars(statement)).all() user_cache = (await db_session.scalars(statement)).first()
if user_cache: if user_cache:
user_cache[0].session_bnum += 1 user_cache.session_bnum += 1
else: else:
user_cache = MessageCountCache( user_cache = MessageCountCache(
session_id=session_id, session_id=session_id,