nonebot2/none/command.py

345 lines
11 KiB
Python
Raw Normal View History

2018-06-15 06:58:24 +08:00
import re
2018-06-25 16:50:34 +08:00
import asyncio
from datetime import datetime
2018-06-25 21:14:27 +08:00
from typing import (
2018-06-27 22:05:12 +08:00
Tuple, Union, Callable, Iterable, Dict, Any, Optional, Sequence
2018-06-25 21:14:27 +08:00
)
2018-06-15 06:58:24 +08:00
2018-07-01 11:01:24 +08:00
from aiocqhttp import CQHttp
2018-06-15 15:00:58 +08:00
from aiocqhttp.message import Message
2018-06-15 06:58:24 +08:00
2018-07-01 11:01:24 +08:00
from . import permission as perm
from .helpers import context_source
from .expression import render
from .session import BaseSession
2018-06-15 06:58:24 +08:00
# Command tree.
# Key: str, one segment of a command name.
# Value: either a nested subtree (a dict of the same shape) or a leaf
#        Command object.
_registry: Dict[str, Any] = {}

# Alias table.
# Key: str, an alias exactly as typed by the user.
# Value: the full command-name tuple the alias resolves to.
_aliases: Dict[str, Tuple[str, ...]] = {}

# Pending command sessions.
# Key: context source, as produced by context_source(ctx).
# Value: the CommandSession currently bound to that source.
_sessions: Dict[str, 'CommandSession'] = {}


class Command:
    """
    A registered command: its full name, the handler coroutine function,
    and the settings that decide when it may run.
    """
    __slots__ = ('name', 'func', 'permission', 'only_to_me', 'args_parser_func')

    def __init__(self, *, name: Tuple[str, ...], func: Callable,
                 permission: int, only_to_me: bool):
        self.name = name
        self.func = func
        self.permission = permission
        self.only_to_me = only_to_me
        # Optional coroutine function run before the handler; it is set
        # through the ``args_parser`` decorator that on_command attaches
        # to the handler function.
        self.args_parser_func = None

    async def run(self, session, check_perm: bool = True) -> bool:
        """
        Run the command in a given session.

        :param session: CommandSession object
        :param check_perm: should check permission before running
        :return: True if the handler ran to completion, False if it was not
                 run (no handler, or the permission check failed)

        Exceptions raised by the args parser or the handler (for example
        the one raised by CommandSession.get() when an argument is
        missing) propagate to the caller.
        """
        if not self.func:
            # Nothing to run, so there is no point asking for permission.
            return False
        if check_perm and not await perm.check_permission(
                session.bot, session.ctx, self.permission):
            return False
        if self.args_parser_func:
            await self.args_parser_func(session)
        await self.func(session)
        return True


def on_command(name: Union[str, Tuple[str, ...]], *,
               aliases: Union[str, Iterable[str]] = (),
               permission: int = perm.EVERYBODY,
               only_to_me: bool = True) -> Callable:
    """
    Decorator factory that registers a coroutine function as a command.

    The decorated function gains an ``args_parser`` attribute, itself a
    decorator, which installs an arguments parser for the command.

    :param name: command name, a str or a tuple of name segments
    :param aliases: alternative names for the command; a single str is
                    treated as one alias, not as an iterable of characters
    :param permission: permission required to run the command
    :param only_to_me: whether the command only responds to messages
                       addressed to the bot
    :return: the decorator
    :raise TypeError: the name is neither a str nor a tuple
    :raise ValueError: the name is empty, or it conflicts with an already
                       registered command (a prefix of the name is itself
                       a command, or the name is a prefix of other commands)
    """
    if not isinstance(name, (str, tuple)):
        raise TypeError('the name of a command must be a str or tuple')
    if not name:
        raise ValueError('the name of a command must not be empty')
    cmd_name = name if isinstance(name, tuple) else (name,)
    if isinstance(aliases, str):
        aliases = (aliases,)

    def deco(func: Callable) -> Callable:
        current_parent = _registry
        for depth, parent_key in enumerate(cmd_name[:-1]):
            node = current_parent.setdefault(parent_key, {})
            if not isinstance(node, dict):
                # A leaf Command cannot also hold sub-commands.
                raise ValueError(
                    'cannot register command {}: {} is already a command'
                    .format(cmd_name, cmd_name[:depth + 1]))
            current_parent = node
        if isinstance(current_parent.get(cmd_name[-1]), dict):
            # Overwriting the subtree would silently drop its sub-commands.
            raise ValueError(
                'cannot register command {}: it is a prefix of other '
                'commands'.format(cmd_name))

        cmd = Command(name=cmd_name, func=func, permission=permission,
                      only_to_me=only_to_me)
        current_parent[cmd_name[-1]] = cmd
        for alias in aliases:
            _aliases[alias] = cmd_name

        def args_parser_deco(parser_func: Callable) -> Callable:
            cmd.args_parser_func = parser_func
            return parser_func

        func.args_parser = args_parser_deco
        return func

    return deco


