🔥 use Any for format type

This commit is contained in:
yanyongyu 2021-08-27 15:08:26 +08:00
parent 58d10abd32
commit 7cfdc2dd37
2 changed files with 10 additions and 20 deletions

View File

@ -10,7 +10,7 @@ import asyncio
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing_extensions import Protocol from typing_extensions import Protocol
from dataclasses import asdict, dataclass, field from dataclasses import dataclass, field, asdict
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping, from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
Generic, Optional, Iterable) Generic, Optional, Iterable)
@ -332,9 +332,7 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message)) self.extend(self._construct(message))
@classmethod @classmethod
def template( def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]:
cls: Type[TM],
format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore
return MessageFormatter(cls, format_string) return MessageFormatter(cls, format_string)
@classmethod @classmethod

View File

@ -5,26 +5,23 @@ from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
Generic, Sequence, TYPE_CHECKING) Generic, Sequence, TYPE_CHECKING)
if TYPE_CHECKING: if TYPE_CHECKING:
from . import Message, MessageSegment from . import Message
TM = TypeVar("TM", bound="Message") TM = TypeVar("TM", bound="Message")
TMS = TypeVar("TMS", bound="MessageSegment")
TAddable = Union[str, TM, TMS]
class MessageFormatter(Formatter, Generic[TM, TMS]): class MessageFormatter(Formatter, Generic[TM]):
def __init__(self, factory: Type[TM], template: str) -> None: def __init__(self, factory: Type[TM], template: str) -> None:
self.template = template self.template = template
self.factory = factory self.factory = factory
def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM, def format(self, *args: Any, **kwargs: Any) -> TM:
TMS]) -> TM:
msg = self.vformat(self.template, args, kwargs) msg = self.vformat(self.template, args, kwargs)
return msg if isinstance(msg, self.factory) else self.factory(msg) return msg if isinstance(msg, self.factory) else self.factory(msg)
def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]], def vformat(self, format_string: str, args: Sequence[Any],
kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM: kwargs: Mapping[str, Any]) -> TM:
used_args = set() used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2) result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs) self.check_unused_args(list(used_args), args, kwargs)
@ -33,8 +30,8 @@ class MessageFormatter(Formatter, Generic[TM, TMS]):
def _vformat( def _vformat(
self, self,
format_string: str, format_string: str,
args: Sequence[TAddable[TM, TMS]], args: Sequence[Any],
kwargs: Mapping[str, TAddable[TM, TMS]], kwargs: Mapping[str, Any],
used_args: Set[Union[int, str]], used_args: Set[Union[int, str]],
recursion_depth: int, recursion_depth: int,
auto_arg_index: int = 0, auto_arg_index: int = 0,
@ -43,7 +40,7 @@ class MessageFormatter(Formatter, Generic[TM, TMS]):
if recursion_depth < 0: if recursion_depth < 0:
raise ValueError("Max string recursion exceeded") raise ValueError("Max string recursion exceeded")
results: List[TAddable[TM, TMS]] = [] results: List[Any] = []
for (literal_text, field_name, format_spec, for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string): conversion) in self.parse(format_string):
@ -100,8 +97,3 @@ class MessageFormatter(Formatter, Generic[TM, TMS]):
return self.factory(functools.reduce(operator.add, results or return self.factory(functools.reduce(operator.add, results or
[""])), auto_arg_index [""])), auto_arg_index
def format_field(self, value: TAddable[TM, TMS],
format_spec: str) -> TAddable[TM, TMS]:
return super().format_field(value,
format_spec) if format_spec else value