2018-07-01 17:51:01 +08:00
|
|
|
|
import asyncio
|
2018-07-02 16:54:29 +08:00
|
|
|
|
import re
|
2018-10-16 01:03:50 +08:00
|
|
|
|
from typing import Iterable, Optional, Callable, Union, NamedTuple
|
2018-07-01 11:01:24 +08:00
|
|
|
|
|
2018-07-04 09:28:31 +08:00
|
|
|
|
from . import NoneBot, permission as perm
|
2018-07-01 17:51:01 +08:00
|
|
|
|
from .command import call_command
|
|
|
|
|
from .log import logger
|
2018-10-14 22:52:37 +08:00
|
|
|
|
from .message import Message
|
2018-07-02 16:54:29 +08:00
|
|
|
|
from .session import BaseSession
|
2018-10-16 01:03:50 +08:00
|
|
|
|
from .typing import Context_T, CommandName_T, CommandArgs_T
|
2018-07-01 11:01:24 +08:00
|
|
|
|
|
|
|
|
|
# Module-wide registry of NLProcessor objects: on_natural_language adds one
# per decorated function, and handle_natural_language iterates over all of
# them for every incoming message.
_nl_processors = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NLProcessor:
    """
    A registered natural language processor together with the filters
    that decide whether it is invoked for a given message.
    """

    __slots__ = ('func', 'keywords', 'permission',
                 'only_to_me', 'only_short_message')

    def __init__(self, *, func: Callable, keywords: Optional[Iterable],
                 permission: int, only_to_me: bool, only_short_message: bool):
        # callable invoked with the NLPSession; its return value is awaited
        # by the dispatcher and may be an NLPResult
        self.func = func
        # substrings of the plain message text, at least one of which must
        # be present; a falsy value means "no keyword filter"
        self.keywords = keywords
        # permission flags the sender must satisfy
        self.permission = permission
        # message-shape filters
        self.only_to_me, self.only_short_message = \
            only_to_me, only_short_message
|
2018-07-01 17:51:01 +08:00
|
|
|
|
|
|
|
|
|
|
2018-07-03 10:36:05 +08:00
|
|
|
|
def on_natural_language(keywords: Union[Optional[Iterable], str, Callable] = None,
                        *, permission: int = perm.EVERYBODY,
                        only_to_me: bool = True,
                        only_short_message: bool = True) -> Callable:
    """
    Decorator to register a function as a natural language processor.

    It may be used bare (``@on_natural_language``) or called with options
    (``@on_natural_language({'weather'}, only_to_me=False)``).

    :param keywords: keywords to respond to; a single string is treated as
                     one keyword. If None (or empty), respond to all messages
    :param permission: permission required by the processor
    :param only_to_me: only handle messages to me
    :param only_short_message: only handle short message
    :return: the decorator, or the decorated function in the bare form
    """
    if callable(keywords):
        # bare-decorator form: here "keywords" is the function to be
        # decorated, so register it with the default options
        return on_natural_language()(keywords)

    if isinstance(keywords, str):
        # a lone string would otherwise be iterated character by character,
        # turning every single character into a keyword
        keywords = (keywords,)
    elif keywords is not None:
        # materialize once so one-shot iterables (e.g. generators) are not
        # exhausted after the first message is checked against them
        keywords = tuple(keywords)

    def deco(func: Callable) -> Callable:
        nl_processor = NLProcessor(func=func, keywords=keywords,
                                   permission=permission,
                                   only_to_me=only_to_me,
                                   only_short_message=only_short_message)
        _nl_processors.add(nl_processor)
        return func

    return deco
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NLPSession(BaseSession):
    """
    Session handed to natural language processors.

    On top of the base session state it keeps the message string it was
    created with, the plain text extracted from that message, and the
    ``url`` of every image segment that carries one.
    """

    __slots__ = ('msg', 'msg_text', 'msg_images')

    def __init__(self, bot: NoneBot, ctx: Context_T, msg: str):
        super().__init__(bot, ctx)
        self.msg = msg

        parsed = Message(msg)
        self.msg_text = parsed.extract_plain_text()

        image_urls = []
        for seg in parsed:
            # only image segments that actually expose a url are kept
            if seg.type != 'image' or 'url' not in seg.data:
                continue
            image_urls.append(seg.data['url'])
        self.msg_images = image_urls