class CommandGroup:
    """
    Group a set of commands with same name prefix.

    The permission and only_to_me given to the group act as defaults for
    every command declared through ``command``; values passed to
    ``command`` itself take precedence.
    """
    __slots__ = ('basename', 'permission', 'only_to_me')

    def __init__(self, name: Union[str, Tuple[str]],
                 permission: Optional[int] = None, *,
                 only_to_me: Optional[bool] = None):
        if isinstance(name, str):
            name = (name,)
        self.basename = name
        self.permission = permission
        self.only_to_me = only_to_me

    def command(self, name: Union[str, Tuple[str]], *,
                aliases: Optional[Iterable] = None,
                permission: Optional[int] = None,
                only_to_me: Optional[bool] = None) -> Callable:
        """
        Declare a command whose name is the group's name plus ``name``.

        Options left as None fall back to the group's setting, and are
        omitted entirely (so on_command's own defaults apply) when the
        group does not set them either.
        """
        if isinstance(name, str):
            name = (name,)
        full_name = self.basename + name

        def first_set(*candidates):
            return next((c for c in candidates if c is not None), None)

        options = {
            'aliases': aliases,
            'permission': first_set(permission, self.permission),
            'only_to_me': first_set(only_to_me, self.only_to_me),
        }
        return on_command(full_name, **{key: value
                                        for key, value in options.items()
                                        if value is not None})


def _find_command(name: Union[str, Tuple[str, ...]]) -> Optional[Command]:
    """
    Look up a registered command by name.

    :param name: command name, a str or a tuple of name segments
    :return: the Command object, or None when the name does not resolve
             to a command. That includes unknown names and names that only
             reach an intermediate group node of the command tree.
    """
    cmd_name = (name,) if isinstance(name, str) else name
    if not cmd_name:
        return None
    cmd_tree = _registry
    for part in cmd_name[:-1]:
        cmd_tree = cmd_tree.get(part)
        if not isinstance(cmd_tree, dict):
            # Unknown segment, or a leaf Command where a subtree is needed.
            return None
    cmd = cmd_tree.get(cmd_name[-1])
    # A prefix of longer command names resolves to a subtree dict, which
    # is not runnable; callers expect a Command or None.
    return cmd if isinstance(cmd, Command) else None


class _FurtherInteractionNeeded(Exception):
    """
    Raised by CommandSession.get() when a requested argument is missing.

    It indicates that the command should enter interactive mode to ask the
    user for the argument. _real_run_command catches it, keeps the session
    registered and records the time of the interaction.
    """


