mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-18 09:25:46 +08:00
Fix type hints and small bugs
This commit is contained in:
parent
73e891d521
commit
0046ebacac
@ -5,11 +5,11 @@ import os
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
import aiocqhttp.message
|
||||
from aiocqhttp import CQHttp
|
||||
from aiocqhttp.message import Message
|
||||
|
||||
from .log import logger
|
||||
from .scheduler import Scheduler
|
||||
from .sched import Scheduler
|
||||
|
||||
if Scheduler:
|
||||
scheduler = Scheduler()
|
||||
@ -18,14 +18,14 @@ else:
|
||||
|
||||
|
||||
class NoneBot(CQHttp):
|
||||
def __init__(self, config_object: Any = None):
|
||||
def __init__(self, config_object: Optional[Any] = None):
|
||||
if config_object is None:
|
||||
from . import default_config as config_object
|
||||
|
||||
config_dict = {k: v for k, v in config_object.__dict__.items()
|
||||
if k.isupper() and not k.startswith('_')}
|
||||
logger.debug(f'Loaded configurations: {config_dict}')
|
||||
super().__init__(message_class=Message,
|
||||
super().__init__(message_class=aiocqhttp.message.Message,
|
||||
**{k.lower(): v for k, v in config_dict.items()})
|
||||
|
||||
self.config = config_object
|
||||
@ -46,7 +46,8 @@ class NoneBot(CQHttp):
|
||||
async def _(ctx):
|
||||
asyncio.ensure_future(handle_notice_or_request(self, ctx))
|
||||
|
||||
def run(self, host: str = None, port: int = None, *args, **kwargs):
|
||||
def run(self, host: Optional[str] = None, port: Optional[int] = None,
|
||||
*args, **kwargs) -> None:
|
||||
host = host or self.config.HOST
|
||||
port = port or self.config.PORT
|
||||
if 'debug' not in kwargs:
|
||||
@ -60,7 +61,7 @@ class NoneBot(CQHttp):
|
||||
_bot: Optional[NoneBot] = None
|
||||
|
||||
|
||||
def init(config_object: Any = None) -> None:
|
||||
def init(config_object: Optional[Any] = None) -> None:
|
||||
"""
|
||||
Initialize NoneBot instance.
|
||||
|
||||
@ -97,7 +98,8 @@ def get_bot() -> NoneBot:
|
||||
return _bot
|
||||
|
||||
|
||||
def run(host: str = None, port: int = None, *args, **kwargs) -> None:
|
||||
def run(host: Optional[str] = None, port: Optional[int] = None,
|
||||
*args, **kwargs) -> None:
|
||||
"""Run the NoneBot instance."""
|
||||
get_bot().run(host=host, port=port, *args, **kwargs)
|
||||
|
||||
@ -143,6 +145,7 @@ def load_builtin_plugins() -> None:
|
||||
load_plugins(plugin_dir, 'none.plugins')
|
||||
|
||||
|
||||
from .exceptions import *
|
||||
from .message import message_preprocessor, Message, MessageSegment
|
||||
from .command import on_command, CommandSession, CommandGroup
|
||||
from .natural_language import on_natural_language, NLPSession, NLPResult
|
||||
|
@ -2,15 +2,18 @@ import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Tuple, Union, Callable, Iterable, Dict, Any, Optional, Sequence
|
||||
Tuple, Union, Callable, Iterable, Any, Optional
|
||||
)
|
||||
|
||||
from . import NoneBot, permission as perm
|
||||
from .log import logger
|
||||
from .message import Message
|
||||
from .expression import render
|
||||
from .helpers import context_id, send_expr
|
||||
from .log import logger
|
||||
from .message import Message
|
||||
from .session import BaseSession
|
||||
from .typing import (
|
||||
Context_T, CommandName_T, CommandArgs_T, Expression_T, Message_T
|
||||
)
|
||||
|
||||
# Key: str (one segment of command name)
|
||||
# Value: subtree or a leaf Command object
|
||||
@ -29,8 +32,8 @@ class Command:
|
||||
__slots__ = ('name', 'func', 'permission',
|
||||
'only_to_me', 'privileged', 'args_parser_func')
|
||||
|
||||
def __init__(self, *, name: Tuple[str], func: Callable, permission: int,
|
||||
only_to_me: bool, privileged: bool):
|
||||
def __init__(self, *, name: CommandName_T, func: Callable,
|
||||
permission: int, only_to_me: bool, privileged: bool):
|
||||
self.name = name
|
||||
self.func = func
|
||||
self.permission = permission
|
||||
@ -71,8 +74,8 @@ class Command:
|
||||
self.permission)
|
||||
|
||||
|
||||
def on_command(name: Union[str, Tuple[str]], *,
|
||||
aliases: Iterable = (),
|
||||
def on_command(name: Union[str, CommandName_T], *,
|
||||
aliases: Iterable[str] = (),
|
||||
permission: int = perm.EVERYBODY,
|
||||
only_to_me: bool = True,
|
||||
privileged: bool = False) -> Callable:
|
||||
@ -117,19 +120,22 @@ class CommandGroup:
|
||||
"""
|
||||
Group a set of commands with same name prefix.
|
||||
"""
|
||||
__slots__ = ('basename', 'permission', 'only_to_me')
|
||||
__slots__ = ('basename', 'permission', 'only_to_me', 'privileged')
|
||||
|
||||
def __init__(self, name: Union[str, Tuple[str]],
|
||||
def __init__(self, name: Union[str, CommandName_T],
|
||||
permission: Optional[int] = None, *,
|
||||
only_to_me: Optional[bool] = None):
|
||||
only_to_me: Optional[bool] = None,
|
||||
privileged: Optional[bool] = None):
|
||||
self.basename = (name,) if isinstance(name, str) else name
|
||||
self.permission = permission
|
||||
self.only_to_me = only_to_me
|
||||
self.privileged = privileged
|
||||
|
||||
def command(self, name: Union[str, Tuple[str]], *,
|
||||
aliases: Optional[Iterable] = None,
|
||||
def command(self, name: Union[str, CommandName_T], *,
|
||||
aliases: Optional[Iterable[str]] = None,
|
||||
permission: Optional[int] = None,
|
||||
only_to_me: Optional[bool] = None) -> Callable:
|
||||
only_to_me: Optional[bool] = None,
|
||||
privileged: Optional[bool] = None) -> Callable:
|
||||
sub_name = (name,) if isinstance(name, str) else name
|
||||
name = self.basename + sub_name
|
||||
|
||||
@ -144,10 +150,14 @@ class CommandGroup:
|
||||
kwargs['only_to_me'] = only_to_me
|
||||
elif self.only_to_me is not None:
|
||||
kwargs['only_to_me'] = self.only_to_me
|
||||
if privileged is not None:
|
||||
kwargs['privileged'] = privileged
|
||||
elif self.privileged is not None:
|
||||
kwargs['privileged'] = self.privileged
|
||||
return on_command(name, **kwargs)
|
||||
|
||||
|
||||
def _find_command(name: Union[str, Tuple[str]]) -> Optional[Command]:
|
||||
def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]:
|
||||
cmd_name = (name,) if isinstance(name, str) else name
|
||||
if not cmd_name:
|
||||
return None
|
||||
@ -204,8 +214,8 @@ class CommandSession(BaseSession):
|
||||
__slots__ = ('cmd', 'current_key', 'current_arg', 'current_arg_text',
|
||||
'current_arg_images', 'args', '_last_interaction', '_running')
|
||||
|
||||
def __init__(self, bot: NoneBot, ctx: Dict[str, Any], cmd: Command, *,
|
||||
current_arg: str = '', args: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, bot: NoneBot, ctx: Context_T, cmd: Command, *,
|
||||
current_arg: str = '', args: Optional[CommandArgs_T] = None):
|
||||
super().__init__(bot, ctx)
|
||||
self.cmd = cmd # Command object
|
||||
self.current_key = None # current key that the command handler needs
|
||||
@ -218,21 +228,21 @@ class CommandSession(BaseSession):
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
def running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
@running.setter
|
||||
def running(self, value):
|
||||
def running(self, value) -> None:
|
||||
if self._running is True and value is False:
|
||||
# change status from running to not running, record the time
|
||||
self._last_interaction = datetime.now()
|
||||
self._running = value
|
||||
|
||||
@property
|
||||
def is_first_run(self):
|
||||
def is_first_run(self) -> bool:
|
||||
return self._last_interaction is None
|
||||
|
||||
def refresh(self, ctx: Dict[str, Any], *, current_arg: str = '') -> None:
|
||||
def refresh(self, ctx: Context_T, *, current_arg: str = '') -> None:
|
||||
"""
|
||||
Refill the session with a new message context.
|
||||
|
||||
@ -256,8 +266,9 @@ class CommandSession(BaseSession):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get(self, key: str, *, prompt: str = None,
|
||||
prompt_expr: Union[str, Sequence[str], Callable] = None) -> Any:
|
||||
def get(self, key: Any, *,
|
||||
prompt: Optional[Message_T] = None,
|
||||
prompt_expr: Optional[Expression_T] = None) -> Any:
|
||||
"""
|
||||
Get an argument with a given key.
|
||||
|
||||
@ -270,7 +281,6 @@ class CommandSession(BaseSession):
|
||||
:param prompt: prompt to ask the user
|
||||
:param prompt_expr: prompt expression to ask the user
|
||||
:return: the argument value
|
||||
:raise FurtherInteractionNeeded: further interaction is needed
|
||||
"""
|
||||
value = self.get_optional(key)
|
||||
if value is not None:
|
||||
@ -282,25 +292,24 @@ class CommandSession(BaseSession):
|
||||
prompt = render(prompt_expr, key=key)
|
||||
self.pause(prompt)
|
||||
|
||||
def get_optional(self, key: str,
|
||||
def get_optional(self, key: Any,
|
||||
default: Optional[Any] = None) -> Optional[Any]:
|
||||
"""Simply get a argument with given key."""
|
||||
return self.args.get(key, default)
|
||||
|
||||
def pause(self, message=None) -> None:
|
||||
def pause(self, message: Optional[Message_T] = None) -> None:
|
||||
"""Pause the session for further interaction."""
|
||||
if message:
|
||||
asyncio.ensure_future(self.send(message))
|
||||
raise _FurtherInteractionNeeded
|
||||
|
||||
def finish(self, message=None) -> None:
|
||||
def finish(self, message: Optional[Message_T] = None) -> None:
|
||||
"""Finish the session."""
|
||||
if message:
|
||||
asyncio.ensure_future(self.send(message))
|
||||
raise _FinishException
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def switch(self, new_ctx_message: Any) -> None:
|
||||
def switch(self, new_ctx_message: Message_T) -> None:
|
||||
"""
|
||||
Finish the session and switch to a new (fake) message context.
|
||||
|
||||
@ -392,7 +401,7 @@ def parse_command(bot: NoneBot,
|
||||
return cmd, ''.join(cmd_remained)
|
||||
|
||||
|
||||
async def handle_command(bot: NoneBot, ctx: Dict[str, Any]) -> bool:
|
||||
async def handle_command(bot: NoneBot, ctx: Context_T) -> bool:
|
||||
"""
|
||||
Handle a message as a command.
|
||||
|
||||
@ -456,10 +465,10 @@ async def handle_command(bot: NoneBot, ctx: Dict[str, Any]) -> bool:
|
||||
disable_interaction=disable_interaction)
|
||||
|
||||
|
||||
async def call_command(bot: NoneBot, ctx: Dict[str, Any],
|
||||
name: Union[str, Tuple[str]], *,
|
||||
async def call_command(bot: NoneBot, ctx: Context_T,
|
||||
name: Union[str, CommandName_T], *,
|
||||
current_arg: str = '',
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
args: Optional[CommandArgs_T] = None,
|
||||
check_perm: bool = True,
|
||||
disable_interaction: bool = False) -> bool:
|
||||
"""
|
||||
@ -543,7 +552,7 @@ async def _real_run_command(session: CommandSession,
|
||||
raise e # this is intended to be propagated to handle_message()
|
||||
|
||||
|
||||
def kill_current_session(bot: NoneBot, ctx: Dict[str, Any]) -> None:
|
||||
def kill_current_session(bot: NoneBot, ctx: Context_T) -> None:
|
||||
"""
|
||||
Force kill current session of the given context,
|
||||
despite whether it is running or not.
|
||||
|
1
none/exceptions.py
Normal file
1
none/exceptions.py
Normal file
@ -0,0 +1 @@
|
||||
from aiocqhttp import Error as CQHttpError
|
@ -1,13 +1,11 @@
|
||||
import random
|
||||
from typing import Union, Sequence, Callable
|
||||
from typing import Sequence, Callable
|
||||
|
||||
from .message import escape
|
||||
|
||||
Expression_T = Union[str, Sequence[str], Callable]
|
||||
from .typing import Expression_T
|
||||
|
||||
|
||||
def render(expr: Expression_T, *, escape_args=True,
|
||||
**kwargs) -> str:
|
||||
def render(expr: Expression_T, *, escape_args: bool = True, **kwargs) -> str:
|
||||
"""
|
||||
Render an expression to message string.
|
||||
|
||||
|
@ -1,12 +1,11 @@
|
||||
import hashlib
|
||||
from typing import Dict, Any, Union, List, Sequence, Callable
|
||||
|
||||
from aiocqhttp import Error as CQHttpError
|
||||
|
||||
from . import NoneBot, expression
|
||||
from .exceptions import CQHttpError
|
||||
from .typing import Context_T, Message_T, Expression_T
|
||||
|
||||
|
||||
def context_id(ctx: Dict[str, Any], *,
|
||||
def context_id(ctx: Context_T, *,
|
||||
mode: str = 'default', use_hash: bool = False) -> str:
|
||||
"""
|
||||
Calculate a unique id representing the current context.
|
||||
@ -42,9 +41,8 @@ def context_id(ctx: Dict[str, Any], *,
|
||||
return ctx_id
|
||||
|
||||
|
||||
async def send(bot: NoneBot, ctx: Dict[str, Any],
|
||||
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
|
||||
*, ignore_failure: bool = True) -> None:
|
||||
async def send(bot: NoneBot, ctx: Context_T, message: Message_T, *,
|
||||
ignore_failure: bool = True) -> None:
|
||||
"""Send a message ignoring failure by default."""
|
||||
try:
|
||||
if ctx.get('post_type') == 'message':
|
||||
@ -64,8 +62,7 @@ async def send(bot: NoneBot, ctx: Dict[str, Any],
|
||||
raise
|
||||
|
||||
|
||||
async def send_expr(bot: NoneBot, ctx: Dict[str, Any],
|
||||
expr: Union[str, Sequence[str], Callable],
|
||||
**kwargs):
|
||||
async def send_expr(bot: NoneBot, ctx: Context_T,
|
||||
expr: Expression_T, **kwargs):
|
||||
"""Sending a expression message ignoring failure by default."""
|
||||
return await send(bot, ctx, expression.render(expr, **kwargs))
|
||||
|
@ -7,6 +7,7 @@ from . import NoneBot
|
||||
from .command import handle_command, SwitchException
|
||||
from .log import logger
|
||||
from .natural_language import handle_natural_language
|
||||
from .typing import Context_T
|
||||
|
||||
_message_preprocessors = set()
|
||||
|
||||
@ -16,12 +17,12 @@ def message_preprocessor(func: Callable) -> Callable:
|
||||
return func
|
||||
|
||||
|
||||
async def handle_message(bot: NoneBot, ctx: Dict[str, Any]) -> None:
|
||||
async def handle_message(bot: NoneBot, ctx: Context_T) -> None:
|
||||
_log_message(ctx)
|
||||
|
||||
coros = []
|
||||
for processor in _message_preprocessors:
|
||||
coros.append(processor(ctx))
|
||||
coros.append(processor(bot, ctx))
|
||||
if coros:
|
||||
await asyncio.wait(coros)
|
||||
|
||||
@ -56,7 +57,7 @@ async def handle_message(bot: NoneBot, ctx: Dict[str, Any]) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _log_message(ctx: Dict[str, Any]) -> None:
|
||||
def _log_message(ctx: Context_T) -> None:
|
||||
msg_from = f'{ctx["user_id"]}'
|
||||
if ctx['message_type'] == 'group':
|
||||
msg_from += f'@[群:{ctx["group_id"]}]'
|
||||
|
@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Any, Iterable, Optional, Callable, Union
|
||||
from typing import Iterable, Optional, Callable, Union, NamedTuple
|
||||
|
||||
from . import NoneBot, permission as perm
|
||||
from .command import call_command
|
||||
from .log import logger
|
||||
from .message import Message
|
||||
from .session import BaseSession
|
||||
from .typing import Context_T, CommandName_T, CommandArgs_T
|
||||
|
||||
_nl_processors = set()
|
||||
|
||||
@ -56,7 +56,7 @@ def on_natural_language(keywords: Union[Optional[Iterable], Callable] = None,
|
||||
class NLPSession(BaseSession):
|
||||
__slots__ = ('msg', 'msg_text', 'msg_images')
|
||||
|
||||
def __init__(self, bot: NoneBot, ctx: Dict[str, Any], msg: str):
|
||||
def __init__(self, bot: NoneBot, ctx: Context_T, msg: str):
|
||||
super().__init__(bot, ctx)
|
||||
self.msg = msg
|
||||
tmp_msg = Message(msg)
|
||||
@ -65,14 +65,13 @@ class NLPSession(BaseSession):
|
||||
if s.type == 'image' and 'url' in s.data]
|
||||
|
||||
|
||||
NLPResult = namedtuple('NLPResult', (
|
||||
'confidence',
|
||||
'cmd_name',
|
||||
'cmd_args',
|
||||
))
|
||||
class NLPResult(NamedTuple):
|
||||
confidence: float
|
||||
cmd_name: Union[str, CommandName_T]
|
||||
cmd_args: Optional[CommandArgs_T] = None
|
||||
|
||||
|
||||
async def handle_natural_language(bot: NoneBot, ctx: Dict[str, Any]) -> bool:
|
||||
async def handle_natural_language(bot: NoneBot, ctx: Context_T) -> bool:
|
||||
"""
|
||||
Handle a message as natural language.
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
from typing import Dict, Any, Optional, Callable, Union
|
||||
from typing import Optional, Callable, Union
|
||||
|
||||
from aiocqhttp import Error as CQHttpError
|
||||
from aiocqhttp.bus import EventBus
|
||||
|
||||
from . import NoneBot
|
||||
from .exceptions import CQHttpError
|
||||
from .log import logger
|
||||
from .session import BaseSession
|
||||
from .typing import Context_T
|
||||
|
||||
_bus = EventBus()
|
||||
|
||||
@ -35,14 +36,14 @@ on_request = _make_event_deco('request')
|
||||
class NoticeSession(BaseSession):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, bot: NoneBot, ctx: Dict[str, Any]):
|
||||
def __init__(self, bot: NoneBot, ctx: Context_T):
|
||||
super().__init__(bot, ctx)
|
||||
|
||||
|
||||
class RequestSession(BaseSession):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, bot: NoneBot, ctx: Dict[str, Any]):
|
||||
def __init__(self, bot: NoneBot, ctx: Context_T):
|
||||
super().__init__(bot, ctx)
|
||||
|
||||
async def approve(self, remark: str = '') -> None:
|
||||
@ -78,7 +79,7 @@ class RequestSession(BaseSession):
|
||||
pass
|
||||
|
||||
|
||||
async def handle_notice_or_request(bot: NoneBot, ctx: Dict[str, Any]) -> None:
|
||||
async def handle_notice_or_request(bot: NoneBot, ctx: Context_T) -> None:
|
||||
post_type = ctx['post_type'] # "notice" or "request"
|
||||
detail_type = ctx[f'{post_type}_type']
|
||||
event = f'{post_type}.{detail_type}'
|
||||
@ -96,9 +97,9 @@ async def handle_notice_or_request(bot: NoneBot, ctx: Dict[str, Any]) -> None:
|
||||
await _bus.emit(event, session)
|
||||
|
||||
|
||||
def _log_notice(ctx: Dict[str, Any]) -> None:
|
||||
def _log_notice(ctx: Context_T) -> None:
|
||||
logger.info(f'Notice: {ctx}')
|
||||
|
||||
|
||||
def _log_request(ctx: Dict[str, Any]) -> None:
|
||||
def _log_request(ctx: Context_T) -> None:
|
||||
logger.info(f'Request: {ctx}')
|
||||
|
@ -1,10 +1,10 @@
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Any
|
||||
|
||||
from aiocache import cached
|
||||
from aiocqhttp import Error as CQHttpError
|
||||
|
||||
from . import NoneBot
|
||||
from .exceptions import CQHttpError
|
||||
from .typing import Context_T
|
||||
|
||||
PRIVATE_FRIEND = 0x0001
|
||||
PRIVATE_GROUP = 0x0002
|
||||
@ -45,7 +45,7 @@ _min_context_fields = (
|
||||
_MinContext = namedtuple('MinContext', _min_context_fields)
|
||||
|
||||
|
||||
async def check_permission(bot: NoneBot, ctx: Dict[str, Any],
|
||||
async def check_permission(bot: NoneBot, ctx: Context_T,
|
||||
permission_required: int) -> bool:
|
||||
"""
|
||||
Check if the context has the permission required.
|
||||
|
@ -1,25 +1,21 @@
|
||||
from typing import Union, Callable, Dict, Any, List, Sequence
|
||||
|
||||
from . import NoneBot
|
||||
from .helpers import send, send_expr
|
||||
from .typing import Context_T, Message_T, Expression_T
|
||||
|
||||
|
||||
class BaseSession:
|
||||
__slots__ = ('bot', 'ctx')
|
||||
|
||||
def __init__(self, bot: NoneBot, ctx: Dict[str, Any]):
|
||||
def __init__(self, bot: NoneBot, ctx: Context_T):
|
||||
self.bot = bot
|
||||
self.ctx = ctx
|
||||
|
||||
async def send(self,
|
||||
message: Union[str, Dict[str, Any], List[Dict[str, Any]]],
|
||||
*, ignore_failure: bool = True) -> None:
|
||||
async def send(self, message: Message_T, *,
|
||||
ignore_failure: bool = True) -> None:
|
||||
"""Send a message ignoring failure by default."""
|
||||
return await send(self.bot, self.ctx, message,
|
||||
ignore_failure=ignore_failure)
|
||||
|
||||
async def send_expr(self,
|
||||
expr: Union[str, Sequence[str], Callable],
|
||||
**kwargs):
|
||||
async def send_expr(self, expr: Expression_T, **kwargs):
|
||||
"""Sending a expression message ignoring failure by default."""
|
||||
return await send_expr(self.bot, self.ctx, expr, **kwargs)
|
||||
|
7
none/typing.py
Normal file
7
none/typing.py
Normal file
@ -0,0 +1,7 @@
|
||||
from typing import Union, List, Dict, Any, Sequence, Callable, Tuple
|
||||
|
||||
Context_T = Dict[str, Any]
|
||||
Message_T = Union[str, Dict[str, Any], List[Dict[str, Any]]]
|
||||
Expression_T = Union[str, Sequence[str], Callable]
|
||||
CommandName_T = Tuple[str]
|
||||
CommandArgs_T = Dict[str, Any]
|
Loading…
Reference in New Issue
Block a user