代码结构优化

This commit is contained in:
ShiXui 2022-09-18 09:15:07 +08:00
parent d3043ca2f0
commit ffcd7992dd
4 changed files with 83 additions and 128 deletions

View File

@ -17,15 +17,14 @@ from nonebot.matcher import Matcher
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent, Message from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent, Message
require("nonebot_plugin_datastore") require("nonebot_plugin_chatrecorder_guild_patch")
from nonebot_plugin_chatrecorder_guild_patch import get_guild_all_channel
require("nonebot_plugin_chatrecorder") require("nonebot_plugin_chatrecorder")
require("nonebot_plugin_guild_patch") require("nonebot_plugin_guild_patch")
from nonebot_plugin_guild_patch import GuildMessageEvent from nonebot_plugin_guild_patch import GuildMessageEvent
from .qqDBRecorder import get_message_records,msg_counter from .function import get_message_records,msg_counter, msg_list2msg
from .config import plugin_config from .config import plugin_config
from .qqGuildJsonRecorder import get_guild_message_records
def parse_datetime(key: str): def parse_datetime(key: str):
"""解析数字,并将结果存入 state 中""" """解析数字,并将结果存入 state 中"""
@ -166,24 +165,37 @@ async def handle_message(
): ):
st = time.time() st = time.time()
bot_id = await bot.call_api('get_login_info')
bot_id = [str(bot_id['user_id'])]
if isinstance(event,GroupMessageEvent):
if isinstance(event,GroupMessageEvent):
bot_id = await bot.call_api('get_login_info')
bot_id = [str(bot_id['user_id'])]
gids:List[str] = [str(event.group_id)] gids:List[str] = [str(event.group_id)]
msg = await get_message_records( msg_list = await get_message_records(
group_ids=gids, group_ids=gids,
exclude_user_ids=bot_id, exclude_user_ids=bot_id,
message_type='group', message_type='group',
time_start=start.astimezone(ZoneInfo("UTC")), time_start=start.astimezone(ZoneInfo("UTC")),
time_stop=stop.astimezone(ZoneInfo("UTC")) time_stop=stop.astimezone(ZoneInfo("UTC"))
) )
msg = await msg_counter(gid=event.group_id, bot=bot, msg=msg,got_num=plugin_config.dialectlist_get_num)
elif isinstance(event, GuildMessageEvent): elif isinstance(event, GuildMessageEvent):
bot_id = await bot.call_api('get_guild_service_profile')
bot_id = [str(bot_id['tiny_id'])]
guild_ids:List[str] = await get_guild_all_channel(event.guild_id,bot=bot)
msg_list = await get_message_records(
group_ids=guild_ids,
exclude_user_ids=bot_id,
message_type='group',
time_start=start.astimezone(ZoneInfo("UTC")),
time_stop=stop.astimezone(ZoneInfo("UTC"))
)
guild_id = event.guild_id
msg = await get_guild_message_records(guild_id=str(guild_id),bot=bot) msg_dict = await msg_counter(msg_list=msg_list)
if isinstance(event,GroupMessageEvent):
msg = await msg_list2msg(msg_list=msg_dict,gid=event.group_id,platform='qq',bot=bot,got_num=plugin_config.dialectlist_get_num)
elif isinstance(event, GuildMessageEvent):
msg = await msg_list2msg(msg_list=msg_dict,gid=event.guild_id,platform='guild',bot=bot,got_num=plugin_config.dialectlist_get_num)
await rankings.send(msg) await rankings.send(msg)
await asyncio.sleep(1) #让图片先发出来 await asyncio.sleep(1) #让图片先发出来

View File

@ -1,21 +1,16 @@
from typing import Optional from typing import Optional, Union, Literal
from nonebot import get_driver from nonebot import get_driver
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from pathlib import Path
import os
class Config(BaseModel, extra=Extra.ignore): class Config(BaseModel, extra=Extra.ignore):
timezone: Optional[str] timezone: Optional[str]
dialectlist_string_format: str = '{index}名:\n{nickname},{chatdatanum}条消息\n' dialectlist_string_format: str = '{index}名:\n{nickname},{chatdatanum}条消息\n' #消息格式
dialectlist_string_suffix_format: str = '你们的职业是水群吗————MYX\n计算花费时间:{timecost}' dialectlist_string_suffix_format: str = '你们的职业是水群吗————MYX\n计算花费时间:{timecost}' #消息后缀格式
dialectlist_path:str = os.path.dirname(__file__) dialectlist_get_num:int = 10 #获取人数数量
dialectlist_image_path: Path = Path(dialectlist_path)/'image.png' dialectlist_visualization:bool = True #是否可视化
dialectlist_imageSvg_path: Path = Path(dialectlist_path)/'image.svg' dialectlist_visualization_type:Literal['饼图','圆环图','柱状图'] = '圆环图' #可视化方案
dialectlist_json_path:Path = Path(dialectlist_path)/'qqguild.json'
dialectlist_get_num:int = 10
dialectlist_visualization:bool = True
global_config = get_driver().config global_config = get_driver().config
plugin_config = Config.parse_obj(global_config) plugin_config = Config.parse_obj(global_config)

