Improve session

This commit is contained in:
Richard Chien 2018-06-25 16:50:34 +08:00
parent 77db6bfc84
commit 7da2043f94
6 changed files with 67 additions and 88 deletions

View File

@ -17,7 +17,7 @@ default_handler.setFormatter(logging.Formatter(
logger.addHandler(default_handler)
from .plugin import handle_message, handle_notice, handle_request
from .command import on_command, call_command
from .command import on_command
def create_bot(config_object: Any = None):

View File

@ -1,6 +1,7 @@
import re
from collections import defaultdict
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional
import asyncio
from datetime import datetime
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional, List
from aiocqhttp import CQHttp, Error as CQHttpError
from aiocqhttp.message import Message
@ -17,27 +18,28 @@ _registry = {}
_aliases = {}
# Key: context source
# Value: list (stack) of Session objects
_sessions = defaultdict(list)
# Value: Session object
_sessions = {}
class Command:
__slots__ = ('name', 'func', 'permission', 'args_parser')
__slots__ = ('name', 'func', 'permission', 'args_parser_func')
def __init__(self, name: Tuple[str], func: Callable, permission: int):
def __init__(self, name: Tuple[str],
func: Callable,
permission: int):
self.name = name
self.func = func
self.permission = permission
self.args_parser = None
self.args_parser_func = None
async def run(self, bot, session, *,
permission: int = None) -> bool:
async def run(self, session, *, permission: int = None) -> bool:
if permission is None:
permission = await calculate_permission(bot, session.ctx)
if isinstance(self.func, Callable) and permission & self.permission:
if isinstance(self.args_parser, Callable):
self.args_parser(session)
await self.func(bot, session)
permission = await calculate_permission(session.bot, session.ctx)
if self.func and permission & self.permission:
if self.args_parser_func:
await self.args_parser_func(session)
await self.func(session)
return True
return False
@ -91,7 +93,7 @@ def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (),
_aliases[alias] = cmd_name
def args_parser(parser_func: Callable):
cmd.args_parser = parser_func
cmd.args_parser_func = parser_func
return parser_func
func.args_parser = args_parser
@ -125,17 +127,16 @@ class FurtherInteractionNeeded(Exception):
class Session:
__slots__ = ('cmd', 'ctx',
'current_key', 'current_prompt',
'current_arg', 'current_arg_text',
__slots__ = ('bot', 'cmd', 'ctx',
'current_key', 'current_arg', 'current_arg_text',
'images', 'args', 'last_interaction')
def __init__(self, cmd: Command, ctx: Dict[str, Any], *,
def __init__(self, bot: CQHttp, cmd: Command, ctx: Dict[str, Any], *,
current_arg: str = '', args: Dict[str, Any] = None):
self.bot = bot
self.cmd = cmd
self.ctx = ctx
self.current_key = None
self.current_prompt = None
self.current_arg = current_arg
self.current_arg_text = Message(current_arg).extract_plain_text()
self.images = [s.data['url'] for s in ctx['message']
@ -152,10 +153,13 @@ class Session:
@property
def is_valid(self):
# TODO: 检查 last_interaction
if self.last_interaction and \
datetime.now() - self.last_interaction > \
self.bot.config.SESSION_EXPIRE_TIMEOUT:
return False
return True
def require_arg(self, key: str, prompt: str = None, *,
def require_arg(self, key: str, prompt: str, *,
interactive: bool = True) -> Any:
"""
Get an argument with a given key.
@ -168,7 +172,7 @@ class Session:
If "interactive" is False, missed key will cause a result of None.
:param key: argument key
:param prompt: prompt to ask the user with
:param prompt: prompt to ask the user
:param interactive: should enter interactive mode while key missing
:return: the argument value
:raise FurtherInteractionNeeded: further interaction is needed
@ -178,9 +182,19 @@ class Session:
return value
self.current_key = key
self.current_prompt = prompt or f'请输入 {self.current_key}'
# ask the user for more information
asyncio.ensure_future(self.send(prompt))
raise FurtherInteractionNeeded
async def send(self,
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
*, ignore_failure: bool = True) -> None:
try:
await self.bot.send(self.ctx, message)
except CQHttpError:
if not ignore_failure:
raise
def _new_command_session(bot: CQHttp,
ctx: Dict[str, Any]) -> Optional[Session]:
@ -222,56 +236,32 @@ def _new_command_session(bot: CQHttp,
if not cmd:
return None
return Session(cmd, ctx, current_arg=''.join(cmd_remained))
return Session(bot, cmd, ctx, current_arg=''.join(cmd_remained))
async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
src = context_source(ctx)
if _sessions[src]:
session = _sessions[src][-1]
session.refresh(ctx, current_arg=str(ctx['message']))
# TODO: 检查 is_valid
else:
session = None
if _sessions.get(src):
session = _sessions[src]
if session and session.is_valid:
session.refresh(ctx, current_arg=str(ctx['message']))
else:
# the session is expired, remove it
del _sessions[src]
session = None
if not session:
session = _new_command_session(bot, ctx)
if not session:
return False
_sessions[src].append(session)
_sessions[src] = session
try:
res = await session.cmd.run(bot, session)
# the command is finished, pop the session
_sessions[src].pop()
if not _sessions[src]:
# session stack of the current user is empty
del _sessions[src]
res = await session.cmd.run(session)
# the command is finished, remove the session
del _sessions[src]
return res
except FurtherInteractionNeeded:
# ask the user for more information
await bot.send(ctx, session.current_prompt)
session.last_interaction = datetime.now()
# return True because this step of the session is successful
return True
async def call_command(name: Union[str, Tuple[str]],
bot: CQHttp, ctx: Dict[str, Any], **kwargs) -> bool:
"""
Call a command internally.
There is no permission restriction on this function,
which means any command can be called from any other command.
Unexpected users should be handled by the caller command's permission
option.
:param name: command name (str or tuple of str)
:param bot: CQHttp instance
:param ctx: event context
:param kwargs: other keyword args that will be passed to Session()
:return: the command is successfully called
"""
cmd = _find_command(name)
if cmd:
session = Session(cmd, ctx, **kwargs)
# TODO: 插入 session
return await cmd.run(bot, session, permission=perm.IS_SUPERUSER)
return False

View File

@ -1,3 +1,5 @@
from datetime import timedelta
API_ROOT = ''
SECRET = ''
ACCESS_TOKEN = ''
@ -8,3 +10,4 @@ DEBUG = True
SUPERUSERS = set()
COMMAND_START = {'/', '!', '', ''}
COMMAND_SEP = {'/', '.'}
SESSION_EXPIRE_TIMEOUT = timedelta(minutes=5)

View File

@ -1,6 +1,4 @@
from typing import Dict, Any, Union, List
from aiocqhttp import CQHttp, Error as CQHttpError
from typing import Dict, Any
def context_source(ctx: Dict[str, Any]) -> str:
@ -12,13 +10,3 @@ def context_source(ctx: Dict[str, Any]) -> str:
if ctx.get('user_id'):
src += f'/user/{ctx["user_id"]}'
return src
async def send(bot: CQHttp, ctx: Dict[str, Any],
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
*, ignore_failure: bool = True) -> None:
try:
await bot.send(ctx, message)
except CQHttpError:
if not ignore_failure:
raise

View File

@ -3,14 +3,13 @@ from aiocqhttp.message import unescape
import none
from none import permissions as perm
from none.command import Session
from none.helpers import send
@none.on_command('echo')
async def echo(bot, session: Session):
await send(bot, session.ctx, session.current_arg)
async def echo(session: Session):
await session.send(session.current_arg)
@none.on_command('say', permission=perm.SUPERUSER)
async def _(bot, session: Session):
await send(bot, session.ctx, unescape(session.current_arg))
async def _(session: Session):
await session.send(unescape(session.current_arg))

View File

@ -1,16 +1,15 @@
import none
from none.command import Session
from none.helpers import send
@none.on_command('weather', aliases=('天气',))
async def weather(bot, session: Session):
async def weather(session: Session):
city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?')
other = session.require_arg('other')
await send(bot, session.ctx, f'你查询了{city}的天气,{other}')
other = session.require_arg('other', prompt='其他信息?')
await session.send(f'你查询了{city}的天气,{other}')
@weather.args_parser
def _(session: Session):
async def _(session: Session):
if session.current_key:
session.args[session.current_key] = session.current_arg.strip()