🎨 使用black格式化代码

This commit is contained in:
Chen_Xu233 2024-07-28 09:33:07 +08:00
parent 065619f0a8
commit 9fb171ec1c
3 changed files with 79 additions and 48 deletions

View File

@ -50,14 +50,18 @@ from nonebot_plugin_session import Session, SessionIdType, extract_session
# from .function import * # from .function import *
from .config import Config, plugin_config from .config import Config, plugin_config
from .usage import __usage__ from .usage import __usage__
from .time import get_datetime_fromisoformat_with_timezone, get_datetime_now_with_timezone,parse_datetime from .time import (
get_datetime_fromisoformat_with_timezone,
get_datetime_now_with_timezone,
parse_datetime,
)
from .model import UserRankInfo from .model import UserRankInfo
from .utils import ( from .utils import (
got_rank, got_rank,
msg_counter, msg_counter,
persist_id2user_id, persist_id2user_id,
user_id2persist_id, user_id2persist_id,
get_rank_image get_rank_image,
) )
__plugin_meta__ = PluginMetadata( __plugin_meta__ = PluginMetadata(
@ -83,17 +87,20 @@ class SameTime(ArparmaBehavior):
interface.behave_fail() interface.behave_fail()
def wrapper(slot: Union[int, str], content: Optional[str],context) -> str: def wrapper(slot: Union[int, str], content: Optional[str], context) -> str:
if slot == "type" and content: if slot == "type" and content:
return content return content
return "" # pragma: no cover return "" # pragma: no cover
rank_cmd = on_alconna( rank_cmd = on_alconna(
Alconna( Alconna(
"B话榜", "B话榜",
Args["type?", ["今日", "昨日", "本周", "上周", "本月", "上月", "年度", "历史"]][ Args["type?", ["今日", "昨日", "本周", "上周", "本月", "上月", "年度", "历史"]][
"time?",str,], "time?",
Option("-g|--group_id",Args["group_id?", str]), str,
],
Option("-g|--group_id", Args["group_id?", str]),
behaviors=[SameTime()], behaviors=[SameTime()],
), ),
aliases={"废话榜"}, aliases={"废话榜"},
@ -206,10 +213,7 @@ async def _group_message(
prompt="请输入你要查询的结束日期(如 2022-02-22", prompt="请输入你要查询的结束日期(如 2022-02-22",
parameterless=[Depends(parse_datetime("stop"))], parameterless=[Depends(parse_datetime("stop"))],
) )
@rank_cmd.got( @rank_cmd.got("group_id", prompt="请输入你要查询的群号。")
"group_id",
prompt="请输入你要查询的群号。"
)
async def handle_rank( async def handle_rank(
state: T_State, state: T_State,
bot: Bot, bot: Bot,
@ -256,23 +260,33 @@ async def handle_rank(
for i in rank: for i in rank:
if user_info := await get_user_info(bot, event, user_id=str(i[0])): if user_info := await get_user_info(bot, event, user_id=str(i[0])):
logger.debug(user_info) logger.debug(user_info)
user_nickname = user_info.user_displayname\ user_nickname = (
if user_info.user_displayname\ user_info.user_displayname
else user_info.user_name\ if user_info.user_displayname
if user_info.user_name\ else user_info.user_name if user_info.user_name else user_info.user_id
else\ )
user_info.user_id user_avatar = (
user_avatar = await user_info.user_avatar.get_image()\ await user_info.user_avatar.get_image()
if user_info.user_avatar\ if user_info.user_avatar
else open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read() else open(
user = UserRankInfo(**model_dump(user_info), os.path.dirname(os.path.abspath(__file__))
user_bnum=i[1], + "/template/avatar/default.jpg",
user_proportion= round(i[1] / total * 100, 2), "rb",
user_index= cn2an.an2cn(index), ).read()
user_nickname= user_nickname, )
user_avatar_bytes= user_avatar, user = UserRankInfo(
**model_dump(user_info),
user_bnum=i[1],
user_proportion=round(i[1] / total * 100, 2),
user_index=cn2an.an2cn(index),
user_nickname=user_nickname,
user_avatar_bytes=user_avatar,
)
user.user_gender = (
""
if user_info.user_gender == "female"
else "" if user_info.user_gender == "male" else "ta"
) )
user.user_gender="" if user_info.user_gender == "female" else "" if user_info.user_gender == "male" else "ta"
rank2.append(user) rank2.append(user)
index += 1 index += 1
@ -283,7 +297,7 @@ async def handle_rank(
str_example = plugin_config.string_format.format( str_example = plugin_config.string_format.format(
index=rank2[i].user_index, index=rank2[i].user_index,
nickname=rank2[i].user_nickname, nickname=rank2[i].user_nickname,
chatdatanum=rank2[i].user_bnum chatdatanum=rank2[i].user_bnum,
) )
string += str_example string += str_example

View File

@ -7,5 +7,5 @@ class UserRankInfo(UserInfo):
user_bnum: int user_bnum: int
user_proportion: float user_proportion: float
user_nickname: str user_nickname: str
user_index: Union[int,str] user_index: Union[int, str]
user_avatar_bytes: bytes user_avatar_bytes: bytes

View File

@ -16,7 +16,7 @@ from nonebot_plugin_orm import get_session
from nonebot_plugin_session import Session, SessionLevel, extract_session from nonebot_plugin_session import Session, SessionLevel, extract_session
from nonebot_plugin_session_orm import SessionModel from nonebot_plugin_session_orm import SessionModel
from nonebot_plugin_userinfo import EventUserInfo, UserInfo from nonebot_plugin_userinfo import EventUserInfo, UserInfo
from nonebot_plugin_htmlrender import html_to_pic,template_to_pic from nonebot_plugin_htmlrender import html_to_pic, template_to_pic
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_chatrecorder import MessageRecord from nonebot_plugin_chatrecorder import MessageRecord
from nonebot_plugin_localstore import get_cache_dir from nonebot_plugin_localstore import get_cache_dir
@ -49,7 +49,7 @@ async def persist_id2user_id(ids: List) -> List[str]:
user_ids = [] user_ids = []
async with get_session() as db_session: async with get_session() as db_session:
for i in ids: for i in ids:
user_id = (await db_session.scalar(select(SessionModel).where(or_(*[SessionModel.id == i])))).id1 # type: ignore user_id = (await db_session.scalar(select(SessionModel).where(or_(*[SessionModel.id == i])))).id1 # type: ignore
user_ids.append(user_id) user_ids.append(user_id)
return user_ids return user_ids
@ -65,6 +65,7 @@ async def user_id2persist_id(ids: List[str]) -> List[int]:
records = (await db_session.scalars(statement)).all() records = (await db_session.scalars(statement)).all()
return [i.id for i in records] return [i.id for i in records]
async def group_id2persist_id(ids: List[str]) -> List[int]: async def group_id2persist_id(ids: List[str]) -> List[int]:
whereclause: List[ColumnElement[bool]] = [] whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id2 == id for id in ids])) whereclause.append(or_(*[SessionModel.id2 == id for id in ids]))
@ -76,6 +77,7 @@ async def group_id2persist_id(ids: List[str]) -> List[int]:
records = (await db_session.scalars(statement)).all() records = (await db_session.scalars(statement)).all()
return [i.id for i in records] return [i.id for i in records]
async def persist_id2group_id(ids: List[str]) -> List[str]: async def persist_id2group_id(ids: List[str]) -> List[str]:
whereclause: List[ColumnElement[bool]] = [] whereclause: List[ColumnElement[bool]] = []
whereclause.append(or_(*[SessionModel.id == id for id in ids])) whereclause.append(or_(*[SessionModel.id == id for id in ids]))
@ -87,6 +89,7 @@ async def persist_id2group_id(ids: List[str]) -> List[str]:
records = (await db_session.scalars(statement)).all() records = (await db_session.scalars(statement)).all()
return [i.id2 for i in records] return [i.id2 for i in records]
def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]: def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]:
"""### 计算每个人的消息量 """### 计算每个人的消息量
@ -137,6 +140,7 @@ def got_rank(msg_dict: Dict[str, int]) -> List:
return rank return rank
def remove_control_characters(string: str) -> str: def remove_control_characters(string: str) -> str:
"""### 将字符串中的控制符去除 """### 将字符串中的控制符去除
@ -148,27 +152,40 @@ def remove_control_characters(string: str) -> str:
""" """
return "".join(ch for ch in string if unicodedata.category(ch)[0] != "C") return "".join(ch for ch in string if unicodedata.category(ch)[0] != "C")
async def get_rank_image(rank: List[UserRankInfo]) -> bytes: async def get_rank_image(rank: List[UserRankInfo]) -> bytes:
for i in rank: for i in rank:
if i.user_avatar: if i.user_avatar:
try: try:
user_avatar = i.user_avatar_bytes user_avatar = i.user_avatar_bytes
except NotImplementedError: except NotImplementedError:
user_avatar = open(os.path.dirname(os.path.abspath(__file__))+"/template/avatar/default.jpg", "rb").read() user_avatar = open(
os.path.dirname(os.path.abspath(__file__))
+ "/template/avatar/default.jpg",
"rb",
).read()
# if not os.path.exists(cache_path / str(i.user_id)): # if not os.path.exists(cache_path / str(i.user_id)):
with open(cache_path / (str(i.user_id) + ".jpg"), "wb") as f: with open(cache_path / (str(i.user_id) + ".jpg"), "wb") as f:
f.write(user_avatar) f.write(user_avatar)
if plugin_config.template_path[:2] == './': if plugin_config.template_path[:2] == "./":
path = os.path.dirname(os.path.abspath(__file__)) + plugin_config.template_path[1:] path = (
os.path.dirname(os.path.abspath(__file__)) + plugin_config.template_path[1:]
)
else: else:
path = plugin_config.template_path path = plugin_config.template_path
path_dir, filename = os.path.split(path) path_dir, filename = os.path.split(path)
logger.debug(os.path.dirname(os.path.abspath(__file__)) + plugin_config.template_path[1:]) logger.debug(
return await template_to_pic(path_dir, os.path.dirname(os.path.abspath(__file__)) + plugin_config.template_path[1:]
filename, )
{'users': rank, return await template_to_pic(
'cache_path': cache_path, path_dir,
'file_path': os.path.dirname(os.path.abspath(__file__))}, filename,
pages={"viewport": {"width": 1366, "height": 10}}) {
"users": rank,
"cache_path": cache_path,
"file_path": os.path.dirname(os.path.abspath(__file__)),
},
pages={"viewport": {"width": 1200, "height": 10}},
)