diff --git a/none/command.py b/none/command.py index 9d648168..c8994fde 100644 --- a/none/command.py +++ b/none/command.py @@ -1,10 +1,12 @@ import re +from collections import defaultdict from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional from aiocqhttp import CQHttp, Error as CQHttpError from aiocqhttp.message import Message from . import permissions as perm +from .helpers import context_source # Key: str (one segment of command name) # Value: subtree or a leaf Command object @@ -15,12 +17,8 @@ _registry = {} _aliases = {} # Key: context source -# Value: Command object -_sessions = {} - - -# TODO: session 保存为一个栈,命令可以调用命令,进入新的 session,命令执行完毕, -# 中间没有抛出异常(标志进入交互模式的异常),则从栈中 pop +# Value: list (stack) of Session objects +_sessions = defaultdict(list) class Command: @@ -74,89 +72,6 @@ async def calculate_permission(bot: CQHttp, ctx: Dict[str, Any]) -> int: return permission -def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]: - cmd_name = name if isinstance(name, tuple) else (name,) - - if not cmd_name: - return None - - cmd_tree = _registry - for part in cmd_name[:-1]: - if part not in cmd_tree: - return None - cmd_tree = cmd_tree[part] - - return cmd_tree.get(cmd_name[-1]) - - -class Session: - __slots__ = ('cmd', 'ctx', - 'current_key', 'current_arg', 'current_arg_text', - 'images', 'args', 'last_interaction') - - def __init__(self, cmd: Command, ctx: Dict[str, Any], *, - current_arg: str = '', args: Dict[str, Any] = None): - self.cmd = cmd - self.ctx = ctx - self.current_key = None - self.current_arg = current_arg - self.current_arg_text = Message(current_arg).extract_plain_text() - self.images = [s.data['url'] for s in ctx['message'] - if s.type == 'image' and 'url' in s.data] - self.args = args or {} - self.last_interaction = None - - def require_arg(self, key: str, prompt: str = '', *, - interactive: bool = True): - # TODO: 检查 key 是否在 args 中,如果不在,抛出异常,保存 session,等待用户填充 - pass - - -async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool: - # TODO: check if there is a session - msg_text = str(ctx['message']).lstrip() - - for start in bot.config.COMMAND_START: - if isinstance(start, type(re.compile(''))): - m = start.search(msg_text) - if m: - full_command = msg_text[len(m.group(0)):].lstrip() - break - elif isinstance(start, str): - if msg_text.startswith(start): - full_command = msg_text[len(start):].lstrip() - break - else: - # it's not a command - return False - - if not full_command: - # command is empty - return False - - cmd_name_text, *cmd_remained = full_command.split(maxsplit=1) - cmd_name = _aliases.get(cmd_name_text) - - if not cmd_name: - for sep in bot.config.COMMAND_SEP: - if isinstance(sep, type(re.compile(''))): - cmd_name = tuple(sep.split(cmd_name_text)) - break - elif isinstance(sep, str): - cmd_name = tuple(cmd_name_text.split(sep)) - break - else: - cmd_name = (cmd_name_text,) - - cmd = _find_command(cmd_name) - if not cmd: - return False - - session = Session(cmd, ctx, current_arg=''.join(cmd_remained)) - # TODO: 插入 session - return await cmd.run(bot, session) - - def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (), permission: int = perm.EVERYONE) -> Callable: def deco(func: Callable) -> Callable: @@ -185,6 +100,159 @@ def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (), return deco +def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]: + cmd_name = name if isinstance(name, tuple) else (name,) + + if not cmd_name: + return None + + cmd_tree = _registry + for part in cmd_name[:-1]: + if part not in cmd_tree: + return None + cmd_tree = cmd_tree[part] + + return cmd_tree.get(cmd_name[-1]) + + +class FurtherInteractionNeeded(Exception): + """ + Raised by session.require_arg() indicating + that the command should enter interactive mode + to ask the user for some arguments. + """ + pass + + +class Session: + __slots__ = ('cmd', 'ctx', + 'current_key', 'current_prompt', + 'current_arg', 'current_arg_text', + 'images', 'args', 'last_interaction') + + def __init__(self, cmd: Command, ctx: Dict[str, Any], *, + current_arg: str = '', args: Dict[str, Any] = None): + self.cmd = cmd + self.ctx = ctx + self.current_key = None + self.current_prompt = None + self.current_arg = current_arg + self.current_arg_text = Message(current_arg).extract_plain_text() + self.images = [s.data['url'] for s in ctx['message'] + if s.type == 'image' and 'url' in s.data] + self.args = args or {} + self.last_interaction = None + + def refresh(self, ctx: Dict[str, Any], *, current_arg: str = ''): + self.ctx = ctx + self.current_arg = current_arg + self.current_arg_text = Message(current_arg).extract_plain_text() + self.images = [s.data['url'] for s in ctx['message'] + if s.type == 'image' and 'url' in s.data] + + @property + def is_valid(self): + # TODO: 检查 last_interaction + return True + + def require_arg(self, key: str, prompt: str = None, *, + interactive: bool = True) -> Any: + """ + Get an argument with a given key. + + If "interactive" is True, and the argument does not exist + in the current session, a FurtherInteractionNeeded exception + will be raised, and the caller of the command will know + it should keep the session for further interaction with the user. + + If "interactive" is False, missed key will cause a result of None. + + :param key: argument key + :param prompt: prompt to ask the user with + :param interactive: should enter interactive mode while key missing + :return: the argument value + :raise FurtherInteractionNeeded: further interaction is needed + """ + value = self.args.get(key) + if value is not None or not interactive: + return value + + self.current_key = key + self.current_prompt = prompt or f'请输入 {self.current_key}:' + raise FurtherInteractionNeeded + + +def _new_command_session(bot: CQHttp, + ctx: Dict[str, Any]) -> Optional[Session]: + msg_text = str(ctx['message']).lstrip() + + for start in bot.config.COMMAND_START: + if isinstance(start, type(re.compile(''))): + m = start.search(msg_text) + if m: + full_command = msg_text[len(m.group(0)):].lstrip() + break + elif isinstance(start, str): + if msg_text.startswith(start): + full_command = msg_text[len(start):].lstrip() + break + else: + # it's not a command + return None + + if not full_command: + # command is empty + return None + + cmd_name_text, *cmd_remained = full_command.split(maxsplit=1) + cmd_name = _aliases.get(cmd_name_text) + + if not cmd_name: + for sep in bot.config.COMMAND_SEP: + if isinstance(sep, type(re.compile(''))): + cmd_name = tuple(sep.split(cmd_name_text)) + break + elif isinstance(sep, str): + cmd_name = tuple(cmd_name_text.split(sep)) + break + else: + cmd_name = (cmd_name_text,) + + cmd = _find_command(cmd_name) + if not cmd: + return None + + return Session(cmd, ctx, current_arg=''.join(cmd_remained)) + + +async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool: + src = context_source(ctx) + if _sessions[src]: + session = _sessions[src][-1] + session.refresh(ctx, current_arg=str(ctx['message'])) + # TODO: 检查 is_valid + else: + session = _new_command_session(bot, ctx) + if not session: + return False + _sessions[src].append(session) + + try: + res = await session.cmd.run(bot, session) + # the command is finished, pop the session + _sessions[src].pop() + if not _sessions[src]: + # session stack of the current user is empty + del _sessions[src] + return res + except FurtherInteractionNeeded: + # ask the user for more information + await bot.send(ctx, session.current_prompt) + + # return True because this step of the session is successful + return True + + async def call_command(name: Union[str, Tuple[str]], bot: CQHttp, ctx: Dict[str, Any], **kwargs) -> bool: """ diff --git a/none/helpers.py b/none/helpers.py index ab668673..50b7b69c 100644 --- a/none/helpers.py +++ b/none/helpers.py @@ -6,11 +6,11 @@ from aiocqhttp import CQHttp, Error as CQHttpError def context_source(ctx: Dict[str, Any]) -> str: src = '' if ctx.get('group_id'): - src += 'g%s' % ctx['group_id'] + src += f'/group/{ctx["group_id"]}' elif ctx.get('discuss_id'): - src += 'd%s' % ctx['discuss_id'] + src += f'/discuss/{ctx["discuss_id"]}' if ctx.get('user_id'): - src += 'p%s' % ctx['user_id'] + src += f'/user/{ctx["user_id"]}' return src diff --git a/plugins/weather.py b/plugins/weather.py index 366aceab..e3673904 100644 --- a/plugins/weather.py +++ b/plugins/weather.py @@ -6,7 +6,8 @@ from none.helpers import send @none.on_command('weather', aliases=('天气',)) async def weather(bot, session: Session): city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?') - await send(bot, session.ctx, f'你查询了{city}的天气') + other = session.require_arg('other') + await send(bot, session.ctx, f'你查询了{city}的天气,{other}') @weather.args_parser