Implement argument filters

This commit is contained in:
Richard Chien 2019-01-25 00:14:30 +08:00
parent 6b6daf7235
commit f8ecc7bba1
5 changed files with 284 additions and 47 deletions

View File

@ -3,16 +3,18 @@ import re
import shlex
from datetime import datetime
from typing import (
Tuple, Union, Callable, Iterable, Any, Optional, List, Dict
Tuple, Union, Callable, Iterable, Any, Optional, List, Dict,
Awaitable
)
from . import NoneBot, permission as perm
from .helpers import context_id, send, render_expression
from .log import logger
from .message import Message
from .session import BaseSession
from .typing import (
Context_T, CommandName_T, CommandArgs_T, Message_T
from nonebot import NoneBot, permission as perm
from nonebot.command.argfilter import ArgFilter_T, ValidateError
from nonebot.helpers import context_id, send, render_expression
from nonebot.log import logger
from nonebot.message import Message
from nonebot.session import BaseSession
from nonebot.typing import (
Context_T, CommandName_T, CommandArgs_T, Message_T, State_T
)
# key: one segment of command name
@ -27,19 +29,25 @@ _aliases = {} # type: Dict[str, CommandName_T]
# value: CommandSession object
_sessions = {} # type: Dict[str, CommandSession]
CommandHandler_T = Callable[['CommandSession'], Any]
class Command:
__slots__ = ('name', 'func', 'permission',
'only_to_me', 'privileged', 'args_parser_func')
def __init__(self, *, name: CommandName_T, func: Callable,
permission: int, only_to_me: bool, privileged: bool):
def __init__(self, *,
name: CommandName_T,
func: CommandHandler_T,
permission: int,
only_to_me: bool,
privileged: bool):
self.name = name
self.func = func
self.permission = permission
self.only_to_me = only_to_me
self.privileged = privileged
self.args_parser_func = None
self.args_parser_func: Optional[CommandHandler_T] = None
async def run(self, session, *,
check_perm: bool = True,
@ -56,8 +64,28 @@ class Command:
if self.func and has_perm:
if dry:
return True
if self.args_parser_func:
await self.args_parser_func(session)
if session.current_arg_filters is not None and \
session.current_key is not None:
# argument-level filters are given, use them
arg = session.current_arg
for f in session.current_arg_filters:
try:
res = f(arg)
if isinstance(res, Awaitable):
res = await res
arg = res
except ValidateError as e:
# validation failed
session.pause(e.message)
# passed all filters
session.state[session.current_key] = arg
else:
# fallback to command-level args_parser_func
if self.args_parser_func:
await self.args_parser_func(session)
await self.func(session)
return True
return False
@ -77,14 +105,14 @@ class Command:
class CommandFunc:
__slots__ = ('cmd', 'func')
def __init__(self, cmd: Command, func: Callable):
def __init__(self, cmd: Command, func: CommandHandler_T):
self.cmd = cmd
self.func = func
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __call__(self, session: 'CommandSession') -> Any:
return self.func(session)
def args_parser(self, parser_func: Callable):
def args_parser(self, parser_func: CommandHandler_T) -> CommandHandler_T:
"""
Decorator to register a function as the arguments parser of
the corresponding command.
@ -110,7 +138,7 @@ def on_command(name: Union[str, CommandName_T], *,
:param shell_like: use shell-like syntax to split arguments
"""
def deco(func: Callable) -> Callable:
def deco(func: CommandHandler_T) -> CommandHandler_T:
if not isinstance(name, (str, tuple)):
raise TypeError('the name of a command must be a str or tuple')
if not name:
@ -153,7 +181,7 @@ class CommandGroup:
privileged: Optional[bool] = None,
shell_like: Optional[bool] = None):
self.basename = (name,) if isinstance(name, str) else name
self.permission = permission
self.permission = permission # TODO: use .pyi
self.only_to_me = only_to_me
self.privileged = privileged
self.shell_like = shell_like
@ -204,10 +232,10 @@ def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]:
return cmd if isinstance(cmd, Command) else None
class _FurtherInteractionNeeded(Exception):
class _PauseException(Exception):
"""
Raised by session.pause() indicating that the command should
enter interactive mode to ask the user for some arguments.
Raised by session.pause() indicating that the command session
should be paused to ask the user for some arguments.
"""
pass
@ -244,23 +272,48 @@ class SwitchException(Exception):
class CommandSession(BaseSession):
__slots__ = ('cmd', 'current_key', 'current_arg',
__slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_filters',
'current_arg_text', 'current_arg_images',
'args', '_last_interaction', '_running')
'_state', '_last_interaction', '_running')
def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *,
current_arg: str = '', args: Optional[CommandArgs_T] = None):
super().__init__(bot, ctx)
self.cmd = cmd # Command object
self.current_key = None # current key that the command handler needs
self.current_arg = None # current argument (with potential CQ codes)
self.current_arg_text = None # current argument without any CQ codes
self.current_arg_images = None # image urls in current argument
self.refresh(ctx, current_arg=current_arg)
self.args = args or {}
# unique key of the argument that is currently requesting (asking)
self.current_key: Optional[str] = None
# initialize current argument
self.current_arg: str = '' # with potential CQ codes
self.current_arg_text: str = '' # without any CQ codes TODO: property
self.current_arg_images: List[str] = [] # image urls
self.refresh(ctx, current_arg=current_arg) # fill the above
# initialize current argument filters
self.current_arg_filters: Optional[List[ArgFilter_T]] = None
self._state: State_T = {}
if args:
self._state.update(args)
self._last_interaction = None # last interaction time of this session
self._running = False
@property
def state(self) -> State_T:
"""
State of the session.
This contains all named arguments and
other session scope temporary values.
"""
return self._state
@property
def args(self) -> CommandArgs_T:
"""Deprecated. Use `session.state` instead."""
return self.state
@property
def running(self) -> bool:
return self._running
@ -292,7 +345,7 @@ class CommandSession(BaseSession):
Shell-like argument list, similar to sys.argv.
Only available while shell_like is True in on_command decorator.
"""
return self.get_optional('argv', [])
return self.state.get('argv', [])
def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None:
"""
@ -308,38 +361,46 @@ class CommandSession(BaseSession):
self.current_arg_images = [s.data['url'] for s in current_arg_as_msg
if s.type == 'image' and 'url' in s.data]
def get(self, key: Any, *,
prompt: Optional[Message_T] = None, **kwargs) -> Any:
def get(self, key: str, *,
prompt: Optional[Message_T] = None,
arg_filters: Optional[List[ArgFilter_T]] = None,
**kwargs) -> Any:
"""
Get an argument with a given key.
If the argument does not exist in the current session,
a FurtherInteractionNeeded exception will be raised,
and the caller of the command will know it should keep
the session for further interaction with the user.
a pause exception will be raised, and the caller of
the command will know it should keep the session for
further interaction with the user.
:param key: argument key
:param prompt: prompt to ask the user
:param arg_filters: argument filters for next user input
:return: the argument value
"""
value = self.get_optional(key)
if value is not None:
return value
if key in self.state:
return self.state[key]
self.current_key = key
self.current_arg_filters = arg_filters
# TODO: self.current_send_kwargs
# ask the user for more information
self.pause(prompt, **kwargs)
def get_optional(self, key: Any,
def get_optional(self, key: str,
default: Optional[Any] = None) -> Optional[Any]:
"""Simply get a argument with given key."""
return self.args.get(key, default)
"""
Simply get a argument with given key.
Deprecated. Use `session.state.get()` instead.
"""
return self.state.get(key, default)
def pause(self, message: Optional[Message_T] = None, **kwargs) -> None:
"""Pause the session for further interaction."""
if message:
asyncio.ensure_future(self.send(message, **kwargs))
raise _FurtherInteractionNeeded
raise _PauseException
def finish(self, message: Optional[Message_T] = None, **kwargs) -> None:
"""Finish the session."""
@ -564,9 +625,7 @@ async def _real_run_command(session: CommandSession,
handled = future.result()
except asyncio.TimeoutError:
handled = True
except (_FurtherInteractionNeeded,
_FinishException,
SwitchException) as e:
except (_PauseException, _FinishException, SwitchException) as e:
raise e
except Exception as e:
logger.error(f'An exception occurred while '
@ -574,7 +633,7 @@ async def _real_run_command(session: CommandSession,
logger.exception(e)
handled = True
raise _FinishException(handled)
except _FurtherInteractionNeeded:
except _PauseException:
session.running = False
if disable_interaction:
# if the command needs further interaction, we view it as failed

View File

@ -0,0 +1,8 @@
from typing import Callable, Any, Awaitable, Union
ArgFilter_T = Callable[[Any], Union[Any, Awaitable[Any]]]
class ValidateError(ValueError):
def __init__(self, message=None):
self.message = message

View File

@ -0,0 +1,40 @@
from typing import Optional, List
def _simple_chinese_to_bool(text: str) -> Optional[bool]:
"""
Convert a chinese text to boolean.
Examples:
是的 -> True
好的呀 -> True
不要 -> False
不用了 -> False
你好呀 -> None
"""
text = text.strip().lower().replace(' ', '') \
.rstrip(',.!?~,。!?~了的呢吧呀啊呗啦')
if text in {'', '', '', '', '', '', '',
'ok', 'okay', 'yeah', 'yep',
'当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'}:
return True
if text in {'', '不要', '不用', '不是', '', '不好', '不对', '不行', '',
'no', 'nono', 'nonono', 'nope', '不ok', '不可以', '不能',
'不可以'}:
return False
return None
def _split_nonempty_lines(text: str) -> List[str]:
return list(filter(lambda x: x, text.splitlines()))
def _split_nonempty_stripped_lines(text: str) -> List[str]:
return list(filter(lambda x: x,
map(lambda x: x.strip(), text.splitlines())))
simple_chinese_to_bool = _simple_chinese_to_bool
split_nonempty_lines = _split_nonempty_lines
split_nonempty_stripped_lines = _split_nonempty_stripped_lines

View File

@ -0,0 +1,29 @@
import re
from typing import List
from nonebot.message import Message
from nonebot.typing import Message_T
def _extract_text(arg: Message_T) -> str:
"""Extract all plain text segments from a message-like object."""
arg_as_msg = Message(arg)
return arg_as_msg.extract_plain_text()
def _extract_image_urls(arg: Message_T) -> List[str]:
"""Extract all image urls from a message-like object."""
arg_as_msg = Message(arg)
return [s.data['url'] for s in arg_as_msg
if s.type == 'image' and 'url' in s.data]
def _extract_numbers(arg: Message_T) -> List[float]:
"""Extract all numbers (integers and floats) from a message-like object."""
s = str(arg)
return list(map(float, re.findall(r'[+-]?(\d*\.?\d+|\d+\.?\d*)', s)))
extract_text = _extract_text
extract_image_urls = _extract_image_urls
extract_numbers = _extract_numbers

View File

@ -0,0 +1,101 @@
import re
from typing import Callable, Any
from nonebot.command.argfilter import ValidateError
class BaseValidator:
def __init__(self, message=None):
self.message = message
def raise_failure(self):
raise ValidateError(self.message)
class not_empty(BaseValidator):
"""
Validate any object to ensure it's not empty (is None or has no elements).
"""
def __call__(self, value):
if value is None:
self.raise_failure()
if hasattr(value, '__len__') and value.__len__() == 0:
self.raise_failure()
return value
class fit_size(BaseValidator):
"""
Validate any sized object to ensure the size/length
is in a given range [min_length, max_length].
"""
def __init__(self, min_length: int = 0, max_length: int = None,
message=None):
super().__init__(message)
self.min_length = min_length
self.max_length = max_length
def __call__(self, value):
length = len(value) if value is not None else 0
if length < self.min_length or \
(self.max_length is not None and length > self.max_length):
self.raise_failure()
return value
class match_regex(BaseValidator):
"""
Validate any string object to ensure it matches a given pattern.
"""
def __init__(self, pattern: str, message=None, *, flags=0,
fullmatch: bool = False):
super().__init__(message)
self.pattern = re.compile(pattern, flags)
self.fullmatch = fullmatch
def __call__(self, value):
if self.fullmatch:
if not re.fullmatch(self.pattern, value):
self.raise_failure()
else:
if not re.match(self.pattern, value):
self.raise_failure()
return value
class ensure_true(BaseValidator):
"""
Validate any object to ensure the result of applying
a boolean function to it is True.
"""
def __init__(self, bool_func: Callable[[Any], bool], message=None):
super().__init__(message)
self.bool_func = bool_func
def __call__(self, value):
if self.bool_func(value) is not True:
self.raise_failure()
return value
class between_inclusive(BaseValidator):
"""
Validate any comparable object to ensure it's between
`start` and `end` inclusively.
"""
def __init__(self, start=None, end=None, message=None):
super().__init__(message)
self.start = start
self.end = end
def __call__(self, value):
if self.start is not None and value < self.start:
self.raise_failure()
if self.end is not None and self.end < value:
self.raise_failure()
return value