🎨 fix message typing error

This commit is contained in:
yanyongyu 2021-06-17 01:07:19 +08:00
parent 6749afe75e
commit b2f21ab974
2 changed files with 30 additions and 46 deletions

View File

@ -233,12 +233,14 @@ class Bot(abc.ABC):
return func
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
T = TypeVar("T")
TMS = TypeVar("TMS")
TM = TypeVar("TM", bound="Message")
# TM = TypeVar("TM_co", bound="Message")
@dataclass
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
class MessageSegment(Mapping, abc.ABC, Generic[TM]):
"""消息段基类"""
type: str
"""
@ -253,7 +255,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
@classmethod
@abc.abstractmethod
def get_message_class(cls) -> Type[T_Message]:
def get_message_class(cls) -> Type[TM]:
raise NotImplementedError
@abc.abstractmethod
@ -264,15 +266,13 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
def __len__(self) -> int:
return len(str(self))
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
def __ne__(self: T, other: T) -> bool:
return not self == other
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> T_Message:
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
return self.get_message_class()(self) + other
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> T_Message:
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
return self.get_message_class()(other) + self
def __getitem__(self, key: str):
@ -299,7 +299,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
def items(self):
return self.data.items()
def copy(self: T_MessageSegment) -> T_MessageSegment:
def copy(self: T) -> T:
return deepcopy(self)
@abc.abstractmethod
@ -307,12 +307,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
raise NotImplementedError
class Message(List[T_MessageSegment], abc.ABC):
class Message(List[TMS], abc.ABC):
"""消息数组"""
def __init__(self: T_Message,
message: Union[str, None, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message, Any] = None,
def __init__(self: TM,
message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM,
Any] = None,
*args,
**kwargs):
"""
@ -332,7 +332,7 @@ class Message(List[T_MessageSegment], abc.ABC):
@classmethod
@abc.abstractmethod
def get_segment_class(cls) -> Type[T_MessageSegment]:
def get_segment_class(cls) -> Type[TMS]:
raise NotImplementedError
def __str__(self):
@ -349,29 +349,19 @@ class Message(List[T_MessageSegment], abc.ABC):
@staticmethod
@abc.abstractmethod
def _construct(
msg: Union[str, Mapping, Iterable[Mapping], Any]
) -> Iterable[T_MessageSegment]:
msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
raise NotImplementedError
def __add__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
result = self.copy()
result += other
return result
def __radd__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
result = self.__class__(other)
return result + self
def __iadd__(
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
if isinstance(other, MessageSegment):
self.append(other)
elif isinstance(other, Message):
@ -380,7 +370,7 @@ class Message(List[T_MessageSegment], abc.ABC):
self.extend(self._construct(other))
return self
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
def append(self: TM, obj: Union[str, TMS]) -> TM:
"""
:说明:
@ -398,8 +388,7 @@ class Message(List[T_MessageSegment], abc.ABC):
raise ValueError(f"Unexpected type: {type(obj)} {obj}")
return self
def extend(self: T_Message,
obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message:
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
"""
:说明:
@ -413,10 +402,10 @@ class Message(List[T_MessageSegment], abc.ABC):
self.append(segment)
return self
def copy(self: T_Message) -> T_Message:
def copy(self: TM) -> TM:
return deepcopy(self)
def extract_plain_text(self) -> str:
def extract_plain_text(self: "Message[MessageSegment]") -> str:
"""
:说明:

View File

@ -2,8 +2,7 @@ import re
from io import BytesIO
from pathlib import Path
from base64 import b64encode
from functools import reduce
from typing import Any, List, Dict, Union, Tuple, Mapping, Iterable, Optional
from typing import Type, Union, Tuple, Mapping, Iterable, Optional
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
@ -17,7 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
"""
@classmethod
def get_message_class(cls):
def get_message_class(cls) -> Type["Message"]:
return Message
@overrides(BaseMessageSegment)
@ -236,22 +235,18 @@ class Message(BaseMessage[MessageSegment]):
"""
@classmethod
def get_segment_class(cls):
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@overrides(BaseMessage)
def __add__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
@overrides(BaseMessage)
def __radd__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)