diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 1f886108..6f209343 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -103,11 +103,12 @@ class Message(List[TMS], abc.ABC): self.extend(self._construct(message)) @classmethod - def template(cls: Type[TM], format_string: str) -> MessageTemplate[TM]: + def template(cls: Type[TM], + format_string: Union[str, TM]) -> MessageTemplate[TM]: """ :说明: - 根据创建消息模板, 用法和 ``str.format`` 大致相同, 但是可以输出消息对象 + 根据创建消息模板, 用法和 ``str.format`` 大致相同, 但是可以输出消息对象, 并且支持以 ``Message`` 对象作为消息模板 :示例: @@ -117,6 +118,13 @@ class Message(List[TMS], abc.ABC): Message(MessageSegment(type='text', data={'text': 'hello world'})) >>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'})) + >>> Message.template( + ... MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) + + ... MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'} + ... ) + Message(MessageSegment(type='text', data={'text': 'test 123456'}), + MessageSegment(type='face', data={'face': 233}), + MessageSegment(type='text', data={'text': 'test hello world'})) :参数: diff --git a/nonebot/adapters/_template.py b/nonebot/adapters/_template.py index 8f63b6f5..38a2dbc1 100644 --- a/nonebot/adapters/_template.py +++ b/nonebot/adapters/_template.py @@ -1,11 +1,10 @@ -import operator import functools from string import Formatter from typing import (TYPE_CHECKING, Any, Set, List, Type, Tuple, Union, Generic, Mapping, TypeVar, Sequence) if TYPE_CHECKING: - from . import Message + from . import Message, MessageSegment TM = TypeVar("TM", bound="Message") @@ -13,7 +12,7 @@ TM = TypeVar("TM", bound="Message") class MessageTemplate(Formatter, Generic[TM]): """消息模板格式化实现类""" - def __init__(self, factory: Type[TM], template: str) -> None: + def __init__(self, factory: Type[TM], template: Union[str, TM]) -> None: self.template = template self.factory = factory @@ -23,8 +22,18 @@ class MessageTemplate(Formatter, Generic[TM]): 根据模板和参数生成消息对象 """ - msg = self.vformat(self.template, args, kwargs) - return msg if isinstance(msg, self.factory) else self.factory(msg) + msg = self.factory() + if isinstance(self.template, str): + msg += self.vformat(self.template, args, kwargs) + elif isinstance(self.template, self.factory): + for seg in self.template: + seg: "MessageSegment" + msg += self.vformat(str(seg), args, + kwargs) if seg.is_text() else seg + else: + raise TypeError('template must be a string or instance of Message!') + + return msg def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TM: @@ -101,9 +110,15 @@ class MessageTemplate(Formatter, Generic[TM]): formatted_text = self.format_field(obj, str(format_control)) results.append(formatted_text) - return self.factory(functools.reduce(operator.add, results or + return self.factory(functools.reduce(self._add, results or [""])), auto_arg_index def format_field(self, value: Any, format_spec: str) -> Any: return super().format_field(value, format_spec) if format_spec else value + + def _add(self, a: Any, b: Any) -> Any: + try: + return a + b + except TypeError: + return a + str(b)