🚸 兼容pydantic

This commit is contained in:
Chenric 2024-07-16 13:56:50 +08:00
parent 86fd356f46
commit e290a94cf3
2 changed files with 32 additions and 7 deletions

View File

@ -19,6 +19,7 @@ from arclet.alconna.arparma import Arparma
from nonebot.log import logger
from nonebot.params import Arg, Depends
from nonebot.typing import T_State
from nonebot.compat import model_dump
from nonebot.adapters import Bot, Event, Message
from nonebot.params import Arg, Depends
from nonebot.permission import SUPERUSER
@ -54,6 +55,7 @@ from .utils import (
got_rank,
msg_counter,
persist_id2user_id,
user_id2persist_id,
get_rank_image
)
@ -208,18 +210,20 @@ async def _group_message(
prompt="请输入你要查询的群号。"
)
async def handle_rank(
state: T_State,
bot: Bot,
event: Event,
session: Session = Depends(extract_session),
start: datetime = Arg(),
stop: datetime = Arg(),
group_id: str = Arg(),
# group_id: str = Arg(),
):
if group_id:
id = group_id
if id := state["group_id"]:
# id = await user_id2persist_id(id)
logger.debug(f"group_id: {id}")
else:
id = session.id2
logger.debug(f"group_id: {id}")
if not id:
await saa.Text("没有指定群哦").finish()
@ -259,7 +263,7 @@ async def handle_rank(
user_avatar = await user_info.user_avatar.get_image()\
if user_info.user_avatar\
else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read()
user = UserRankInfo(**user_info.model_dump(),
user = UserRankInfo(**model_dump(user_info),
user_bnum=i[1],
user_proportion= round(i[1] / total * 100, 2),
user_index= rank.index(i) + 1,

View File

@ -57,17 +57,38 @@ async def persist_id2user_id(ids: List) -> List[str]:
return [i.id1 for i in records]
async def user_id2persist_id(id: str) -> int:
async def user_id2persist_id(ids: List[str]) -> List[int]:
whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id2 == id]))
whereclause.append(or_(*[SessionModel.id1 == id for id in ids]))
statement = (
select(SessionModel).where(*whereclause)
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
)
async with get_session() as db_session:
records = (await db_session.scalars(statement)).all()
return records[0].id
return [i.id for i in records]
async def group_id2persist_id(ids: List[str]) -> List[int]:
whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id2 == id for id in ids]))
statement = (
select(SessionModel).where(*whereclause)
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
)
async with get_session() as db_session:
records = (await db_session.scalars(statement)).all()
return [i.id for i in records]
async def persist_id2group_id(ids: List[str]) -> List[str]:
whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id == id for id in ids]))
statement = (
select(SessionModel).where(*whereclause)
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
)
async with get_session() as db_session:
records = (await db_session.scalars(statement)).all()
return [i.id2 for i in records]
def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]:
"""### 计算每个人的消息量