diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index c2bf5062..e1986bbf 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -233,12 +233,14 @@ class Bot(abc.ABC): return func -T_Message = TypeVar("T_Message", bound="Message") -T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]") +T = TypeVar("T") +TMS = TypeVar("TMS") +TM = TypeVar("TM", bound="Message") +# TM = TypeVar("TM_co", bound="Message") @dataclass -class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): +class MessageSegment(Mapping, abc.ABC, Generic[TM]): """消息段基类""" type: str """ @@ -253,7 +255,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): @classmethod @abc.abstractmethod - def get_message_class(cls) -> Type[T_Message]: + def get_message_class(cls) -> Type[TM]: raise NotImplementedError @abc.abstractmethod @@ -264,15 +266,13 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): def __len__(self) -> int: return len(str(self)) - def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool: + def __ne__(self: T, other: T) -> bool: return not self == other - def __add__(self, other: Union[str, Mapping, - Iterable[Mapping]]) -> T_Message: + def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: return self.get_message_class()(self) + other - def __radd__(self, other: Union[str, Mapping, - Iterable[Mapping]]) -> T_Message: + def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: return self.get_message_class()(other) + self def __getitem__(self, key: str): @@ -299,7 +299,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): def items(self): return self.data.items() - def copy(self: T_MessageSegment) -> T_MessageSegment: + def copy(self: T) -> T: return deepcopy(self) @abc.abstractmethod @@ -307,12 +307,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): raise NotImplementedError -class Message(List[T_MessageSegment], abc.ABC): +class Message(List[TMS], abc.ABC): """消息数组""" - def __init__(self: T_Message, - message: Union[str, None, Mapping, Iterable[Mapping], - T_MessageSegment, T_Message, Any] = None, + def __init__(self: TM, + message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM, + Any] = None, *args, **kwargs): """ @@ -332,7 +332,7 @@ class Message(List[T_MessageSegment], abc.ABC): @classmethod @abc.abstractmethod - def get_segment_class(cls) -> Type[T_MessageSegment]: + def get_segment_class(cls) -> Type[TMS]: raise NotImplementedError def __str__(self): @@ -349,29 +349,19 @@ class Message(List[T_MessageSegment], abc.ABC): @staticmethod @abc.abstractmethod def _construct( - msg: Union[str, Mapping, Iterable[Mapping], Any] - ) -> Iterable[T_MessageSegment]: + msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]: raise NotImplementedError - def __add__( - self: T_Message, other: Union[str, Mapping, Iterable[Mapping], - T_MessageSegment, T_Message] - ) -> T_Message: + def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: result = self.copy() result += other return result - def __radd__( - self: T_Message, other: Union[str, Mapping, Iterable[Mapping], - T_MessageSegment, T_Message] - ) -> T_Message: + def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: result = self.__class__(other) return result + self - def __iadd__( - self: T_Message, other: Union[str, Mapping, Iterable[Mapping], - T_MessageSegment, T_Message] - ) -> T_Message: + def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: if isinstance(other, MessageSegment): self.append(other) elif isinstance(other, Message): @@ -380,7 +370,7 @@ class Message(List[T_MessageSegment], abc.ABC): self.extend(self._construct(other)) return self - def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message: + def append(self: TM, obj: Union[str, TMS]) -> TM: """ :说明: @@ -398,8 +388,7 @@ class Message(List[T_MessageSegment], abc.ABC): raise ValueError(f"Unexpected type: {type(obj)} {obj}") return self - def extend(self: T_Message, - obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message: + def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM: """ :说明: @@ -413,10 +402,10 @@ class Message(List[T_MessageSegment], abc.ABC): self.append(segment) return self - def copy(self: T_Message) -> T_Message: + def copy(self: TM) -> TM: return deepcopy(self) - def extract_plain_text(self) -> str: + def extract_plain_text(self: "Message[MessageSegment]") -> str: """ :说明: diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py index 48affc8a..1b0184b6 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py @@ -2,8 +2,7 @@ import re from io import BytesIO from pathlib import Path from base64 import b64encode -from functools import reduce -from typing import Any, List, Dict, Union, Tuple, Mapping, Iterable, Optional +from typing import Type, Union, Tuple, Mapping, Iterable, Optional from nonebot.typing import overrides from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment @@ -17,7 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]): """ @classmethod - def get_message_class(cls): + def get_message_class(cls) -> Type["Message"]: return Message @overrides(BaseMessageSegment) @@ -236,22 +235,18 @@ class Message(BaseMessage[MessageSegment]): """ @classmethod - def get_segment_class(cls): + def get_segment_class(cls) -> Type[MessageSegment]: return MessageSegment @overrides(BaseMessage) - def __add__( - self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment, - "Message"] - ) -> "Message": + def __add__(self, other: Union[str, Mapping, + Iterable[Mapping]]) -> "Message": return super(Message, self).__add__( MessageSegment.text(other) if isinstance(other, str) else other) @overrides(BaseMessage) - def __radd__( - self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment, - "Message"] - ) -> "Message": + def __radd__(self, other: Union[str, Mapping, + Iterable[Mapping]]) -> "Message": return super(Message, self).__radd__( MessageSegment.text(other) if isinstance(other, str) else other)