修复了一些过时的bug,暂时删除对频道的支持

This commit is contained in:
unknown 2023-03-19 09:58:50 +08:00
parent be3c14ef28
commit d0801c4319
3 changed files with 42 additions and 105 deletions

View File

@ -1,7 +1,7 @@
import re import re
import time import time
import asyncio import asyncio
from typing import List, Tuple, Union from typing import Tuple, Union
from datetime import datetime, timedelta from datetime import datetime, timedelta
try: try:
@ -17,13 +17,13 @@ from nonebot.matcher import Matcher
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent, Message from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent, Message
require("nonebot_plugin_chatrecorder_guild_patch")
from nonebot_plugin_chatrecorder_guild_patch import get_guild_all_channel
require("nonebot_plugin_chatrecorder") require("nonebot_plugin_chatrecorder")
from nonebot_plugin_chatrecorder import get_message_records
require("nonebot_plugin_guild_patch") require("nonebot_plugin_guild_patch")
from nonebot_plugin_guild_patch import GuildMessageEvent from nonebot_plugin_guild_patch import GuildMessageEvent
from .function import get_message_records,msg_counter, msg_list2msg from .function import msg_counter, msg_list2msg
from .config import plugin_config from .config import plugin_config
def parse_datetime(key: str): def parse_datetime(key: str):
@ -104,6 +104,9 @@ async def _group_message(
elif command == "昨日群话痨排行榜": elif command == "昨日群话痨排行榜":
state["stop"] = dt.replace(hour=0, minute=0, second=0, microsecond=0) state["stop"] = dt.replace(hour=0, minute=0, second=0, microsecond=0)
state["start"] = state["stop"] - timedelta(days=1) state["start"] = state["stop"] - timedelta(days=1)
elif command == "前日群话痨排行榜":
state["stop"] = dt.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
state["start"] = state["stop"] - timedelta(days=1)
elif command == "本周群话痨排行榜": elif command == "本周群话痨排行榜":
state["start"] = dt.replace( state["start"] = dt.replace(
hour=0, minute=0, second=0, microsecond=0 hour=0, minute=0, second=0, microsecond=0
@ -159,45 +162,46 @@ async def _private_message(
) )
async def handle_message( async def handle_message(
bot: Bot, bot: Bot,
event: Union[GroupMessageEvent, GuildMessageEvent], event: GroupMessageEvent, #Union[GroupMessageEvent,GuildMessageEvent],
stop: datetime = Arg(), stop: datetime = Arg(),
start: datetime = Arg() start: datetime = Arg()
): ):
st = time.time() st = time.time()
if isinstance(event,GroupMessageEvent): #if isinstance(event,GroupMessageEvent):
if plugin_config.dialectlist_excluded_self: if plugin_config.dialectlist_excluded_self:
bot_id = await bot.call_api('get_login_info') bot_id:dict = await bot.call_api('get_login_info')
plugin_config.dialectlist_excluded_people.append(str(bot_id["user_id"])) plugin_config.dialectlist_excluded_people.append(bot_id['user_id'])
gids:List[str] = [str(event.group_id)] print(event.self_id)
msg_list = await get_message_records( msg_list = await get_message_records(
group_ids=gids, bot_ids=[str(bot.self_id)],
platforms=[str('qq')],
group_ids=[str(event.group_id)],
exclude_user_ids=plugin_config.dialectlist_excluded_people, exclude_user_ids=plugin_config.dialectlist_excluded_people,
message_type='group',
time_start=start.astimezone(ZoneInfo("UTC")), time_start=start.astimezone(ZoneInfo("UTC")),
time_stop=stop.astimezone(ZoneInfo("UTC")) time_stop=stop.astimezone(ZoneInfo("UTC"))
) )
elif isinstance(event, GuildMessageEvent): # elif isinstance(event, GuildMessageEvent):
if plugin_config.dialectlist_excluded_self: # if plugin_config.dialectlist_excluded_self:
bot_id = await bot.call_api('get_guild_service_profile') # bot_id = await bot.call_api('get_guild_service_profile')
plugin_config.dialectlist_excluded_people.append(str(bot_id["user_id"])) # guild_ids:List[str] = [str(event.guild_id)]
guild_ids:List[str] = await get_guild_all_channel(event.guild_id,bot=bot) # msg_list = await get_message_records(
msg_list = await get_message_records( # bot_ids=[str(bot.self_id)],
group_ids=guild_ids, # platforms=['qqguild'],
exclude_user_ids=plugin_config.dialectlist_excluded_people, # guild_ids=guild_ids,
message_type='group', # exclude_user_ids=plugin_config.dialectlist_excluded_people,
time_start=start.astimezone(ZoneInfo("UTC")), # time_start=start.astimezone(ZoneInfo("UTC")),
time_stop=stop.astimezone(ZoneInfo("UTC")) # time_stop=stop.astimezone(ZoneInfo("UTC"))
) # )
msg_dict = await msg_counter(msg_list=msg_list) msg_dict = await msg_counter(msg_list=msg_list)
if isinstance(event,GroupMessageEvent): # if isinstance(event,GroupMessageEvent):
msg = await msg_list2msg(msg_list=msg_dict,gid=event.group_id,platform='qq',bot=bot,got_num=plugin_config.dialectlist_get_num) msg = await msg_list2msg(msg_list=msg_dict,gid=event.group_id,platform='qq',bot=bot,got_num=plugin_config.dialectlist_get_num)
elif isinstance(event, GuildMessageEvent): # elif isinstance(event, GuildMessageEvent):
msg = await msg_list2msg(msg_list=msg_dict,gid=event.guild_id,platform='guild',bot=bot,got_num=plugin_config.dialectlist_get_num) # msg = await msg_list2msg(msg_list=msg_dict,gid=event.guild_id,platform='guild',bot=bot,got_num=plugin_config.dialectlist_get_num)
await rankings.send(msg) await rankings.send(msg)
await asyncio.sleep(1) #让图片先发出来 await asyncio.sleep(1) #让图片先发出来

View File

@ -7,8 +7,8 @@ class Config(BaseModel, extra=Extra.ignore):
timezone: Optional[str] timezone: Optional[str]
dialectlist_string_format: str = '{index}名:\n{nickname},{chatdatanum}条消息\n' #消息格式 dialectlist_string_format: str = '{index}名:\n{nickname},{chatdatanum}条消息\n' #消息格式
dialectlist_string_suffix_format: Optional[str] = '你们的职业是水群吗————MYX\n计算花费时间:{timecost}' #消息后缀格式 dialectlist_string_suffix_format: Optional[str] = '数你们聊天记录都要花{timecost}秒,你看看你们多能聊!' #消息后缀格式
dialectlist_get_num:int = 10 #获取人数数量 dialectlist_get_num:int = 5 #获取人数数量
dialectlist_visualization:bool = True #是否可视化 dialectlist_visualization:bool = True #是否可视化
dialectlist_visualization_type:Literal['饼图','圆环图','柱状图'] = '圆环图' #可视化方案 dialectlist_visualization_type:Literal['饼图','圆环图','柱状图'] = '圆环图' #可视化方案
dialectlist_font:str = 'SimHei'#字体格式 dialectlist_font:str = 'SimHei'#字体格式

View File

@ -1,10 +1,8 @@
import pygal import pygal
import unicodedata import unicodedata
from datetime import datetime
from sqlmodel import select, or_
from typing_extensions import Literal from typing_extensions import Literal
from typing import Iterable, List, Optional, Dict from typing import List, Optional, Dict
from pygal.style import Style from pygal.style import Style
from nonebot.log import logger from nonebot.log import logger
@ -12,7 +10,6 @@ from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Message,MessageSegment from nonebot.adapters.onebot.v11 import Message,MessageSegment
from nonebot.exception import ActionFailed from nonebot.exception import ActionFailed
from nonebot_plugin_datastore import create_session
from nonebot_plugin_chatrecorder.model import MessageRecord from nonebot_plugin_chatrecorder.model import MessageRecord
from .config import plugin_config from .config import plugin_config
@ -31,70 +28,6 @@ def remove_control_characters(string:str) -> str:
return "".join(ch for ch in string if unicodedata.category(ch)[0]!="C") return "".join(ch for ch in string if unicodedata.category(ch)[0]!="C")
async def get_message_records(
user_ids: Optional[Iterable[str]] = None,
group_ids: Optional[Iterable[str]] = None,
platforms: Optional[Iterable[str]] = None,
exclude_user_ids: Optional[Iterable[str]] = None,
exclude_group_ids: Optional[Iterable[str]] = None,
message_type: Optional[Literal['private', 'group']] = None,
time_start: Optional[datetime] = None,
time_stop: Optional[datetime] = None,
)->List[MessageRecord]:
"""
:说明:
获取消息记录
:参数:
* ``user_ids: Optional[Iterable[str]]``: 用户列表为空表示所有用户
* ``group_ids: Optional[Iterable[str]]``: 群组列表为空表示所有群组
* ``platform: OPtional[Iterable[str]]``: 消息来源列表为空表示所有来源
* ``exclude_user_ids: Optional[Iterable[str]]``: 不包含的用户列表为空表示不限制
* ``exclude_group_ids: Optional[Iterable[str]]``: 不包含的群组列表为空表示不限制
* ``message_type: Optional[Literal['private', 'group']]``: 消息类型可选值'private' 'group'为空表示所有类型
* ``time_start: Optional[datetime]``: 起始时间UTC 时间为空表示不限制起始时间
* ``time_stop: Optional[datetime]``: 结束时间UTC 时间为空表示不限制结束时间
:返回值:
* ``List[MessageRecord]``:返回信息
"""
whereclause = []
if user_ids:
whereclause.append(
or_(*[MessageRecord.user_id == user_id for user_id in user_ids]) # type: ignore
)
if group_ids:
whereclause.append(
or_(*[MessageRecord.group_id == group_id for group_id in group_ids]) # type: ignore
)
if platforms:
whereclause.append(
or_(*[MessageRecord.platform == platform for platform in platforms]) # type: ignore
)
if exclude_user_ids:
for user_id in exclude_user_ids:
whereclause.append(MessageRecord.user_id != user_id)
if exclude_group_ids:
for group_id in exclude_group_ids:
whereclause.append(MessageRecord.group_id != group_id)
if message_type:
whereclause.append(MessageRecord.detail_type == message_type)
if time_start:
whereclause.append(MessageRecord.time >= time_start)
if time_stop:
whereclause.append(MessageRecord.time <= time_stop)
statement = select(MessageRecord).where(*whereclause)
async with create_session() as session:
records: List[MessageRecord] = (await session.exec(statement)).all() # type: ignore
return records
async def msg_counter(msg_list:List[MessageRecord])->Dict[str,int]: async def msg_counter(msg_list:List[MessageRecord])->Dict[str,int]:
''' '''
计算出话最多的几个人的id并返回 计算出话最多的几个人的id并返回
@ -106,9 +39,9 @@ async def msg_counter(msg_list:List[MessageRecord])->Dict[str,int]:
for i in msg_list: for i in msg_list:
try: try:
lst[i.user_id] +=1 lst[i.user_id] += 1
except KeyError: except KeyError:
lst[i.user_id] =1 lst[i.user_id] = 1
logger.debug(lst) logger.debug(lst)