diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 10932667..123914e6 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -8,19 +8,21 @@ import abc import asyncio from copy import deepcopy +from dataclasses import asdict, dataclass, field from functools import partial -from typing_extensions import Protocol -from dataclasses import dataclass, field, asdict -from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping, - Generic, Optional, Iterable) +from typing import (Any, Dict, Generic, Iterable, List, Mapping, Optional, Set, + Tuple, Type, TypeVar, Union) from pydantic import BaseModel +from typing_extensions import Protocol -from nonebot.log import logger from nonebot.config import Config -from nonebot.utils import DataclassEncoder from nonebot.drivers import Driver, HTTPConnection, HTTPResponse -from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook +from nonebot.log import logger +from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook +from nonebot.utils import DataclassEncoder + +from ._formatter import MessageFormatter class _ApiCall(Protocol): @@ -329,6 +331,10 @@ class Message(List[TMS], abc.ABC): else: self.extend(self._construct(message)) + @classmethod + def template(cls: Type[TM], format_string: str) -> MessageFormatter[TM]: + return MessageFormatter(cls, format_string) + @classmethod @abc.abstractmethod def get_segment_class(cls) -> Type[TMS]: diff --git a/nonebot/adapters/_formatter.py b/nonebot/adapters/_formatter.py new file mode 100644 index 00000000..1c11afa3 --- /dev/null +++ b/nonebot/adapters/_formatter.py @@ -0,0 +1,118 @@ +import functools +import operator +from string import Formatter +from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple, + Type, TypeVar, Union, TYPE_CHECKING) + +if TYPE_CHECKING: + from nonebot.adapters import Message + + +class AddAble(Protocol): + + def __add__(self, __s: Any) -> "AddAble": + ... + + def __str__(self) -> str: + ... + + +AddAble_T = TypeVar("AddAble_T", bound=AddAble) +MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True) + + +class MessageFormatter(Formatter, Generic[MessageResult_T]): + + def __init__(self, factory: Type[MessageResult_T], template: str) -> None: + super().__init__() + self.template = template + self.factory = factory + + def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T: + msg: AddAble = super().format(self.template, *args, **kwargs) + return msg if isinstance(msg, self.factory) else self.factory( + msg) # type: ignore + + def vformat(self, format_string: str, args: Sequence[AddAble], + kwargs: Mapping[str, AddAble]): + result, arg_index, used_args = self._vformat(format_string, args, + kwargs, set(), 2) + self.check_unused_args(list(used_args), args, kwargs) + return result + + def _vformat( + self, + format_string: str, + args: Sequence[Any], + kwargs: Mapping[str, Any], + used_args: Set[Union[int, str]], + recursion_depth: int, + auto_arg_index: int = 0, + ) -> Tuple[AddAble, int, Set[Union[int, str]]]: + + if recursion_depth < 0: + raise ValueError("Max string recursion exceeded") + + results: List[AddAble] = [] + + for (literal_text, field_name, format_spec, + conversion) in self.parse(format_string): + + # output the literal text + if literal_text: + results.append(literal_text) + + # if there's a field, output it + if field_name is not None: + # this is some markup, find the object and do + # the formatting + + # handle arg indexing when empty field_names are given. + if field_name == "": + if auto_arg_index is False: + raise ValueError( + "cannot switch from manual field specification to " + "automatic field numbering") + field_name = str(auto_arg_index) + auto_arg_index += 1 + elif field_name.isdigit(): + if auto_arg_index: + raise ValueError( + "cannot switch from manual field specification to " + "automatic field numbering") + # disable auto arg incrementing, if it gets + # used later on, then an exception will be raised + auto_arg_index = False + + # given the field_name, find the object it references + # and the argument it came from + obj, arg_used = self.get_field(field_name, args, kwargs) + used_args.add(arg_used) + + assert format_spec is not None + + # do any conversion on the resulting object + obj = self.convert_field(obj, conversion) if conversion else obj + + # expand the format spec, if needed + format_control, auto_arg_index, formatted_args = self._vformat( + format_spec, + args, + kwargs, + used_args.copy(), + recursion_depth - 1, + auto_arg_index, + ) + used_args |= formatted_args + + # format the object and append to the result + formatted_text = self.format_field(obj, str(format_control)) + results.append(formatted_text) + + return functools.reduce(operator.add, results or + [""]), auto_arg_index, used_args + + def format_field(self, value: AddAble_T, + format_spec: str) -> Union[AddAble_T, str]: + return super().format_field(value, + format_spec) if format_spec else value