Interactive mode done

This commit is contained in:
Richard Chien 2018-06-25 15:22:59 +08:00
parent 492fed9b50
commit 77db6bfc84
3 changed files with 162 additions and 93 deletions

View File

@ -1,10 +1,12 @@
import re import re
from collections import defaultdict
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional
from aiocqhttp import CQHttp, Error as CQHttpError from aiocqhttp import CQHttp, Error as CQHttpError
from aiocqhttp.message import Message from aiocqhttp.message import Message
from . import permissions as perm from . import permissions as perm
from .helpers import context_source
# Key: str (one segment of command name) # Key: str (one segment of command name)
# Value: subtree or a leaf Command object # Value: subtree or a leaf Command object
@ -15,12 +17,8 @@ _registry = {}
_aliases = {} _aliases = {}
# Key: context source # Key: context source
# Value: Command object # Value: list (stack) of Session objects
_sessions = {} _sessions = defaultdict(list)
# TODO: session 保存为一个栈,命令可以调用命令,进入新的 session命令执行完毕
# 中间没有抛出异常(标志进入交互模式的异常),则从栈中 pop
class Command: class Command:
@ -74,89 +72,6 @@ async def calculate_permission(bot: CQHttp, ctx: Dict[str, Any]) -> int:
return permission return permission
def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]:
cmd_name = name if isinstance(name, tuple) else (name,)
if not cmd_name:
return None
cmd_tree = _registry
for part in cmd_name[:-1]:
if part not in cmd_tree:
return None
cmd_tree = cmd_tree[part]
return cmd_tree.get(cmd_name[-1])
class Session:
__slots__ = ('cmd', 'ctx',
'current_key', 'current_arg', 'current_arg_text',
'images', 'args', 'last_interaction')
def __init__(self, cmd: Command, ctx: Dict[str, Any], *,
current_arg: str = '', args: Dict[str, Any] = None):
self.cmd = cmd
self.ctx = ctx
self.current_key = None
self.current_arg = current_arg
self.current_arg_text = Message(current_arg).extract_plain_text()
self.images = [s.data['url'] for s in ctx['message']
if s.type == 'image' and 'url' in s.data]
self.args = args or {}
self.last_interaction = None
def require_arg(self, key: str, prompt: str = '', *,
interactive: bool = True):
# TODO: 检查 key 是否在 args 中,如果不在,抛出异常,保存 session等待用户填充
pass
async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
# TODO: check if there is a session
msg_text = str(ctx['message']).lstrip()
for start in bot.config.COMMAND_START:
if isinstance(start, type(re.compile(''))):
m = start.search(msg_text)
if m:
full_command = msg_text[len(m.group(0)):].lstrip()
break
elif isinstance(start, str):
if msg_text.startswith(start):
full_command = msg_text[len(start):].lstrip()
break
else:
# it's not a command
return False
if not full_command:
# command is empty
return False
cmd_name_text, *cmd_remained = full_command.split(maxsplit=1)
cmd_name = _aliases.get(cmd_name_text)
if not cmd_name:
for sep in bot.config.COMMAND_SEP:
if isinstance(sep, type(re.compile(''))):
cmd_name = tuple(sep.split(cmd_name_text))
break
elif isinstance(sep, str):
cmd_name = tuple(cmd_name_text.split(sep))
break
else:
cmd_name = (cmd_name_text,)
cmd = _find_command(cmd_name)
if not cmd:
return False
session = Session(cmd, ctx, current_arg=''.join(cmd_remained))
# TODO: 插入 session
return await cmd.run(bot, session)
def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (), def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (),
permission: int = perm.EVERYONE) -> Callable: permission: int = perm.EVERYONE) -> Callable:
def deco(func: Callable) -> Callable: def deco(func: Callable) -> Callable:
@ -185,6 +100,159 @@ def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (),
return deco return deco
def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]:
cmd_name = name if isinstance(name, tuple) else (name,)
if not cmd_name:
return None
cmd_tree = _registry
for part in cmd_name[:-1]:
if part not in cmd_tree:
return None
cmd_tree = cmd_tree[part]
return cmd_tree.get(cmd_name[-1])
class FurtherInteractionNeeded(Exception):
"""
Raised by session.require_arg() indicating
that the command should enter interactive mode
to ask the user for some arguments.
"""
pass
class Session:
__slots__ = ('cmd', 'ctx',
'current_key', 'current_prompt',
'current_arg', 'current_arg_text',
'images', 'args', 'last_interaction')
def __init__(self, cmd: Command, ctx: Dict[str, Any], *,
current_arg: str = '', args: Dict[str, Any] = None):
self.cmd = cmd
self.ctx = ctx
self.current_key = None
self.current_prompt = None
self.current_arg = current_arg
self.current_arg_text = Message(current_arg).extract_plain_text()
self.images = [s.data['url'] for s in ctx['message']
if s.type == 'image' and 'url' in s.data]
self.args = args or {}
self.last_interaction = None
def refresh(self, ctx: Dict[str, Any], *, current_arg: str = ''):
self.ctx = ctx
self.current_arg = current_arg
self.current_arg_text = Message(current_arg).extract_plain_text()
self.images = [s.data['url'] for s in ctx['message']
if s.type == 'image' and 'url' in s.data]
@property
def is_valid(self):
# TODO: 检查 last_interaction
return True
def require_arg(self, key: str, prompt: str = None, *,
interactive: bool = True) -> Any:
"""
Get an argument with a given key.
If "interactive" is True, and the argument does not exist
in the current session, a FurtherInteractionNeeded exception
will be raised, and the caller of the command will know
it should keep the session for further interaction with the user.
If "interactive" is False, missed key will cause a result of None.
:param key: argument key
:param prompt: prompt to ask the user with
:param interactive: should enter interactive mode while key missing
:return: the argument value
:raise FurtherInteractionNeeded: further interaction is needed
"""
value = self.args.get(key)
if value is not None or not interactive:
return value
self.current_key = key
self.current_prompt = prompt or f'请输入 {self.current_key}'
raise FurtherInteractionNeeded
def _new_command_session(bot: CQHttp,
ctx: Dict[str, Any]) -> Optional[Session]:
msg_text = str(ctx['message']).lstrip()
for start in bot.config.COMMAND_START:
if isinstance(start, type(re.compile(''))):
m = start.search(msg_text)
if m:
full_command = msg_text[len(m.group(0)):].lstrip()
break
elif isinstance(start, str):
if msg_text.startswith(start):
full_command = msg_text[len(start):].lstrip()
break
else:
# it's not a command
return None
if not full_command:
# command is empty
return None
cmd_name_text, *cmd_remained = full_command.split(maxsplit=1)
cmd_name = _aliases.get(cmd_name_text)
if not cmd_name:
for sep in bot.config.COMMAND_SEP:
if isinstance(sep, type(re.compile(''))):
cmd_name = tuple(sep.split(cmd_name_text))
break
elif isinstance(sep, str):
cmd_name = tuple(cmd_name_text.split(sep))
break
else:
cmd_name = (cmd_name_text,)
cmd = _find_command(cmd_name)
if not cmd:
return None
return Session(cmd, ctx, current_arg=''.join(cmd_remained))
async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
src = context_source(ctx)
if _sessions[src]:
session = _sessions[src][-1]
session.refresh(ctx, current_arg=str(ctx['message']))
# TODO: 检查 is_valid
else:
session = _new_command_session(bot, ctx)
if not session:
return False
_sessions[src].append(session)
try:
res = await session.cmd.run(bot, session)
# the command is finished, pop the session
_sessions[src].pop()
if not _sessions[src]:
# session stack of the current user is empty
del _sessions[src]
return res
except FurtherInteractionNeeded:
# ask the user for more information
await bot.send(ctx, session.current_prompt)
# return True because this step of the session is successful
return True
async def call_command(name: Union[str, Tuple[str]], async def call_command(name: Union[str, Tuple[str]],
bot: CQHttp, ctx: Dict[str, Any], **kwargs) -> bool: bot: CQHttp, ctx: Dict[str, Any], **kwargs) -> bool:
""" """

View File

@ -6,11 +6,11 @@ from aiocqhttp import CQHttp, Error as CQHttpError
def context_source(ctx: Dict[str, Any]) -> str: def context_source(ctx: Dict[str, Any]) -> str:
src = '' src = ''
if ctx.get('group_id'): if ctx.get('group_id'):
src += 'g%s' % ctx['group_id'] src += f'/group/{ctx["group_id"]}'
elif ctx.get('discuss_id'): elif ctx.get('discuss_id'):
src += 'd%s' % ctx['discuss_id'] src += f'/discuss/{ctx["discuss_id"]}'
if ctx.get('user_id'): if ctx.get('user_id'):
src += 'p%s' % ctx['user_id'] src += f'/user/{ctx["user_id"]}'
return src return src

View File

@ -6,7 +6,8 @@ from none.helpers import send
@none.on_command('weather', aliases=('天气',)) @none.on_command('weather', aliases=('天气',))
async def weather(bot, session: Session): async def weather(bot, session: Session):
city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?') city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?')
await send(bot, session.ctx, f'你查询了{city}的天气') other = session.require_arg('other')
await send(bot, session.ctx, f'你查询了{city}的天气,{other}')
@weather.args_parser @weather.args_parser