🐛 fix prompt error

This commit is contained in:
yanyongyu 2021-03-02 14:35:02 +08:00
parent 3c8dca67fa
commit f6289ff1b3
2 changed files with 42 additions and 6 deletions

View File

@ -152,7 +152,7 @@ T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
@dataclass @dataclass
class MessageSegment(abc.ABC): class MessageSegment(abc.ABC, Mapping):
"""消息段基类""" """消息段基类"""
type: str type: str
""" """
@ -166,10 +166,16 @@ class MessageSegment(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def __str__(self: T_MessageSegment) -> str: def __str__(self) -> str:
"""该消息段所代表的 str在命令匹配部分使用""" """该消息段所代表的 str在命令匹配部分使用"""
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int:
return len(str(self))
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
return not self == other
@abc.abstractmethod @abc.abstractmethod
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
T_Message]) -> T_Message: T_Message]) -> T_Message:
@ -203,9 +209,24 @@ class MessageSegment(abc.ABC):
def __setitem__(self, key, value): def __setitem__(self, key, value):
return setattr(self, key, value) return setattr(self, key, value)
def get(self, key, default=None): def __iter__(self):
yield from self.data.__iter__()
def __contains__(self, key: object) -> bool:
return key in self.data
def get(self, key: str, default=None):
return getattr(self, key, default) return getattr(self, key, default)
def keys(self):
return self.data.keys()
def values(self):
return self.data.values()
def items(self):
return self.data.items()
def copy(self: T_MessageSegment) -> T_MessageSegment: def copy(self: T_MessageSegment) -> T_MessageSegment:
return copy(self) return copy(self)

View File

@ -10,7 +10,7 @@ from functools import wraps
from datetime import datetime from datetime import datetime
from contextvars import ContextVar from contextvars import ContextVar
from collections import defaultdict from collections import defaultdict
from typing import Type, List, Dict, Union, Callable, Optional, NoReturn, TYPE_CHECKING from typing import Type, List, Dict, Union, Mapping, Iterable, Callable, Optional, NoReturn, TYPE_CHECKING
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.log import logger from nonebot.log import logger
@ -345,8 +345,23 @@ class Matcher(metaclass=MatcherMeta):
state["_current_key"] = key state["_current_key"] = key
if key not in state: if key not in state:
if prompt: if prompt:
await bot.send(event=event, if isinstance(prompt, str):
message=str(prompt).format(**state)) await bot.send(event=event,
message=prompt.format(**state))
elif isinstance(prompt, Mapping):
if prompt.is_text():
await bot.send(event=event,
message=str(prompt).format(**state))
else:
await bot.send(event=event, message=prompt)
elif isinstance(prompt, Iterable):
await bot.send(
event=event,
message=prompt.__class__(
str(prompt).format(**state)) # type: ignore
)
else:
logger.warning("Unknown prompt type, ignored.")
raise PausedException raise PausedException
else: else:
state["_skip_key"] = True state["_skip_key"] = True