🚸 兼容pydantic

This commit is contained in:
Chenric 2024-07-16 13:56:50 +08:00
parent 86fd356f46
commit e290a94cf3
2 changed files with 32 additions and 7 deletions

View File

@ -19,6 +19,7 @@ from arclet.alconna.arparma import Arparma
from nonebot.log import logger from nonebot.log import logger
from nonebot.params import Arg, Depends from nonebot.params import Arg, Depends
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.compat import model_dump
from nonebot.adapters import Bot, Event, Message from nonebot.adapters import Bot, Event, Message
from nonebot.params import Arg, Depends from nonebot.params import Arg, Depends
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
@ -54,6 +55,7 @@ from .utils import (
got_rank, got_rank,
msg_counter, msg_counter,
persist_id2user_id, persist_id2user_id,
user_id2persist_id,
get_rank_image get_rank_image
) )
@ -208,18 +210,20 @@ async def _group_message(
prompt="请输入你要查询的群号。" prompt="请输入你要查询的群号。"
) )
async def handle_rank( async def handle_rank(
state: T_State,
bot: Bot, bot: Bot,
event: Event, event: Event,
session: Session = Depends(extract_session), session: Session = Depends(extract_session),
start: datetime = Arg(), start: datetime = Arg(),
stop: datetime = Arg(), stop: datetime = Arg(),
group_id: str = Arg(), # group_id: str = Arg(),
): ):
if group_id: if id := state["group_id"]:
id = group_id # id = await user_id2persist_id(id)
logger.debug(f"group_id: {id}") logger.debug(f"group_id: {id}")
else: else:
id = session.id2 id = session.id2
logger.debug(f"group_id: {id}")
if not id: if not id:
await saa.Text("没有指定群哦").finish() await saa.Text("没有指定群哦").finish()
@ -259,7 +263,7 @@ async def handle_rank(
user_avatar = await user_info.user_avatar.get_image()\ user_avatar = await user_info.user_avatar.get_image()\
if user_info.user_avatar\ if user_info.user_avatar\
else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read() else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read()
user = UserRankInfo(**user_info.model_dump(), user = UserRankInfo(**model_dump(user_info),
user_bnum=i[1], user_bnum=i[1],
user_proportion= round(i[1] / total * 100, 2), user_proportion= round(i[1] / total * 100, 2),
user_index= rank.index(i) + 1, user_index= rank.index(i) + 1,

View File

@ -57,17 +57,38 @@ async def persist_id2user_id(ids: List) -> List[str]:
return [i.id1 for i in records] return [i.id1 for i in records]
async def user_id2persist_id(id: str) -> int: async def user_id2persist_id(ids: List[str]) -> List[int]:
whereclause: List[ColumnElement[bool]] = [] whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id2 == id])) whereclause.append(or_(*[SessionModel.id1 == id for id in ids]))
statement = ( statement = (
select(SessionModel).where(*whereclause) select(SessionModel).where(*whereclause)
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id) # .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
) )
async with get_session() as db_session: async with get_session() as db_session:
records = (await db_session.scalars(statement)).all() records = (await db_session.scalars(statement)).all()
return records[0].id return [i.id for i in records]
async def group_id2persist_id(ids: List[str]) -> List[int]:
whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id2 == id for id in ids]))
statement = (
select(SessionModel).where(*whereclause)
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
)
async with get_session() as db_session:
records = (await db_session.scalars(statement)).all()
return [i.id for i in records]
async def persist_id2group_id(ids: List[str]) -> List[str]:
whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id == id for id in ids]))
statement = (
select(SessionModel).where(*whereclause)
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
)
async with get_session() as db_session:
records = (await db_session.scalars(statement)).all()
return [i.id2 for i in records]
def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]: def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]:
"""### 计算每个人的消息量 """### 计算每个人的消息量