Fix type hints and small bugs

This commit is contained in:
Richard Chien 2018-10-16 01:03:50 +08:00
parent 73e891d521
commit 0046ebacac
12 changed files with 98 additions and 86 deletions

View File

@ -5,11 +5,11 @@ import os
import re import re
from typing import Any, Optional from typing import Any, Optional
import aiocqhttp.message
from aiocqhttp import CQHttp from aiocqhttp import CQHttp
from aiocqhttp.message import Message
from .log import logger from .log import logger
from .scheduler import Scheduler from .sched import Scheduler
if Scheduler: if Scheduler:
scheduler = Scheduler() scheduler = Scheduler()
@ -18,14 +18,14 @@ else:
class NoneBot(CQHttp): class NoneBot(CQHttp):
def __init__(self, config_object: Any = None): def __init__(self, config_object: Optional[Any] = None):
if config_object is None: if config_object is None:
from . import default_config as config_object from . import default_config as config_object
config_dict = {k: v for k, v in config_object.__dict__.items() config_dict = {k: v for k, v in config_object.__dict__.items()
if k.isupper() and not k.startswith('_')} if k.isupper() and not k.startswith('_')}
logger.debug(f'Loaded configurations: {config_dict}') logger.debug(f'Loaded configurations: {config_dict}')
super().__init__(message_class=Message, super().__init__(message_class=aiocqhttp.message.Message,
**{k.lower(): v for k, v in config_dict.items()}) **{k.lower(): v for k, v in config_dict.items()})
self.config = config_object self.config = config_object
@ -46,7 +46,8 @@ class NoneBot(CQHttp):
async def _(ctx): async def _(ctx):
asyncio.ensure_future(handle_notice_or_request(self, ctx)) asyncio.ensure_future(handle_notice_or_request(self, ctx))
def run(self, host: str = None, port: int = None, *args, **kwargs): def run(self, host: Optional[str] = None, port: Optional[int] = None,
*args, **kwargs) -> None:
host = host or self.config.HOST host = host or self.config.HOST
port = port or self.config.PORT port = port or self.config.PORT
if 'debug' not in kwargs: if 'debug' not in kwargs:
@ -60,7 +61,7 @@ class NoneBot(CQHttp):
_bot: Optional[NoneBot] = None _bot: Optional[NoneBot] = None
def init(config_object: Any = None) -> None: def init(config_object: Optional[Any] = None) -> None:
""" """
Initialize NoneBot instance. Initialize NoneBot instance.
@ -97,7 +98,8 @@ def get_bot() -> NoneBot:
return _bot return _bot
def run(host: str = None, port: int = None, *args, **kwargs) -> None: def run(host: Optional[str] = None, port: Optional[int] = None,
*args, **kwargs) -> None:
"""Run the NoneBot instance.""" """Run the NoneBot instance."""
get_bot().run(host=host, port=port, *args, **kwargs) get_bot().run(host=host, port=port, *args, **kwargs)
@ -143,6 +145,7 @@ def load_builtin_plugins() -> None:
load_plugins(plugin_dir, 'none.plugins') load_plugins(plugin_dir, 'none.plugins')
from .exceptions import *
from .message import message_preprocessor, Message, MessageSegment from .message import message_preprocessor, Message, MessageSegment
from .command import on_command, CommandSession, CommandGroup from .command import on_command, CommandSession, CommandGroup
from .natural_language import on_natural_language, NLPSession, NLPResult from .natural_language import on_natural_language, NLPSession, NLPResult

View File

