🔥 use Any for format type

This commit is contained in:
yanyongyu 2021-08-27 15:08:26 +08:00
parent 58d10abd32
commit 7cfdc2dd37
2 changed files with 10 additions and 20 deletions

View File

@ -10,7 +10,7 @@ import asyncio
from copy import deepcopy
from functools import partial
from typing_extensions import Protocol
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field, asdict
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
Generic, Optional, Iterable)
@ -332,9 +332,7 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message))
@classmethod
def template(
cls: Type[TM],
format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore
def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]:
return MessageFormatter(cls, format_string)
@classmethod

View File

@ -5,26 +5,23 @@ from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
Generic, Sequence, TYPE_CHECKING)
if TYPE_CHECKING:
from . import Message, MessageSegment
from . import Message
TM = TypeVar("TM", bound="Message")
TMS = TypeVar("TMS", bound="MessageSegment")
TAddable = Union[str, TM, TMS]
class MessageFormatter(Formatter, Generic[TM, TMS]):
class MessageFormatter(Formatter, Generic[TM]):
def __init__(self, factory: Type[TM], template: str) -> None:
self.template = template
self.factory = factory
def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM,
TMS]) -> TM:
def format(self, *args: Any, **kwargs: Any) -> TM:
msg = self.vformat(self.template, args, kwargs)
return msg if isinstance(msg, self.factory) else self.factory(msg)
def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM:
def vformat(self, format_string: str, args: Sequence[Any],
kwargs: Mapping[str, Any]) -> TM:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
@ -33,8 +30,8 @@ class MessageFormatter(Formatter, Generic[TM, TMS]):
def _vformat(
self,
format_string: str,
args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, TAddable[TM, TMS]],
args: Sequence[Any],
kwargs: Mapping[str, Any],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
@ -43,7 +40,7 @@ class MessageFormatter(Formatter, Generic[TM, TMS]):
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[TAddable[TM, TMS]] = []
results: List[Any] = []
for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string):
@ -100,8 +97,3 @@ class MessageFormatter(Formatter, Generic[TM, TMS]):
return self.factory(functools.reduce(operator.add, results or
[""])), auto_arg_index
def format_field(self, value: TAddable[TM, TMS],
format_spec: str) -> TAddable[TM, TMS]:
return super().format_field(value,
format_spec) if format_spec else value