class CommandSession(BaseSession):
    """
    State of one command invocation.

    A session can outlive a single message. When the command asks the user
    for a missing argument, the session stays in the session table and is
    refreshed with each follow-up message from the same context source.
    """
    __slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_text',
                 'current_arg_images', 'args', 'last_interaction')

    def __init__(self, bot: CQHttp, ctx: Dict[str, Any], cmd: Command, *,
                 current_arg: str = '',
                 args: Optional[Dict[str, Any]] = None):
        super().__init__(bot, ctx)
        self.cmd = cmd
        # key of the argument most recently requested through get()
        self.current_key = None
        # the three current_arg* slots are filled in by refresh() below
        self.current_arg = None
        self.current_arg_text = None
        self.current_arg_images = None
        self.refresh(ctx, current_arg=current_arg)
        self.args = args or {}
        # set when the command pauses for user input; None until then
        self.last_interaction = None

    def refresh(self, ctx: Dict[str, Any], *, current_arg: str = '') -> None:
        """
        Refill the session with a new message context.

        :param ctx: new message context
        :param current_arg: new command argument as a string; its plain
                            text and the URLs of its image segments are
                            extracted as well
        """
        self.ctx = ctx
        self.current_arg = current_arg
        current_arg_as_msg = Message(current_arg)
        self.current_arg_text = current_arg_as_msg.extract_plain_text()
        self.current_arg_images = [s.data['url'] for s in current_arg_as_msg
                                   if s.type == 'image' and 'url' in s.data]

    @property
    def is_valid(self) -> bool:
        """
        Check if the session is expired or not.

        A session that has never waited for user input does not expire, and
        neither does any session when SESSION_EXPIRE_TIMEOUT is None.
        """
        if self.last_interaction is None:
            return True
        timeout = self.bot.config.SESSION_EXPIRE_TIMEOUT
        if timeout is None:
            return True
        return datetime.now() - self.last_interaction <= timeout

    def get(self, key: str, *, prompt: Optional[str] = None,
            prompt_expr: Optional[Union[str, Sequence[str], Callable]] = None
            ) -> Any:
        """
        Get an argument with a given key.

        If the argument does not exist in the current session,
        a _FurtherInteractionNeeded exception will be raised,
        and the caller of the command will know it should keep
        the session for further interaction with the user.

        :param key: argument key
        :param prompt: prompt to ask the user
        :param prompt_expr: prompt expression to ask the user; when given,
                            its rendering replaces ``prompt``
        :return: the argument value
        :raise _FurtherInteractionNeeded: further interaction is needed
        """
        value = self.get_optional(key)
        if value is not None:
            return value
        # record which argument is being asked for
        self.current_key = key
        # ask the user for more information
        if prompt_expr is not None:
            prompt = render(prompt_expr, key=key)
        if prompt:
            # scheduled in the background; the exception below is raised
            # without waiting for the message to be sent
            asyncio.ensure_future(self.send(prompt))
        raise _FurtherInteractionNeeded

    def get_optional(self, key: str,
                     default: Optional[Any] = None) -> Optional[Any]:
        """
        Get an argument with a given key, never asking the user.

        :param key: argument key
        :param default: value returned when the argument does not exist
        :return: the argument value, or ``default``
        """
        return self.args.get(key, default)


