mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
move decorators to plugin module
This commit is contained in:
parent
9fbd09331c
commit
4f9a9136f9
@ -15,12 +15,16 @@ else:
|
||||
|
||||
|
||||
class NoneBot(CQHttp):
|
||||
|
||||
def __init__(self, config_object: Optional[Any] = None):
|
||||
if config_object is None:
|
||||
from . import default_config as config_object
|
||||
|
||||
config_dict = {k: v for k, v in config_object.__dict__.items()
|
||||
if k.isupper() and not k.startswith('_')}
|
||||
config_dict = {
|
||||
k: v
|
||||
for k, v in config_object.__dict__.items()
|
||||
if k.isupper() and not k.startswith('_')
|
||||
}
|
||||
logger.debug(f'Loaded configurations: {config_dict}')
|
||||
super().__init__(message_class=aiocqhttp.message.Message,
|
||||
**{k.lower(): v for k, v in config_dict.items()})
|
||||
@ -43,8 +47,11 @@ class NoneBot(CQHttp):
|
||||
async def _(event: aiocqhttp.Event):
|
||||
asyncio.create_task(handle_notice_or_request(self, event))
|
||||
|
||||
def run(self, host: Optional[str] = None, port: Optional[int] = None,
|
||||
*args, **kwargs) -> None:
|
||||
def run(self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
host = host or self.config.HOST
|
||||
port = port or self.config.PORT
|
||||
if 'debug' not in kwargs:
|
||||
@ -99,8 +106,8 @@ def get_bot() -> NoneBot:
|
||||
return _bot
|
||||
|
||||
|
||||
def run(host: Optional[str] = None, port: Optional[int] = None,
|
||||
*args, **kwargs) -> None:
|
||||
def run(host: Optional[str] = None, port: Optional[int] = None, *args,
|
||||
**kwargs) -> None:
|
||||
"""Run the NoneBot instance."""
|
||||
get_bot().run(host=host, port=port, *args, **kwargs)
|
||||
|
||||
@ -125,31 +132,40 @@ def on_websocket_connect(func: Callable[[aiocqhttp.Event], Awaitable[None]]) \
|
||||
|
||||
from .exceptions import *
|
||||
from .message import message_preprocessor, Message, MessageSegment
|
||||
from .plugin import (load_plugin, load_plugins, load_builtin_plugins,
|
||||
from .plugin import (on_command, on_natural_language, on_notice, on_request,
|
||||
load_plugin, load_plugins, load_builtin_plugins,
|
||||
get_loaded_plugins)
|
||||
from .command import on_command, CommandSession, CommandGroup
|
||||
from .natural_language import (on_natural_language, NLPSession, NLPResult,
|
||||
IntentCommand)
|
||||
from .notice_request import (on_notice, NoticeSession,
|
||||
on_request, RequestSession)
|
||||
from .command import CommandSession, CommandGroup
|
||||
from .natural_language import NLPSession, NLPResult, IntentCommand
|
||||
from .notice_request import NoticeSession, RequestSession
|
||||
from .helpers import context_id
|
||||
|
||||
__all__ = [
|
||||
'NoneBot', 'scheduler', 'init', 'get_bot', 'run',
|
||||
|
||||
'on_startup', 'on_websocket_connect',
|
||||
|
||||
'NoneBot',
|
||||
'scheduler',
|
||||
'init',
|
||||
'get_bot',
|
||||
'run',
|
||||
'on_startup',
|
||||
'on_websocket_connect',
|
||||
'CQHttpError',
|
||||
|
||||
'load_plugin', 'load_plugins', 'load_builtin_plugins',
|
||||
'load_plugin',
|
||||
'load_plugins',
|
||||
'load_builtin_plugins',
|
||||
'get_loaded_plugins',
|
||||
|
||||
'message_preprocessor', 'Message', 'MessageSegment',
|
||||
|
||||
'on_command', 'CommandSession', 'CommandGroup',
|
||||
|
||||
'on_natural_language', 'NLPSession', 'NLPResult', 'IntentCommand',
|
||||
'on_notice', 'NoticeSession', 'on_request', 'RequestSession',
|
||||
|
||||
'message_preprocessor',
|
||||
'Message',
|
||||
'MessageSegment',
|
||||
'on_command',
|
||||
'CommandSession',
|
||||
'CommandGroup',
|
||||
'on_natural_language',
|
||||
'NLPSession',
|
||||
'NLPResult',
|
||||
'IntentCommand',
|
||||
'on_notice',
|
||||
'NoticeSession',
|
||||
'on_request',
|
||||
'RequestSession',
|
||||
'context_id',
|
||||
]
|
||||
|
@ -4,6 +4,7 @@ from .command import CommandSession
|
||||
|
||||
|
||||
class ParserExit(RuntimeError):
|
||||
|
||||
def __init__(self, status=0, message=None):
|
||||
self.status = status
|
||||
self.message = message
|
||||
|
@ -15,20 +15,17 @@ from nonebot.helpers import context_id, send, render_expression
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import Message
|
||||
from nonebot.session import BaseSession
|
||||
from nonebot.typing import (CommandName_T, CommandArgs_T, Message_T, State_T,
|
||||
Filter_T)
|
||||
from nonebot.typing import (CommandName_T, CommandArgs_T, CommandHandler_T,
|
||||
Message_T, State_T, Filter_T)
|
||||
|
||||
# key: context id
|
||||
# value: CommandSession object
|
||||
_sessions = {} # type: Dict[str, "CommandSession"]
|
||||
|
||||
CommandHandler_T = Callable[['CommandSession'], Any]
|
||||
|
||||
|
||||
class Command:
|
||||
__slots__ = ('name', 'func', 'permission', 'only_to_me', 'privileged',
|
||||
'args_parser_func', '__name__', '__qualname__', '__doc__',
|
||||
'__annotations__', '__dict__')
|
||||
'args_parser_func')
|
||||
|
||||
def __init__(self, *, name: CommandName_T, func: CommandHandler_T,
|
||||
permission: int, only_to_me: bool, privileged: bool):
|
||||
@ -361,55 +358,6 @@ class CommandManager:
|
||||
cmd_name] if state is None else bool(state)
|
||||
|
||||
|
||||
def on_command(name: Union[str, CommandName_T],
|
||||
*,
|
||||
aliases: Union[Iterable[str], str] = (),
|
||||
permission: int = perm.EVERYBODY,
|
||||
only_to_me: bool = True,
|
||||
privileged: bool = False,
|
||||
shell_like: bool = False) -> Callable:
|
||||
"""
|
||||
Decorator to register a function as a command.
|
||||
|
||||
:param name: command name (e.g. 'echo' or ('random', 'number'))
|
||||
:param aliases: aliases of command name, for convenient access
|
||||
:param permission: permission required by the command
|
||||
:param only_to_me: only handle messages to me
|
||||
:param privileged: can be run even when there is already a session
|
||||
:param shell_like: use shell-like syntax to split arguments
|
||||
"""
|
||||
|
||||
def deco(func: CommandHandler_T) -> Command:
|
||||
if not isinstance(name, (str, tuple)):
|
||||
raise TypeError('the name of a command must be a str or tuple')
|
||||
if not name:
|
||||
raise ValueError('the name of a command must not be empty')
|
||||
|
||||
cmd_name = (name,) if isinstance(name, str) else name
|
||||
|
||||
cmd = Command(name=cmd_name,
|
||||
func=func,
|
||||
permission=permission,
|
||||
only_to_me=only_to_me,
|
||||
privileged=privileged)
|
||||
|
||||
if shell_like:
|
||||
|
||||
async def shell_like_args_parser(session):
|
||||
session.args['argv'] = shlex.split(session.current_arg)
|
||||
|
||||
cmd.args_parser_func = shell_like_args_parser
|
||||
|
||||
CommandManager.add_command(cmd_name, cmd)
|
||||
CommandManager.add_aliases(aliases, cmd)
|
||||
|
||||
update_wrapper(wrapper=cmd, wrapped=func) # type: ignore
|
||||
|
||||
return cmd
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
class _PauseException(Exception):
|
||||
"""
|
||||
Raised by session.pause() indicating that the command session
|
||||
|
@ -4,5 +4,6 @@ from nonebot.typing import Message_T
|
||||
|
||||
|
||||
class ValidateError(ValueError):
|
||||
|
||||
def __init__(self, message: Optional[Message_T] = None):
|
||||
self.message = message
|
||||
|
@ -11,8 +11,8 @@ def handle_cancellation(session: CommandSession):
|
||||
|
||||
def control(value):
|
||||
if _is_cancellation(value) is True:
|
||||
session.finish(render_expression(
|
||||
session.bot.config.SESSION_CANCEL_EXPRESSION))
|
||||
session.finish(
|
||||
render_expression(session.bot.config.SESSION_CANCEL_EXPRESSION))
|
||||
return value
|
||||
|
||||
return control
|
||||
|
@ -15,13 +15,15 @@ def _simple_chinese_to_bool(text: str) -> Optional[bool]:
|
||||
"""
|
||||
text = text.strip().lower().replace(' ', '') \
|
||||
.rstrip(',.!?~,。!?~了的呢吧呀啊呗啦')
|
||||
if text in {'要', '用', '是', '好', '对', '嗯', '行',
|
||||
'ok', 'okay', 'yeah', 'yep',
|
||||
'当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'}:
|
||||
if text in {
|
||||
'要', '用', '是', '好', '对', '嗯', '行', 'ok', 'okay', 'yeah', 'yep',
|
||||
'当真', '当然', '必须', '可以', '肯定', '没错', '确定', '确认'
|
||||
}:
|
||||
return True
|
||||
if text in {'不', '不要', '不用', '不是', '否', '不好', '不对', '不行', '别',
|
||||
'no', 'nono', 'nonono', 'nope', '不ok', '不可以', '不能',
|
||||
'不可以'}:
|
||||
if text in {
|
||||
'不', '不要', '不用', '不是', '否', '不好', '不对', '不行', '别', 'no', 'nono',
|
||||
'nonono', 'nope', '不ok', '不可以', '不能', '不可以'
|
||||
}:
|
||||
return False
|
||||
return None
|
||||
|
||||
@ -31,8 +33,8 @@ def _split_nonempty_lines(text: str) -> List[str]:
|
||||
|
||||
|
||||
def _split_nonempty_stripped_lines(text: str) -> List[str]:
|
||||
return list(filter(lambda x: x,
|
||||
map(lambda x: x.strip(), text.splitlines())))
|
||||
return list(filter(lambda x: x, map(lambda x: x.strip(),
|
||||
text.splitlines())))
|
||||
|
||||
|
||||
simple_chinese_to_bool = _simple_chinese_to_bool
|
||||
|
@ -14,8 +14,11 @@ def _extract_text(arg: Message_T) -> str:
|
||||
def _extract_image_urls(arg: Message_T) -> List[str]:
|
||||
"""Extract all image urls from a message-like object."""
|
||||
arg_as_msg = Message(arg)
|
||||
return [s.data['url'] for s in arg_as_msg
|
||||
if s.type == 'image' and 'url' in s.data]
|
||||
return [
|
||||
s.data['url']
|
||||
for s in arg_as_msg
|
||||
if s.type == 'image' and 'url' in s.data
|
||||
]
|
||||
|
||||
|
||||
def _extract_numbers(arg: Message_T) -> List[float]:
|
||||
|
@ -6,6 +6,7 @@ from nonebot.typing import Filter_T
|
||||
|
||||
|
||||
class BaseValidator:
|
||||
|
||||
def __init__(self, message=None):
|
||||
self.message = message
|
||||
|
||||
@ -69,8 +70,7 @@ def match_regex(pattern: str, message=None, *, flags=0,
|
||||
return validate
|
||||
|
||||
|
||||
def ensure_true(bool_func: Callable[[Any], bool],
|
||||
message=None) -> Filter_T:
|
||||
def ensure_true(bool_func: Callable[[Any], bool], message=None) -> Filter_T:
|
||||
"""
|
||||
Validate any object to ensure the result of applying
|
||||
a boolean function to it is True.
|
||||
|
@ -45,6 +45,4 @@ TOO_MANY_VALIDATION_FAILURES_EXPRESSION: Expression_T = \
|
||||
|
||||
SESSION_CANCEL_EXPRESSION: Expression_T = '好的'
|
||||
|
||||
APSCHEDULER_CONFIG: Dict[str, Any] = {
|
||||
'apscheduler.timezone': 'Asia/Shanghai'
|
||||
}
|
||||
APSCHEDULER_CONFIG: Dict[str, Any] = {'apscheduler.timezone': 'Asia/Shanghai'}
|
||||
|
@ -10,8 +10,8 @@ from .message import escape
|
||||
from .typing import Message_T, Expression_T
|
||||
|
||||
|
||||
def context_id(event: CQEvent, *,
|
||||
mode: str = 'default', use_hash: bool = False) -> str:
|
||||
def context_id(event: CQEvent, *, mode: str = 'default',
|
||||
use_hash: bool = False) -> str:
|
||||
"""
|
||||
Calculate a unique id representing the context of the given event.
|
||||
|
||||
@ -48,8 +48,10 @@ def context_id(event: CQEvent, *,
|
||||
return ctx_id
|
||||
|
||||
|
||||
async def send(bot: NoneBot, event: CQEvent,
|
||||
message: Message_T, *,
|
||||
async def send(bot: NoneBot,
|
||||
event: CQEvent,
|
||||
message: Message_T,
|
||||
*,
|
||||
ensure_private: bool = False,
|
||||
ignore_failure: bool = True,
|
||||
**kwargs) -> Any:
|
||||
@ -65,8 +67,10 @@ async def send(bot: NoneBot, event: CQEvent,
|
||||
return None
|
||||
|
||||
|
||||
def render_expression(expr: Expression_T, *args,
|
||||
escape_args: bool = True, **kwargs) -> str:
|
||||
def render_expression(expr: Expression_T,
|
||||
*args,
|
||||
escape_args: bool = True,
|
||||
**kwargs) -> str:
|
||||
"""
|
||||
Render an expression to message string.
|
||||
|
||||
@ -82,8 +86,8 @@ def render_expression(expr: Expression_T, *args,
|
||||
expr = random.choice(expr)
|
||||
if escape_args:
|
||||
return expr.format(
|
||||
*[escape(s) if isinstance(s, str) else s for s in args],
|
||||
**{k: escape(v) if isinstance(v, str) else v
|
||||
for k, v in kwargs.items()}
|
||||
)
|
||||
*[escape(s) if isinstance(s, str) else s for s in args], **{
|
||||
k: escape(v) if isinstance(v, str) else v
|
||||
for k, v in kwargs.items()
|
||||
})
|
||||
return expr.format(*args, **kwargs)
|
||||
|
@ -10,7 +10,6 @@ import sys
|
||||
|
||||
logger = logging.getLogger('nonebot')
|
||||
default_handler = logging.StreamHandler(sys.stdout)
|
||||
default_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s %(name)s] %(levelname)s: %(message)s'
|
||||
))
|
||||
default_handler.setFormatter(
|
||||
logging.Formatter('[%(asctime)s %(name)s] %(levelname)s: %(message)s'))
|
||||
logger.addHandler(default_handler)
|
||||
|
@ -57,7 +57,8 @@ async def handle_message(bot: NoneBot, event: CQEvent) -> None:
|
||||
|
||||
while True:
|
||||
try:
|
||||
handled = await handle_command(bot, event, plugin_manager.cmd_manager)
|
||||
handled = await handle_command(bot, event,
|
||||
plugin_manager.cmd_manager)
|
||||
break
|
||||
except SwitchException as e:
|
||||
# we are sure that there is no session existing now
|
||||
@ -67,7 +68,8 @@ async def handle_message(bot: NoneBot, event: CQEvent) -> None:
|
||||
logger.info(f'Message {event.message_id} is handled as a command')
|
||||
return
|
||||
|
||||
handled = await handle_natural_language(bot, event, plugin_manager.nlp_manager)
|
||||
handled = await handle_natural_language(bot, event,
|
||||
plugin_manager.nlp_manager)
|
||||
if handled:
|
||||
logger.info(f'Message {event.message_id} is handled '
|
||||
f'as natural language')
|
||||
@ -121,8 +123,8 @@ def _check_calling_me_nickname(bot: NoneBot, event: CQEvent) -> None:
|
||||
else:
|
||||
nicknames = filter(lambda n: n, bot.config.NICKNAME)
|
||||
nickname_regex = '|'.join(nicknames)
|
||||
m = re.search(rf'^({nickname_regex})([\s,,]*|$)',
|
||||
first_text, re.IGNORECASE)
|
||||
m = re.search(rf'^({nickname_regex})([\s,,]*|$)', first_text,
|
||||
re.IGNORECASE)
|
||||
if m:
|
||||
nickname = m.group(1)
|
||||
logger.debug(f'User is calling me {nickname}')
|
||||
|
@ -15,8 +15,7 @@ from .typing import CommandName_T, CommandArgs_T
|
||||
|
||||
class NLProcessor:
|
||||
__slots__ = ('func', 'keywords', 'permission', 'only_to_me',
|
||||
'only_short_message', 'allow_empty_message', '__name__', '__qualname__', '__doc__',
|
||||
'__annotations__', '__dict__')
|
||||
'only_short_message', 'allow_empty_message')
|
||||
|
||||
def __init__(self, *, func: Callable, keywords: Optional[Iterable],
|
||||
permission: int, only_to_me: bool, only_short_message: bool,
|
||||
@ -101,44 +100,6 @@ class NLPManager:
|
||||
return False
|
||||
|
||||
|
||||
def on_natural_language(
|
||||
keywords: Union[Optional[Iterable], str, Callable] = None,
|
||||
*,
|
||||
permission: int = perm.EVERYBODY,
|
||||
only_to_me: bool = True,
|
||||
only_short_message: bool = True,
|
||||
allow_empty_message: bool = False) -> Callable:
|
||||
"""
|
||||
Decorator to register a function as a natural language processor.
|
||||
|
||||
:param keywords: keywords to respond to, if None, respond to all messages
|
||||
:param permission: permission required by the processor
|
||||
:param only_to_me: only handle messages to me
|
||||
:param only_short_message: only handle short messages
|
||||
:param allow_empty_message: handle empty messages
|
||||
"""
|
||||
|
||||
def deco(func: Callable) -> NLProcessor:
|
||||
nl_processor = NLProcessor(
|
||||
func=func,
|
||||
keywords=keywords, # type: ignore
|
||||
permission=permission,
|
||||
only_to_me=only_to_me,
|
||||
only_short_message=only_short_message,
|
||||
allow_empty_message=allow_empty_message)
|
||||
NLPManager.add_nl_processor(nl_processor)
|
||||
update_wrapper(wrapper=nl_processor, wrapped=func) # type: ignore
|
||||
return nl_processor
|
||||
|
||||
if isinstance(keywords, Callable):
|
||||
# here "keywords" is the function to be decorated
|
||||
return on_natural_language()(keywords)
|
||||
else:
|
||||
if isinstance(keywords, str):
|
||||
keywords = (keywords,)
|
||||
return deco
|
||||
|
||||
|
||||
class NLPSession(BaseSession):
|
||||
__slots__ = ('msg', 'msg_text', 'msg_images')
|
||||
|
||||
|
@ -11,41 +11,15 @@ from .session import BaseSession
|
||||
|
||||
_bus = EventBus()
|
||||
|
||||
|
||||
class EventHandler:
|
||||
__slots__ = ('events', 'func', '__name__', '__qualname__', '__doc__',
|
||||
'__annotations__', '__dict__')
|
||||
|
||||
__slots__ = ('events', 'func')
|
||||
|
||||
def __init__(self, events: List[str], func: Callable):
|
||||
self.events = events
|
||||
self.func = func
|
||||
|
||||
|
||||
def _make_event_deco(post_type: str) -> Callable:
|
||||
def deco_deco(arg: Optional[Union[str, Callable]] = None,
|
||||
*events: str) -> Callable:
|
||||
def deco(func: Callable) -> EventHandler:
|
||||
if isinstance(arg, str):
|
||||
events_tmp = list(map(lambda x: f"{post_type}.{x}", [arg] + list(events)))
|
||||
for e in events_tmp:
|
||||
_bus.subscribe(e, func)
|
||||
handler = EventHandler(events_tmp, func)
|
||||
return update_wrapper(handler, func) # type: ignore
|
||||
else:
|
||||
_bus.subscribe(post_type, func)
|
||||
handler = EventHandler([post_type], func)
|
||||
return update_wrapper(handler, func) # type: ignore
|
||||
|
||||
if isinstance(arg, Callable):
|
||||
return deco(arg) # type: ignore
|
||||
return deco
|
||||
|
||||
return deco_deco
|
||||
|
||||
|
||||
on_notice = _make_event_deco('notice')
|
||||
on_request = _make_event_deco('request')
|
||||
|
||||
|
||||
class NoticeSession(BaseSession):
|
||||
__slots__ = ()
|
||||
|
||||
@ -66,12 +40,13 @@ class RequestSession(BaseSession):
|
||||
:param remark: remark of friend (only works in friend request)
|
||||
"""
|
||||
try:
|
||||
await self.bot.call_action(
|
||||
action='.handle_quick_operation_async',
|
||||
self_id=self.event.self_id,
|
||||
context=self.event,
|
||||
operation={'approve': True, 'remark': remark}
|
||||
)
|
||||
await self.bot.call_action(action='.handle_quick_operation_async',
|
||||
self_id=self.event.self_id,
|
||||
context=self.event,
|
||||
operation={
|
||||
'approve': True,
|
||||
'remark': remark
|
||||
})
|
||||
except CQHttpError:
|
||||
pass
|
||||
|
||||
@ -82,12 +57,13 @@ class RequestSession(BaseSession):
|
||||
:param reason: reason to reject (only works in group request)
|
||||
"""
|
||||
try:
|
||||
await self.bot.call_action(
|
||||
action='.handle_quick_operation_async',
|
||||
self_id=self.event.self_id,
|
||||
context=self.event,
|
||||
operation={'approve': False, 'reason': reason}
|
||||
)
|
||||
await self.bot.call_action(action='.handle_quick_operation_async',
|
||||
self_id=self.event.self_id,
|
||||
context=self.event,
|
||||
operation={
|
||||
'approve': False,
|
||||
'reason': reason
|
||||
})
|
||||
except CQHttpError:
|
||||
pass
|
||||
|
||||
|
@ -88,8 +88,7 @@ async def _check(bot: NoneBot, min_event: _MinEvent,
|
||||
self_id=min_event.self_id,
|
||||
group_id=min_event.group_id,
|
||||
user_id=min_event.user_id,
|
||||
no_cache=True
|
||||
)
|
||||
no_cache=True)
|
||||
if member_info:
|
||||
if member_info['role'] == 'owner':
|
||||
permission |= IS_GROUP_OWNER
|
||||
|
@ -1,20 +1,29 @@
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import warnings
|
||||
import importlib
|
||||
from types import ModuleType
|
||||
from typing import Any, Set, Dict, Optional
|
||||
from typing import Any, Set, Dict, Union, Optional, Iterable, Callable
|
||||
|
||||
from .log import logger
|
||||
from nonebot import permission as perm
|
||||
from .command import Command, CommandManager
|
||||
from .natural_language import NLProcessor, NLPManager
|
||||
from .notice_request import _bus, EventHandler
|
||||
from .typing import CommandName_T, CommandHandler_T
|
||||
|
||||
_tmp_command: Set[Command] = set()
|
||||
_tmp_nl_processor: Set[NLProcessor] = set()
|
||||
_tmp_event_handler: Set[EventHandler] = set()
|
||||
|
||||
|
||||
class Plugin:
|
||||
__slots__ = ('module', 'name', 'usage', 'commands', 'nl_processors', 'event_handlers')
|
||||
__slots__ = ('module', 'name', 'usage', 'commands', 'nl_processors',
|
||||
'event_handlers')
|
||||
|
||||
def __init__(self, module: ModuleType,
|
||||
def __init__(self,
|
||||
module: ModuleType,
|
||||
name: Optional[str] = None,
|
||||
usage: Optional[Any] = None,
|
||||
commands: Set[Command] = set(),
|
||||
@ -27,40 +36,58 @@ class Plugin:
|
||||
self.nl_processors = nl_processors
|
||||
self.event_handlers = event_handlers
|
||||
|
||||
|
||||
class PluginManager:
|
||||
_plugins: Dict[str, Plugin] = {}
|
||||
_anonymous_plugins: Set[Plugin] = set()
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.cmd_manager = CommandManager()
|
||||
self.nlp_manager = NLPManager()
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_plugin(cls, plugin: Plugin) -> None:
|
||||
def add_plugin(cls, module_path: str, plugin: Plugin) -> None:
|
||||
"""Register a plugin
|
||||
|
||||
Args:
|
||||
name (str): module path
|
||||
plugin (Plugin): Plugin object
|
||||
"""
|
||||
if plugin.name:
|
||||
if plugin.name in cls._plugins:
|
||||
warnings.warn(f"Plugin {plugin.name} already exists")
|
||||
return
|
||||
cls._plugins[plugin.name] = plugin
|
||||
else:
|
||||
cls._anonymous_plugins.add(plugin)
|
||||
|
||||
@classmethod
|
||||
def get_plugin(cls, name: str) -> Optional[Plugin]:
|
||||
return cls._plugins.get(name)
|
||||
|
||||
# TODO: plugin重加载
|
||||
@classmethod
|
||||
def reload_plugin(cls, plugin: Plugin) -> None:
|
||||
pass
|
||||
if module_path in cls._plugins:
|
||||
warnings.warn(f"Plugin {module_path} already exists")
|
||||
return
|
||||
cls._plugins[module_path] = plugin
|
||||
|
||||
@classmethod
|
||||
def switch_plugin_global(cls, name: str, state: Optional[bool] = None) -> None:
|
||||
def get_plugin(cls, module_path: str) -> Optional[Plugin]:
|
||||
"""Get plugin object by plugin path
|
||||
|
||||
Args:
|
||||
name (str): plugin path
|
||||
|
||||
Returns:
|
||||
Optional[Plugin]: Plugin object
|
||||
"""
|
||||
return cls._plugins.get(module_path, None)
|
||||
|
||||
@classmethod
|
||||
def remove_plugin(cls, module_path: str) -> bool:
|
||||
plugin = cls.get_plugin(module_path)
|
||||
if not plugin:
|
||||
warnings.warn(f"Plugin {module_path} not exists")
|
||||
return False
|
||||
for command in plugin.commands:
|
||||
CommandManager.remove_command(command.name)
|
||||
for nl_processor in plugin.nl_processors:
|
||||
NLPManager.remove_nl_processor(nl_processor)
|
||||
for event_handler in plugin.event_handlers:
|
||||
for event in event_handler.events:
|
||||
_bus.unsubscribe(event, event_handler.func)
|
||||
del cls._plugins[module_path]
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def switch_plugin_global(cls, name: str,
|
||||
state: Optional[bool] = None) -> None:
|
||||
"""Change plugin state globally or simply switch it if `state` is None
|
||||
|
||||
Args:
|
||||
@ -79,11 +106,13 @@ class PluginManager:
|
||||
for event in event_handler.events:
|
||||
if event_handler.func in _bus._subscribers[event] and not state:
|
||||
_bus.unsubscribe(event, event_handler.func)
|
||||
elif event_handler.func not in _bus._subscribers[event] and state != False:
|
||||
elif event_handler.func not in _bus._subscribers[
|
||||
event] and state != False:
|
||||
_bus.subscribe(event, event_handler.func)
|
||||
|
||||
@classmethod
|
||||
def switch_command_global(cls, name: str, state: Optional[bool] = None) -> None:
|
||||
def switch_command_global(cls, name: str,
|
||||
state: Optional[bool] = None) -> None:
|
||||
"""Change plugin command state globally or simply switch it if `state` is None
|
||||
|
||||
Args:
|
||||
@ -96,9 +125,10 @@ class PluginManager:
|
||||
return
|
||||
for command in plugin.commands:
|
||||
CommandManager.switch_command_global(command.name, state)
|
||||
|
||||
|
||||
@classmethod
|
||||
def switch_nlprocessor_global(cls, name: str, state: Optional[bool] = None) -> None:
|
||||
def switch_nlprocessor_global(cls, name: str,
|
||||
state: Optional[bool] = None) -> None:
|
||||
"""Change plugin nlprocessor state globally or simply switch it if `state` is None
|
||||
|
||||
Args:
|
||||
@ -113,7 +143,8 @@ class PluginManager:
|
||||
NLPManager.switch_nlprocessor_global(processor, state)
|
||||
|
||||
@classmethod
|
||||
def switch_eventhandler_global(cls, name: str, state: Optional[bool] = None) -> None:
|
||||
def switch_eventhandler_global(cls, name: str,
|
||||
state: Optional[bool] = None) -> None:
|
||||
"""Change plugin event handler state globally or simply switch it if `state` is None
|
||||
|
||||
Args:
|
||||
@ -128,7 +159,8 @@ class PluginManager:
|
||||
for event in event_handler.events:
|
||||
if event_handler.func in _bus._subscribers[event] and not state:
|
||||
_bus.unsubscribe(event, event_handler.func)
|
||||
elif event_handler.func not in _bus._subscribers[event] and state != False:
|
||||
elif event_handler.func not in _bus._subscribers[
|
||||
event] and state != False:
|
||||
_bus.subscribe(event, event_handler.func)
|
||||
|
||||
def switch_plugin(self, name: str, state: Optional[bool] = None) -> None:
|
||||
@ -151,7 +183,7 @@ class PluginManager:
|
||||
self.cmd_manager.switch_command(command.name, state)
|
||||
for nl_processor in plugin.nl_processors:
|
||||
self.nlp_manager.switch_nlprocessor(nl_processor, state)
|
||||
|
||||
|
||||
def switch_command(self, name: str, state: Optional[bool] = None) -> None:
|
||||
"""Change plugin command state or simply switch it if `state` is None
|
||||
|
||||
@ -166,7 +198,8 @@ class PluginManager:
|
||||
for command in plugin.commands:
|
||||
self.cmd_manager.switch_command(command.name, state)
|
||||
|
||||
def switch_nlprocessor(self, name: str, state: Optional[bool] = None) -> None:
|
||||
def switch_nlprocessor(self, name: str,
|
||||
state: Optional[bool] = None) -> None:
|
||||
"""Change plugin nlprocessor state or simply switch it if `state` is None
|
||||
|
||||
Args:
|
||||
@ -181,41 +214,42 @@ class PluginManager:
|
||||
self.nlp_manager.switch_nlprocessor(processor, state)
|
||||
|
||||
|
||||
def load_plugin(module_name: str) -> Optional[Plugin]:
|
||||
"""
|
||||
Load a module as a plugin.
|
||||
|
||||
:param module_name: name of module to import
|
||||
:return: successful or not
|
||||
def load_plugin(module_path: str) -> Optional[Plugin]:
|
||||
"""Load a module as a plugin
|
||||
|
||||
Args:
|
||||
module_path (str): path of module to import
|
||||
|
||||
Returns:
|
||||
Optional[Plugin]: Plugin object loaded
|
||||
"""
|
||||
# Make sure tmp is clean
|
||||
_tmp_command.clear()
|
||||
_tmp_nl_processor.clear()
|
||||
_tmp_event_handler.clear()
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
module = importlib.import_module(module_path)
|
||||
name = getattr(module, '__plugin_name__', None)
|
||||
usage = getattr(module, '__plugin_usage__', None)
|
||||
commands = set()
|
||||
nl_processors = set()
|
||||
event_handlers = set()
|
||||
for attr in dir(module):
|
||||
func = getattr(module, attr)
|
||||
if isinstance(func, Command):
|
||||
commands.add(func)
|
||||
elif isinstance(func, NLProcessor):
|
||||
nl_processors.add(func)
|
||||
elif isinstance(func, EventHandler):
|
||||
event_handlers.add(func)
|
||||
plugin = Plugin(module, name, usage, commands, nl_processors, event_handlers)
|
||||
PluginManager.add_plugin(plugin)
|
||||
logger.info(f'Succeeded to import "{module_name}"')
|
||||
commands = _tmp_command.copy()
|
||||
nl_processors = _tmp_nl_processor.copy()
|
||||
event_handlers = _tmp_event_handler.copy()
|
||||
plugin = Plugin(module, name, usage, commands, nl_processors,
|
||||
event_handlers)
|
||||
PluginManager.add_plugin(module_path, plugin)
|
||||
logger.info(f'Succeeded to import "{module_path}"')
|
||||
return plugin
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to import "{module_name}", error: {e}')
|
||||
logger.error(f'Failed to import "{module_path}", error: {e}')
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
|
||||
# TODO: plugin重加载
|
||||
def reload_plugin(module_name: str) -> Optional[Plugin]:
|
||||
pass
|
||||
def reload_plugin(module_path: str) -> Optional[Plugin]:
|
||||
result = PluginManager.remove_plugin(module_path)
|
||||
if not result:
|
||||
return None
|
||||
return load_plugin(module_path)
|
||||
|
||||
|
||||
def load_plugins(plugin_dir: str, module_prefix: str) -> Set[Plugin]:
|
||||
@ -262,4 +296,121 @@ def get_loaded_plugins() -> Set[Plugin]:
|
||||
|
||||
:return: a set of Plugin objects
|
||||
"""
|
||||
return set(PluginManager._plugins.values()) | PluginManager._anonymous_plugins
|
||||
return set(PluginManager._plugins.values())
|
||||
|
||||
|
||||
def on_command(name: Union[str, CommandName_T],
|
||||
*,
|
||||
aliases: Union[Iterable[str], str] = (),
|
||||
permission: int = perm.EVERYBODY,
|
||||
only_to_me: bool = True,
|
||||
privileged: bool = False,
|
||||
shell_like: bool = False) -> Callable:
|
||||
"""
|
||||
Decorator to register a function as a command.
|
||||
|
||||
:param name: command name (e.g. 'echo' or ('random', 'number'))
|
||||
:param aliases: aliases of command name, for convenient access
|
||||
:param permission: permission required by the command
|
||||
:param only_to_me: only handle messages to me
|
||||
:param privileged: can be run even when there is already a session
|
||||
:param shell_like: use shell-like syntax to split arguments
|
||||
"""
|
||||
|
||||
def deco(func: CommandHandler_T) -> CommandHandler_T:
|
||||
if not isinstance(name, (str, tuple)):
|
||||
raise TypeError('the name of a command must be a str or tuple')
|
||||
if not name:
|
||||
raise ValueError('the name of a command must not be empty')
|
||||
|
||||
cmd_name = (name,) if isinstance(name, str) else name
|
||||
|
||||
cmd = Command(name=cmd_name,
|
||||
func=func,
|
||||
permission=permission,
|
||||
only_to_me=only_to_me,
|
||||
privileged=privileged)
|
||||
|
||||
if shell_like:
|
||||
|
||||
async def shell_like_args_parser(session):
|
||||
session.args['argv'] = shlex.split(session.current_arg)
|
||||
|
||||
cmd.args_parser_func = shell_like_args_parser
|
||||
|
||||
CommandManager.add_command(cmd_name, cmd)
|
||||
CommandManager.add_aliases(aliases, cmd)
|
||||
|
||||
_tmp_command.add(cmd)
|
||||
func.args_parser = cmd.args_parser
|
||||
|
||||
return func
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def on_natural_language(
|
||||
keywords: Union[Optional[Iterable], str, Callable] = None,
|
||||
*,
|
||||
permission: int = perm.EVERYBODY,
|
||||
only_to_me: bool = True,
|
||||
only_short_message: bool = True,
|
||||
allow_empty_message: bool = False) -> Callable:
|
||||
"""
|
||||
Decorator to register a function as a natural language processor.
|
||||
|
||||
:param keywords: keywords to respond to, if None, respond to all messages
|
||||
:param permission: permission required by the processor
|
||||
:param only_to_me: only handle messages to me
|
||||
:param only_short_message: only handle short messages
|
||||
:param allow_empty_message: handle empty messages
|
||||
"""
|
||||
|
||||
def deco(func: Callable) -> Callable:
|
||||
nl_processor = NLProcessor(
|
||||
func=func,
|
||||
keywords=keywords, # type: ignore
|
||||
permission=permission,
|
||||
only_to_me=only_to_me,
|
||||
only_short_message=only_short_message,
|
||||
allow_empty_message=allow_empty_message)
|
||||
NLPManager.add_nl_processor(nl_processor)
|
||||
_tmp_nl_processor.add(nl_processor)
|
||||
return func
|
||||
|
||||
if isinstance(keywords, Callable):
|
||||
# here "keywords" is the function to be decorated
|
||||
return on_natural_language()(keywords)
|
||||
else:
|
||||
if isinstance(keywords, str):
|
||||
keywords = (keywords,)
|
||||
return deco
|
||||
|
||||
|
||||
def _make_event_deco(post_type: str) -> Callable:
|
||||
|
||||
def deco_deco(arg: Optional[Union[str, Callable]] = None,
|
||||
*events: str) -> Callable:
|
||||
|
||||
def deco(func: Callable) -> Callable:
|
||||
if isinstance(arg, str):
|
||||
events_tmp = list(
|
||||
map(lambda x: f"{post_type}.{x}", [arg] + list(events)))
|
||||
for e in events_tmp:
|
||||
_bus.subscribe(e, func)
|
||||
handler = EventHandler(events_tmp, func)
|
||||
else:
|
||||
_bus.subscribe(post_type, func)
|
||||
handler = EventHandler([post_type], func)
|
||||
_tmp_event_handler.add(handler)
|
||||
return func
|
||||
|
||||
if isinstance(arg, Callable):
|
||||
return deco(arg) # type: ignore
|
||||
return deco
|
||||
|
||||
return deco_deco
|
||||
|
||||
|
||||
on_notice = _make_event_deco('notice')
|
||||
on_request = _make_event_deco('request')
|
||||
|
@ -5,6 +5,7 @@ except ImportError:
|
||||
AsyncIOScheduler = None
|
||||
|
||||
if AsyncIOScheduler:
|
||||
|
||||
class Scheduler(AsyncIOScheduler):
|
||||
pass
|
||||
else:
|
||||
|
@ -24,7 +24,9 @@ class BaseSession:
|
||||
def self_id(self) -> int:
|
||||
return self.event.self_id
|
||||
|
||||
async def send(self, message: Message_T, *,
|
||||
async def send(self,
|
||||
message: Message_T,
|
||||
*,
|
||||
at_sender: bool = False,
|
||||
ensure_private: bool = False,
|
||||
ignore_failure: bool = True,
|
||||
@ -38,7 +40,10 @@ class BaseSession:
|
||||
:param ignore_failure: if any CQHttpError raised, ignore it
|
||||
:return: the result returned by CQHTTP
|
||||
"""
|
||||
return await send(self.bot, self.event, message,
|
||||
return await send(self.bot,
|
||||
self.event,
|
||||
message,
|
||||
at_sender=at_sender,
|
||||
ensure_private=ensure_private,
|
||||
ignore_failure=ignore_failure, **kwargs)
|
||||
ignore_failure=ignore_failure,
|
||||
**kwargs)
|
||||
|
@ -5,5 +5,6 @@ Message_T = Union[str, Dict[str, Any], List[Dict[str, Any]]]
|
||||
Expression_T = Union[str, Sequence[str], Callable]
|
||||
CommandName_T = Tuple[str, ...]
|
||||
CommandArgs_T = Dict[str, Any]
|
||||
CommandHandler_T = Callable[["CommandSession"], Any]
|
||||
State_T = Dict[str, Any]
|
||||
Filter_T = Callable[[Any], Union[Any, Awaitable[Any]]]
|
||||
|
Loading…
Reference in New Issue
Block a user