Improve session

This commit is contained in:
Richard Chien 2018-06-25 16:50:34 +08:00
parent 77db6bfc84
commit 7da2043f94
6 changed files with 67 additions and 88 deletions

View File

@ -17,7 +17,7 @@ default_handler.setFormatter(logging.Formatter(
logger.addHandler(default_handler) logger.addHandler(default_handler)
from .plugin import handle_message, handle_notice, handle_request from .plugin import handle_message, handle_notice, handle_request
from .command import on_command, call_command from .command import on_command
def create_bot(config_object: Any = None): def create_bot(config_object: Any = None):

View File

@ -1,6 +1,7 @@
import re import re
from collections import defaultdict import asyncio
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional from datetime import datetime
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional, List
from aiocqhttp import CQHttp, Error as CQHttpError from aiocqhttp import CQHttp, Error as CQHttpError
from aiocqhttp.message import Message from aiocqhttp.message import Message
@ -17,27 +18,28 @@ _registry = {}
_aliases = {} _aliases = {}
# Key: context source # Key: context source
# Value: list (stack) of Session objects # Value: Session object
_sessions = defaultdict(list) _sessions = {}
class Command: class Command:
__slots__ = ('name', 'func', 'permission', 'args_parser') __slots__ = ('name', 'func', 'permission', 'args_parser_func')
def __init__(self, name: Tuple[str], func: Callable, permission: int): def __init__(self, name: Tuple[str],
func: Callable,
permission: int):
self.name = name self.name = name
self.func = func self.func = func
self.permission = permission self.permission = permission
self.args_parser = None self.args_parser_func = None
async def run(self, bot, session, *, async def run(self, session, *, permission: int = None) -> bool:
permission: int = None) -> bool:
if permission is None: if permission is None:
permission = await calculate_permission(bot, session.ctx) permission = await calculate_permission(session.bot, session.ctx)
if isinstance(self.func, Callable) and permission & self.permission: if self.func and permission & self.permission:
if isinstance(self.args_parser, Callable): if self.args_parser_func:
self.args_parser(session) await self.args_parser_func(session)
await self.func(bot, session) await self.func(session)
return True return True
return False return False
@ -91,7 +93,7 @@ def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (),
_aliases[alias] = cmd_name _aliases[alias] = cmd_name
def args_parser(parser_func: Callable): def args_parser(parser_func: Callable):
cmd.args_parser = parser_func cmd.args_parser_func = parser_func
return parser_func return parser_func
func.args_parser = args_parser func.args_parser = args_parser
@ -125,17 +127,16 @@ class FurtherInteractionNeeded(Exception):
class Session: class Session:
__slots__ = ('cmd', 'ctx', __slots__ = ('bot', 'cmd', 'ctx',
'current_key', 'current_prompt', 'current_key', 'current_arg', 'current_arg_text',
'current_arg', 'current_arg_text',
'images', 'args', 'last_interaction') 'images', 'args', 'last_interaction')
def __init__(self, cmd: Command, ctx: Dict[str, Any], *, def __init__(self, bot: CQHttp, cmd: Command, ctx: Dict[str, Any], *,
current_arg: str = '', args: Dict[str, Any] = None): current_arg: str = '', args: Dict[str, Any] = None):
self.bot = bot
self.cmd = cmd self.cmd = cmd
self.ctx = ctx self.ctx = ctx
self.current_key = None self.current_key = None
self.current_prompt = None
self.current_arg = current_arg self.current_arg = current_arg
self.current_arg_text = Message(current_arg).extract_plain_text() self.current_arg_text = Message(current_arg).extract_plain_text()
self.images = [s.data['url'] for s in ctx['message'] self.images = [s.data['url'] for s in ctx['message']
@ -152,10 +153,13 @@ class Session:
@property @property
def is_valid(self): def is_valid(self):
# TODO: 检查 last_interaction if self.last_interaction and \
datetime.now() - self.last_interaction > \
self.bot.config.SESSION_EXPIRE_TIMEOUT:
return False
return True return True
def require_arg(self, key: str, prompt: str = None, *, def require_arg(self, key: str, prompt: str, *,
interactive: bool = True) -> Any: interactive: bool = True) -> Any:
""" """
Get an argument with a given key. Get an argument with a given key.
@ -168,7 +172,7 @@ class Session:
If "interactive" is False, missed key will cause a result of None. If "interactive" is False, missed key will cause a result of None.
:param key: argument key :param key: argument key
:param prompt: prompt to ask the user with :param prompt: prompt to ask the user
:param interactive: should enter interactive mode while key missing :param interactive: should enter interactive mode while key missing
:return: the argument value :return: the argument value
:raise FurtherInteractionNeeded: further interaction is needed :raise FurtherInteractionNeeded: further interaction is needed
@ -178,9 +182,19 @@ class Session:
return value return value
self.current_key = key self.current_key = key
self.current_prompt = prompt or f'请输入 {self.current_key}' # ask the user for more information
asyncio.ensure_future(self.send(prompt))
raise FurtherInteractionNeeded raise FurtherInteractionNeeded
async def send(self,
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
*, ignore_failure: bool = True) -> None:
try:
await self.bot.send(self.ctx, message)
except CQHttpError:
if not ignore_failure:
raise
def _new_command_session(bot: CQHttp, def _new_command_session(bot: CQHttp,
ctx: Dict[str, Any]) -> Optional[Session]: ctx: Dict[str, Any]) -> Optional[Session]:
@ -222,56 +236,32 @@ def _new_command_session(bot: CQHttp,
if not cmd: if not cmd:
return None return None
return Session(cmd, ctx, current_arg=''.join(cmd_remained)) return Session(bot, cmd, ctx, current_arg=''.join(cmd_remained))
async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool: async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
src = context_source(ctx) src = context_source(ctx)
if _sessions[src]: session = None
session = _sessions[src][-1] if _sessions.get(src):
session.refresh(ctx, current_arg=str(ctx['message'])) session = _sessions[src]
# TODO: 检查 is_valid if session and session.is_valid:
else: session.refresh(ctx, current_arg=str(ctx['message']))
else:
# the session is expired, remove it
del _sessions[src]
session = None
if not session:
session = _new_command_session(bot, ctx) session = _new_command_session(bot, ctx)
if not session: if not session:
return False return False
_sessions[src].append(session) _sessions[src] = session
try: try:
res = await session.cmd.run(bot, session) res = await session.cmd.run(session)
# the command is finished, pop the session # the command is finished, remove the session
_sessions[src].pop() del _sessions[src]
if not _sessions[src]:
# session stack of the current user is empty
del _sessions[src]
return res return res
except FurtherInteractionNeeded: except FurtherInteractionNeeded:
# ask the user for more information session.last_interaction = datetime.now()
await bot.send(ctx, session.current_prompt)
# return True because this step of the session is successful # return True because this step of the session is successful
return True return True
async def call_command(name: Union[str, Tuple[str]],
bot: CQHttp, ctx: Dict[str, Any], **kwargs) -> bool:
"""
Call a command internally.
There is no permission restriction on this function,
which means any command can be called from any other command.
Unexpected users should be handled by the caller command's permission
option.
:param name: command name (str or tuple of str)
:param bot: CQHttp instance
:param ctx: event context
:param kwargs: other keyword args that will be passed to Session()
:return: the command is successfully called
"""
cmd = _find_command(name)
if cmd:
session = Session(cmd, ctx, **kwargs)
# TODO: 插入 session
return await cmd.run(bot, session, permission=perm.IS_SUPERUSER)
return False

View File

@ -1,3 +1,5 @@
from datetime import timedelta
API_ROOT = '' API_ROOT = ''
SECRET = '' SECRET = ''
ACCESS_TOKEN = '' ACCESS_TOKEN = ''
@ -8,3 +10,4 @@ DEBUG = True
SUPERUSERS = set() SUPERUSERS = set()
COMMAND_START = {'/', '!', '', ''} COMMAND_START = {'/', '!', '', ''}
COMMAND_SEP = {'/', '.'} COMMAND_SEP = {'/', '.'}
SESSION_EXPIRE_TIMEOUT = timedelta(minutes=5)

View File

@ -1,6 +1,4 @@
from typing import Dict, Any, Union, List from typing import Dict, Any
from aiocqhttp import CQHttp, Error as CQHttpError
def context_source(ctx: Dict[str, Any]) -> str: def context_source(ctx: Dict[str, Any]) -> str:
@ -12,13 +10,3 @@ def context_source(ctx: Dict[str, Any]) -> str:
if ctx.get('user_id'): if ctx.get('user_id'):
src += f'/user/{ctx["user_id"]}' src += f'/user/{ctx["user_id"]}'
return src return src
async def send(bot: CQHttp, ctx: Dict[str, Any],
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
*, ignore_failure: bool = True) -> None:
try:
await bot.send(ctx, message)
except CQHttpError:
if not ignore_failure:
raise

View File

@ -3,14 +3,13 @@ from aiocqhttp.message import unescape
import none import none
from none import permissions as perm from none import permissions as perm
from none.command import Session from none.command import Session
from none.helpers import send
@none.on_command('echo') @none.on_command('echo')
async def echo(bot, session: Session): async def echo(session: Session):
await send(bot, session.ctx, session.current_arg) await session.send(session.current_arg)
@none.on_command('say', permission=perm.SUPERUSER) @none.on_command('say', permission=perm.SUPERUSER)
async def _(bot, session: Session): async def _(session: Session):
await send(bot, session.ctx, unescape(session.current_arg)) await session.send(unescape(session.current_arg))

View File

@ -1,16 +1,15 @@
import none import none
from none.command import Session from none.command import Session
from none.helpers import send
@none.on_command('weather', aliases=('天气',)) @none.on_command('weather', aliases=('天气',))
async def weather(bot, session: Session): async def weather(session: Session):
city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?') city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?')
other = session.require_arg('other') other = session.require_arg('other', prompt='其他信息?')
await send(bot, session.ctx, f'你查询了{city}的天气,{other}') await session.send(f'你查询了{city}的天气,{other}')
@weather.args_parser @weather.args_parser
def _(session: Session): async def _(session: Session):
if session.current_key: if session.current_key:
session.args[session.current_key] = session.current_arg.strip() session.args[session.current_key] = session.current_arg.strip()