From 7cfdc2dd37198103a912280a60275b84efb52e13 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Fri, 27 Aug 2021 15:08:26 +0800 Subject: [PATCH] :fire: use Any for format type --- nonebot/adapters/_base.py | 6 ++---- nonebot/adapters/_formatter.py | 24 ++++++++---------------- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 387cf316..30fa312a 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -10,7 +10,7 @@ import asyncio from copy import deepcopy from functools import partial from typing_extensions import Protocol -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field, asdict from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping, Generic, Optional, Iterable) @@ -332,9 +332,7 @@ class Message(List[TMS], abc.ABC): self.extend(self._construct(message)) @classmethod - def template( - cls: Type[TM], - format_string: str) -> MessageFormatter[TM, TMS]: # type: ignore + def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]: return MessageFormatter(cls, format_string) @classmethod diff --git a/nonebot/adapters/_formatter.py b/nonebot/adapters/_formatter.py index 9efc004c..9dad44d1 100644 --- a/nonebot/adapters/_formatter.py +++ b/nonebot/adapters/_formatter.py @@ -5,26 +5,23 @@ from typing import (Any, Set, List, Type, Tuple, Union, TypeVar, Mapping, Generic, Sequence, TYPE_CHECKING) if TYPE_CHECKING: - from . import Message, MessageSegment + from . import Message TM = TypeVar("TM", bound="Message") -TMS = TypeVar("TMS", bound="MessageSegment") -TAddable = Union[str, TM, TMS] -class MessageFormatter(Formatter, Generic[TM, TMS]): +class MessageFormatter(Formatter, Generic[TM]): def __init__(self, factory: Type[TM], template: str) -> None: self.template = template self.factory = factory - def format(self, *args: TAddable[TM, TMS], **kwargs: TAddable[TM, - TMS]) -> TM: + def format(self, *args: Any, **kwargs: Any) -> TM: msg = self.vformat(self.template, args, kwargs) return msg if isinstance(msg, self.factory) else self.factory(msg) - def vformat(self, format_string: str, args: Sequence[TAddable[TM, TMS]], - kwargs: Mapping[str, TAddable[TM, TMS]]) -> TM: + def vformat(self, format_string: str, args: Sequence[Any], + kwargs: Mapping[str, Any]) -> TM: used_args = set() result, _ = self._vformat(format_string, args, kwargs, used_args, 2) self.check_unused_args(list(used_args), args, kwargs) @@ -33,8 +30,8 @@ class MessageFormatter(Formatter, Generic[TM, TMS]): def _vformat( self, format_string: str, - args: Sequence[TAddable[TM, TMS]], - kwargs: Mapping[str, TAddable[TM, TMS]], + args: Sequence[Any], + kwargs: Mapping[str, Any], used_args: Set[Union[int, str]], recursion_depth: int, auto_arg_index: int = 0, @@ -43,7 +40,7 @@ class MessageFormatter(Formatter, Generic[TM, TMS]): if recursion_depth < 0: raise ValueError("Max string recursion exceeded") - results: List[TAddable[TM, TMS]] = [] + results: List[Any] = [] for (literal_text, field_name, format_spec, conversion) in self.parse(format_string): @@ -100,8 +97,3 @@ class MessageFormatter(Formatter, Generic[TM, TMS]): return self.factory(functools.reduce(operator.add, results or [""])), auto_arg_index - - def format_field(self, value: TAddable[TM, TMS], - format_spec: str) -> TAddable[TM, TMS]: - return super().format_field(value, - format_spec) if format_spec else value