🎨 change typing for formatter

This commit is contained in:
yanyongyu 2021-08-27 14:46:15 +08:00
parent f0bc47ec5e
commit 58d10abd32
2 changed files with 37 additions and 46 deletions

View File

@ -8,19 +8,19 @@
import abc
import asyncio
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import (Any, Dict, Generic, Iterable, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
from typing_extensions import Protocol
from dataclasses import asdict, dataclass, field
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
Generic, Optional, Iterable)
from pydantic import BaseModel
from typing_extensions import Protocol
from nonebot.config import Config
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from nonebot.log import logger
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
from nonebot.config import Config
from nonebot.utils import DataclassEncoder
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from ._formatter import MessageFormatter
@ -332,7 +332,9 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message))
@classmethod
def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]:
def template(
cls: Type[TM],
format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore
return MessageFormatter(cls, format_string)
@classmethod

View File

@ -1,59 +1,49 @@
import functools
import operator
from string import Formatter
from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple,
Type, TypeVar, Union, TYPE_CHECKING)
from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping,
Generic, Sequence, TYPE_CHECKING)
if TYPE_CHECKING:
from nonebot.adapters import Message
from . import Message, MessageSegment
TM = TypeVar("TM", bound="Message")
TMS = TypeVar("TMS", bound="MessageSegment")
TAddable = Union[str, TM, TMS]
class AddAble(Protocol):
class MessageFormatter(Formatter, Generic[TM, TMS]):
def __add__(self, __s: Any) -> "AddAble":
...
def __str__(self) -> str:
...
AddAble_T = TypeVar("AddAble_T", bound=AddAble)
MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True)
class MessageFormatter(Formatter, Generic[MessageResult_T]):
def __init__(self, factory: Type[MessageResult_T], template: str) -> None:
super().__init__()
def __init__(self, factory: Type[TM], template: str) -> None:
self.template = template
self.factory = factory
def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T:
msg: AddAble = super().format(self.template, *args, **kwargs)
return msg if isinstance(msg, self.factory) else self.factory(
msg) # type: ignore
def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM,
TMS]) -> TM:
msg = self.vformat(self.template, args, kwargs)
return msg if isinstance(msg, self.factory) else self.factory(msg)
def vformat(self, format_string: str, args: Sequence[AddAble],
kwargs: Mapping[str, AddAble]):
result, arg_index, used_args = self._vformat(format_string, args,
kwargs, set(), 2)
def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM:
used_args = set()
result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
self.check_unused_args(list(used_args), args, kwargs)
return result
def _vformat(
self,
format_string: str,
args: Sequence[Any],
kwargs: Mapping[str, Any],
args: Sequence[TAddable[TM, TMS]],
kwargs: Mapping[str, TAddable[TM, TMS]],
used_args: Set[Union[int, str]],
recursion_depth: int,
auto_arg_index: int = 0,
) -> Tuple[AddAble, int, Set[Union[int, str]]]:
) -> Tuple[TM, int]:
if recursion_depth < 0:
raise ValueError("Max string recursion exceeded")
results: List[AddAble] = []
results: List[TAddable[TM, TMS]] = []
for (literal_text, field_name, format_spec,
conversion) in self.parse(format_string):
@ -95,24 +85,23 @@ class MessageFormatter(Formatter, Generic[MessageResult_T]):
obj = self.convert_field(obj, conversion) if conversion else obj
# expand the format spec, if needed
format_control, auto_arg_index, formatted_args = self._vformat(
format_control, auto_arg_index = self._vformat(
format_spec,
args,
kwargs,
used_args.copy(),
used_args,
recursion_depth - 1,
auto_arg_index,
)
used_args |= formatted_args
# format the object and append to the result
formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text)
return functools.reduce(operator.add, results or
[""]), auto_arg_index, used_args
return self.factory(functools.reduce(operator.add, results or
[""])), auto_arg_index
def format_field(self, value: AddAble_T,
format_spec: str) -> Union[AddAble_T, str]:
def format_field(self, value: TAddable[TM, TMS],
format_spec: str) -> TAddable[TM, TMS]:
return super().format_field(value,
format_spec) if format_spec else value