From f6289ff1b36c15f57ecda8dfae331503bc62e515 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Tue, 2 Mar 2021 14:35:02 +0800 Subject: [PATCH] :bug: fix prompt error --- nonebot/adapters/_base.py | 27 ++++++++++++++++++++++++--- nonebot/matcher.py | 21 ++++++++++++++++++--- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 328883c8..5ef603d3 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -152,7 +152,7 @@ T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") @dataclass -class MessageSegment(abc.ABC): +class MessageSegment(abc.ABC, Mapping): """消息段基类""" type: str """ @@ -166,10 +166,16 @@ class MessageSegment(abc.ABC): """ @abc.abstractmethod - def __str__(self: T_MessageSegment) -> str: + def __str__(self) -> str: """该消息段所代表的 str,在命令匹配部分使用""" raise NotImplementedError + def __len__(self) -> int: + return len(str(self)) + + def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool: + return not self == other + @abc.abstractmethod def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, T_Message]) -> T_Message: @@ -203,9 +209,24 @@ class MessageSegment(abc.ABC): def __setitem__(self, key, value): return setattr(self, key, value) - def get(self, key, default=None): + def __iter__(self): + yield from self.data.__iter__() + + def __contains__(self, key: object) -> bool: + return key in self.data + + def get(self, key: str, default=None): return getattr(self, key, default) + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + def copy(self: T_MessageSegment) -> T_MessageSegment: return copy(self) diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 2b7775f6..8a76eb52 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -10,7 +10,7 @@ from functools import wraps from datetime import datetime from contextvars import ContextVar from collections import defaultdict -from typing import Type, List, Dict, Union, Callable, Optional, NoReturn, TYPE_CHECKING +from typing import Type, List, Dict, Union, Mapping, Iterable, Callable, Optional, NoReturn, TYPE_CHECKING from nonebot.rule import Rule from nonebot.log import logger @@ -345,8 +345,23 @@ class Matcher(metaclass=MatcherMeta): state["_current_key"] = key if key not in state: if prompt: - await bot.send(event=event, - message=str(prompt).format(**state)) + if isinstance(prompt, str): + await bot.send(event=event, + message=prompt.format(**state)) + elif isinstance(prompt, Mapping): + if prompt.is_text(): + await bot.send(event=event, + message=str(prompt).format(**state)) + else: + await bot.send(event=event, message=prompt) + elif isinstance(prompt, Iterable): + await bot.send( + event=event, + message=prompt.__class__( + str(prompt).format(**state)) # type: ignore + ) + else: + logger.warning("Unknown prompt type, ignored.") raise PausedException else: state["_skip_key"] = True