diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 14b43dd1..0c579600 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -124,9 +124,9 @@ def on_websocket_connect(func: Callable[[aiocqhttp.Event], Awaitable[None]]) \ from .exceptions import * +from .message import message_preprocessor, Message, MessageSegment from .plugin import (load_plugin, load_plugins, load_builtin_plugins, get_loaded_plugins) -from .message import message_preprocessor, Message, MessageSegment from .command import on_command, CommandSession, CommandGroup from .natural_language import (on_natural_language, NLPSession, NLPResult, IntentCommand) diff --git a/nonebot/command/__init__.py b/nonebot/command/__init__.py index 29f042b2..013919a3 100644 --- a/nonebot/command/__init__.py +++ b/nonebot/command/__init__.py @@ -1,13 +1,11 @@ -import asyncio import re import shlex +import asyncio import warnings from datetime import datetime from functools import partial -from typing import ( - Tuple, Union, Callable, Iterable, Any, Optional, List, Dict, - Awaitable -) +from typing import (Tuple, Union, Callable, Iterable, Any, Optional, List, Dict, + Awaitable) from aiocqhttp import Event as CQEvent @@ -17,38 +15,22 @@ from nonebot.helpers import context_id, send, render_expression from nonebot.log import logger from nonebot.message import Message from nonebot.session import BaseSession -from nonebot.typing import ( - CommandName_T, CommandArgs_T, Message_T, State_T, Filter_T -) - -# key: one segment of command name -# value: subtree or a leaf Command object -_registry = {} # type: Dict[str, Union[Dict, Command]] - -# key: alias -# value: real command name -_aliases = {} # type: Dict[str, CommandName_T] +from nonebot.typing import (CommandName_T, CommandArgs_T, Message_T, State_T, + Filter_T) # key: context id # value: CommandSession object -_sessions = {} # type: Dict[str, CommandSession] +_sessions = {} # type: Dict[str, "CommandSession"] CommandHandler_T = Callable[['CommandSession'], Any] class Command: - __slots__ = ('name', 'func', - 'permission', - 'only_to_me', - 'privileged', + __slots__ = ('name', 'func', 'permission', 'only_to_me', 'privileged', 'args_parser_func') - def __init__(self, *, - name: CommandName_T, - func: CommandHandler_T, - permission: int, - only_to_me: bool, - privileged: bool): + def __init__(self, *, name: CommandName_T, func: CommandHandler_T, + permission: int, only_to_me: bool, privileged: bool): self.name = name self.func = func self.permission = permission @@ -56,8 +38,7 @@ class Command: self.privileged = privileged self.args_parser_func: Optional[CommandHandler_T] = None - async def run(self, session, *, - check_perm: bool = True, + async def run(self, session, *, check_perm: bool = True, dry: bool = False) -> bool: """ Run the command in a given session. @@ -94,15 +75,16 @@ class Command: if session.state['__validation_failure_num'] >= \ config.MAX_VALIDATION_FAILURES: # noinspection PyProtectedMember - session.finish(render_expression( - config.TOO_MANY_VALIDATION_FAILURES_EXPRESSION - ), **session._current_send_kwargs) + session.finish( + render_expression( + config. + TOO_MANY_VALIDATION_FAILURES_EXPRESSION + ), **session._current_send_kwargs) failure_message = e.message if failure_message is None: failure_message = render_expression( - config.DEFAULT_VALIDATION_FAILURE_EXPRESSION - ) + config.DEFAULT_VALIDATION_FAILURE_EXPRESSION) # noinspection PyProtectedMember session.pause(failure_message, **session._current_send_kwargs) @@ -133,6 +115,14 @@ class Command: return await perm.check_permission(session.bot, session.event, self.permission) + def args_parser(self, parser_func: CommandHandler_T) -> CommandHandler_T: + """ + Decorator to register a function as the arguments parser of + the corresponding command. + """ + self.args_parser_func = parser_func + return parser_func + def __repr__(self): return f'' @@ -140,7 +130,238 @@ class Command: return self.__repr__() -def on_command(name: Union[str, CommandName_T], *, +class CommandManager: + """Global Command Manager""" + _commands = {} # type: Dict[CommandName_T, Command] + _aliases = {} # type: Dict[str, Command] + _switches = {} # type: Dict[CommandName_T, bool] + + def __init__(self): + self.commands = CommandManager._commands.copy() + self.aliases = CommandManager._aliases.copy() + self.switches = CommandManager._switches.copy() + + @classmethod + def add_command(cls, cmd_name: CommandName_T, cmd: Command) -> None: + """Register a command + + Args: + cmd_name (CommandName_T): Command name + cmd (Command): Command object + """ + if cmd_name in cls._commands: + warnings.warn(f"Command {cmd_name} already exists") + return + cls._switches[cmd_name] = True + cls._commands[cmd_name] = cmd + + @classmethod + def reload_command(cls, cmd_name: CommandName_T, cmd: Command) -> None: + """Reload a command + + **Warning! Dangerous function** + + Args: + cmd_name (CommandName_T): Command name + cmd (Command): Command object + """ + if cmd_name not in cls._commands: + warnings.warn( + f"Command {cmd_name} does not exist. Please use add_command instead" + ) + return + cls._commands[cmd_name] = cmd + + @classmethod + def remove_command(cls, cmd_name: CommandName_T) -> bool: + """Remove a command + + **Warning! Dangerous function** + + Args: + cmd_name (CommandName_T): Command name to remove + + Returns: + bool: Success or not + """ + if cmd_name in cls._commands: + del cls._commands[cmd_name] + if cmd_name in cls._switches: + del cls._switches[cmd_name] + return True + return False + + @classmethod + def switch_command_global(cls, + cmd_name: CommandName_T, + state: Optional[bool] = None): + """Change command state globally or simply switch it if `state` is None + + Args: + cmd_name (CommandName_T): Command name + state (Optional[bool]): State to change to. Defaults to None. + """ + cls._switches[cmd_name] = not cls._switches[ + cmd_name] if state is None else bool(state) + + @classmethod + def add_aliases(cls, aliases: Union[Iterable[str], str], cmd: Command): + """Register command alias(es) + + Args: + aliases (Union[Iterable[str], str]): Command aliases + cmd_name (Command): Command + """ + if isinstance(aliases, str): + aliases = (aliases,) + for alias in aliases: + if not isinstance(alias, str): + warnings.warn(f"Alias {alias} is not a string! Ignored") + return + elif alias in cls._aliases: + warnings.warn(f"Alias {alias} already exists") + return + cls._aliases[alias] = cmd + + def _add_command_to_tree(self, cmd_name: CommandName_T, cmd: Command, + tree: Dict[str, Union[Dict, Command]]) -> None: + """Add command to the target command tree. + + Args: + cmd_name (CommandName_T): Name of the command + cmd (Command): Command object + tree (Dict[str, Union[Dict, Command]): Target command tree + """ + current_parent = tree + for parent_key in cmd_name[:-1]: + current_parent[parent_key] = current_parent.get(parent_key) or {} + current_parent = current_parent[parent_key] + # TODO: 支持test test.sub子命令 + if not isinstance(current_parent, dict): + warnings.warn(f"{current_parent} is not a registry dict") + return + if cmd_name[-1] in current_parent: + warnings.warn(f"There is already a command named {cmd_name}") + return + current_parent[cmd_name[-1]] = cmd + + def _generate_command_tree(self, commands: Dict[CommandName_T, Command] + ) -> Dict[str, Union[Dict, Command]]: + """Generate command tree from commands dictionary. + + Args: + commands (Dict[CommandName_T, Command]): Dictionary of commands + + Returns: + Dict[str, Union[Dict, "Command"]]: Command tree + """ + cmd_tree = {} #type: Dict[str, Union[Dict, "Command"]] + for cmd_name, cmd in commands.items(): + self._add_command_to_tree(cmd_name, cmd, cmd_tree) + return cmd_tree + + def _find_command(self, + name: Union[str, CommandName_T]) -> Optional[Command]: + cmd_name = (name,) if isinstance(name, str) else name + if not cmd_name: + return None + + cmd_tree = self._generate_command_tree({ + name: cmd + for name, cmd in self.commands.items() + if self.switches.get(name, True) + }) + for part in cmd_name[:-1]: + if part not in cmd_tree or not isinstance( + cmd_tree[part], #type: ignore + dict): + return None + cmd_tree = cmd_tree[part] # type: ignore + + cmd = cmd_tree.get(cmd_name[-1]) # type: ignore + return cmd if isinstance(cmd, Command) else None + + def parse_command(self, bot: NoneBot, cmd_string: str + ) -> Tuple[Optional[Command], Optional[str]]: + logger.debug(f'Parsing command: {repr(cmd_string)}') + + matched_start = None + for start in bot.config.COMMAND_START: + # loop through COMMAND_START to find the longest matched start + curr_matched_start = None + if isinstance(start, type(re.compile(''))): + m = start.search(cmd_string) + if m and m.start(0) == 0: + curr_matched_start = m.group(0) + elif isinstance(start, str): + if cmd_string.startswith(start): + curr_matched_start = start + + if curr_matched_start is not None and \ + (matched_start is None or + len(curr_matched_start) > len(matched_start)): + # a longer start, use it + matched_start = curr_matched_start + + if matched_start is None: + # it's not a command + logger.debug('It\'s not a command') + return None, None + + logger.debug(f'Matched command start: ' + f'{matched_start}{"(empty)" if not matched_start else ""}') + full_command = cmd_string[len(matched_start):].lstrip() + + if not full_command: + # command is empty + return None, None + + cmd_name_text, *cmd_remained = full_command.split(maxsplit=1) + + cmd_name = None + for sep in bot.config.COMMAND_SEP: + # loop through COMMAND_SEP to find the most optimized split + curr_cmd_name = None + if isinstance(sep, type(re.compile(''))): + curr_cmd_name = tuple(sep.split(cmd_name_text)) + elif isinstance(sep, str): + curr_cmd_name = tuple(cmd_name_text.split(sep)) + + if curr_cmd_name is not None and \ + (not cmd_name or len(curr_cmd_name) > len(cmd_name)): + # a more optimized split, use it + cmd_name = curr_cmd_name + + if not cmd_name: + cmd_name = (cmd_name_text,) + logger.debug(f'Split command name: {cmd_name}') + + cmd = self._find_command(cmd_name) # type: ignore + if not cmd: + logger.debug(f'Command {cmd_name} not found. Try to match aliases') + cmd = self.aliases.get(cmd_name_text) + + if not cmd: + return None, None + + logger.debug(f'Command {cmd.name} found, function: {cmd.func}') + return cmd, ''.join(cmd_remained) + + def switch_command(self, + cmd_name: CommandName_T, + state: Optional[bool] = None): + """Change command state or simply switch it if `state` is None + + Args: + cmd_name (CommandName_T): Command name + state (Optional[bool]): State to change to. Defaults to None. + """ + self.switches[cmd_name] = not self.switches[ + cmd_name] if state is None else bool(state) + + +def on_command(name: Union[str, CommandName_T], + *, aliases: Union[Iterable[str], str] = (), permission: int = perm.EVERYBODY, only_to_me: bool = True, @@ -157,7 +378,7 @@ def on_command(name: Union[str, CommandName_T], *, :param shell_like: use shell-like syntax to split arguments """ - def deco(func: CommandHandler_T) -> CommandHandler_T: + def deco(func: CommandHandler_T) -> Command: if not isinstance(name, (str, tuple)): raise TypeError('the name of a command must be a str or tuple') if not name: @@ -165,63 +386,27 @@ def on_command(name: Union[str, CommandName_T], *, cmd_name = (name,) if isinstance(name, str) else name - cmd = Command(name=cmd_name, func=func, permission=permission, - only_to_me=only_to_me, privileged=privileged) - - def args_parser(parser_func: CommandHandler_T) -> CommandHandler_T: - """ - Decorator to register a function as the arguments parser of - the corresponding command. - """ - cmd.args_parser_func = parser_func - return parser_func - - func.args_parser = args_parser + cmd = Command(name=cmd_name, + func=func, + permission=permission, + only_to_me=only_to_me, + privileged=privileged) if shell_like: + async def shell_like_args_parser(session): session.args['argv'] = shlex.split(session.current_arg) cmd.args_parser_func = shell_like_args_parser - current_parent = _registry - for parent_key in cmd_name[:-1]: - current_parent[parent_key] = current_parent.get(parent_key) or {} - current_parent = current_parent[parent_key] - if not isinstance(current_parent, dict): - warnings.warn(f'{current_parent} is not a registry dict') - return func - if cmd_name[-1] in current_parent: - warnings.warn(f'There is already a command named {cmd_name}') - return func - current_parent[cmd_name[-1]] = cmd + CommandManager.add_command(cmd_name, cmd) + CommandManager.add_aliases(aliases, cmd) - nonlocal aliases - if isinstance(aliases, str): - aliases = (aliases,) - for alias in aliases: - _aliases[alias] = cmd_name - - return func + return cmd return deco -def _find_command(name: Union[str, CommandName_T]) -> Optional[Command]: - cmd_name = (name,) if isinstance(name, str) else name - if not cmd_name: - return None - - cmd_tree = _registry - for part in cmd_name[:-1]: - if part not in cmd_tree or not isinstance(cmd_tree[part], dict): - return None - cmd_tree = cmd_tree[part] - - cmd = cmd_tree.get(cmd_name[-1]) - return cmd if isinstance(cmd, Command) else None - - class _PauseException(Exception): """ Raised by session.pause() indicating that the command session @@ -262,13 +447,18 @@ class SwitchException(Exception): class CommandSession(BaseSession): - __slots__ = ('cmd', - 'current_key', 'current_arg_filters', '_current_send_kwargs', - 'current_arg', '_current_arg_text', '_current_arg_images', - '_state', '_last_interaction', '_running', '_run_future') + __slots__ = ('cmd', 'current_key', 'current_arg_filters', + '_current_send_kwargs', 'current_arg', '_current_arg_text', + '_current_arg_images', '_state', '_last_interaction', + '_running', '_run_future') - def __init__(self, bot: NoneBot, event: CQEvent, cmd: Command, *, - current_arg: str = '', args: Optional[CommandArgs_T] = None): + def __init__(self, + bot: NoneBot, + event: CQEvent, + cmd: Command, + *, + current_arg: Optional[str] = '', + args: Optional[CommandArgs_T] = None): super().__init__(bot, event) self.cmd = cmd # Command object @@ -281,7 +471,7 @@ class CommandSession(BaseSession): self._current_send_kwargs: Dict[str, Any] = {} # initialize current argument - self.current_arg: str = '' # with potential CQ codes + self.current_arg: Optional[str] = '' # with potential CQ codes self._current_arg_text = None self._current_arg_images = None self.refresh(event, current_arg=current_arg) # fill the above @@ -353,7 +543,8 @@ class CommandSession(BaseSession): """ if self._current_arg_images is None: self._current_arg_images = [ - s.data['url'] for s in Message(self.current_arg) + s.data['url'] + for s in Message(self.current_arg) if s.type == 'image' and 'url' in s.data ] return self._current_arg_images @@ -366,7 +557,8 @@ class CommandSession(BaseSession): """ return self.state.get('argv', []) - def refresh(self, event: CQEvent, *, current_arg: str = '') -> None: + def refresh(self, event: CQEvent, *, + current_arg: Optional[str] = '') -> None: """ Refill the session with a new message event. @@ -378,7 +570,9 @@ class CommandSession(BaseSession): self._current_arg_text = None self._current_arg_images = None - def get(self, key: str, *, + def get(self, + key: str, + *, prompt: Optional[Message_T] = None, arg_filters: Optional[List[Filter_T]] = None, **kwargs) -> Any: @@ -444,79 +638,8 @@ class CommandSession(BaseSession): raise SwitchException(new_message) -def parse_command(bot: NoneBot, - cmd_string: str) -> Tuple[Optional[Command], Optional[str]]: - """ - Parse a command string (typically from a message). - - :param bot: NoneBot instance - :param cmd_string: command string - :return: (Command object, current arg string) - """ - logger.debug(f'Parsing command: {repr(cmd_string)}') - - matched_start = None - for start in bot.config.COMMAND_START: - # loop through COMMAND_START to find the longest matched start - curr_matched_start = None - if isinstance(start, type(re.compile(''))): - m = start.search(cmd_string) - if m and m.start(0) == 0: - curr_matched_start = m.group(0) - elif isinstance(start, str): - if cmd_string.startswith(start): - curr_matched_start = start - - if curr_matched_start is not None and \ - (matched_start is None or - len(curr_matched_start) > len(matched_start)): - # a longer start, use it - matched_start = curr_matched_start - - if matched_start is None: - # it's not a command - logger.debug('It\'s not a command') - return None, None - - logger.debug(f'Matched command start: ' - f'{matched_start}{"(empty)" if not matched_start else ""}') - full_command = cmd_string[len(matched_start):].lstrip() - - if not full_command: - # command is empty - return None, None - - cmd_name_text, *cmd_remained = full_command.split(maxsplit=1) - cmd_name = _aliases.get(cmd_name_text) - - if not cmd_name: - for sep in bot.config.COMMAND_SEP: - # loop through COMMAND_SEP to find the most optimized split - curr_cmd_name = None - if isinstance(sep, type(re.compile(''))): - curr_cmd_name = tuple(sep.split(cmd_name_text)) - elif isinstance(sep, str): - curr_cmd_name = tuple(cmd_name_text.split(sep)) - - if curr_cmd_name is not None and \ - (not cmd_name or len(curr_cmd_name) > len(cmd_name)): - # a more optimized split, use it - cmd_name = curr_cmd_name - - if not cmd_name: - cmd_name = (cmd_name_text,) - - logger.debug(f'Split command name: {cmd_name}') - cmd = _find_command(cmd_name) - if not cmd: - logger.debug(f'Command {cmd_name} not found') - return None, None - - logger.debug(f'Command {cmd.name} found, function: {cmd.func}') - return cmd, ''.join(cmd_remained) - - -async def handle_command(bot: NoneBot, event: CQEvent) -> bool: +async def handle_command(bot: NoneBot, event: CQEvent, + manager: CommandManager) -> Optional[bool]: """ Handle a message as a command. @@ -524,13 +647,14 @@ async def handle_command(bot: NoneBot, event: CQEvent) -> bool: :param bot: NoneBot instance :param event: message event + :param manager: command manager :return: the message is handled as a command """ - cmd, current_arg = parse_command(bot, str(event.message).lstrip()) + cmd, current_arg = manager.parse_command(bot, str(event.message).lstrip()) is_privileged_cmd = cmd and cmd.privileged if is_privileged_cmd and cmd.only_to_me and not event['to_me']: is_privileged_cmd = False - disable_interaction = is_privileged_cmd + disable_interaction = bool(is_privileged_cmd) if is_privileged_cmd: logger.debug(f'Command {cmd.name} is a privileged command') @@ -551,10 +675,9 @@ async def handle_command(bot: NoneBot, event: CQEvent) -> bool: if session.running: logger.warning(f'There is a session of command ' f'{session.cmd.name} running, notify the user') - asyncio.ensure_future(send( - bot, event, - render_expression(bot.config.SESSION_RUNNING_EXPRESSION) - )) + asyncio.ensure_future( + send(bot, event, + render_expression(bot.config.SESSION_RUNNING_EXPRESSION))) # pretend we are successful, so that NLP won't handle it return True @@ -582,16 +705,20 @@ async def handle_command(bot: NoneBot, event: CQEvent) -> bool: session = CommandSession(bot, event, cmd, current_arg=current_arg) logger.debug(f'New session of command {session.cmd.name} created') - return await _real_run_command(session, ctx_id, check_perm=check_perm, + return await _real_run_command(session, + ctx_id, + check_perm=check_perm, disable_interaction=disable_interaction) -async def call_command(bot: NoneBot, event: CQEvent, - name: Union[str, CommandName_T], *, +async def call_command(bot: NoneBot, + event: CQEvent, + name: Union[str, CommandName_T], + *, current_arg: str = '', args: Optional[CommandArgs_T] = None, check_perm: bool = True, - disable_interaction: bool = False) -> bool: + disable_interaction: bool = False) -> Optional[bool]: """ Call a command internally. @@ -612,22 +739,24 @@ async def call_command(bot: NoneBot, event: CQEvent, :param disable_interaction: disable the command's further interaction :return: the command is successfully called """ - cmd = _find_command(name) + cmd = CommandManager()._find_command(name) if not cmd: return False - session = CommandSession(bot, event, cmd, - current_arg=current_arg, args=args) - return await _real_run_command( - session, context_id(session.event), - check_perm=check_perm, - disable_interaction=disable_interaction - ) + session = CommandSession(bot, + event, + cmd, + current_arg=current_arg, + args=args) + return await _real_run_command(session, + context_id(session.event), + check_perm=check_perm, + disable_interaction=disable_interaction) async def _real_run_command(session: CommandSession, ctx_id: str, disable_interaction: bool = False, - **kwargs) -> bool: + **kwargs) -> Optional[bool]: if not disable_interaction: # override session only when interaction is not disabled _sessions[ctx_id] = session diff --git a/nonebot/message.py b/nonebot/message.py index 94c43ee3..a5898925 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -1,13 +1,15 @@ +import re import asyncio -from typing import Callable +from typing import Callable, Iterable from aiocqhttp import Event as CQEvent -from aiocqhttp.message import * +from aiocqhttp.message import escape, unescape, Message, MessageSegment from . import NoneBot -from .command import handle_command, SwitchException from .log import logger from .natural_language import handle_natural_language +from .command import handle_command, SwitchException +from .plugin import PluginManager _message_preprocessors = set() @@ -22,11 +24,12 @@ async def handle_message(bot: NoneBot, event: CQEvent) -> None: assert isinstance(event.message, Message) if not event.message: - event.message.append(MessageSegment.text('')) + event.message.append(MessageSegment.text('')) # type: ignore coros = [] + plugin_manager = PluginManager() for preprocessor in _message_preprocessors: - coros.append(preprocessor(bot, event)) + coros.append(preprocessor(bot, event, plugin_manager)) if coros: await asyncio.wait(coros) @@ -37,7 +40,7 @@ async def handle_message(bot: NoneBot, event: CQEvent) -> None: while True: try: - handled = await handle_command(bot, event) + handled = await handle_command(bot, event, plugin_manager.cmd_manager) break except SwitchException as e: # we are sure that there is no session existing now @@ -47,7 +50,7 @@ async def handle_message(bot: NoneBot, event: CQEvent) -> None: logger.info(f'Message {event.message_id} is handled as a command') return - handled = await handle_natural_language(bot, event) + handled = await handle_natural_language(bot, event, plugin_manager.nlp_manager) if handled: logger.info(f'Message {event.message_id} is handled ' f'as natural language') diff --git a/nonebot/natural_language.py b/nonebot/natural_language.py index 63a7e111..5b168bd4 100644 --- a/nonebot/natural_language.py +++ b/nonebot/natural_language.py @@ -1,5 +1,6 @@ import asyncio -from typing import Iterable, Optional, Callable, Union, NamedTuple +import warnings +from typing import Set, Iterable, Optional, Callable, Union, NamedTuple from aiocqhttp import Event as CQEvent @@ -10,13 +11,10 @@ from .message import Message from .session import BaseSession from .typing import CommandName_T, CommandArgs_T -_nl_processors = set() - class NLProcessor: - __slots__ = ('func', 'keywords', 'permission', - 'only_to_me', 'only_short_message', - 'allow_empty_message') + __slots__ = ('func', 'keywords', 'permission', 'only_to_me', + 'only_short_message', 'allow_empty_message') def __init__(self, *, func: Callable, keywords: Optional[Iterable], permission: int, only_to_me: bool, only_short_message: bool, @@ -29,8 +27,80 @@ class NLProcessor: self.allow_empty_message = allow_empty_message +class NLPManager: + _nl_processors: Set[NLProcessor] = set() + + def __init__(self): + self.nl_processors = NLPManager._nl_processors.copy() + + @classmethod + def add_nl_processor(cls, processor: NLProcessor) -> None: + """Register a natural language processor + + Args: + processor (NLProcessor): Processor object + """ + if processor in cls._nl_processors: + warnings.warn(f"NLProcessor {processor} already exists") + return + cls._nl_processors.add(processor) + + @classmethod + def remove_nl_processor(cls, processor: NLProcessor) -> bool: + """Remove a natural language processor globally + + Args: + processor (NLProcessor): Processor to remove + + Returns: + bool: Success or not + """ + if processor in cls._nl_processors: + cls._nl_processors.remove(processor) + return True + return False + + @classmethod + def switch_processor_global(cls, + processor: NLProcessor, + state: Optional[bool] = None) -> Optional[bool]: + """Remove or add a processor + + Args: + processor (NLProcessor): Processor object + + Returns: + bool: True if removed, False if added + """ + if processor in cls._nl_processors and not state: + cls._nl_processors.remove(processor) + return True + elif processor not in cls._nl_processors and state != False: + cls._nl_processors.add(processor) + return False + + def switch_processor(self, + processor: NLProcessor, + state: Optional[bool] = None) -> Optional[bool]: + """Remove or add processor + + Args: + processor (NLProcessor): Processor to remove + + Returns: + bool: True if removed, False if added + """ + if processor in self.nl_processors and not state: + self.nl_processors.remove(processor) + return True + elif processor not in self.nl_processors and state != False: + self.nl_processors.add(processor) + return False + + def on_natural_language( - keywords: Union[Optional[Iterable], str, Callable] = None, *, + keywords: Union[Optional[Iterable], str, Callable] = None, + *, permission: int = perm.EVERYBODY, only_to_me: bool = True, only_short_message: bool = True, @@ -45,14 +115,16 @@ def on_natural_language( :param allow_empty_message: handle empty messages """ - def deco(func: Callable) -> Callable: - nl_processor = NLProcessor(func=func, keywords=keywords, - permission=permission, - only_to_me=only_to_me, - only_short_message=only_short_message, - allow_empty_message=allow_empty_message) - _nl_processors.add(nl_processor) - return func + def deco(func: Callable) -> NLProcessor: + nl_processor = NLProcessor( + func=func, + keywords=keywords, # type: ignore + permission=permission, + only_to_me=only_to_me, + only_short_message=only_short_message, + allow_empty_message=allow_empty_message) + NLPManager.add_nl_processor(nl_processor) + return nl_processor if isinstance(keywords, Callable): # here "keywords" is the function to be decorated @@ -71,8 +143,11 @@ class NLPSession(BaseSession): self.msg = msg tmp_msg = Message(msg) self.msg_text = tmp_msg.extract_plain_text() - self.msg_images = [s.data['url'] for s in tmp_msg - if s.type == 'image' and 'url' in s.data] + self.msg_images = [ + s.data['url'] + for s in tmp_msg + if s.type == 'image' and 'url' in s.data + ] class NLPResult(NamedTuple): @@ -100,7 +175,8 @@ class IntentCommand(NamedTuple): current_arg: str = '' -async def handle_natural_language(bot: NoneBot, event: CQEvent) -> bool: +async def handle_natural_language(bot: NoneBot, event: CQEvent, + manager: NLPManager) -> bool: """ Handle a message as natural language. @@ -108,6 +184,7 @@ async def handle_natural_language(bot: NoneBot, event: CQEvent) -> bool: :param bot: NoneBot instance :param event: message event + :param manager: natural language processor manager :return: the message is handled as natural language """ session = NLPSession(bot, event, str(event.message)) @@ -117,7 +194,7 @@ async def handle_natural_language(bot: NoneBot, event: CQEvent) -> bool: msg_text_length = len(session.msg_text) futures = [] - for p in _nl_processors: + for p in manager.nl_processors: if not p.allow_empty_message and not session.msg: # don't allow empty msg, but it is one, so skip to next continue @@ -164,12 +241,12 @@ async def handle_natural_language(bot: NoneBot, event: CQEvent) -> bool: chosen_cmd = intent_commands[0] logger.debug( f'Intent command with highest confidence: {chosen_cmd}') - return await call_command( - bot, event, chosen_cmd.name, - args=chosen_cmd.args, - current_arg=chosen_cmd.current_arg, - check_perm=False - ) + return await call_command(bot, + event, + chosen_cmd.name, + args=chosen_cmd.args, + current_arg=chosen_cmd.current_arg, + check_perm=False) # type: ignore else: logger.debug('No intent command has enough confidence') return False diff --git a/nonebot/notice_request.py b/nonebot/notice_request.py index a9b72e41..bc333ae4 100644 --- a/nonebot/notice_request.py +++ b/nonebot/notice_request.py @@ -1,4 +1,4 @@ -from typing import Optional, Callable, Union +from typing import List, Optional, Callable, Union from aiocqhttp import Event as CQEvent from aiocqhttp.bus import EventBus @@ -10,20 +10,29 @@ from .session import BaseSession _bus = EventBus() +class EventHandler: + __slots__ = ('events', 'func') + + def __init__(self, events: List[str], func: Callable): + self.events = events + self.func = func + def _make_event_deco(post_type: str) -> Callable: def deco_deco(arg: Optional[Union[str, Callable]] = None, *events: str) -> Callable: - def deco(func: Callable) -> Callable: + def deco(func: Callable) -> EventHandler: if isinstance(arg, str): - for e in [arg] + list(events): - _bus.subscribe(f'{post_type}.{e}', func) + events_tmp = list(map(lambda x: f"{post_type}.{x}", [arg] + list(events))) + for e in events_tmp: + _bus.subscribe(e, func) + return EventHandler(events_tmp, func) else: _bus.subscribe(post_type, func) - return func + return EventHandler([post_type], func) if isinstance(arg, Callable): - return deco(arg) + return deco(arg) # type: ignore return deco return deco_deco diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 94db455f..09ce4046 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -1,26 +1,111 @@ -import importlib import os import re -from typing import Any, Set, Optional +import warnings +import importlib +from types import ModuleType +from typing import Any, Set, Dict, Optional from .log import logger +from .command import Command, CommandManager +from .natural_language import NLProcessor, NLPManager +from .notice_request import _bus, EventHandler class Plugin: - __slots__ = ('module', 'name', 'usage') + __slots__ = ('module', 'name', 'usage', 'commands', 'nl_processors', 'event_handlers') - def __init__(self, module: Any, + def __init__(self, module: ModuleType, name: Optional[str] = None, - usage: Optional[Any] = None): + usage: Optional[Any] = None, + commands: Set[Command] = set(), + nl_processors: Set[NLProcessor] = set(), + event_handlers: Set[EventHandler] = set()): self.module = module self.name = name self.usage = usage + self.commands = commands + self.nl_processors = nl_processors + self.event_handlers = event_handlers + +class PluginManager: + _plugins: Dict[str, Plugin] = {} + _anonymous_plugins: Set[Plugin] = set() + + def __init__(self): + self.cmd_manager = CommandManager() + self.nlp_manager = NLPManager() + + @classmethod + def add_plugin(cls, plugin: Plugin) -> None: + """Register a plugin + + Args: + plugin (Plugin): Plugin object + """ + if plugin.name: + if plugin.name in cls._plugins: + warnings.warn(f"Plugin {plugin.name} already exists") + return + cls._plugins[plugin.name] = plugin + else: + cls._anonymous_plugins.add(plugin) + + @classmethod + def get_plugin(cls, name: str) -> Optional[Plugin]: + return cls._plugins.get(name) + + # TODO: plugin重加载 + @classmethod + def reload_plugin(cls, plugin: Plugin) -> None: + pass + + @classmethod + def switch_plugin_global(cls, name: str, state: Optional[bool] = None) -> None: + """Change plugin state globally or simply switch it if `state` is None + + Args: + name (str): Plugin name + state (Optional[bool]): State to change to. Defaults to None. + """ + plugin = cls.get_plugin(name) + if not plugin: + warnings.warn(f"Plugin {name} not found") + return + for command in plugin.commands: + CommandManager.switch_command_global(command.name, state) + for nl_processor in plugin.nl_processors: + NLPManager.switch_processor_global(nl_processor, state) + for event_handler in plugin.event_handlers: + for event in event_handler.events: + if event_handler.func in _bus._subscribers[event] and not state: + _bus.unsubscribe(event, event_handler.func) + elif event_handler.func not in _bus._subscribers[event] and state != False: + _bus.subscribe(event, event_handler.func) + + def switch_plugin(self, name: str, state: Optional[bool] = None) -> None: + """Change plugin state or simply switch it if `state` is None + + Args: + name (str): Plugin name + state (Optional[bool]): State to change to. Defaults to None. + """ + plugin = self.get_plugin(name) + if not plugin: + warnings.warn(f"Plugin {name} not found") + return + for command in plugin.commands: + self.cmd_manager.switch_command(command.name, state) + for nl_processor in plugin.nl_processors: + self.nlp_manager.switch_processor(nl_processor, state) + # for event_handler in plugin.event_handlers: + # for event in event_handler.events: + # if event_handler.func in _bus._subscribers[event] and not state: + # _bus.unsubscribe(event, event_handler.func) + # elif event_handler.func not in _bus._subscribers[event] and state != False: + # _bus.subscribe(event, event_handler.func) -_plugins: Set[Plugin] = set() - - -def load_plugin(module_name: str) -> bool: +def load_plugin(module_name: str) -> Optional[Plugin]: """ Load a module as a plugin. @@ -31,16 +116,33 @@ def load_plugin(module_name: str) -> bool: module = importlib.import_module(module_name) name = getattr(module, '__plugin_name__', None) usage = getattr(module, '__plugin_usage__', None) - _plugins.add(Plugin(module, name, usage)) + commands = set() + nl_processors = set() + event_handlers = set() + for attr in dir(module): + func = getattr(module, attr) + if isinstance(func, Command): + commands.add(func) + elif isinstance(func, NLProcessor): + nl_processors.add(func) + elif isinstance(func, EventHandler): + event_handlers.add(func) + plugin = Plugin(module, name, usage, commands, nl_processors, event_handlers) + PluginManager.add_plugin(plugin) logger.info(f'Succeeded to import "{module_name}"') - return True + return plugin except Exception as e: logger.error(f'Failed to import "{module_name}", error: {e}') logger.exception(e) - return False + return None -def load_plugins(plugin_dir: str, module_prefix: str) -> int: +# TODO: plugin重加载 +def reload_plugin(module_name: str) -> Optional[Plugin]: + pass + + +def load_plugins(plugin_dir: str, module_prefix: str) -> Set[Plugin]: """ Find all non-hidden modules or packages in a given directory, and import them with the given module prefix. @@ -49,7 +151,7 @@ def load_plugins(plugin_dir: str, module_prefix: str) -> int: :param module_prefix: module prefix used while importing :return: number of plugins successfully loaded """ - count = 0 + count = set() for name in os.listdir(plugin_dir): path = os.path.join(plugin_dir, name) if os.path.isfile(path) and \ @@ -64,12 +166,13 @@ def load_plugins(plugin_dir: str, module_prefix: str) -> int: if not m: continue - if load_plugin(f'{module_prefix}.{m.group(1)}'): - count += 1 + result = load_plugin(f'{module_prefix}.{m.group(1)}') + if result: + count.add(result) return count -def load_builtin_plugins() -> int: +def load_builtin_plugins() -> Set[Plugin]: """ Load built-in plugins distributed along with "nonebot" package. """ @@ -83,4 +186,4 @@ def get_loaded_plugins() -> Set[Plugin]: :return: a set of Plugin objects """ - return _plugins + return set(PluginManager._plugins.values()) | PluginManager._anonymous_plugins