mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
🎨 fix message typing error
This commit is contained in:
parent
6749afe75e
commit
b2f21ab974
@ -233,12 +233,14 @@ class Bot(abc.ABC):
|
||||
return func
|
||||
|
||||
|
||||
T_Message = TypeVar("T_Message", bound="Message")
|
||||
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
|
||||
T = TypeVar("T")
|
||||
TMS = TypeVar("TMS")
|
||||
TM = TypeVar("TM", bound="Message")
|
||||
# TM = TypeVar("TM_co", bound="Message")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||
class MessageSegment(Mapping, abc.ABC, Generic[TM]):
|
||||
"""消息段基类"""
|
||||
type: str
|
||||
"""
|
||||
@ -253,7 +255,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_message_class(cls) -> Type[T_Message]:
|
||||
def get_message_class(cls) -> Type[TM]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -264,15 +266,13 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||
def __len__(self) -> int:
|
||||
return len(str(self))
|
||||
|
||||
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
|
||||
def __ne__(self: T, other: T) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __add__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> T_Message:
|
||||
def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
|
||||
return self.get_message_class()(self) + other
|
||||
|
||||
def __radd__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> T_Message:
|
||||
def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
|
||||
return self.get_message_class()(other) + self
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
@ -299,7 +299,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def copy(self: T_MessageSegment) -> T_MessageSegment:
|
||||
def copy(self: T) -> T:
|
||||
return deepcopy(self)
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -307,12 +307,12 @@ class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Message(List[T_MessageSegment], abc.ABC):
|
||||
class Message(List[TMS], abc.ABC):
|
||||
"""消息数组"""
|
||||
|
||||
def __init__(self: T_Message,
|
||||
message: Union[str, None, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message, Any] = None,
|
||||
def __init__(self: TM,
|
||||
message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM,
|
||||
Any] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""
|
||||
@ -332,7 +332,7 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_segment_class(cls) -> Type[T_MessageSegment]:
|
||||
def get_segment_class(cls) -> Type[TMS]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
@ -349,29 +349,19 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def _construct(
|
||||
msg: Union[str, Mapping, Iterable[Mapping], Any]
|
||||
) -> Iterable[T_MessageSegment]:
|
||||
msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __add__(
|
||||
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message]
|
||||
) -> T_Message:
|
||||
def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
|
||||
result = self.copy()
|
||||
result += other
|
||||
return result
|
||||
|
||||
def __radd__(
|
||||
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message]
|
||||
) -> T_Message:
|
||||
def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
|
||||
result = self.__class__(other)
|
||||
return result + self
|
||||
|
||||
def __iadd__(
|
||||
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message]
|
||||
) -> T_Message:
|
||||
def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM:
|
||||
if isinstance(other, MessageSegment):
|
||||
self.append(other)
|
||||
elif isinstance(other, Message):
|
||||
@ -380,7 +370,7 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
self.extend(self._construct(other))
|
||||
return self
|
||||
|
||||
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
|
||||
def append(self: TM, obj: Union[str, TMS]) -> TM:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -398,8 +388,7 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
raise ValueError(f"Unexpected type: {type(obj)} {obj}")
|
||||
return self
|
||||
|
||||
def extend(self: T_Message,
|
||||
obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message:
|
||||
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -413,10 +402,10 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
self.append(segment)
|
||||
return self
|
||||
|
||||
def copy(self: T_Message) -> T_Message:
|
||||
def copy(self: TM) -> TM:
|
||||
return deepcopy(self)
|
||||
|
||||
def extract_plain_text(self) -> str:
|
||||
def extract_plain_text(self: "Message[MessageSegment]") -> str:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
|
@ -2,8 +2,7 @@ import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from base64 import b64encode
|
||||
from functools import reduce
|
||||
from typing import Any, List, Dict, Union, Tuple, Mapping, Iterable, Optional
|
||||
from typing import Type, Union, Tuple, Mapping, Iterable, Optional
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
||||
@ -17,7 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_message_class(cls):
|
||||
def get_message_class(cls) -> Type["Message"]:
|
||||
return Message
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
@ -236,22 +235,18 @@ class Message(BaseMessage[MessageSegment]):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_segment_class(cls):
|
||||
def get_segment_class(cls) -> Type[MessageSegment]:
|
||||
return MessageSegment
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __add__(
|
||||
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
|
||||
"Message"]
|
||||
) -> "Message":
|
||||
def __add__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> "Message":
|
||||
return super(Message, self).__add__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __radd__(
|
||||
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
|
||||
"Message"]
|
||||
) -> "Message":
|
||||
def __radd__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> "Message":
|
||||
return super(Message, self).__radd__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user