diff --git a/docs/api/adapters/README.md b/docs/api/adapters/README.md index d1431b1f..2e8555c8 100644 --- a/docs/api/adapters/README.md +++ b/docs/api/adapters/README.md @@ -300,6 +300,40 @@ await bot.send_msg(message="hello world") +### _classmethod_ `template(format_string)` + + +* **说明** + + 根据创建消息模板, 用法和 `str.format` 大致相同, 但是可以输出消息对象 + + + +* **示例** + + +```python +>>> Message.template("{} {}").format("hello", "world") +Message(MessageSegment(type='text', data={'text': 'hello world'})) +>>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") +Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'})) +``` + + +* **参数** + + + * `format_string: str`: 格式化字符串 + + + +* **返回** + + + * `MessageFormatter[TM]`: 消息格式化器 + + + ### `append(obj)` @@ -499,3 +533,19 @@ Event 基类。提供获取关键信息的方法,其余信息可直接获取 * `bool` + + + +## _class_ `MessageFormatter` + +基类:`string.Formatter`, `Generic`[`nonebot.adapters._formatter.TM`] + +消息模板格式化实现类 + + +### `format(*args, **kwargs)` + + +* **说明** + + 根据模板和参数生成消息对象 diff --git a/docs_build/adapters/README.rst b/docs_build/adapters/README.rst index 8e759794..14f91a77 100644 --- a/docs_build/adapters/README.rst +++ b/docs_build/adapters/README.rst @@ -11,3 +11,9 @@ NoneBot.adapters 模块 :private-members: :special-members: __init__ :show-inheritance: + +.. automodule:: nonebot.adapters._formatter + :members: + :private-members: + :special-members: __init__ + :show-inheritance: \ No newline at end of file diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 10932667..223dea6d 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -19,8 +19,10 @@ from pydantic import BaseModel from nonebot.log import logger from nonebot.config import Config from nonebot.utils import DataclassEncoder -from nonebot.drivers import Driver, HTTPConnection, HTTPResponse from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook +from nonebot.drivers import Driver, HTTPConnection, HTTPResponse + +from ._formatter import MessageFormatter class _ApiCall(Protocol): @@ -329,6 +331,32 @@ class Message(List[TMS], abc.ABC): else: self.extend(self._construct(message)) + @classmethod + def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]: + """ + :说明: + + 根据创建消息模板, 用法和 ``str.format`` 大致相同, 但是可以输出消息对象 + + :示例: + + .. code-block:: python + + >>> Message.template("{} {}").format("hello", "world") + Message(MessageSegment(type='text', data={'text': 'hello world'})) + >>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") + Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'})) + + :参数: + + * ``format_string: str``: 格式化字符串 + + :返回: + + - ``MessageFormatter[TM]``: 消息格式化器 + """ + return MessageFormatter(cls, format_string) + @classmethod @abc.abstractmethod def get_segment_class(cls) -> Type[TMS]: diff --git a/nonebot/adapters/_formatter.py b/nonebot/adapters/_formatter.py new file mode 100644 index 00000000..68994f42 --- /dev/null +++ b/nonebot/adapters/_formatter.py @@ -0,0 +1,105 @@ +import functools +import operator +from string import Formatter +from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping, + Generic, Sequence, TYPE_CHECKING) + +if TYPE_CHECKING: + from . import Message + +TM = TypeVar("TM", bound="Message") + + +class MessageFormatter(Formatter, Generic[TM]): + """消息模板格式化实现类""" + + def __init__(self, factory: Type[TM], template: str) -> None: + self.template = template + self.factory = factory + + def format(self, *args: Any, **kwargs: Any) -> TM: + """ + :说明: + + 根据模板和参数生成消息对象 + """ + msg = self.vformat(self.template, args, kwargs) + return msg if isinstance(msg, self.factory) else self.factory(msg) + + def vformat(self, format_string: str, args: Sequence[Any], + kwargs: Mapping[str, Any]) -> TM: + used_args = set() + result, _ = self._vformat(format_string, args, kwargs, used_args, 2) + self.check_unused_args(list(used_args), args, kwargs) + return result + + def _vformat( + self, + format_string: str, + args: Sequence[Any], + kwargs: Mapping[str, Any], + used_args: Set[Union[int, str]], + recursion_depth: int, + auto_arg_index: int = 0, + ) -> Tuple[TM, int]: + + if recursion_depth < 0: + raise ValueError("Max string recursion exceeded") + + results: List[Any] = [] + + for (literal_text, field_name, format_spec, + conversion) in self.parse(format_string): + + # output the literal text + if literal_text: + results.append(literal_text) + + # if there's a field, output it + if field_name is not None: + # this is some markup, find the object and do + # the formatting + + # handle arg indexing when empty field_names are given. + if field_name == "": + if auto_arg_index is False: + raise ValueError( + "cannot switch from manual field specification to " + "automatic field numbering") + field_name = str(auto_arg_index) + auto_arg_index += 1 + elif field_name.isdigit(): + if auto_arg_index: + raise ValueError( + "cannot switch from manual field specification to " + "automatic field numbering") + # disable auto arg incrementing, if it gets + # used later on, then an exception will be raised + auto_arg_index = False + + # given the field_name, find the object it references + # and the argument it came from + obj, arg_used = self.get_field(field_name, args, kwargs) + used_args.add(arg_used) + + assert format_spec is not None + + # do any conversion on the resulting object + obj = self.convert_field(obj, conversion) if conversion else obj + + # expand the format spec, if needed + format_control, auto_arg_index = self._vformat( + format_spec, + args, + kwargs, + used_args, + recursion_depth - 1, + auto_arg_index, + ) + + # format the object and append to the result + formatted_text = self.format_field(obj, str(format_control)) + results.append(formatted_text) + + return self.factory(functools.reduce(operator.add, results or + [""])), auto_arg_index