适配onebotV12+代码结构优化[没有测试

This commit is contained in:
unknown 2023-04-09 09:33:57 +08:00
parent 06a6f7ce0d
commit fa48ca798e
3 changed files with 355 additions and 218 deletions

View File

@ -1,73 +1,29 @@
import re import re
import time import time
import asyncio
from typing import Tuple, Union from typing import Tuple, Union
from datetime import datetime, timedelta from datetime import datetime, timedelta
try:
from zoneinfo import ZoneInfo
except ImportError:
from backports.zoneinfo import ZoneInfo # type: ignore
from nonebot import on_command, require from nonebot import on_command, require
from nonebot.log import logger from nonebot.log import logger
from nonebot.params import Command, CommandArg, Arg, Depends from nonebot.params import Command, CommandArg, Arg, Depends
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.adapters import Bot from nonebot.adapters.onebot import V11Bot, V12Bot, V11Event, V12Event, V11Message, V12Message # type: ignore
from nonebot.adapters.onebot.v11 import GroupMessageEvent, PrivateMessageEvent, Message
try:
from zoneinfo import ZoneInfo
except ImportError:
from backports.zoneinfo import ZoneInfo # type: ignore
require("nonebot_plugin_chatrecorder") require("nonebot_plugin_chatrecorder")
from nonebot_plugin_chatrecorder import get_message_records from nonebot_plugin_chatrecorder import get_message_records
require("nonebot_plugin_guild_patch")
from nonebot_plugin_guild_patch import GuildMessageEvent
from .function import msg_counter, msg_list2msg from .function import *
from .config import plugin_config from .config import plugin_config
def parse_datetime(key: str):
"""解析数字,并将结果存入 state 中"""
async def _key_parser( ranks = on_command(
matcher: Matcher, "群话痨排行榜",
state: T_State,
input: Union[datetime, Message] = Arg(key)
):
if isinstance(input, datetime):
return
plaintext = input.extract_plain_text()
try:
state[key] = get_datetime_fromisoformat_with_timezone(plaintext)
except ValueError:
await matcher.reject_arg(key, "请输入正确的日期,不然我没法理解呢!")
return _key_parser
def get_datetime_now_with_timezone() -> datetime:
"""获取当前时间,并包含时区信息"""
if plugin_config.timezone:
return datetime.now(ZoneInfo(plugin_config.timezone))
else:
return datetime.now().astimezone()
def get_datetime_fromisoformat_with_timezone(date_string: str) -> datetime:
"""从 iso8601 格式字符串中获取时间,并包含时区信息"""
if plugin_config.timezone:
return datetime.fromisoformat(date_string).astimezone(
ZoneInfo(plugin_config.timezone)
)
else:
return datetime.fromisoformat(date_string).astimezone()
rankings = on_command(
'群话痨排行榜',
aliases={ aliases={
"今日群话痨排行榜", "今日群话痨排行榜",
"昨日群话痨排行榜", "昨日群话痨排行榜",
@ -78,26 +34,34 @@ rankings = on_command(
"历史群话痨排行榜", "历史群话痨排行榜",
}, },
priority=6, priority=6,
block=True block=True,
) )
@rankings.handle()
async def _group_message(
event:Union[GroupMessageEvent, GuildMessageEvent],
state: T_State,commands: Tuple[str, ...] = Command(),
args: Message = CommandArg()
):
if isinstance(event, GroupMessageEvent): @ranks.handle()
logger.debug('handle command from qq') async def _group_message(
elif isinstance(event, GuildMessageEvent): matcher: Matcher,
logger.debug('handle command from qqguild') event: Union[
V11Event.GroupMessageEvent,
V12Event.GroupMessageEvent,
V12Event.ChannelMessageEvent,
],
state: T_State,
commands: Tuple[str, ...] = Command(),
args: Union[V11Message, V11Message] = CommandArg(),
):
if isinstance(event, V11Event.GroupMessageEvent):
logger.debug("handle command from onebotV11 adapter(qq)")
elif isinstance(event, V12Event.GroupMessageEvent):
logger.debug("handle command from onebotV12 adapter")
dt = get_datetime_now_with_timezone() dt = get_datetime_now_with_timezone()
command = commands[0] command = commands[0]
if command == "群话痨排行榜": if command == "群话痨排行榜":
state["start"] = dt.replace(year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0) state["start"] = dt.replace(
year=2000, month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
state["stop"] = dt state["stop"] = dt
elif command == "今日群话痨排行榜": elif command == "今日群话痨排行榜":
state["start"] = dt.replace(hour=0, minute=0, second=0, microsecond=0) state["start"] = dt.replace(hour=0, minute=0, second=0, microsecond=0)
@ -106,7 +70,9 @@ async def _group_message(
state["stop"] = dt.replace(hour=0, minute=0, second=0, microsecond=0) state["stop"] = dt.replace(hour=0, minute=0, second=0, microsecond=0)
state["start"] = state["stop"] - timedelta(days=1) state["start"] = state["stop"] - timedelta(days=1)
elif command == "前日群话痨排行榜": elif command == "前日群话痨排行榜":
state["stop"] = dt.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1) state["stop"] = dt.replace(
hour=0, minute=0, second=0, microsecond=0
) - timedelta(days=1)
state["start"] = state["stop"] - timedelta(days=1) state["start"] = state["stop"] - timedelta(days=1)
elif command == "本周群话痨排行榜": elif command == "本周群话痨排行榜":
state["start"] = dt.replace( state["start"] = dt.replace(
@ -145,74 +111,71 @@ async def _group_message(
) )
state["stop"] = state["start"] + timedelta(days=1) state["stop"] = state["start"] + timedelta(days=1)
except ValueError: except ValueError:
await rankings.finish("请输入正确的日期,不然我没法理解呢!") await matcher.finish("请输入正确的日期,不然我没法理解呢!")
else: else:
pass pass
@rankings.handle()
@ranks.handle()
async def _private_message( async def _private_message(
event:PrivateMessageEvent, matcher: Matcher,
state: T_State,commands: Tuple[str, ...] = Command(), event: Union[V11Event.GroupMessageEvent, V12Event.GroupMessageEvent],
args: Message = CommandArg() state: T_State,
commands: Tuple[str, ...] = Command(),
args: Union[V11Message, V12Message] = CommandArg(),
): ):
# TODO:支持私聊的查询 # TODO:支持私聊的查询
await rankings.finish('暂不支持私聊查询,今后可能会添加这一项功能') await matcher.finish("暂不支持私聊查询,今后可能会添加这一项功能")
@rankings.got(
@ranks.got(
"start", "start",
prompt="请输入你要查询的起始日期(如 2022-01-01", prompt="请输入你要查询的起始日期(如 2022-01-01",
parameterless=[Depends(parse_datetime("start"))] parameterless=[Depends(parse_datetime("start"))],
) )
@rankings.got( @ranks.got(
"stop", "stop",
prompt="请输入你要查询的结束日期(如 2022-02-22", prompt="请输入你要查询的结束日期(如 2022-02-22",
parameterless=[Depends(parse_datetime("stop"))] parameterless=[Depends(parse_datetime("stop"))],
) )
async def handle_message( async def handle_message(
bot: Bot, matcher: Matcher,
event: GroupMessageEvent, #Union[GroupMessageEvent,GuildMessageEvent], bot: Union[V11Bot, V12Bot],
event: Union[
V11Event.GroupMessageEvent,
V12Event.GroupMessageEvent,
V12Event.ChannelMessageEvent,
],
stop: datetime = Arg(), stop: datetime = Arg(),
start: datetime = Arg() start: datetime = Arg(),
): ):
st = time.time() st = time.time()
#if isinstance(event,GroupMessageEvent):
if plugin_config.dialectlist_excluded_self: if plugin_config.dialectlist_excluded_self:
bot_id:dict = await bot.call_api('get_login_info') bot_id = await bot.call_api("get_login_info")
plugin_config.dialectlist_excluded_people.append(bot_id['user_id']) plugin_config.dialectlist_excluded_people.append(bot_id["user_id"])
print(event.self_id)
msg_list = await get_message_records( msg_list = await get_message_records(
bot_ids=[str(bot.self_id)], bot_ids=[str(bot.self_id)],
platforms=[str('qq')], platforms=[str(bot.platform)],
group_ids=[str(event.group_id)], group_ids=[str(event.group_id)]
if isinstance(event, (V11Event.GroupMessageEvent, V12Event.GroupMessageEvent))
else None,
guild_ids=[str(event.guild_id)]
if isinstance(event, V12Event.ChannelMessageEvent)
else None,
exclude_user_ids=plugin_config.dialectlist_excluded_people, exclude_user_ids=plugin_config.dialectlist_excluded_people,
time_start=start.astimezone(ZoneInfo("UTC")), time_start=start.astimezone(ZoneInfo("UTC")),
time_stop=stop.astimezone(ZoneInfo("UTC")) time_stop=stop.astimezone(ZoneInfo("UTC")),
) )
# elif isinstance(event, GuildMessageEvent): if isinstance(event, V11Event.GroupMessageEvent):
# if plugin_config.dialectlist_excluded_self: processer = V11GroupMsgProcesser(bot=bot, gid=str(event.group_id), msg_list=msg_list) # type: ignore
# bot_id = await bot.call_api('get_guild_service_profile') elif isinstance(event, V12Event.GroupMessageEvent):
# guild_ids:List[str] = [str(event.guild_id)] processer = V12GroupMsgProcesser(bot=bot, gid=str(event.group_id), msg_list=msg_list) # type: ignore
# msg_list = await get_message_records( elif isinstance(event, V12Event.ChannelMessageEvent):
# bot_ids=[str(bot.self_id)], pass
# platforms=['qqguild'], else:
# guild_ids=guild_ids, raise NotImplementedError("没支持呢(())")
# exclude_user_ids=plugin_config.dialectlist_excluded_people,
# time_start=start.astimezone(ZoneInfo("UTC")),
# time_stop=stop.astimezone(ZoneInfo("UTC"))
# )
msg_dict = await msg_counter(msg_list=msg_list)
# if isinstance(event,GroupMessageEvent):
msg = await msg_list2msg(msg_list=msg_dict,gid=event.group_id,platform='qq',bot=bot,got_num=plugin_config.dialectlist_get_num)
# elif isinstance(event, GuildMessageEvent):
# msg = await msg_list2msg(msg_list=msg_dict,gid=event.guild_id,platform='guild',bot=bot,got_num=plugin_config.dialectlist_get_num)
await rankings.send(msg)
await asyncio.sleep(1) #让图片先发出来
if plugin_config.dialectlist_string_suffix_format:
await rankings.finish(plugin_config.dialectlist_string_suffix_format.format(timecost=time.time()-st-1))
msg = await processer.get_send_msg() # type: ignore
await matcher.send(msg)

View File

@ -4,16 +4,15 @@ from pydantic import BaseModel, Extra
class Config(BaseModel, extra=Extra.ignore): class Config(BaseModel, extra=Extra.ignore):
timezone: Optional[str] timezone: Optional[str]
dialectlist_string_format: str = '{index}名:\n{nickname},{chatdatanum}条消息\n' #消息格式 dialectlist_string_format: str = "{index}名:\n{nickname},{chatdatanum}条消息\n" # 消息格式
dialectlist_string_suffix_format: Optional[str] = '数你们聊天记录都要花{timecost}秒,你看看你们多能聊!' #消息后缀格式
dialectlist_get_num: int = 5 # 获取人数数量 dialectlist_get_num: int = 5 # 获取人数数量
dialectlist_visualization: bool = True # 是否可视化 dialectlist_visualization: bool = True # 是否可视化
dialectlist_visualization_type:Literal['饼图','圆环图','柱状图'] = '圆环图' #可视化方案 dialectlist_visualization_type: Literal["饼图", "圆环图", "柱状图"] = "圆环图" # 可视化方案
dialectlist_font:str = 'SimHei'#字体格式 dialectlist_font: str = "SimHei" # 字体格式
dialectlist_excluded_people:List[str] = []#排除的人的QQ号(或频道号?(未经测试)) dialectlist_excluded_people: List[str] = [] # 排除的人的QQ号
dialectlist_excluded_self: bool = True dialectlist_excluded_self: bool = True
global_config = get_driver().config global_config = get_driver().config
plugin_config = Config(**global_config.dict()) plugin_config = Config(**global_config.dict())

View File

@ -1,23 +1,47 @@
import abc
import pygal import pygal
import unicodedata import unicodedata
import requests
from datetime import datetime
from typing_extensions import Literal from typing import List, Dict, Union
from typing import List, Optional, Dict
from pygal.style import Style from pygal.style import Style
from nonebot import require
from nonebot.log import logger from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.params import Arg
from nonebot.adapters.onebot.v11 import Message,MessageSegment from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Message
from nonebot.adapters.onebot import V11Bot, V12Bot, V11Message, V12Message, V11MessageSegment, V12MessageSegment # type: ignore
from nonebot.exception import ActionFailed from nonebot.exception import ActionFailed
try:
from zoneinfo import ZoneInfo
except ImportError:
from backports.zoneinfo import ZoneInfo # type: ignore
require("nonebot_plugin_htmlrender")
from nonebot_plugin_htmlrender import (
md_to_pic,
html_to_pic,
text_to_pic,
capture_element,
template_to_pic,
template_to_html,
)
require("nonebot_plugin_chatrecorder")
from nonebot_plugin_chatrecorder import get_message_records
from nonebot_plugin_chatrecorder.model import MessageRecord from nonebot_plugin_chatrecorder.model import MessageRecord
from .config import plugin_config from .config import plugin_config
style = Style(font_family=plugin_config.dialectlist_font) style = Style(font_family=plugin_config.dialectlist_font)
def remove_control_characters(string: str) -> str: def remove_control_characters(string: str) -> str:
"""将字符串中的控制符去除 """### 将字符串中的控制符去除
Args: Args:
string (str): 需要去除的字符串 string (str): 需要去除的字符串
@ -28,14 +52,57 @@ def remove_control_characters(string:str) -> str:
return "".join(ch for ch in string if unicodedata.category(ch)[0] != "C") return "".join(ch for ch in string if unicodedata.category(ch)[0] != "C")
async def msg_counter(msg_list:List[MessageRecord])->Dict[str,int]: def parse_datetime(key: str):
''' """解析数字,并将结果存入 state 中"""
计算出话最多的几个人的id并返回
''' async def _key_parser(
matcher: Matcher,
state: T_State,
input: Union[datetime, Union[V11Message, V12Message]] = Arg(key),
):
if isinstance(input, datetime):
return
plaintext = input.extract_plain_text()
try:
state[key] = get_datetime_fromisoformat_with_timezone(plaintext)
except ValueError:
await matcher.reject_arg(key, "请输入正确的日期,不然我没法理解呢!")
return _key_parser
def get_datetime_now_with_timezone() -> datetime:
"""获取当前时间,并包含时区信息"""
if plugin_config.timezone:
return datetime.now(ZoneInfo(plugin_config.timezone))
else:
return datetime.now().astimezone()
def get_datetime_fromisoformat_with_timezone(date_string: str) -> datetime:
"""从 iso8601 格式字符串中获取时间,并包含时区信息"""
if plugin_config.timezone:
return datetime.fromisoformat(date_string).astimezone(
ZoneInfo(plugin_config.timezone)
)
else:
return datetime.fromisoformat(date_string).astimezone()
def msg_counter(msg_list: List[MessageRecord]) -> Dict[str, int]:
"""### 计算每个人的消息量
Args:
msg_list (list[MessageRecord]): 需要处理的消息列表
Returns:
(dict[str,int]): 处理后的消息数量字典,键为用户,值为消息数量
"""
lst: Dict[str, int] = {} lst: Dict[str, int] = {}
msg_len = len(msg_list) msg_len = len(msg_list)
logger.info('wow , there are {} msgs to count !!!'.format(msg_len)) logger.info("wow , there are {} msgs to count !!!".format(msg_len))
for i in msg_list: for i in msg_list:
try: try:
@ -47,78 +114,186 @@ async def msg_counter(msg_list:List[MessageRecord])->Dict[str,int]:
return lst return lst
async def msg_list2msg(
msg_list:Dict[str,int],
gid:int,
got_num:int,
platform:Optional[Literal['guild', 'qq']],
bot:Bot
)->Message:
ranking = [] def got_rank(msg_dict: Dict[str, int]) -> List[List[Union[str, int]]]:
while len(ranking) < got_num: """### 获得排行榜
Args:
msg_dict (Dict[str,int]): 要处理的字典
Returns:
List[Tuple[str,int]]: 排行榜列表(已排序)
"""
rank = []
while len(rank) < plugin_config.dialectlist_get_num:
try: try:
maxkey = max(msg_list, key=msg_list.get) # type: ignore max_key = max(msg_dict.items(), key=lambda x: x[1])
rank.append(list(max_key))
except ValueError: except ValueError:
ranking.append(["null",0]) rank.append(["null", 0])
continue continue
logger.debug('searching member {} from group {}'.format(str(maxkey),str(gid))) return rank
try:
if platform == 'qq': class MsgProcesser(abc.ABC):
member_info = await bot.call_api( def __init__(self, bot: Bot, gid: str, msg_list: List[MessageRecord]) -> None:
"get_group_member_info", if isinstance(bot, Bot):
group_id=int(gid), self.bot = bot
user_id=int(maxkey),
no_cache=True
)
nickname:str = member_info['nickname']if not member_info['card'] else member_info['card']
else: else:
member_info = await bot.call_api( self.bot = None
"get_guild_member_profile", self.gid = gid
guild_id=str(gid), self.rank = got_rank(msg_counter(msg_list))
user_id=str(maxkey)
)
nickname:str = member_info['nickname']
ranking.append([remove_control_characters(nickname).strip(),msg_list.pop(maxkey)])
except ActionFailed as e:
logger.warning(e)
logger.warning('member {} is not exit in group {}'.format(str(maxkey),str(gid)))
msg_list.pop(maxkey)
@abc.abstractmethod
async def get_nickname_list(self) -> List:
"""
### 获得昵称
#### 抽象原因
要对onebot协议不同版本进行适配
"""
raise NotImplementedError
logger.debug('loaded list:\n{}'.format(ranking)) @abc.abstractmethod
def get_head_portrait_urls(self) -> List:
raise NotImplementedError
@abc.abstractmethod
async def get_send_msg(self) -> Message:
raise NotImplementedError
async def get_msg(self) -> List[Union[str, bytes, None]]:
str_msg = await self.render_template_msg()
pic_msg = None
if plugin_config.dialectlist_visualization: if plugin_config.dialectlist_visualization:
if plugin_config.dialectlist_visualization_type == '圆环图': try:
pic_msg = self.render_template_pic()
except OSError:
plugin_config.dialectlist_visualization = False
str_msg += "\n\n无法发送可视化图片请检查是否安装GTK+详细安装教程可见github\nhttps://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer \n若不想安装这个软件,再次使用这个指令不会显示这个提示"
return [str_msg, pic_msg]
async def render_template_msg(self) -> str:
"""渲染文字"""
string: str = ""
rank: List = self.rank
nicknames: List = await self.get_nickname_list()
for i in range(len(rank)):
index = i + 1
nickname, chatdatanum = nicknames[i], rank[i]
str_example = plugin_config.dialectlist_string_format.format(
index=index, nickname=nickname, chatdatanum=chatdatanum
)
string += str_example
return string
def render_template_pic(self) -> bytes:
if plugin_config.dialectlist_visualization_type == "圆环图":
view = pygal.Pie(inner_radius=0.6, style=style) view = pygal.Pie(inner_radius=0.6, style=style)
elif plugin_config.dialectlist_visualization_type == '饼图': elif plugin_config.dialectlist_visualization_type == "饼图":
view = pygal.Pie(style=style) view = pygal.Pie(style=style)
else: else:
view = pygal.Bar(style=style) view = pygal.Bar(style=style)
view.title = '消息可视化' view.title = "消息可视化"
for i in ranking: for i, j in zip(self.rank, self.get_nickname_list()): # type: ignore
view.add(str(i[0]),int(i[1])) view.add(str(j), int(i[1]))
try:
png: bytes = view.render_to_png() # type: ignore png: bytes = view.render_to_png() # type: ignore
process_msg = Message(MessageSegment.image(png)) self.img = png
except OSError: return png
logger.error('GTK+(GIMP Toolkit) is not installed, the svg can not be transformed to png')
plugin_config.dialectlist_visualization = False
process_msg = Message('无法发送可视化图片请检查是否安装GTK+详细安装教程可见github\nhttps://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer \n若不想安装这个软件,再次使用这个指令不会显示这个提示')
else:
process_msg = ''
out:str = ''
for i in range(got_num):
index = i+1
nickname,chatdatanum = ranking[i]
str_example = plugin_config.dialectlist_string_format.format(index=index,nickname=nickname,chatdatanum=chatdatanum)
out = out + str_example
logger.debug(out) class V11GroupMsgProcesser(MsgProcesser):
def __init__(self, bot: V11Bot, gid: str, msg_list: List[MessageRecord]) -> None:
super().__init__(bot, gid, msg_list)
self.bot = bot
return Message(out)+process_msg async def get_nickname_list(self) -> List:
nicknames = []
for i in range(len(self.rank)):
try:
member_info = await self.bot.get_group_member_info(
group_id=int(self.gid), user_id=int(self.rank[i][0]), no_cache=True
)
nicknames.append(
member_info["nickname"]
if not member_info["card"]
else member_info["card"]
)
except ActionFailed as e:
nicknames.append("{}这家伙不在群里了".format(self.rank[i][0]))
return nicknames
def get_head_portrait_urls(self) -> List:
self.portrait_urls = [
"http://q2.qlogo.cn/headimg_dl?dst_uin={}&spec=640".format(i[0])
for i in self.rank
]
return self.portrait_urls
async def get_send_msg(self) -> V11Message:
msgs: List = await self.get_msg()
msg = V11Message()
msg += V11MessageSegment.text(msgs[0]) # type: ignore
msg += V12MessageSegment.image(msgs[1]) # type: ignore
return msg
class V12MsgProcesser(MsgProcesser):
def __init__(self, bot: V12Bot, gid: str, msg_list: List[MessageRecord]) -> None:
super().__init__(bot, gid, msg_list)
self.bot = bot
async def get_send_msg(self) -> V12Message:
msgs: List = await self.get_msg()
msg = V12Message()
msg += V12MessageSegment.text(msgs[0]) # type: ignore
msg += V12MessageSegment.image(msgs[1]) # type: ignore
return msg
def get_head_portrait_urls(self) -> List:
return super().get_head_portrait_urls()
class V12GroupMsgProcesser(V12MsgProcesser):
def __init__(self, bot: V12Bot, gid: str, msg_list: List[MessageRecord]) -> None:
super().__init__(bot, gid, msg_list)
async def get_nickname_list(self) -> List:
nicknames = []
for i in range(len(self.rank)):
try:
member_info = await self.bot.get_group_member_info(
group_id=str(self.gid), user_id=str(self.rank[i][0]), no_cache=True
)
nicknames.append(
member_info["user_displayname"]
if member_info["user_displayname"]
else member_info["user_name"]
)
except ActionFailed as e:
nicknames.append("{}这家伙不在群里了".format(self.rank[i][0]))
return nicknames
class V12GuildMsgProcesser(V12MsgProcesser):
def __init__(self, bot: V12Bot, gid: str, msg_list: List[MessageRecord]) -> None:
super().__init__(bot, gid, msg_list)
async def get_nickname_list(self) -> List:
nicknames = []
for i in range(len(self.rank)):
try:
member_info = await self.bot.get_guild_member_info(
guild_id=str(self.gid), user_id=str(self.rank[i][0]), no_cache=True
)
nicknames.append(
member_info["user_displayname"]
if member_info["user_displayname"]
else member_info["user_name"]
)
except ActionFailed as e:
nicknames.append("{}这家伙不在群里了".format(self.rank[i][0]))
return nicknames