mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
✨ Add message template formatter
ref: https://github.com/nonebot/discussions/discussions/27
This commit is contained in:
parent
b5f2b1a76d
commit
f0bc47ec5e
@ -8,19 +8,21 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from functools import partial
|
||||
from typing_extensions import Protocol
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
|
||||
Generic, Optional, Iterable)
|
||||
from typing import (Any, Dict, Generic, Iterable, List, Mapping, Optional, Set,
|
||||
Tuple, Type, TypeVar, Union)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
|
||||
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
|
||||
from nonebot.utils import DataclassEncoder
|
||||
|
||||
from ._formatter import MessageFormatter
|
||||
|
||||
|
||||
class _ApiCall(Protocol):
|
||||
@ -329,6 +331,10 @@ class Message(List[TMS], abc.ABC):
|
||||
else:
|
||||
self.extend(self._construct(message))
|
||||
|
||||
@classmethod
|
||||
def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]:
|
||||
return MessageFormatter(cls, format_string)
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_segment_class(cls) -> Type[TMS]:
|
||||
|
118
nonebot/adapters/_formatter.py
Normal file
118
nonebot/adapters/_formatter.py
Normal file
@ -0,0 +1,118 @@
|
||||
import functools
|
||||
import operator
|
||||
from string import Formatter
|
||||
from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple,
|
||||
Type, TypeVar, Union, TYPE_CHECKING)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Message
|
||||
|
||||
|
||||
class AddAble(Protocol):
|
||||
|
||||
def __add__(self, __s: Any) -> "AddAble":
|
||||
...
|
||||
|
||||
def __str__(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
AddAble_T = TypeVar("AddAble_T", bound=AddAble)
|
||||
MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True)
|
||||
|
||||
|
||||
class MessageFormatter(Formatter, Generic[MessageResult_T]):
|
||||
|
||||
def __init__(self, factory: Type[MessageResult_T], template: str) -> None:
|
||||
super().__init__()
|
||||
self.template = template
|
||||
self.factory = factory
|
||||
|
||||
def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T:
|
||||
msg: AddAble = super().format(self.template, *args, **kwargs)
|
||||
return msg if isinstance(msg, self.factory) else self.factory(
|
||||
msg) # type: ignore
|
||||
|
||||
def vformat(self, format_string: str, args: Sequence[AddAble],
|
||||
kwargs: Mapping[str, AddAble]):
|
||||
result, arg_index, used_args = self._vformat(format_string, args,
|
||||
kwargs, set(), 2)
|
||||
self.check_unused_args(list(used_args), args, kwargs)
|
||||
return result
|
||||
|
||||
def _vformat(
|
||||
self,
|
||||
format_string: str,
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
used_args: Set[Union[int, str]],
|
||||
recursion_depth: int,
|
||||
auto_arg_index: int = 0,
|
||||
) -> Tuple[AddAble, int, Set[Union[int, str]]]:
|
||||
|
||||
if recursion_depth < 0:
|
||||
raise ValueError("Max string recursion exceeded")
|
||||
|
||||
results: List[AddAble] = []
|
||||
|
||||
for (literal_text, field_name, format_spec,
|
||||
conversion) in self.parse(format_string):
|
||||
|
||||
# output the literal text
|
||||
if literal_text:
|
||||
results.append(literal_text)
|
||||
|
||||
# if there's a field, output it
|
||||
if field_name is not None:
|
||||
# this is some markup, find the object and do
|
||||
# the formatting
|
||||
|
||||
# handle arg indexing when empty field_names are given.
|
||||
if field_name == "":
|
||||
if auto_arg_index is False:
|
||||
raise ValueError(
|
||||
"cannot switch from manual field specification to "
|
||||
"automatic field numbering")
|
||||
field_name = str(auto_arg_index)
|
||||
auto_arg_index += 1
|
||||
elif field_name.isdigit():
|
||||
if auto_arg_index:
|
||||
raise ValueError(
|
||||
"cannot switch from manual field specification to "
|
||||
"automatic field numbering")
|
||||
# disable auto arg incrementing, if it gets
|
||||
# used later on, then an exception will be raised
|
||||
auto_arg_index = False
|
||||
|
||||
# given the field_name, find the object it references
|
||||
# and the argument it came from
|
||||
obj, arg_used = self.get_field(field_name, args, kwargs)
|
||||
used_args.add(arg_used)
|
||||
|
||||
assert format_spec is not None
|
||||
|
||||
# do any conversion on the resulting object
|
||||
obj = self.convert_field(obj, conversion) if conversion else obj
|
||||
|
||||
# expand the format spec, if needed
|
||||
format_control, auto_arg_index, formatted_args = self._vformat(
|
||||
format_spec,
|
||||
args,
|
||||
kwargs,
|
||||
used_args.copy(),
|
||||
recursion_depth - 1,
|
||||
auto_arg_index,
|
||||
)
|
||||
used_args |= formatted_args
|
||||
|
||||
# format the object and append to the result
|
||||
formatted_text = self.format_field(obj, str(format_control))
|
||||
results.append(formatted_text)
|
||||
|
||||
return functools.reduce(operator.add, results or
|
||||
[""]), auto_arg_index, used_args
|
||||
|
||||
def format_field(self, value: AddAble_T,
|
||||
format_spec: str) -> Union[AddAble_T, str]:
|
||||
return super().format_field(value,
|
||||
format_spec) if format_spec else value
|
Loading…
Reference in New Issue
Block a user