def _new_command_session(bot: CQHttp,
                         ctx: Dict[str, Any]) -> Optional[CommandSession]:
    """
    Create a new session for a command.

    This will firstly attempt to parse the current message as
    a command, and if succeeded, it then create a session for
    the command and return. If the message is not a valid command,
    None will be returned.

    Parsing steps:
    1. Strip the first COMMAND_START entry (str prefix, or compiled regex
       matched at the beginning of the message) that matches.
    2. Take the first whitespace-separated word as the command name text
       and the rest as the argument.
    3. Resolve the name through the alias table, otherwise split it with
       the first COMMAND_SEP entry that actually occurs in it.

    :param bot: CQHttp instance
    :param ctx: message context
    :return: CommandSession object or None
    """
    pattern_type = type(re.compile(''))
    msg_text = str(ctx['message']).lstrip()

    full_command = None
    for start in bot.config.COMMAND_START:
        if isinstance(start, pattern_type):
            # match() anchors at the beginning of the message; search()
            # could find the prefix mid-message, and slicing from index 0
            # would then cut off the wrong characters.
            m = start.match(msg_text)
            if m:
                full_command = msg_text[m.end():]
                break
        elif isinstance(start, str) and msg_text.startswith(start):
            full_command = msg_text[len(start):]
            break
    if full_command is None:
        # it's not a command
        return None

    full_command = full_command.lstrip()
    if not full_command:
        # command is empty
        return None

    cmd_name_text, *cmd_remained = full_command.split(maxsplit=1)

    cmd_name = _aliases.get(cmd_name_text)
    if not cmd_name:
        for sep in bot.config.COMMAND_SEP:
            if isinstance(sep, pattern_type):
                parts = tuple(sep.split(cmd_name_text))
            elif isinstance(sep, str) and sep:
                # an empty separator would make str.split raise ValueError
                parts = tuple(cmd_name_text.split(sep))
            else:
                continue
            if len(parts) > 1:
                # use the first separator that actually occurs in the name,
                # so later separators are not shadowed by earlier ones
                cmd_name = parts
                break
        else:
            cmd_name = (cmd_name_text,)

    cmd = _find_command(cmd_name)
    if not cmd:
        return None
    if cmd.only_to_me and not ctx['to_me']:
        return None

    return CommandSession(bot, ctx, cmd, current_arg=''.join(cmd_remained))


async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
    """
    Handle a message as a command.

    A pending, unexpired session for the message's context source takes
    priority: the whole message becomes its new argument, and permission is
    not checked again because that already happened (or was deliberately
    skipped) when the session first ran. Otherwise the message is parsed
    as a new command.

    This function is typically called by "handle_message".

    :param bot: CQHttp instance
    :param ctx: message context
    :return: the message is handled as a command
    """
    src = context_source(ctx)

    session = _sessions.get(src)
    if session and not session.is_valid:
        # the session is expired, remove it
        del _sessions[src]
        session = None

    if session:
        session.refresh(ctx, current_arg=str(ctx['message']))
        # there is no need to check permission for existing session
        check_perm = False
    else:
        session = _new_command_session(bot, ctx)
        if not session:
            return False
        check_perm = True

    return await _real_run_command(session, src, check_perm=check_perm)


async def call_command(bot: CQHttp, ctx: Dict[str, Any],
                       name: Union[str, Tuple[str]],
                       args: Dict[str, Any]) -> bool:
    """
    Call a command internally.

    This function is typically called by some other commands
    or "handle_natural_language" when handling NLPResult object.
    The command runs with the given arguments pre-filled and without
    a permission check.

    :param bot: CQHttp instance
    :param ctx: message context
    :param name: command name
    :param args: command args
    :return: the command is successfully called
    """
    command = _find_command(name)
    if not command:
        return False

    new_session = CommandSession(bot, ctx, command, args=args)
    source = context_source(new_session.ctx)
    # internal calls bypass the permission check
    return await _real_run_command(new_session, source, check_perm=False)


async def _real_run_command(session: CommandSession,
                            ctx_src: str, **kwargs) -> bool:
    """
    Run the session's command and keep the session table consistent.

    The session is registered under ``ctx_src`` while its command runs. It
    stays registered only if the command asks the user for more input.
    When the command finishes, or raises any other exception, the session
    is removed. Otherwise a crashed session (whose last_interaction is still
    None, so it never expires) would capture every later message from the
    same source.

    :param session: session to run
    :param ctx_src: context source the session is registered under
    :param kwargs: forwarded to Command.run (e.g. check_perm)
    :return: True if the command is waiting for further input, otherwise
             the result of Command.run
    """
    _sessions[ctx_src] = session
    waiting_for_input = False
    try:
        res = await session.cmd.run(session, **kwargs)
    except _FurtherInteractionNeeded:
        session.last_interaction = datetime.now()
        waiting_for_input = True
        # return True because this step of the session is successful
        return True
    finally:
        # A nested call_command made from inside the running command may
        # already have replaced or removed the entry for this source, so
        # only remove it while it still belongs to this session.
        if not waiting_for_input and _sessions.get(ctx_src) is session:
            del _sessions[ctx_src]
    return res