Improve code docs

This commit is contained in:
Richard Chien 2018-07-01 20:01:05 +08:00
parent 6ec3ac66f7
commit 026d83e507
9 changed files with 100 additions and 39 deletions

View File

@ -13,7 +13,7 @@ from .notice_request import handle_notice_or_request
from .log import logger from .log import logger
def create_bot(config_object: Any = None): def create_bot(config_object: Any = None) -> CQHttp:
if config_object is None: if config_object is None:
from . import default_config as config_object from . import default_config as config_object
@ -46,7 +46,7 @@ def create_bot(config_object: Any = None):
_plugins = set() _plugins = set()
def load_plugins(plugin_dir: str, module_prefix: str): def load_plugins(plugin_dir: str, module_prefix: str) -> None:
for name in os.listdir(plugin_dir): for name in os.listdir(plugin_dir):
path = os.path.join(plugin_dir, name) path = os.path.join(plugin_dir, name)
if os.path.isfile(path) and \ if os.path.isfile(path) and \
@ -69,7 +69,7 @@ def load_plugins(plugin_dir: str, module_prefix: str):
logger.warning('Failed to import "{}"'.format(mod_name)) logger.warning('Failed to import "{}"'.format(mod_name))
def load_builtin_plugins(): def load_builtin_plugins() -> None:
plugin_dir = os.path.join(os.path.dirname(__file__), 'plugins') plugin_dir = os.path.join(os.path.dirname(__file__), 'plugins')
load_plugins(plugin_dir, 'none.plugins') load_plugins(plugin_dir, 'none.plugins')

View File