|
2018-07-01 11:01:24 +08:00
|
|
|
|
|
|
|
|
|
|
2018-10-16 01:03:50 +08:00
|
|
|
|
class NLPResult(NamedTuple):
    """
    What a natural language processor proposes: a command to run and how
    confident the processor is about it.

    The dispatcher runs only the highest-confidence result, and only when
    that confidence is at least 60.
    """

    # score used to rank competing results from different processors
    confidence: float
    # command to invoke, either as a plain string or a CommandName_T
    cmd_name: Union[str, CommandName_T]
    # forwarded as ``args`` to call_command; defaults to None
    cmd_args: Optional[CommandArgs_T] = None
|
2018-07-01 11:01:24 +08:00
|
|
|
|
|
|
|
|
|
|
2018-10-16 01:03:50 +08:00
|
|
|
|
async def handle_natural_language(bot: NoneBot, ctx: Context_T) -> bool:
    """
    Handle a message as natural language.

    This function is typically called by "handle_message".

    If the message starts with one of the configured nicknames followed by
    whitespace or a comma, the nickname is stripped and ``ctx['to_me']`` is
    set to True. Every registered processor whose filters pass is then run
    concurrently, and the result with the highest confidence is executed as
    a command if that confidence is at least 60.

    :param bot: NoneBot instance
    :param ctx: message context
    :return: the message is handled as natural language
    """
    msg = str(ctx['message'])
    if bot.config.NICKNAME:
        # check if the user is calling me with my nickname
        if isinstance(bot.config.NICKNAME, str) or \
                not isinstance(bot.config.NICKNAME, Iterable):
            nicknames = (bot.config.NICKNAME,)
        else:
            nicknames = bot.config.NICKNAME
        # Nicknames are literal names, not regex fragments: escape them so
        # characters such as "+" or "?" neither raise re.error nor change the
        # match. Empty entries are dropped, and if none remain we skip the
        # search, since "^()" followed by separators would match any message
        # that merely starts with whitespace.
        nick_pattern = '|'.join(re.escape(str(n)) for n in nicknames if n)
        if nick_pattern:
            m = re.search(rf'^({nick_pattern})[\s,,]+', msg, re.IGNORECASE)
            if m:
                logger.debug(f'User is calling me {m.group(1)}')
                ctx['to_me'] = True
                msg = msg[m.end():]

    session = NLPSession(bot, ctx, msg)

    # use msg_text here because CQ code "share" may be very long,
    # at the same time some plugins may want to handle it
    msg_text_length = len(session.msg_text)

    coros = []
    for p in _nl_processors:
        if p.only_short_message and \
                msg_text_length > bot.config.SHORT_MESSAGE_MAX_LENGTH:
            continue
        if p.only_to_me and not ctx['to_me']:
            continue
        # test the cheap keyword filter before awaiting the permission check,
        # so that check is skipped for processors that could not run anyway
        if p.keywords and \
                not any(kw in session.msg_text for kw in p.keywords):
            continue
        if not await perm.check_permission(bot, ctx, p.permission):
            continue
        coros.append(p.func(session))

    if not coros:
        return False

    # wait for possible results, and sort them by confidence
    results = sorted(filter(None, await asyncio.gather(*coros)),
                     key=lambda r: r.confidence, reverse=True)
    logger.debug(f'NLP results: {results}')
    if results and results[0].confidence >= 60.0:
        # choose the result with highest confidence
        logger.debug(f'NLP result with highest confidence: {results[0]}')
        return await call_command(bot, ctx, results[0].cmd_name,
                                  args=results[0].cmd_args,
                                  check_perm=False)

    logger.debug('No NLP result having enough confidence')
    return False
|