🔀 Merge pull request #510

Feature: str.format like message formatting support
This commit is contained in:
Ju4tCode 2021-08-28 13:06:02 +08:00 committed by GitHub
commit 9518e3c568
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 190 additions and 1 deletions

View File

@ -300,6 +300,40 @@ await bot.send_msg(message="hello world")
### _classmethod_ `template(format_string)`
* **说明**
根据创建消息模板, 用法和 `str.format` 大致相同, 但是可以输出消息对象
* **示例**
```python
>>> Message.template("{} {}").format("hello", "world")
Message(MessageSegment(type='text', data={'text': 'hello world'}))
>>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world")
Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'}))
```
* **参数**
* `format_string: str`: 格式化字符串
* **返回**
* `MessageFormatter[TM]`: 消息格式化器
### `append(obj)` ### `append(obj)`
@ -499,3 +533,19 @@ Event 基类。提供获取关键信息的方法,其余信息可直接获取
* `bool` * `bool`
## _class_ `MessageFormatter`
基类:`string.Formatter`, `Generic`[`nonebot.adapters._formatter.TM`]
消息模板格式化实现类
### `format(*args, **kwargs)`
* **说明**
根据模板和参数生成消息对象

View File

@ -11,3 +11,9 @@ NoneBot.adapters 模块
:private-members: :private-members:
:special-members: __init__ :special-members: __init__
:show-inheritance: :show-inheritance:
.. automodule:: nonebot.adapters._formatter
:members:
:private-members:
:special-members: __init__
:show-inheritance:

View File

@ -19,8 +19,10 @@ from pydantic import BaseModel
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from ._formatter import MessageFormatter
class _ApiCall(Protocol): class _ApiCall(Protocol):
@ -329,6 +331,32 @@ class Message(List[TMS], abc.ABC):
else: else:
self.extend(self._construct(message)) self.extend(self._construct(message))
@classmethod
def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]:
"""
:说明:
根据创建消息模板, 用法和 ``str.format`` 大致相同, 但是可以输出消息对象
:示例:
.. code-block:: python
>>> Message.template("{} {}").format("hello", "world")
Message(MessageSegment(type='text', data={'text': 'hello world'}))
>>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world")
Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'}))
:参数:
* ``format_string: str``: 格式化字符串
:返回:
- ``MessageFormatter[TM]``: 消息格式化器
"""
return MessageFormatter(cls, format_string)
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def get_segment_class(cls) -> Type[TMS]: def get_segment_class(cls) -> Type[TMS]:

View File

@ -0,0 +1,105 @@
import functools
import operator
from string import Formatter
from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
Generic, Sequence, TYPE_CHECKING)
if TYPE_CHECKING:
from . import Message
TM = TypeVar("TM", bound="Message")
class MessageFormatter(Formatter, Generic[TM]):
"""消息模板格式化实现类"""
def __init__(self, factory: Type[TM], template: str) -> None:
self.template = template
self.factory = factory
def format(self, *args: Any, **kwargs: Any) -> TM:
"""
:说明:
根据模板和参数生成消息对象
"""
msg = self.vformat(self.template, args, kwargs)
return msg if isinstance(msg, self.factory) else self.factory(msg)
def vformat(self, format_string: str, args: Sequence[Any],
kwargs: Mapping[str, Any]) -> TM:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
return result
def _vformat(
self,
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
) -> Tuple[TM, int]:
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[Any] = []
for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string):
# output the literal text
if literal_text:
results.append(literal_text)
# if there's a field, output it
if field_name is not None:
# this is some markup, find the object and do
# the formatting
# handle arg indexing when empty field_names are given.
if field_name == "":
if auto_arg_index is False:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering")
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
raise ValueError(
"cannot switch from manual field specification to "
"automatic field numbering")
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False
# given the field_name, find the object it references
# and the argument it came from
obj, arg_used = self.get_field(field_name, args, kwargs)
used_args.add(arg_used)
assert format_spec is not None
# do any conversion on the resulting object
obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed
format_control, auto_arg_index = self._vformat(
format_spec,
args,
kwargs,
used_args,
recursion_depth - 1,
auto_arg_index,
)
# format the object and append to the result
formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text)
return self.factory(functools.reduce(operator.add, results or
[""])), auto_arg_index