mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 05:49:02 +08:00
Improve session
This commit is contained in:
parent
77db6bfc84
commit
7da2043f94
@ -17,7 +17,7 @@ default_handler.setFormatter(logging.Formatter(
|
|||||||
logger.addHandler(default_handler)
|
logger.addHandler(default_handler)
|
||||||
|
|
||||||
from .plugin import handle_message, handle_notice, handle_request
|
from .plugin import handle_message, handle_notice, handle_request
|
||||||
from .command import on_command, call_command
|
from .command import on_command
|
||||||
|
|
||||||
|
|
||||||
def create_bot(config_object: Any = None):
|
def create_bot(config_object: Any = None):
|
||||||
|
118
none/command.py
118
none/command.py
@ -1,6 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
import asyncio
|
||||||
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional
|
from datetime import datetime
|
||||||
|
from typing import Tuple, Union, Callable, Iterable, Dict, Any, Optional, List
|
||||||
|
|
||||||
from aiocqhttp import CQHttp, Error as CQHttpError
|
from aiocqhttp import CQHttp, Error as CQHttpError
|
||||||
from aiocqhttp.message import Message
|
from aiocqhttp.message import Message
|
||||||
@ -17,27 +18,28 @@ _registry = {}
|
|||||||
_aliases = {}
|
_aliases = {}
|
||||||
|
|
||||||
# Key: context source
|
# Key: context source
|
||||||
# Value: list (stack) of Session objects
|
# Value: Session object
|
||||||
_sessions = defaultdict(list)
|
_sessions = {}
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
class Command:
|
||||||
__slots__ = ('name', 'func', 'permission', 'args_parser')
|
__slots__ = ('name', 'func', 'permission', 'args_parser_func')
|
||||||
|
|
||||||
def __init__(self, name: Tuple[str], func: Callable, permission: int):
|
def __init__(self, name: Tuple[str],
|
||||||
|
func: Callable,
|
||||||
|
permission: int):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.func = func
|
self.func = func
|
||||||
self.permission = permission
|
self.permission = permission
|
||||||
self.args_parser = None
|
self.args_parser_func = None
|
||||||
|
|
||||||
async def run(self, bot, session, *,
|
async def run(self, session, *, permission: int = None) -> bool:
|
||||||
permission: int = None) -> bool:
|
|
||||||
if permission is None:
|
if permission is None:
|
||||||
permission = await calculate_permission(bot, session.ctx)
|
permission = await calculate_permission(session.bot, session.ctx)
|
||||||
if isinstance(self.func, Callable) and permission & self.permission:
|
if self.func and permission & self.permission:
|
||||||
if isinstance(self.args_parser, Callable):
|
if self.args_parser_func:
|
||||||
self.args_parser(session)
|
await self.args_parser_func(session)
|
||||||
await self.func(bot, session)
|
await self.func(session)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -91,7 +93,7 @@ def on_command(name: Union[str, Tuple[str]], aliases: Iterable = (),
|
|||||||
_aliases[alias] = cmd_name
|
_aliases[alias] = cmd_name
|
||||||
|
|
||||||
def args_parser(parser_func: Callable):
|
def args_parser(parser_func: Callable):
|
||||||
cmd.args_parser = parser_func
|
cmd.args_parser_func = parser_func
|
||||||
return parser_func
|
return parser_func
|
||||||
|
|
||||||
func.args_parser = args_parser
|
func.args_parser = args_parser
|
||||||
@ -125,17 +127,16 @@ class FurtherInteractionNeeded(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
__slots__ = ('cmd', 'ctx',
|
__slots__ = ('bot', 'cmd', 'ctx',
|
||||||
'current_key', 'current_prompt',
|
'current_key', 'current_arg', 'current_arg_text',
|
||||||
'current_arg', 'current_arg_text',
|
|
||||||
'images', 'args', 'last_interaction')
|
'images', 'args', 'last_interaction')
|
||||||
|
|
||||||
def __init__(self, cmd: Command, ctx: Dict[str, Any], *,
|
def __init__(self, bot: CQHttp, cmd: Command, ctx: Dict[str, Any], *,
|
||||||
current_arg: str = '', args: Dict[str, Any] = None):
|
current_arg: str = '', args: Dict[str, Any] = None):
|
||||||
|
self.bot = bot
|
||||||
self.cmd = cmd
|
self.cmd = cmd
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.current_key = None
|
self.current_key = None
|
||||||
self.current_prompt = None
|
|
||||||
self.current_arg = current_arg
|
self.current_arg = current_arg
|
||||||
self.current_arg_text = Message(current_arg).extract_plain_text()
|
self.current_arg_text = Message(current_arg).extract_plain_text()
|
||||||
self.images = [s.data['url'] for s in ctx['message']
|
self.images = [s.data['url'] for s in ctx['message']
|
||||||
@ -152,10 +153,13 @@ class Session:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_valid(self):
|
def is_valid(self):
|
||||||
# TODO: 检查 last_interaction
|
if self.last_interaction and \
|
||||||
|
datetime.now() - self.last_interaction > \
|
||||||
|
self.bot.config.SESSION_EXPIRE_TIMEOUT:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def require_arg(self, key: str, prompt: str = None, *,
|
def require_arg(self, key: str, prompt: str, *,
|
||||||
interactive: bool = True) -> Any:
|
interactive: bool = True) -> Any:
|
||||||
"""
|
"""
|
||||||
Get an argument with a given key.
|
Get an argument with a given key.
|
||||||
@ -168,7 +172,7 @@ class Session:
|
|||||||
If "interactive" is False, missed key will cause a result of None.
|
If "interactive" is False, missed key will cause a result of None.
|
||||||
|
|
||||||
:param key: argument key
|
:param key: argument key
|
||||||
:param prompt: prompt to ask the user with
|
:param prompt: prompt to ask the user
|
||||||
:param interactive: should enter interactive mode while key missing
|
:param interactive: should enter interactive mode while key missing
|
||||||
:return: the argument value
|
:return: the argument value
|
||||||
:raise FurtherInteractionNeeded: further interaction is needed
|
:raise FurtherInteractionNeeded: further interaction is needed
|
||||||
@ -178,9 +182,19 @@ class Session:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
self.current_key = key
|
self.current_key = key
|
||||||
self.current_prompt = prompt or f'请输入 {self.current_key}:'
|
# ask the user for more information
|
||||||
|
asyncio.ensure_future(self.send(prompt))
|
||||||
raise FurtherInteractionNeeded
|
raise FurtherInteractionNeeded
|
||||||
|
|
||||||
|
async def send(self,
|
||||||
|
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
|
||||||
|
*, ignore_failure: bool = True) -> None:
|
||||||
|
try:
|
||||||
|
await self.bot.send(self.ctx, message)
|
||||||
|
except CQHttpError:
|
||||||
|
if not ignore_failure:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _new_command_session(bot: CQHttp,
|
def _new_command_session(bot: CQHttp,
|
||||||
ctx: Dict[str, Any]) -> Optional[Session]:
|
ctx: Dict[str, Any]) -> Optional[Session]:
|
||||||
@ -222,56 +236,32 @@ def _new_command_session(bot: CQHttp,
|
|||||||
if not cmd:
|
if not cmd:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return Session(cmd, ctx, current_arg=''.join(cmd_remained))
|
return Session(bot, cmd, ctx, current_arg=''.join(cmd_remained))
|
||||||
|
|
||||||
|
|
||||||
async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
|
async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
|
||||||
src = context_source(ctx)
|
src = context_source(ctx)
|
||||||
if _sessions[src]:
|
session = None
|
||||||
session = _sessions[src][-1]
|
if _sessions.get(src):
|
||||||
session.refresh(ctx, current_arg=str(ctx['message']))
|
session = _sessions[src]
|
||||||
# TODO: 检查 is_valid
|
if session and session.is_valid:
|
||||||
else:
|
session.refresh(ctx, current_arg=str(ctx['message']))
|
||||||
|
else:
|
||||||
|
# the session is expired, remove it
|
||||||
|
del _sessions[src]
|
||||||
|
session = None
|
||||||
|
if not session:
|
||||||
session = _new_command_session(bot, ctx)
|
session = _new_command_session(bot, ctx)
|
||||||
if not session:
|
if not session:
|
||||||
return False
|
return False
|
||||||
_sessions[src].append(session)
|
_sessions[src] = session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = await session.cmd.run(bot, session)
|
res = await session.cmd.run(session)
|
||||||
# the command is finished, pop the session
|
# the command is finished, remove the session
|
||||||
_sessions[src].pop()
|
del _sessions[src]
|
||||||
if not _sessions[src]:
|
|
||||||
# session stack of the current user is empty
|
|
||||||
del _sessions[src]
|
|
||||||
return res
|
return res
|
||||||
except FurtherInteractionNeeded:
|
except FurtherInteractionNeeded:
|
||||||
# ask the user for more information
|
session.last_interaction = datetime.now()
|
||||||
await bot.send(ctx, session.current_prompt)
|
|
||||||
|
|
||||||
# return True because this step of the session is successful
|
# return True because this step of the session is successful
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def call_command(name: Union[str, Tuple[str]],
|
|
||||||
bot: CQHttp, ctx: Dict[str, Any], **kwargs) -> bool:
|
|
||||||
"""
|
|
||||||
Call a command internally.
|
|
||||||
|
|
||||||
There is no permission restriction on this function,
|
|
||||||
which means any command can be called from any other command.
|
|
||||||
Unexpected users should be handled by the caller command's permission
|
|
||||||
option.
|
|
||||||
|
|
||||||
:param name: command name (str or tuple of str)
|
|
||||||
:param bot: CQHttp instance
|
|
||||||
:param ctx: event context
|
|
||||||
:param kwargs: other keyword args that will be passed to Session()
|
|
||||||
:return: the command is successfully called
|
|
||||||
"""
|
|
||||||
cmd = _find_command(name)
|
|
||||||
if cmd:
|
|
||||||
session = Session(cmd, ctx, **kwargs)
|
|
||||||
# TODO: 插入 session
|
|
||||||
return await cmd.run(bot, session, permission=perm.IS_SUPERUSER)
|
|
||||||
return False
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
API_ROOT = ''
|
API_ROOT = ''
|
||||||
SECRET = ''
|
SECRET = ''
|
||||||
ACCESS_TOKEN = ''
|
ACCESS_TOKEN = ''
|
||||||
@ -8,3 +10,4 @@ DEBUG = True
|
|||||||
SUPERUSERS = set()
|
SUPERUSERS = set()
|
||||||
COMMAND_START = {'/', '!', '/', '!'}
|
COMMAND_START = {'/', '!', '/', '!'}
|
||||||
COMMAND_SEP = {'/', '.'}
|
COMMAND_SEP = {'/', '.'}
|
||||||
|
SESSION_EXPIRE_TIMEOUT = timedelta(minutes=5)
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
from typing import Dict, Any, Union, List
|
from typing import Dict, Any
|
||||||
|
|
||||||
from aiocqhttp import CQHttp, Error as CQHttpError
|
|
||||||
|
|
||||||
|
|
||||||
def context_source(ctx: Dict[str, Any]) -> str:
|
def context_source(ctx: Dict[str, Any]) -> str:
|
||||||
@ -12,13 +10,3 @@ def context_source(ctx: Dict[str, Any]) -> str:
|
|||||||
if ctx.get('user_id'):
|
if ctx.get('user_id'):
|
||||||
src += f'/user/{ctx["user_id"]}'
|
src += f'/user/{ctx["user_id"]}'
|
||||||
return src
|
return src
|
||||||
|
|
||||||
|
|
||||||
async def send(bot: CQHttp, ctx: Dict[str, Any],
|
|
||||||
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
|
|
||||||
*, ignore_failure: bool = True) -> None:
|
|
||||||
try:
|
|
||||||
await bot.send(ctx, message)
|
|
||||||
except CQHttpError:
|
|
||||||
if not ignore_failure:
|
|
||||||
raise
|
|
||||||
|
@ -3,14 +3,13 @@ from aiocqhttp.message import unescape
|
|||||||
import none
|
import none
|
||||||
from none import permissions as perm
|
from none import permissions as perm
|
||||||
from none.command import Session
|
from none.command import Session
|
||||||
from none.helpers import send
|
|
||||||
|
|
||||||
|
|
||||||
@none.on_command('echo')
|
@none.on_command('echo')
|
||||||
async def echo(bot, session: Session):
|
async def echo(session: Session):
|
||||||
await send(bot, session.ctx, session.current_arg)
|
await session.send(session.current_arg)
|
||||||
|
|
||||||
|
|
||||||
@none.on_command('say', permission=perm.SUPERUSER)
|
@none.on_command('say', permission=perm.SUPERUSER)
|
||||||
async def _(bot, session: Session):
|
async def _(session: Session):
|
||||||
await send(bot, session.ctx, unescape(session.current_arg))
|
await session.send(unescape(session.current_arg))
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
import none
|
import none
|
||||||
from none.command import Session
|
from none.command import Session
|
||||||
from none.helpers import send
|
|
||||||
|
|
||||||
|
|
||||||
@none.on_command('weather', aliases=('天气',))
|
@none.on_command('weather', aliases=('天气',))
|
||||||
async def weather(bot, session: Session):
|
async def weather(session: Session):
|
||||||
city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?')
|
city = session.require_arg('city', prompt='你想知道哪个城市的天气呢?')
|
||||||
other = session.require_arg('other')
|
other = session.require_arg('other', prompt='其他信息?')
|
||||||
await send(bot, session.ctx, f'你查询了{city}的天气,{other}')
|
await session.send(f'你查询了{city}的天气,{other}')
|
||||||
|
|
||||||
|
|
||||||
@weather.args_parser
|
@weather.args_parser
|
||||||
def _(session: Session):
|
async def _(session: Session):
|
||||||
if session.current_key:
|
if session.current_key:
|
||||||
session.args[session.current_key] = session.current_arg.strip()
|
session.args[session.current_key] = session.current_arg.strip()
|
||||||
|
Loading…
Reference in New Issue
Block a user