mirror of
https://github.com/ChenXu233/nonebot_plugin_dialectlist.git
synced 2024-11-24 08:05:26 +08:00
🚸 兼容pydantic
This commit is contained in:
parent
86fd356f46
commit
e290a94cf3
@ -19,6 +19,7 @@ from arclet.alconna.arparma import Arparma
|
|||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.params import Arg, Depends
|
from nonebot.params import Arg, Depends
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
from nonebot.compat import model_dump
|
||||||
from nonebot.adapters import Bot, Event, Message
|
from nonebot.adapters import Bot, Event, Message
|
||||||
from nonebot.params import Arg, Depends
|
from nonebot.params import Arg, Depends
|
||||||
from nonebot.permission import SUPERUSER
|
from nonebot.permission import SUPERUSER
|
||||||
@ -54,6 +55,7 @@ from .utils import (
|
|||||||
got_rank,
|
got_rank,
|
||||||
msg_counter,
|
msg_counter,
|
||||||
persist_id2user_id,
|
persist_id2user_id,
|
||||||
|
user_id2persist_id,
|
||||||
get_rank_image
|
get_rank_image
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -208,18 +210,20 @@ async def _group_message(
|
|||||||
prompt="请输入你要查询的群号。"
|
prompt="请输入你要查询的群号。"
|
||||||
)
|
)
|
||||||
async def handle_rank(
|
async def handle_rank(
|
||||||
|
state: T_State,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
event: Event,
|
event: Event,
|
||||||
session: Session = Depends(extract_session),
|
session: Session = Depends(extract_session),
|
||||||
start: datetime = Arg(),
|
start: datetime = Arg(),
|
||||||
stop: datetime = Arg(),
|
stop: datetime = Arg(),
|
||||||
group_id: str = Arg(),
|
# group_id: str = Arg(),
|
||||||
):
|
):
|
||||||
if group_id:
|
if id := state["group_id"]:
|
||||||
id = group_id
|
# id = await user_id2persist_id(id)
|
||||||
logger.debug(f"group_id: {id}")
|
logger.debug(f"group_id: {id}")
|
||||||
else:
|
else:
|
||||||
id = session.id2
|
id = session.id2
|
||||||
|
logger.debug(f"group_id: {id}")
|
||||||
|
|
||||||
if not id:
|
if not id:
|
||||||
await saa.Text("没有指定群哦").finish()
|
await saa.Text("没有指定群哦").finish()
|
||||||
@ -259,7 +263,7 @@ async def handle_rank(
|
|||||||
user_avatar = await user_info.user_avatar.get_image()\
|
user_avatar = await user_info.user_avatar.get_image()\
|
||||||
if user_info.user_avatar\
|
if user_info.user_avatar\
|
||||||
else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read()
|
else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read()
|
||||||
user = UserRankInfo(**user_info.model_dump(),
|
user = UserRankInfo(**model_dump(user_info),
|
||||||
user_bnum=i[1],
|
user_bnum=i[1],
|
||||||
user_proportion= round(i[1] / total * 100, 2),
|
user_proportion= round(i[1] / total * 100, 2),
|
||||||
user_index= rank.index(i) + 1,
|
user_index= rank.index(i) + 1,
|
||||||
|
@ -57,17 +57,38 @@ async def persist_id2user_id(ids: List) -> List[str]:
|
|||||||
return [i.id1 for i in records]
|
return [i.id1 for i in records]
|
||||||
|
|
||||||
|
|
||||||
async def user_id2persist_id(id: str) -> int:
|
async def user_id2persist_id(ids: List[str]) -> List[int]:
|
||||||
whereclause: List[ColumnElement[bool]] = []
|
whereclause: List[ColumnElement[bool]] = []
|
||||||
whereclause.append(or_(*[SessionModel.id2 == id]))
|
whereclause.append(or_(*[SessionModel.id1 == id for id in ids]))
|
||||||
statement = (
|
statement = (
|
||||||
select(SessionModel).where(*whereclause)
|
select(SessionModel).where(*whereclause)
|
||||||
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
|
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
|
||||||
)
|
)
|
||||||
async with get_session() as db_session:
|
async with get_session() as db_session:
|
||||||
records = (await db_session.scalars(statement)).all()
|
records = (await db_session.scalars(statement)).all()
|
||||||
return records[0].id
|
return [i.id for i in records]
|
||||||
|
|
||||||
|
async def group_id2persist_id(ids: List[str]) -> List[int]:
|
||||||
|
whereclause: List[ColumnElement[bool]] = []
|
||||||
|
whereclause.append(or_(*[SessionModel.id2 == id for id in ids]))
|
||||||
|
statement = (
|
||||||
|
select(SessionModel).where(*whereclause)
|
||||||
|
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
|
||||||
|
)
|
||||||
|
async with get_session() as db_session:
|
||||||
|
records = (await db_session.scalars(statement)).all()
|
||||||
|
return [i.id for i in records]
|
||||||
|
|
||||||
|
async def persist_id2group_id(ids: List[str]) -> List[str]:
|
||||||
|
whereclause: List[ColumnElement[bool]] = []
|
||||||
|
whereclause.append(or_(*[SessionModel.id == id for id in ids]))
|
||||||
|
statement = (
|
||||||
|
select(SessionModel).where(*whereclause)
|
||||||
|
# .join(SessionModel, SessionModel.id == MessageRecord.session_persist_id)
|
||||||
|
)
|
||||||
|
async with get_session() as db_session:
|
||||||
|
records = (await db_session.scalars(statement)).all()
|
||||||
|
return [i.id2 for i in records]
|
||||||
|
|
||||||
def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]:
|
def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]:
|
||||||
"""### 计算每个人的消息量
|
"""### 计算每个人的消息量
|
||||||
|
Loading…
Reference in New Issue
Block a user