View File

@ -9,18 +9,18 @@ from typing import Iterable, List, Optional, Dict
from pygal.style import Style from pygal.style import Style
style=Style(font_family="SimHei",) style=Style(font_family="SimHei",)
from nonebot.log import logger from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Message,MessageSegment from nonebot.adapters.onebot.v11 import Message,MessageSegment
from nonebot.adapters.onebot.v11.exception import ActionFailed from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot_plugin_datastore import create_session from nonebot_plugin_datastore import create_session
from nonebot_plugin_chatrecorder.model import MessageRecord from nonebot_plugin_chatrecorder.model import MessageRecord
from .config import plugin_config from .config import plugin_config
def remove_control_characters(string:str) -> str: def remove_control_characters(string:str) -> str:
"""将字符串中的控制符去除 """将字符串中的控制符去除
@ -32,6 +32,8 @@ def remove_control_characters(string:str) -> str:
""" """
return "".join(ch for ch in string if unicodedata.category(ch)[0]!="C") return "".join(ch for ch in string if unicodedata.category(ch)[0]!="C")
async def get_message_records( async def get_message_records(
user_ids: Optional[Iterable[str]] = None, user_ids: Optional[Iterable[str]] = None,
group_ids: Optional[Iterable[str]] = None, group_ids: Optional[Iterable[str]] = None,
@ -95,78 +97,90 @@ async def get_message_records(
async def msg_counter( async def msg_counter(msg_list:List[MessageRecord])->Dict[str,int]:
gid:int,
bot:Bot,
msg:List[MessageRecord],
got_num:int=10,
)->Message:
''' '''
计算出结果并返回可以直接发送的字符串和图片 计算出话最多的几个人的id并返回
''' '''
st = time.time()
logger.debug('loading msg from group {}'.format(gid))
gnl = await bot.call_api('get_group_member_list',group_id=int(gid))
logger.debug('group {} have number {}'.format(gid,len(gnl)))
lst:Dict[str,int] = {} lst:Dict[str,int] = {}
msg_len = len(msg) msg_len = len(msg_list)
for i in msg: logger.info('wow , there are {} msgs to count !!!'.format(msg_len))
for i in msg_list:
try: try:
lst[i.user_id] +=1 lst[i.user_id] +=1
except KeyError: except KeyError:
lst[i.user_id] =1 lst[i.user_id] =1
logger.debug(lst) logger.debug(lst)
logger.debug('group number num is '+str(len(lst)))
return lst
async def msg_list2msg(
msg_list:Dict[str,int],
gid:int,
got_num:int,
platform:Optional[Literal['guild', 'qq']],
bot:Bot
)->Message:
ranking = [] ranking = []
while len(ranking) < got_num: while len(ranking) < got_num:
try: try:
maxkey = max(lst, key=lst.get) # type: ignore maxkey = max(msg_list, key=msg_list.get) # type: ignore
except ValueError: except ValueError:
ranking.append(["null",0]) ranking.append(["null",0])
continue continue
logger.debug('searching number {} from group {}'.format(str(maxkey),str(gid))) logger.debug('searching member {} from group {}'.format(str(maxkey),str(gid)))
try: try:
if platform == 'qq':
member_info = await bot.call_api( member_info = await bot.call_api(
"get_group_member_info", "get_group_member_info",
group_id=int(gid), group_id=int(gid),
user_id=int(maxkey), user_id=int(maxkey),
no_cache=True no_cache=True
) )
nickname:str = member_info['nickname']if not member_info['card'] else member_info['card'] nickname:str = member_info['nickname']if not member_info['card'] else member_info['card']
ranking.append([remove_control_characters(nickname).strip(),lst.pop(maxkey)]) else:
member_info = await bot.call_api(
"get_guild_member_profile",
guild_id=str(gid),
user_id=str(maxkey)
)
nickname:str = member_info['nickname']
ranking.append([remove_control_characters(nickname).strip(),msg_list.pop(maxkey)])
except ActionFailed as e: except ActionFailed as e:
logger.warning(e) logger.warning(e)
logger.warning('number {} is not exit in group {}'.format(str(maxkey),str(gid))) logger.warning('member {} is not exit in group {}'.format(str(maxkey),str(gid)))
lst.pop(maxkey) msg_list.pop(maxkey)
logger.debug('loaded list:\n{}'.format(ranking)) logger.debug('loaded list:\n{}'.format(ranking))
if plugin_config.dialectlist_visualization: if plugin_config.dialectlist_visualization:
if plugin_config.dialectlist_visualization_type == '圆环图':
view = pygal.Pie(inner_radius=0.6,style=style)
elif plugin_config.dialectlist_visualization_type == '饼图':
view = pygal.Pie(style=style)
else:
view = pygal.Bar(style=style)
view = pygal.Pie(inner_radius=0.6,style=style) view.title = '消息可视化'
view.title = '消息圆环图'
for i in ranking: for i in ranking:
view.add(str(i[0]),int(i[1])) view.add(str(i[0]),int(i[1]))
try: try:
png: bytes = view.render_to_png()# type: ignore png: bytes = view.render_to_png() # type: ignore
process_msg = Message(MessageSegment.image(png)) process_msg = Message(MessageSegment.image(png))
except OSError: except OSError:
logger.error('GTK+(GIMP Toolkit) is not installed, the svg can not be transformed to png') logger.error('GTK+(GIMP Toolkit) is not installed, the svg can not be transformed to png')
plugin_config.dialectlist_visualization = False plugin_config.dialectlist_visualization = False
process_msg = Message('无法发送可视化图片请检查是否安装GTK+详细安装教程可见github\nhttps://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer \n若不想安装这个软件,再次使用指令会转换为发送字符串而不是发送图片') process_msg = Message('无法发送可视化图片请检查是否安装GTK+详细安装教程可见github\nhttps://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer \n若不想安装这个软件,再次使用这个指令不会显示这个提示')
else: else:
process_msg = '' process_msg = ''
out:str = '' out:str = ''
for i in range(got_num): for i in range(got_num):
index = i+1 index = i+1
@ -175,6 +189,5 @@ async def msg_counter(
out = out + str_example out = out + str_example
logger.debug(out) logger.debug(out)
logger.info('spent {} seconds to count from {} msg'.format(time.time()-st,msg_len))
return Message(out)+process_msg return Message(out)+process_msg

View File

@ -1,65 +0,0 @@
import json
from typing import Dict
from nonebot.log import logger
from nonebot.message import event_postprocessor
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Message
from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot_plugin_guild_patch import GuildMessageEvent
from .config import plugin_config
def update_json(updatedata:Dict):
with open(plugin_config.dialectlist_json_path, 'w', encoding='utf-8') as f:
json.dump(updatedata, f, ensure_ascii=False, indent=4)
def get_json()-> Dict[str,Dict]:
if not plugin_config.dialectlist_json_path.exists():
return {}
with open(plugin_config.dialectlist_json_path, 'r', encoding='utf-8') as f:
data:Dict = json.load(f)
return data
@event_postprocessor
async def _pocesser(event:GuildMessageEvent):
data = get_json()
try:
data[str(event.guild_id)][str(event.sender.nickname)] += 1
except KeyError:
data[str(event.guild_id)] = {str(event.sender.nickname):1}
update_json(data)
async def get_guild_message_records(
guild_id:str,
bot:Bot,
got_num:int=10,
)->Message:
data = get_json()
ranking = []
while len(ranking) < got_num:
try:
maxkey = max(data[guild_id], key=data[guild_id].get) # type: ignore
except ValueError:
ranking.append(("null",0))
continue
ranking.append((maxkey,data[guild_id].pop(maxkey)))
logger.debug('loaded list:\n{}'.format(ranking))
out:str = ''
for i in range(got_num):
index = i+1
nickname,chatdatanum = ranking[i]
str_example = plugin_config.dialectlist_string_format.format(index=index,nickname=nickname,chatdatanum=chatdatanum)
out = out + str_example
return Message(out)