@ -2,15 +2,18 @@ import asyncio
import re import re
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
Tuple, Union, Callable, Iterable, Dict, Any, Optional, Sequence Tuple, Union, Callable, Iterable, Any, Optional
) )
from . import NoneBot, permission as perm from . import NoneBot, permission as perm
from .log import logger
from .message import Message
from .expression import render from .expression import render
from .helpers import context_id, send_expr from .helpers import context_id, send_expr
from .log import logger
from .message import Message
from .session import BaseSession from .session import BaseSession
from .typing import (
Context_T, CommandName_T, CommandArgs_T, Expression_T, Message_T
)
# Key: str (one segment of command name) # Key: str (one segment of command name)
# Value: subtree or a leaf Command object # Value: subtree or a leaf Command object
@ -29,8 +32,8 @@ class Command:
__slots__ = ('name', 'func', 'permission', __slots__ = ('name', 'func', 'permission',
'only_to_me', 'privileged', 'args_parser_func') 'only_to_me', 'privileged', 'args_parser_func')
def __init__(self, *, name: Tuple[str], func: Callable, permission: int, def __init__(self, *, name: CommandName_T, func: Callable,
only_to_me: bool, privileged: bool): permission: int, only_to_me: bool, privileged: bool):
self.name = name self.name = name
self.func = func self.func = func
self.permission = permission self.permission = permission
@ -71,8 +74,8 @@ class Command:
self.permission) self.permission)
def on_command(name: Union[str, Tuple[str]], *, def on_command(name: Union[str, CommandName_T], *,
aliases: Iterable = (), aliases: Iterable[str] = (),
permission: int = perm.EVERYBODY, permission: int = perm.EVERYBODY,
only_to_me: bool = True, only_to_me: bool = True,
privileged: bool = False) -> Callable: privileged: bool = False) -> Callable:
@ -117,19 +120,22 @@ class CommandGroup:
""" """
Group a set of commands with same name prefix. Group a set of commands with same name prefix.
""" """
__slots__ = ('basename', 'permission', 'only_to_me') __slots__ = ('basename', 'permission', 'only_to_me', 'privileged')
def __init__(self, name: Union[str, Tuple[str]], def __init__(self, name: Union[str, CommandName_T],
permission: Optional[int] = None, *, permission: Optional[int] = None, *,
only_to_me: Optional[bool] = None): only_to_me: Optional[bool] = None,
privileged: Optional[bool] = None):
self.basename = (name,) if isinstance(name, str) else name self.basename = (name,) if isinstance(name, str) else name
self.permission = permission self.permission = permission
self.only_to_me = only_to_me self.only_to_me = only_to_me
self.privileged = privileged
def command(self, name: Union[str, Tuple[str]], *, def command(self, name: Union[str, CommandName_T], *,
aliases: Optional[Iterable] = None, aliases: Optional[Iterable[str]] = None,
permission: Optional[int] = None, permission: Optional[int] = None,
only_to_me: Optional[bool] = None) -> Callable: only_to_me: Optional[bool] = None,
privileged: Optional[bool] = None) -> Callable:
sub_name = (name,) if isinstance(name, str) else name sub_name = (name,) if isinstance(name, str) else name
name = self.basename + sub_name name = self.basename + sub_name
@ -144,10 +150,14 @@ class CommandGroup:
kwargs['only_to_me'] = only_to_me kwargs['only_to_me'] = only_to_me
elif self.only_to_me is not None: elif self.only_to_me is not None:
kwargs['only_to_me'] = self.only_to_me kwargs['only_to_me'] = self.only_to_me
if privileged is not None:
kwargs['privileged'] = privileged
elif self.privileged is not None:
kwargs['privileged'] = self.privileged
return on_command(name, **kwargs) return on_command(name, **kwargs)
def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]: def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]:
cmd_name = (name,) if isinstance(name, str) else name cmd_name = (name,) if isinstance(name, str) else name
if not cmd_name: if not cmd_name:
return None return None
@ -204,8 +214,8 @@ class CommandSession(BaseSession):
__slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_text', __slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_text',
'current_arg_images', 'args', '_last_interaction', '_running') 'current_arg_images', 'args', '_last_interaction', '_running')
def __init__(self, bot: NoneBot, ctx: Dict[str, Any], cmd: Command, *, def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *,
current_arg: str = '', args: Optional[Dict[str, Any]] = None): current_arg: str = '', args: Optional[CommandArgs_T] = None):
super().__init__(bot, ctx) super().__init__(bot, ctx)
self.cmd = cmd # Command object self.cmd = cmd # Command object
self.current_key = None # current key that the command handler needs self.current_key = None # current key that the command handler needs
@ -218,21 +228,21 @@ class CommandSession(BaseSession):
self._running = False self._running = False
@property @property
def running(self): def running(self) -> bool:
return self._running return self._running
@running.setter @running.setter
def running(self, value): def running(self, value) -> None:
if self._running is True and value is False: if self._running is True and value is False:
# change status from running to not running, record the time # change status from running to not running, record the time
self._last_interaction = datetime.now() self._last_interaction = datetime.now()
self._running = value self._running = value
@property @property
def is_first_run(self): def is_first_run(self) -> bool:
return self._last_interaction is None return self._last_interaction is None
def refresh(self, ctx: Dict[str, Any], *, current_arg: str = '') -> None: def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None:
""" """
Refill the session with a new message context. Refill the session with a new message context.
@ -256,8 +266,9 @@ class CommandSession(BaseSession):
return False return False
return True return True
def get(self, key: str, *, prompt: str = None, def get(self, key: Any, *,
prompt_expr: Union[str, Sequence[str], Callable] = None) -> Any: prompt: Optional[Message_T] = None,
prompt_expr: Optional[Expression_T] = None) -> Any:
""" """
Get an argument with a given key. Get an argument with a given key.
@ -270,7 +281,6 @@ class CommandSession(BaseSession):
:param prompt: prompt to ask the user :param prompt: prompt to ask the user
:param prompt_expr: prompt expression to ask the user :param prompt_expr: prompt expression to ask the user
:return: the argument value :return: the argument value
:raise FurtherInteractionNeeded: further interaction is needed
""" """
value = self.get_optional(key) value = self.get_optional(key)
if value is not None: if value is not None:
@ -282,25 +292,24 @@ class CommandSession(BaseSession):
prompt = render(prompt_expr, key=key) prompt = render(prompt_expr, key=key)
self.pause(prompt) self.pause(prompt)
def get_optional(self, key: str, def get_optional(self, key: Any,
default: Optional[Any] = None) -> Optional[Any]: default: Optional[Any] = None) -> Optional[Any]:
"""Simply get a argument with given key.""" """Simply get a argument with given key."""
return self.args.get(key, default) return self.args.get(key, default)
def pause(self, message=None) -> None: def pause(self, message: Optional[Message_T] = None) -> None:
"""Pause the session for further interaction.""" """Pause the session for further interaction."""
if message: if message:
asyncio.ensure_future(self.send(message)) asyncio.ensure_future(self.send(message))
raise _FurtherInteractionNeeded raise _FurtherInteractionNeeded
def finish(self, message=None) -> None: def finish(self, message: Optional[Message_T] = None) -> None:
"""Finish the session.""" """Finish the session."""
if message: if message:
asyncio.ensure_future(self.send(message)) asyncio.ensure_future(self.send(message))
raise _FinishException raise _FinishException
# noinspection PyMethodMayBeStatic def switch(self, new_ctx_message: Message_T) -> None:
def switch(self, new_ctx_message: Any) -> None:
""" """
Finish the session and switch to a new (fake) message context. Finish the session and switch to a new (fake) message context.
@ -392,7 +401,7 @@ def parse_command(bot: NoneBot,
return cmd, ''.join(cmd_remained) return cmd, ''.join(cmd_remained)
async def handle_command(bot: NoneBot, ctx: Dict[str, Any]) -> bool: async def handle_command(bot: NoneBot, ctx: Context_T) -> bool:
""" """
Handle a message as a command. Handle a message as a command.
@ -456,10 +465,10 @@ async def handle_command(bot: NoneBot, ctx: Dict[str, Any]) -> bool:
disable_interaction=disable_interaction) disable_interaction=disable_interaction)
async def call_command(bot: NoneBot, ctx: Dict[str, Any], async def call_command(bot: NoneBot, ctx: Context_T,
name: Union[str, Tuple[str]], *, name: Union[str, CommandName_T], *,
current_arg: str = '', current_arg: str = '',
args: Optional[Dict[str, Any]] = None, args: Optional[CommandArgs_T] = None,
check_perm: bool = True, check_perm: bool = True,
disable_interaction: bool = False) -> bool: disable_interaction: bool = False) -> bool:
""" """
@ -543,7 +552,7 @@ async def _real_run_command(session: CommandSession,
raise e # this is intended to be propagated to handle_message() raise e # this is intended to be propagated to handle_message()
def kill_current_session(bot: NoneBot, ctx: Dict[str, Any]) -> None: def kill_current_session(bot: NoneBot, ctx: Context_T) -> None:
""" """
Force kill current session of the given context, Force kill current session of the given context,
despite whether it is running or not. despite whether it is running or not.

1
none/exceptions.py Normal file
View File

@ -0,0 +1 @@
from aiocqhttp import Error as CQHttpError

View File

@ -1,13 +1,11 @@
import random import random
from typing import Union, Sequence, Callable from typing import Sequence, Callable
from .message import escape from .message import escape
from .typing import Expression_T
Expression_T = Union[str, Sequence[str], Callable]
def render(expr: Expression_T, *, escape_args=True, def render(expr: Expression_T, *, escape_args: bool = True, **kwargs) -> str:
**kwargs) -> str:
""" """
Render an expression to message string. Render an expression to message string.

View File

@ -1,12 +1,11 @@
import hashlib import hashlib
from typing import Dict, Any, Union, List, Sequence, Callable
from aiocqhttp import Error as CQHttpError
from . import NoneBot, expression from . import NoneBot, expression
from .exceptions import CQHttpError
from .typing import Context_T, Message_T, Expression_T
def context_id(ctx: Dict[str, Any], *, def context_id(ctx: Context_T, *,
mode: str = 'default', use_hash: bool = False) -> str: mode: str = 'default', use_hash: bool = False) -> str:
""" """
Calculate a unique id representing the current context. Calculate a unique id representing the current context.
@ -42,9 +41,8 @@ def context_id(ctx: Dict[str, Any], *,
return ctx_id return ctx_id
async def send(bot: NoneBot, ctx: Dict[str, Any], async def send(bot: NoneBot, ctx: Context_T, message: Message_T, *,
message: Union[str, Dict[str, Any], List[Dict[str, Any]]], ignore_failure: bool = True) -> None:
*, ignore_failure: bool = True) -> None:
"""Send a message ignoring failure by default.""" """Send a message ignoring failure by default."""
try: try:
if ctx.get('post_type') == 'message': if ctx.get('post_type') == 'message':
@ -64,8 +62,7 @@ async def send(bot: NoneBot, ctx: Dict[str, Any],
raise raise
async def send_expr(bot: NoneBot, ctx: Dict[str, Any], async def send_expr(bot: NoneBot, ctx: Context_T,
expr: Union[str, Sequence[str], Callable], expr: Expression_T, **kwargs):
**kwargs):
"""Sending a expression message ignoring failure by default.""" """Sending a expression message ignoring failure by default."""
return await send(bot, ctx, expression.render(expr, **kwargs)) return await send(bot, ctx, expression.render(expr, **kwargs))

View File

@ -7,6 +7,7 @@ from . import NoneBot
from .command import handle_command, SwitchException from .command import handle_command, SwitchException
from .log import logger from .log import logger
from .natural_language import handle_natural_language from .natural_language import handle_natural_language
from .typing import Context_T
_message_preprocessors = set() _message_preprocessors = set()
@ -16,12 +17,12 @@ def message_preprocessor(func: Callable) -> Callable:
return func return func
async def handle_message(bot: NoneBot, ctx: Dict[str, Any]) -> None: async def handle_message(bot: NoneBot, ctx: Context_T) -> None:
_log_message(ctx) _log_message(ctx)
coros = [] coros = []
for processor in _message_preprocessors: for processor in _message_preprocessors:
coros.append(processor(ctx)) coros.append(processor(bot, ctx))
if coros: if coros:
await asyncio.wait(coros) await asyncio.wait(coros)
@ -56,7 +57,7 @@ async def handle_message(bot: NoneBot, ctx: Dict[str, Any]) -> None:
return return
def _log_message(ctx: Dict[str, Any]) -> None: def _log_message(ctx: Context_T) -> None:
msg_from = f'{ctx["user_id"]}' msg_from = f'{ctx["user_id"]}'
if ctx['message_type'] == 'group': if ctx['message_type'] == 'group':
msg_from += f'@[群:{ctx["group_id"]}]' msg_from += f'@[群:{ctx["group_id"]}]'

View File

@ -1,13 +1,13 @@
import asyncio import asyncio
import re import re
from collections import namedtuple from typing import Iterable, Optional, Callable, Union, NamedTuple
from typing import Dict, Any, Iterable, Optional, Callable, Union
from . import NoneBot, permission as perm from . import NoneBot, permission as perm
from .command import call_command from .command import call_command
from .log import logger from .log import logger
from .message import Message from .message import Message
from .session import BaseSession from .session import BaseSession
from .typing import Context_T, CommandName_T, CommandArgs_T
_nl_processors = set() _nl_processors = set()
@ -56,7 +56,7 @@ def on_natural_language(keywords: Union[Optional[Iterable], Callable] = None,
class NLPSession(BaseSession): class NLPSession(BaseSession):
__slots__ = ('msg', 'msg_text', 'msg_images') __slots__ = ('msg', 'msg_text', 'msg_images')
def __init__(self, bot: NoneBot, ctx: Dict[str, Any], msg: str): def __init__(self, bot: NoneBot, ctx: Context_T, msg: str):
super().__init__(bot, ctx) super().__init__(bot, ctx)
self.msg = msg self.msg = msg
tmp_msg = Message(msg) tmp_msg = Message(msg)
@ -65,14 +65,13 @@ class NLPSession(BaseSession):
if s.type == 'image' and 'url' in s.data] if s.type == 'image' and 'url' in s.data]
NLPResult = namedtuple('NLPResult', ( class NLPResult(NamedTuple):
'confidence', confidence: float
'cmd_name', cmd_name: Union[str, CommandName_T]
'cmd_args', cmd_args: Optional[CommandArgs_T] = None
))
async def handle_natural_language(bot: NoneBot, ctx: Dict[str, Any]) -> bool: async def handle_natural_language(bot: NoneBot, ctx: Context_T) -> bool:
""" """
Handle a message as natural language. Handle a message as natural language.

View File

@ -1,11 +1,12 @@
from typing import Dict, Any, Optional, Callable, Union from typing import Optional, Callable, Union
from aiocqhttp import Error as CQHttpError
from aiocqhttp.bus import EventBus from aiocqhttp.bus import EventBus
from . import NoneBot from . import NoneBot
from .exceptions import CQHttpError
from .log import logger from .log import logger
from .session import BaseSession from .session import BaseSession
from .typing import Context_T
_bus = EventBus() _bus = EventBus()
@ -35,14 +36,14 @@ on_request = _make_event_deco('request')
class NoticeSession(BaseSession): class NoticeSession(BaseSession):
__slots__ = () __slots__ = ()
def __init__(self, bot: NoneBot, ctx: Dict[str, Any]): def __init__(self, bot: NoneBot, ctx: Context_T):
super().__init__(bot, ctx) super().__init__(bot, ctx)
class RequestSession(BaseSession): class RequestSession(BaseSession):
__slots__ = () __slots__ = ()
def __init__(self, bot: NoneBot, ctx: Dict[str, Any]): def __init__(self, bot: NoneBot, ctx: Context_T):
super().__init__(bot, ctx) super().__init__(bot, ctx)
async def approve(self, remark: str = '') -> None: async def approve(self, remark: str = '') -> None:
@ -78,7 +79,7 @@ class RequestSession(BaseSession):
pass pass
async def handle_notice_or_request(bot: NoneBot, ctx: Dict[str, Any]) -> None: async def handle_notice_or_request(bot: NoneBot, ctx: Context_T) -> None:
post_type = ctx['post_type'] # "notice" or "request" post_type = ctx['post_type'] # "notice" or "request"
detail_type = ctx[f'{post_type}_type'] detail_type = ctx[f'{post_type}_type']
event = f'{post_type}.{detail_type}' event = f'{post_type}.{detail_type}'
@ -96,9 +97,9 @@ async def handle_notice_or_request(bot: NoneBot, ctx: Dict[str, Any]) -> None:
await _bus.emit(event, session) await _bus.emit(event, session)
def _log_notice(ctx: Dict[str, Any]) -> None: def _log_notice(ctx: Context_T) -> None:
logger.info(f'Notice: {ctx}') logger.info(f'Notice: {ctx}')
def _log_request(ctx: Dict[str, Any]) -> None: def _log_request(ctx: Context_T) -> None:
logger.info(f'Request: {ctx}') logger.info(f'Request: {ctx}')

View File

@ -1,10 +1,10 @@
from collections import namedtuple from collections import namedtuple
from typing import Dict, Any
from aiocache import cached from aiocache import cached
from aiocqhttp import Error as CQHttpError
from . import NoneBot from . import NoneBot
from .exceptions import CQHttpError
from .typing import Context_T
PRIVATE_FRIEND = 0x0001 PRIVATE_FRIEND = 0x0001
PRIVATE_GROUP = 0x0002 PRIVATE_GROUP = 0x0002
@ -45,7 +45,7 @@ _min_context_fields = (
_MinContext = namedtuple('MinContext', _min_context_fields) _MinContext = namedtuple('MinContext', _min_context_fields)
async def check_permission(bot: NoneBot, ctx: Dict[str, Any], async def check_permission(bot: NoneBot, ctx: Context_T,
permission_required: int) -> bool: permission_required: int) -> bool:
""" """
Check if the context has the permission required. Check if the context has the permission required.

View File

@ -1,25 +1,21 @@
from typing import Union, Callable, Dict, Any, List, Sequence
from . import NoneBot from . import NoneBot
from .helpers import send, send_expr from .helpers import send, send_expr
from .typing import Context_T, Message_T, Expression_T
class BaseSession: class BaseSession:
__slots__ = ('bot', 'ctx') __slots__ = ('bot', 'ctx')
def __init__(self, bot: NoneBot, ctx: Dict[str, Any]): def __init__(self, bot: NoneBot, ctx: Context_T):
self.bot = bot self.bot = bot
self.ctx = ctx self.ctx = ctx
async def send(self, async def send(self, message: Message_T, *,
message: Union[str, Dict[str, Any], List[Dict[str, Any]]], ignore_failure: bool = True) -> None:
*, ignore_failure: bool = True) -> None:
"""Send a message ignoring failure by default.""" """Send a message ignoring failure by default."""
return await send(self.bot, self.ctx, message, return await send(self.bot, self.ctx, message,
ignore_failure=ignore_failure) ignore_failure=ignore_failure)
async def send_expr(self, async def send_expr(self, expr: Expression_T, **kwargs):
expr: Union[str, Sequence[str], Callable],
**kwargs):
"""Sending a expression message ignoring failure by default.""" """Sending a expression message ignoring failure by default."""
return await send_expr(self.bot, self.ctx, expr, **kwargs) return await send_expr(self.bot, self.ctx, expr, **kwargs)

7
none/typing.py Normal file
View File

@ -0,0 +1,7 @@
from typing import Union, List, Dict, Any, Sequence, Callable, Tuple
Context_T = Dict[str, Any]
Message_T = Union[str, Dict[str, Any], List[Dict[str, Any]]]
Expression_T = Union[str, Sequence[str], Callable]
CommandName_T = Tuple[str]
CommandArgs_T = Dict[str, Any]