diff --git a/command.py b/command.py index 7d4b269f..335744c6 100644 --- a/command.py +++ b/command.py @@ -292,14 +292,33 @@ class CommandHub: hub = CommandHub() -def split_args(maxsplit=0): +class CommandArgumentError(Exception): + pass + + +def split_arguments(maxsplit=0): + """ + To use this decorator, you should add a parameter exactly named 'argv' to the function of the command, + which will be set to the split argument list when called. + + However, the first parameter, typically 'args_text', will remain to be the whole argument string, like before. + + :param maxsplit: max split time + """ + def decorator(func): + @functools.wraps(func) def wrapper(argument, *args, **kwargs): - if isinstance(argument, (list, tuple)): - args_list = list(argument) + if argument is None: + raise CommandArgumentError + if kwargs.get('argv') is not None: + argv = kwargs['argv'] + del kwargs['argv'] + elif isinstance(argument, (list, tuple)): + argv = list(argument) else: - args_list = list(filter(lambda arg: arg, re.split('|'.join(_command_args_seps), argument, maxsplit))) - return func(args_list, *args, **kwargs) + argv = list(filter(lambda arg: arg, re.split('|'.join(_command_args_seps), argument, maxsplit))) + return func(argument, argv=argv, *args, **kwargs) return wrapper diff --git a/commands/note.py b/commands/note.py index 07f18417..81ff87db 100644 --- a/commands/note.py +++ b/commands/note.py @@ -1,5 +1,4 @@ import sqlite3 -from functools import wraps from datetime import datetime import pytz @@ -7,7 +6,7 @@ import pytz from command import CommandRegistry from commands import core from interactive import * -from little_shit import get_default_db_path, get_source, get_target +from little_shit import get_default_db_path, get_source, get_target, check_target __registry__ = cr = CommandRegistry() @@ -26,19 +25,6 @@ def _open_db_conn(): return conn -def _check_target(func): - @wraps(func) - def wrapper(args_text, ctx_msg, *args, **kwargs): - target = get_target(ctx_msg) - if not target: - core.echo('似乎出错了,请稍后再试吧~', ctx_msg) - return - else: - return func(args_text, ctx_msg, *args, **kwargs) - - return wrapper - - _cmd_take = 'note.take' _cmd_remove = 'note.remove' @@ -46,7 +32,7 @@ _cmd_remove = 'note.remove' @cr.register('记笔记', '添加笔记') @cr.register('take', 'add', hidden=True) @cr.restrict(group_admin_only=True) -@_check_target +@check_target def take(args_text, ctx_msg, allow_interactive=True): source = get_source(ctx_msg) if allow_interactive and (not args_text or has_session(source, _cmd_take)): @@ -67,7 +53,7 @@ def take(args_text, ctx_msg, allow_interactive=True): @cr.register('列出所有笔记', '查看所有笔记', '所有笔记') @cr.register('list', hidden=True) -@_check_target +@check_target def list_all(_, ctx_msg): conn = _open_db_conn() target = get_target(ctx_msg) @@ -90,7 +76,7 @@ def list_all(_, ctx_msg): @cr.register('删除笔记') @cr.register('remove', 'delete', hidden=True) @cr.restrict(group_admin_only=True) -@_check_target +@check_target def remove(args_text, ctx_msg, allow_interactive=True): source = get_source(ctx_msg) if allow_interactive and (not args_text or has_session(source, _cmd_remove)): @@ -118,7 +104,7 @@ def remove(args_text, ctx_msg, allow_interactive=True): @cr.register('清空笔记', '清空所有笔记', '删除所有笔记') @cr.register('clear', hidden=True) @cr.restrict(group_admin_only=True) -@_check_target +@check_target def clear(_, ctx_msg): conn = _open_db_conn() target = get_target(ctx_msg) diff --git a/commands/scheduler.py b/commands/scheduler.py index d21510df..db82b126 100644 --- a/commands/scheduler.py +++ b/commands/scheduler.py @@ -29,6 +29,7 @@ _scheduler = BackgroundScheduler( _command_args_start_flags = get_command_args_start_flags() _args_split_sep = '[ \n\t]' +_job_id_suffix_start = '@' def _init(): @@ -148,7 +149,7 @@ def add_job(args_text, ctx_msg, internal=False): if len(tmp) < 2: raise _IncompleteArgsError job_id_without_suffix, command_raw = tmp - job_id = job_id_without_suffix + '_' + get_target(ctx_msg) + job_id = job_id_without_suffix + _job_id_suffix_start + get_target(ctx_msg) command_list = [] if multi: command_raw_list = command_raw.split('\n') @@ -190,7 +191,7 @@ def remove_job(args_text, ctx_msg, internal=False): if not job_id_without_suffix: _send_text('请指定计划任务的 ID', ctx_msg, internal) return False - job_id = job_id_without_suffix + '_' + get_target(ctx_msg) + job_id = job_id_without_suffix + _job_id_suffix_start + get_target(ctx_msg) try: _scheduler.remove_job(job_id, 'default') _send_text('成功删除计划任务 ' + job_id_without_suffix, ctx_msg, internal) @@ -208,9 +209,11 @@ def get_job(args_text, ctx_msg, internal=False): if not job_id_without_suffix: _send_text('请指定计划任务的 ID', ctx_msg, internal) return None - job_id = job_id_without_suffix + '_' + get_target(ctx_msg) + job_id = job_id_without_suffix + _job_id_suffix_start + get_target(ctx_msg) job = _scheduler.get_job(job_id, 'default') if internal: + if job: + job.id = job_id_without_suffix return job if not job: core.echo('没有找到该计划任务,请指定正确的计划任务 ID', ctx_msg, internal) @@ -220,7 +223,7 @@ def get_job(args_text, ctx_msg, internal=False): reply += '下次触发时间:\n%s\n' % job.next_run_time.strftime('%Y-%m-%d %H:%M') reply += '命令:\n' command_list = job.kwargs['command_list'] - reply += _convert_command_list_to_str(command_list) + reply += convert_command_list_to_str(command_list) _send_text(reply, ctx_msg, internal) @@ -229,9 +232,11 @@ def get_job(args_text, ctx_msg, internal=False): @_check_target def list_jobs(_, ctx_msg, internal=False): target = get_target(ctx_msg) - job_id_suffix = '_' + target + job_id_suffix = _job_id_suffix_start + target jobs = list(filter(lambda j: j.id.endswith(job_id_suffix), _scheduler.get_jobs('default'))) if internal: + for job in jobs: + job.id = job.id[:-len(job_id_suffix)] return jobs for job in jobs: @@ -240,7 +245,7 @@ def list_jobs(_, ctx_msg, internal=False): reply = 'ID:' + job_id + '\n' reply += '下次触发时间:\n%s\n' % job.next_run_time.strftime('%Y-%m-%d %H:%M') reply += '命令:\n' - reply += _convert_command_list_to_str(command_list) + reply += convert_command_list_to_str(command_list) _send_text(reply, ctx_msg, internal) if len(jobs): _send_text('以上', ctx_msg, internal) @@ -256,12 +261,12 @@ def execute_job(args_text, ctx_msg, internal=False): if not job: core.echo('没有找到该计划任务,请指定正确的计划任务 ID', ctx_msg, internal) return - job_id_suffix = '_' + get_target(ctx_msg) + job_id_suffix = _job_id_suffix_start + get_target(ctx_msg) job_id = job.id[:-len(job_id_suffix)] _call_commands(job_id, job.kwargs['command_list'], job.kwargs['ctx_msg'], internal) -def _convert_command_list_to_str(command_list): +def convert_command_list_to_str(command_list): s = '' if len(command_list) > 1: for c in command_list: diff --git a/commands/subscribe.py b/commands/subscribe.py new file mode 100644 index 00000000..f39114c5 --- /dev/null +++ b/commands/subscribe.py @@ -0,0 +1,164 @@ +import re +from datetime import datetime + +from command import CommandRegistry, split_arguments +from commands import core, scheduler +from interactive import * +from little_shit import get_source, check_target + +__registry__ = cr = CommandRegistry() + +_cmd_subscribe = 'subscribe.subscribe' +_scheduler_job_id_prefix = _cmd_subscribe + '_' + + +@cr.register('subscribe', '订阅') +@cr.restrict(group_admin_only=True) +@split_arguments(maxsplit=1) +@check_target +def subscribe(args_text, ctx_msg, argv=None, allow_interactive=True): + source = get_source(ctx_msg) + if allow_interactive and has_session(source, _cmd_subscribe): + # Already in a session, no need to pass in data, + # because the interactive version of this command will take care of it + return _subscribe_interactively(args_text, ctx_msg, source, None) + + data = {} + if argv: + m = re.match('([0-1]\d|[2][0-3])(?::|:)?([0-5]\d)', argv[0]) + if not m: + # Got command but no time + data['command'] = args_text + else: + # Got time + data['hour'], data['minute'] = m.group(1), m.group(2) + if len(argv) == 2: + # Got command + data['command'] = argv[1] + + if allow_interactive: + if data.keys() != {'command', 'hour', 'minute'}: + # First visit and data is not enough + return _subscribe_interactively(args_text, ctx_msg, source, data) + + # Got both time and command, do the job! + hour, minute = data['hour'], data['minute'] + command = data['command'] + job = scheduler.add_job( + '-H %s -M %s %s %s' % (hour, minute, _scheduler_job_id_prefix + str(datetime.now().timestamp()), command), + ctx_msg, internal=True + ) + if job: + # Succeeded to add a job + print('成功订阅:', hour, minute, command) + reply = '订阅成功,我会在每天 %s 推送哦~' % ':'.join((hour, minute)) + else: + reply = '订阅失败,可能后台出了点小问题~' + + core.echo(reply, ctx_msg) + + +@cr.register('subscribe_list', 'subscribe-list', '订阅列表', '查看订阅', '查看所有订阅', '所有订阅') +@cr.restrict(group_admin_only=True) +@check_target +def subscribe_list(_, ctx_msg, internal=False): + jobs = sorted(filter( + lambda j: j.id.startswith(_scheduler_job_id_prefix), + scheduler.list_jobs('', ctx_msg, internal=True) + ), key=lambda j: j.id) + + if internal: + return jobs + + if not jobs: + core.echo('暂时还没有订阅哦~', ctx_msg) + return + + for index, job in enumerate(jobs): + command_list = job.kwargs['command_list'] + reply = 'ID:' + str(index + 1) + '\n' + reply += '下次推送时间:\n%s\n' % job.next_run_time.strftime('%Y-%m-%d %H:%M') + reply += '命令:\n' + reply += scheduler.convert_command_list_to_str(command_list) + core.echo(reply, ctx_msg) + core.echo('以上~', ctx_msg) + + +@cr.register('unsubscribe', '取消订阅') +@cr.restrict(group_admin_only=True) +@split_arguments() +@check_target +def unsubscribe(_, ctx_msg, argv=None, internal=False): + if not argv: + core.echo('请在命令名后指定要取消订阅的 ID(多个 ID、ID 和命令名之间用空格隔开)哦~\n\n' + '你可以通过「查看所有订阅」命令来查看所有订阅项目的 ID', ctx_msg, internal) + return + + jobs = subscribe_list('', ctx_msg, internal=True) + min_id = 1 + max_id = len(jobs) + if not all(map(lambda x: x.isdigit() and int(x) in range(min_id, max_id + 1), argv)): + core.echo('请输入正确的 ID 哦~\n\n' + '你可以通过「查看所有订阅」命令来查看所有订阅项目的 ID', ctx_msg, internal) + + result = [] + for i in argv: + result.append(scheduler.remove_job(jobs[int(i) - 1].id, ctx_msg, internal=True)) + if all(result): + core.echo('取消订阅成功~', ctx_msg, internal) + else: + core.echo('出了点小问题,可能有一些订阅项目没有成功取消订阅,请使用「查看所有订阅」命令来检查哦~', + ctx_msg, internal) + + +def _subscribe_interactively(args_text, ctx_msg, source, data): + sess = get_session(source, _cmd_subscribe) + if data: + sess.data.update(data) + + state_command = 1 + state_time = 2 + state_finish = -1 + if sess.state == state_command: + if not args_text.strip(): + core.echo('你输入的命令不正确,请重新发送订阅命令哦~', ctx_msg) + sess.state = state_finish + else: + sess.data['command'] = args_text + elif sess.state == state_time: + m = re.match('([0-1]\d|[2][0-3])(?::|:)?([0-5]\d)', args_text.strip()) + if not m: + core.echo('你输入的时间格式不正确,请重新发送订阅命令哦~', ctx_msg) + sess.state = state_finish + else: + sess.data['hour'], sess.data['minute'] = m.group(1), m.group(2) + + if sess.state == state_finish: + remove_session(source, _cmd_subscribe) + return + + if 'command' not in sess.data: + # Ask for command + core.echo( + '请输入你需要订阅的命令(包括所需的参数),不需要加开头的斜杠哦~\n\n' + '例如(序号后的):\n' + '(1) 天气 南京\n' + '(2) 知乎日报\n' + '(3) 历史上的今天', + ctx_msg + ) + sess.state = state_command + return + + if 'hour' not in sess.data or 'minute' not in sess.data: + # Ask for time + core.echo('请输入你需要推送的时间,格式如 22:00', ctx_msg) + sess.state = state_time + return + + subscribe( + '', ctx_msg, + argv=[':'.join((sess.data['hour'], sess.data['minute'])), sess.data['command']], + allow_interactive=False + ) + remove_session(source, _cmd_subscribe) diff --git a/commands/sudo.py b/commands/sudo.py index 9324ecc4..c5da62d8 100644 --- a/commands/sudo.py +++ b/commands/sudo.py @@ -1,6 +1,6 @@ import sqlite3 -from command import CommandRegistry, split_args +from command import CommandRegistry, split_arguments from commands import core from little_shit import get_default_db_path, get_target @@ -26,16 +26,16 @@ def test(_, ctx_msg): @cr.register('block') @cr.restrict(full_command_only=True, superuser_only=True) -@split_args(maxsplit=2) -def block(args, ctx_msg): +@split_arguments(maxsplit=2) +def block(_, ctx_msg, argv=None): def _send_error_msg(): core.echo('参数不正确。\n\n正确使用方法:\nsudo.block wx|qq ', ctx_msg) - if len(args) != 2: + if len(argv) != 2: _send_error_msg() return - via, account = args + via, account = argv # Get a target using a fake context message target = get_target({ 'via': via, @@ -74,16 +74,16 @@ def block_list(_, ctx_msg, internal=False): @cr.register('unblock') @cr.restrict(full_command_only=True, superuser_only=True) -@split_args(maxsplit=2) -def unblock(args, ctx_msg): +@split_arguments(maxsplit=2) +def unblock(_, ctx_msg, argv=None): def _send_error_msg(): core.echo('参数不正确。\n\n正确使用方法:\nsudo.unblock wx|qq ', ctx_msg) - if len(args) != 2: + if len(argv) != 2: _send_error_msg() return - via, account = args + via, account = argv # Get a target using a fake context message target = get_target({ 'via': via, diff --git a/commands/weather.py b/commands/weather.py index 883e770f..c0b97926 100644 --- a/commands/weather.py +++ b/commands/weather.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta import requests -from command import CommandRegistry, split_args +from command import CommandRegistry, split_arguments from commands import core from little_shit import get_source, get_db_dir, get_tmp_dir from interactive import * @@ -25,14 +25,14 @@ _weekday_string = ['周一', '周二', '周三', '周四', '周五', '周六', ' @cr.register('weather') @cr.register('天气', '查天气', '天气预报', '查天气预报') -@split_args() -def weather(args, ctx_msg, allow_interactive=True): +@split_arguments() +def weather(args_text, ctx_msg, argv: list = None, allow_interactive=True): source = get_source(ctx_msg) - if allow_interactive and (len(args) < 1 or not args[0].startswith('CN') or has_session(source, _cmd_weather)): + if allow_interactive and (not argv or not argv[0].startswith('CN') or has_session(source, _cmd_weather)): # Be interactive - return _do_interactively(_cmd_weather, weather, args, ctx_msg, source) + return _do_interactively(_cmd_weather, weather, args_text.strip(), ctx_msg, source) - city_id = args[0] + city_id = argv[0] text = '' data = _get_weather(city_id) @@ -41,12 +41,13 @@ def weather(args, ctx_msg, allow_interactive=True): now = data['now'] aqi = data['aqi']['city'] - text += '\n\n实时:\n%s,气温%s˚C,体感温度%s˚C,%s%s级,能见度%skm,空气质量指数:%s,%s,PM2.5:%s,PM10:%s' \ + text += '\n\n实时:\n\n%s,气温%s˚C,体感温度%s˚C,%s%s级,' \ + '能见度%skm,空气质量指数:%s,%s,PM2.5:%s,PM10:%s' \ % (now['cond']['txt'], now['tmp'], now['fl'], now['wind']['dir'], now['wind']['sc'], now['vis'], aqi['aqi'], aqi['qlty'], aqi['pm25'], aqi['pm10']) daily_forecast = data['daily_forecast'] - text += '\n\n预报:\n' + text += '\n\n预报:\n\n' for forecast in daily_forecast: d = datetime.strptime(forecast['date'], '%Y-%m-%d') @@ -59,24 +60,25 @@ def weather(args, ctx_msg, allow_interactive=True): text += forecast['tmp']['min'] + '~' + forecast['tmp']['max'] + '°C,' text += forecast['wind']['dir'] + forecast['wind']['sc'] + '级,' text += '降雨概率%s%%' % forecast['pop'] - text += '\n' + text += '\n\n' + text = text.rstrip() if text: core.echo(text, ctx_msg) else: core.echo('查询失败了,请稍后再试哦~', ctx_msg) -@cr.register('suggestion') +@cr.register('suggestion', hidden=True) @cr.register('生活指数', '生活建议', '天气建议') -@split_args() -def suggestion(args, ctx_msg, allow_interactive=True): +@split_arguments() +def suggestion(args_text, ctx_msg, argv: list = None, allow_interactive=True): source = get_source(ctx_msg) - if allow_interactive and (len(args) < 1 or not args[0].startswith('CN') or has_session(source, _cmd_suggestion)): + if allow_interactive and (len(argv) < 1 or not argv[0].startswith('CN') or has_session(source, _cmd_suggestion)): # Be interactive - return _do_interactively(_cmd_suggestion, suggestion, args, ctx_msg, source) + return _do_interactively(_cmd_suggestion, suggestion, args_text.strip(), ctx_msg, source) - city_id = args[0] + city_id = argv[0] text = '' data = _get_weather(city_id) @@ -101,9 +103,9 @@ def suggestion(args, ctx_msg, allow_interactive=True): _state_machines = {} -def _do_interactively(command_name, func, args, ctx_msg, source): +def _do_interactively(command_name, func, args_text, ctx_msg, source): def ask_for_city(s, a, c): - if len(a) > 0: + if a: if search_city(s, a, c): return True else: @@ -111,11 +113,11 @@ def _do_interactively(command_name, func, args, ctx_msg, source): s.state += 1 def search_city(s, a, c): - if len(a) < 1: + if not a: core.echo('你输入的城市不正确哦,请重新发送命令~', c) return True - city_list = _get_city_list(a[0]) + city_list = _get_city_list(a) if not city_list: core.echo('没有找到你输入的城市哦,请重新发送命令~', c) @@ -125,7 +127,7 @@ def _do_interactively(command_name, func, args, ctx_msg, source): if len(city_list) == 1: # Directly choose the first one - choose_city(s, ['1'], c) + choose_city(s, '1', c) return True # Here comes more than one city with the same name @@ -140,11 +142,11 @@ def _do_interactively(command_name, func, args, ctx_msg, source): s.state += 1 def choose_city(s, a, c): - if len(a) != 1 or not a[0].isdigit(): + if not a or not a.isdigit(): core.echo('你输入的序号不正确哦,请重新发送命令~', c) return True - choice = int(a[0]) - 1 # Should be from 0 to len(city_list) - 1 + choice = int(a) - 1 # Should be from 0 to len(city_list) - 1 city_list = s.data['city_list'] if choice < 0 or choice >= len(city_list): core.echo('你输入的序号超出范围了,请重新发送命令~', c) @@ -164,7 +166,7 @@ def _do_interactively(command_name, func, args, ctx_msg, source): sess = get_session(source, command_name) sess.data['func'] = func - if _state_machines[command_name][sess.state](sess, args, ctx_msg): + if _state_machines[command_name][sess.state](sess, args_text, ctx_msg): # Done remove_session(source, command_name) diff --git a/docs/Write_Command.md b/docs/Write_Command.md index 71f658ef..7ee5638d 100644 --- a/docs/Write_Command.md +++ b/docs/Write_Command.md @@ -104,13 +104,13 @@ Source 表示命令的来源(由谁发出),Target 表示命令将对谁产 ## 命令参数 -命令的函数的第一个参数为命令参数,默认情况下,是一个字符串,即用户发送的消息中命令后面的内容,可以自行切割、分析。如果需要使用默认的命令参数分隔符,可以使用 `command.py` 中的 `split_args` 装饰器,使用之后,命令的函数接收到的第一个参数将变为命令参数列表,而不再是字符串。例如: +命令的函数的第一个参数为命令参数,默认情况下,是一个字符串,即用户发送的消息中命令后面的内容,可以自行切割、分析。如果需要使用默认的命令参数分隔符,可以使用 `command.py` 中的 `split_arguments` 装饰器,使用之后,命令的函数将接受到一个名为 `argv` 的参数,为分割后的参数列表,而原先第一个参数还保留为原字符串。例如: ```python @__registry__.register('test') @__registry__.restrict(group_admin_only=True) -@split_args() -def test(args, ctx_msg): - if len(args) > 0: +@split_arguments() +def test(args_text, ctx_msg, argv=None): + if argv: print(args[0]) ``` diff --git a/little_shit.py b/little_shit.py index ff3c7a69..102f53ab 100644 --- a/little_shit.py +++ b/little_shit.py @@ -1,9 +1,11 @@ import os import hashlib import random +import functools from datetime import datetime from config import config +from apiclient import client as api class SkipException(Exception): @@ -87,6 +89,23 @@ def get_target(ctx_msg): return None +def check_target(func): + """ + This decorator checks whether there is a target value, and prevent calling the function if not. + """ + + @functools.wraps(func) + def wrapper(args_text, ctx_msg, *args, **kwargs): + target = get_target(ctx_msg) + if not target: + api.send_message('当前语境无法使用这个命令,请尝试发送私聊消息或稍后再试吧~', ctx_msg) + return + else: + return func(args_text, ctx_msg, *args, **kwargs) + + return wrapper + + def get_command_start_flags(): return tuple(sorted(config.get('command_start_flags', ('',)), reverse=True)) @@ -96,11 +115,11 @@ def get_command_name_separators(): def get_command_args_start_flags(): - return tuple(sorted(('[ \t\n]',) + config.get('command_args_start_flags', ()), reverse=True)) + return tuple(sorted(('[ \t\n]+',) + config.get('command_args_start_flags', ()), reverse=True)) def get_command_args_separators(): - return tuple(sorted(('[ \t\n]',) + config.get('command_args_separators', ()), reverse=True)) + return tuple(sorted(('[ \t\n]+',) + config.get('command_args_separators', ()), reverse=True)) def get_fallback_command():