diff --git a/docs/api/adapters/README.md b/docs/api/adapters/README.md index e53a751b..3c1bcc8b 100644 --- a/docs/api/adapters/README.md +++ b/docs/api/adapters/README.md @@ -305,7 +305,7 @@ await bot.send_msg(message="hello world") * **说明** - 根据创建消息模板, 用法和 `str.format` 大致相同, 但是可以输出消息对象 + 根据创建消息模板, 用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板 @@ -317,6 +317,13 @@ await bot.send_msg(message="hello world") Message(MessageSegment(type='text', data={'text': 'hello world'})) >>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'})) +>>> Message.template( +... MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) + +... MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'} +... ) +Message(MessageSegment(type='text', data={'text': 'test 123456'}), + MessageSegment(type='face', data={'face': 233}), + MessageSegment(type='text', data={'text': 'test hello world'})) ``` diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 1f886108..6f209343 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -103,11 +103,12 @@ class Message(List[TMS], abc.ABC): self.extend(self._construct(message)) @classmethod - def template(cls: Type[TM], format_string: str) -> MessageTemplate[TM]: + def template(cls: Type[TM], + format_string: Union[str, TM]) -> MessageTemplate[TM]: """ :说明: - 根据创建消息模板, 用法和 ``str.format`` 大致相同, 但是可以输出消息对象 + 根据创建消息模板, 用法和 ``str.format`` 大致相同, 但是可以输出消息对象, 并且支持以 ``Message`` 对象作为消息模板 :示例: @@ -117,6 +118,13 @@ class Message(List[TMS], abc.ABC): Message(MessageSegment(type='text', data={'text': 'hello world'})) >>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'})) + >>> Message.template( + ... MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) + + ... MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'} + ... ) + Message(MessageSegment(type='text', data={'text': 'test 123456'}), + MessageSegment(type='face', data={'face': 233}), + MessageSegment(type='text', data={'text': 'test hello world'})) :参数: diff --git a/nonebot/adapters/_template.py b/nonebot/adapters/_template.py index 8f63b6f5..38a2dbc1 100644 --- a/nonebot/adapters/_template.py +++ b/nonebot/adapters/_template.py @@ -1,11 +1,10 @@ -import operator import functools from string import Formatter from typing import (TYPE_CHECKING, Any, Set, List, Type, Tuple, Union, Generic, Mapping, TypeVar, Sequence) if TYPE_CHECKING: - from . import Message + from . import Message, MessageSegment TM = TypeVar("TM", bound="Message") @@ -13,7 +12,7 @@ TM = TypeVar("TM", bound="Message") class MessageTemplate(Formatter, Generic[TM]): """消息模板格式化实现类""" - def __init__(self, factory: Type[TM], template: str) -> None: + def __init__(self, factory: Type[TM], template: Union[str, TM]) -> None: self.template = template self.factory = factory @@ -23,8 +22,18 @@ class MessageTemplate(Formatter, Generic[TM]): 根据模板和参数生成消息对象 """ - msg = self.vformat(self.template, args, kwargs) - return msg if isinstance(msg, self.factory) else self.factory(msg) + msg = self.factory() + if isinstance(self.template, str): + msg += self.vformat(self.template, args, kwargs) + elif isinstance(self.template, self.factory): + for seg in self.template: + seg: "MessageSegment" + msg += self.vformat(str(seg), args, + kwargs) if seg.is_text() else seg + else: + raise TypeError('template must be a string or instance of Message!') + + return msg def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TM: @@ -101,9 +110,15 @@ class MessageTemplate(Formatter, Generic[TM]): formatted_text = self.format_field(obj, str(format_control)) results.append(formatted_text) - return self.factory(functools.reduce(operator.add, results or + return self.factory(functools.reduce(self._add, results or [""])), auto_arg_index def format_field(self, value: Any, format_spec: str) -> Any: return super().format_field(value, format_spec) if format_spec else value + + def _add(self, a: Any, b: Any) -> Any: + try: + return a + b + except TypeError: + return a + str(b) diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py index 34ea948f..14d6f9e8 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py @@ -274,22 +274,6 @@ class MessageChain(BaseMessage[MessageSegment]): def get_segment_class(cls) -> Type[MessageSegment]: return MessageSegment - @overrides(BaseMessage) - def __init__(self, message: Union[List[Dict[str, - Any]], Iterable[MessageSegment], - MessageSegment, str], **kwargs): - super().__init__(**kwargs) - if isinstance(message, MessageSegment): - self.append(message) - elif isinstance(message, str): - self.append(MessageSegment.plain(text=message)) - elif isinstance(message, Iterable): - self.extend(self._construct(message)) - else: - raise ValueError( - f'Type {type(message).__name__} is not supported in mirai adapter.' - ) - @overrides(BaseMessage) def _construct( self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment]]