@ -9,7 +9,7 @@ from aiocqhttp import CQHttp
from aiocqhttp.message import Message from aiocqhttp.message import Message
from . import permission as perm from . import permission as perm
from .helpers import context_source from .helpers import context_id
from .expression import render from .expression import render
from .session import BaseSession from .session import BaseSession
@ -21,8 +21,8 @@ _registry = {}
# Value: tuple that identifies a command # Value: tuple that identifies a command
_aliases = {} _aliases = {}
# Key: context source # Key: context id
# Value: Session object # Value: CommandSession object
_sessions = {} _sessions = {}
@ -62,6 +62,15 @@ def on_command(name: Union[str, Tuple[str]], *,
aliases: Iterable = (), aliases: Iterable = (),
permission: int = perm.EVERYBODY, permission: int = perm.EVERYBODY,
only_to_me: bool = True) -> Callable: only_to_me: bool = True) -> Callable:
"""
Decorator to register a function as a command.
:param name: command name (e.g. 'echo' or ('random', 'number'))
:param aliases: aliases of command name, for convenient access
:param permission: permission required by the command
:param only_to_me: only handle messages to me
"""
def deco(func: Callable) -> Callable: def deco(func: Callable) -> Callable:
if not isinstance(name, (str, tuple)): if not isinstance(name, (str, tuple)):
raise TypeError('the name of a command must be a str or tuple') raise TypeError('the name of a command must be a str or tuple')
@ -93,7 +102,6 @@ class CommandGroup:
""" """
Group a set of commands with same name prefix. Group a set of commands with same name prefix.
""" """
__slots__ = ('basename', 'permission', 'only_to_me') __slots__ = ('basename', 'permission', 'only_to_me')
def __init__(self, name: Union[str, Tuple[str]], def __init__(self, name: Union[str, Tuple[str]],
@ -140,9 +148,8 @@ def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]:
class _FurtherInteractionNeeded(Exception): class _FurtherInteractionNeeded(Exception):
""" """
Raised by session.require_arg() indicating Raised by session.get() indicating that the command should
that the command should enter interactive mode enter interactive mode to ask the user for some arguments.
to ask the user for some arguments.
""" """
pass pass
@ -154,14 +161,14 @@ class CommandSession(BaseSession):
def __init__(self, bot: CQHttp, ctx: Dict[str, Any], cmd: Command, *, def __init__(self, bot: CQHttp, ctx: Dict[str, Any], cmd: Command, *,
current_arg: str = '', args: Optional[Dict[str, Any]] = None): current_arg: str = '', args: Optional[Dict[str, Any]] = None):
super().__init__(bot, ctx) super().__init__(bot, ctx)
self.cmd = cmd self.cmd = cmd # Command object
self.current_key = None self.current_key = None # current key that the command handler needs
self.current_arg = None self.current_arg = None # current argument (with potential CQ codes)
self.current_arg_text = None self.current_arg_text = None # current argument without any CQ codes
self.current_arg_images = None self.current_arg_images = None # image urls in current argument
self.refresh(ctx, current_arg=current_arg) self.refresh(ctx, current_arg=current_arg)
self.args = args or {} self.args = args or {}
self.last_interaction = None self.last_interaction = None # last interaction time of this session
def refresh(self, ctx: Dict[str, Any], *, current_arg: str = '') -> None: def refresh(self, ctx: Dict[str, Any], *, current_arg: str = '') -> None:
""" """
@ -224,26 +231,25 @@ def _new_command_session(bot: CQHttp,
""" """
Create a new session for a command. Create a new session for a command.
This will firstly attempt to parse the current message as This will attempt to parse the current message as a command,
a command, and if succeeded, it then create a session for and if succeeded, it then create a session for the command and return.
the command and return. If the message is not a valid command, If the message is not a valid command, None will be returned.
None will be returned.
:param bot: CQHttp instance :param bot: CQHttp instance
:param ctx: message context :param ctx: message context
:return: CommandSession object or None :return: CommandSession object or None
""" """
msg_text = str(ctx['message']).lstrip() msg = str(ctx['message']).lstrip()
for start in bot.config.COMMAND_START: for start in bot.config.COMMAND_START:
if isinstance(start, type(re.compile(''))): if isinstance(start, type(re.compile(''))):
m = start.search(msg_text) m = start.search(msg)
if m: if m:
full_command = msg_text[len(m.group(0)):].lstrip() full_command = msg[len(m.group(0)):].lstrip()
break break
elif isinstance(start, str): elif isinstance(start, str):
if msg_text.startswith(start): if msg.startswith(start):
full_command = msg_text[len(start):].lstrip() full_command = msg[len(start):].lstrip()
break break
else: else:
# it's not a command # it's not a command
@ -286,25 +292,24 @@ async def handle_command(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
:param ctx: message context :param ctx: message context
:return: the message is handled as a command :return: the message is handled as a command
""" """
src = context_source(ctx) ctx_id = context_id(ctx)
session = None session = None
check_perm = True check_perm = True
if _sessions.get(src): if _sessions.get(ctx_id):
session = _sessions[src] session = _sessions[ctx_id]
if session and session.is_valid: if session and session.is_valid:
session.refresh(ctx, current_arg=str(ctx['message'])) session.refresh(ctx, current_arg=str(ctx['message']))
# there is no need to check permission for existing session # there is no need to check permission for existing session
check_perm = False check_perm = False
else: else:
# the session is expired, remove it # the session is expired, remove it
del _sessions[src] del _sessions[ctx_id]
session = None session = None
if not session: if not session:
session = _new_command_session(bot, ctx) session = _new_command_session(bot, ctx)
if not session: if not session:
return False return False
return await _real_run_command(session, ctx_id, check_perm=check_perm)
return await _real_run_command(session, src, check_perm=check_perm)
async def call_command(bot: CQHttp, ctx: Dict[str, Any], async def call_command(bot: CQHttp, ctx: Dict[str, Any],
@ -316,6 +321,10 @@ async def call_command(bot: CQHttp, ctx: Dict[str, Any],
This function is typically called by some other commands This function is typically called by some other commands
or "handle_natural_language" when handling NLPResult object. or "handle_natural_language" when handling NLPResult object.
Note: After calling this function, any previous command session
will be overridden, even if the command being called here does
not need further interaction (a.k.a asking the user for more info).
:param bot: CQHttp instance :param bot: CQHttp instance
:param ctx: message context :param ctx: message context
:param name: command name :param name: command name
@ -326,17 +335,17 @@ async def call_command(bot: CQHttp, ctx: Dict[str, Any],
if not cmd: if not cmd:
return False return False
session = CommandSession(bot, ctx, cmd, args=args) session = CommandSession(bot, ctx, cmd, args=args)
return await _real_run_command(session, context_source(session.ctx), return await _real_run_command(session, context_id(session.ctx),
check_perm=False) check_perm=False)
async def _real_run_command(session: CommandSession, async def _real_run_command(session: CommandSession,
ctx_src: str, **kwargs) -> bool: ctx_id: str, **kwargs) -> bool:
_sessions[ctx_src] = session _sessions[ctx_id] = session
try: try:
res = await session.cmd.run(session, **kwargs) res = await session.cmd.run(session, **kwargs)
# the command is finished, remove the session # the command is finished, remove the session
del _sessions[ctx_src] del _sessions[ctx_id]
return res return res
except _FurtherInteractionNeeded: except _FurtherInteractionNeeded:
session.last_interaction = datetime.now() session.last_interaction = datetime.now()

View File

@ -6,6 +6,14 @@ from aiocqhttp import message
def render(expr: Union[str, Sequence[str], Callable], *, escape_args=True, def render(expr: Union[str, Sequence[str], Callable], *, escape_args=True,
**kwargs) -> str: **kwargs) -> str:
"""
Render an expression to message string.
:param expr: expression to render
:param escape_args: should escape arguments or not
:param kwargs: keyword arguments used in str.format()
:return: the rendered message
"""
if isinstance(expr, Callable): if isinstance(expr, Callable):
expr = expr() expr = expr()
elif isinstance(expr, Sequence): elif isinstance(expr, Sequence):

View File

@ -5,7 +5,10 @@ from aiocqhttp import CQHttp, Error as CQHttpError
from . import expression from . import expression
def context_source(ctx: Dict[str, Any]) -> str: def context_id(ctx: Dict[str, Any]) -> str:
"""
Calculate a unique id representing the current user.
"""
src = '' src = ''
if ctx.get('group_id'): if ctx.get('group_id'):
src += f'/group/{ctx["group_id"]}' src += f'/group/{ctx["group_id"]}'
@ -19,6 +22,9 @@ def context_source(ctx: Dict[str, Any]) -> str:
async def send(bot: CQHttp, ctx: Dict[str, Any], async def send(bot: CQHttp, ctx: Dict[str, Any],
message: Union[str, Dict[str, Any], List[Dict[str, Any]]], message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
*, ignore_failure: bool = True) -> None: *, ignore_failure: bool = True) -> None:
"""
Send a message ignoring failure by default.
"""
try: try:
if ctx.get('post_type') == 'message': if ctx.get('post_type') == 'message':
await bot.send(ctx, message) await bot.send(ctx, message)
@ -40,4 +46,7 @@ async def send(bot: CQHttp, ctx: Dict[str, Any],
async def send_expr(bot: CQHttp, ctx: Dict[str, Any], async def send_expr(bot: CQHttp, ctx: Dict[str, Any],
expr: Union[str, Sequence[str], Callable], expr: Union[str, Sequence[str], Callable],
**kwargs): **kwargs):
"""
Sending a expression message ignoring failure by default.
"""
return await send(bot, ctx, expression.render(expr, **kwargs)) return await send(bot, ctx, expression.render(expr, **kwargs))

View File

@ -23,7 +23,7 @@ async def handle_message(bot: CQHttp, ctx: Dict[str, Any]) -> None:
handled = await handle_command(bot, ctx) handled = await handle_command(bot, ctx)
if handled: if handled:
logger.debug('Message is handled as command') logger.debug('Message is handled as a command')
return return
handled = await handle_natural_language(bot, ctx) handled = await handle_natural_language(bot, ctx)

View File

@ -28,6 +28,14 @@ class NLProcessor:
def on_natural_language(keywords: Union[Optional[Iterable], Callable] = None, *, def on_natural_language(keywords: Union[Optional[Iterable], Callable] = None, *,
permission: int = perm.EVERYBODY, permission: int = perm.EVERYBODY,
only_to_me: bool = True) -> Callable: only_to_me: bool = True) -> Callable:
"""
Decorator to register a function as a natural language processor.
:param keywords: keywords to respond, if None, respond to all messages
:param permission: permission required by the processor
:param only_to_me: only handle messages to me
"""
def deco(func: Callable) -> Callable: def deco(func: Callable) -> Callable:
nl_processor = NLProcessor(func=func, keywords=keywords, nl_processor = NLProcessor(func=func, keywords=keywords,
permission=permission, only_to_me=only_to_me) permission=permission, only_to_me=only_to_me)
@ -61,12 +69,23 @@ NLPResult = namedtuple('NLPResult', (
async def handle_natural_language(bot: CQHttp, ctx: Dict[str, Any]) -> bool: async def handle_natural_language(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
"""
Handle a message as natural language.
This function is typically called by "handle_message".
:param bot: CQHttp instance
:param ctx: message context
:return: the message is handled as natural language
"""
msg = str(ctx['message']) msg = str(ctx['message'])
if bot.config.NICKNAME: if bot.config.NICKNAME:
# check if the user is calling to me with my nickname
m = re.search(rf'^{bot.config.NICKNAME}[\s,]+', msg) m = re.search(rf'^{bot.config.NICKNAME}[\s,]+', msg)
if m: if m:
ctx['to_me'] = True ctx['to_me'] = True
msg = msg[m.end():] msg = msg[m.end():]
session = NLPSession(bot, ctx, msg) session = NLPSession(bot, ctx, msg)
coros = [] coros = []
@ -86,10 +105,12 @@ async def handle_natural_language(bot: CQHttp, ctx: Dict[str, Any]) -> bool:
coros.append(p.func(session)) coros.append(p.func(session))
if coros: if coros:
# wait for possible results, and sort them by confidence
results = sorted(filter(lambda r: r, await asyncio.gather(*coros)), results = sorted(filter(lambda r: r, await asyncio.gather(*coros)),
key=lambda r: r.confidence, reverse=True) key=lambda r: r.confidence, reverse=True)
logger.debug(results) logger.debug(results)
if results and results[0].confidence >= 60.0: if results and results[0].confidence >= 60.0:
# choose the result with highest confidence
return await call_command(bot, ctx, return await call_command(bot, ctx,
results[0].cmd_name, results[0].cmd_args) results[0].cmd_name, results[0].cmd_args)
return False return False

View File

@ -82,7 +82,7 @@ async def handle_notice_or_request(bot: CQHttp, ctx: Dict[str, Any]) -> None:
if post_type == 'notice': if post_type == 'notice':
session = NoticeSession(bot, ctx) session = NoticeSession(bot, ctx)
else: else: # must be 'request'
session = RequestSession(bot, ctx) session = RequestSession(bot, ctx)
logger.debug(f'Emitting event: {event}') logger.debug(f'Emitting event: {event}')

View File

@ -44,6 +44,14 @@ _MinContext = namedtuple('MinContext', _min_context_fields)
async def check_permission(bot: CQHttp, ctx: Dict[str, Any], async def check_permission(bot: CQHttp, ctx: Dict[str, Any],
permission_required: int) -> bool: permission_required: int) -> bool:
"""
Check if the context has the permission required.
:param bot: CQHttp instance
:param ctx: message context
:param permission_required: permission required
:return: the context has the permission
"""
min_ctx_kwargs = {} min_ctx_kwargs = {}
for field in _min_context_fields: for field in _min_context_fields:
if field in ctx: if field in ctx:
@ -54,7 +62,7 @@ async def check_permission(bot: CQHttp, ctx: Dict[str, Any],
return await _check(bot, min_ctx, permission_required) return await _check(bot, min_ctx, permission_required)
@cached(ttl=2 * 60) # cache the result for 1 minute @cached(ttl=2 * 60) # cache the result for 2 minute
async def _check(bot: CQHttp, min_ctx: _MinContext, async def _check(bot: CQHttp, min_ctx: _MinContext,
permission_required: int) -> bool: permission_required: int) -> bool:
permission = 0 permission = 0

View File

@ -15,10 +15,16 @@ class BaseSession:
async def send(self, async def send(self,
message: Union[str, Dict[str, Any], List[Dict[str, Any]]], message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
*, ignore_failure: bool = True) -> None: *, ignore_failure: bool = True) -> None:
"""
Send a message ignoring failure by default.
"""
return await send(self.bot, self.ctx, message, return await send(self.bot, self.ctx, message,
ignore_failure=ignore_failure) ignore_failure=ignore_failure)
async def send_expr(self, async def send_expr(self,
expr: Union[str, Sequence[str], Callable], expr: Union[str, Sequence[str], Callable],
**kwargs): **kwargs):
"""
Sending a expression message ignoring failure by default.
"""
return await send_expr(self.bot, self.ctx, expr, **kwargs) return await send_expr(self.bot, self.ctx, expr, **kwargs)