improve radd support for messagesegment

This commit is contained in:
yanyongyu 2020-12-28 17:39:33 +08:00
parent e14d3d8d73
commit ab61be26a9
2 changed files with 43 additions and 18 deletions

View File

@ -9,7 +9,7 @@ import abc
from typing_extensions import Literal from typing_extensions import Literal
from functools import reduce, partial from functools import reduce, partial
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING from typing import Any, Dict, Union, TypeVar, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING
from pydantic import BaseModel from pydantic import BaseModel
@ -267,6 +267,10 @@ class Event(abc.ABC, BaseModel):
raise NotImplementedError raise NotImplementedError
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
@dataclass @dataclass
class MessageSegment(abc.ABC): class MessageSegment(abc.ABC):
"""消息段基类""" """消息段基类"""
@ -282,19 +286,34 @@ class MessageSegment(abc.ABC):
""" """
@abc.abstractmethod @abc.abstractmethod
def __str__(self) -> str: def __str__(self: T_MessageSegment) -> str:
"""该消息段所代表的 str在命令匹配部分使用""" """该消息段所代表的 str在命令匹配部分使用"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def __add__(self, other) -> "Message": def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
T_Message]) -> "T_Message":
"""你需要在这里实现不同消息段的合并: """你需要在这里实现不同消息段的合并:
比如 比如
if isinstance(other, str): if isinstance(other, str):
... ...
elif isinstance(other, MessageSegment): elif isinstance(other, MessageSegment):
... ...
注意不能返回 self需要返回一个新生成的对象 注意需要返回一个新生成的对象
"""
raise NotImplementedError
@abc.abstractmethod
def __radd__(
self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment,
T_Message]) -> "T_Message":
"""你需要在这里实现不同消息段的合并:
比如
if isinstance(other, str):
...
elif isinstance(other, MessageSegment):
...
注意需要返回一个新生成的对象
""" """
raise NotImplementedError raise NotImplementedError
@ -316,17 +335,17 @@ class Message(list, abc.ABC):
"""消息数组""" """消息数组"""
def __init__(self, def __init__(self,
message: Union[str, dict, list, BaseModel, MessageSegment, message: Union[str, dict, list, T_MessageSegment,
"Message"] = None, T_Message] = None,
*args, *args,
**kwargs): **kwargs):
""" """
:参数: :参数:
* ``message: Union[str, dict, list, BaseModel, MessageSegment, Message]``: 消息内容 * ``message: Union[str, dict, list, MessageSegment, Message]``: 消息内容
""" """
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if isinstance(message, (str, dict, list, BaseModel)): if isinstance(message, (str, dict, list)):
self.extend(self._construct(message)) self.extend(self._construct(message))
elif isinstance(message, Message): elif isinstance(message, Message):
self.extend(message) self.extend(message)
@ -347,11 +366,12 @@ class Message(list, abc.ABC):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _construct( def _construct(
msg: Union[str, dict, list, BaseModel]) -> Iterable[MessageSegment]: msg: Union[str, dict, list,
BaseModel]) -> Iterable[T_MessageSegment]:
raise NotImplementedError raise NotImplementedError
def __add__(self, other: Union[str, MessageSegment, def __add__(self: T_Message, other: Union[str, T_MessageSegment,
"Message"]) -> "Message": T_Message]) -> T_Message:
result = self.__class__(self) result = self.__class__(self)
if isinstance(other, str): if isinstance(other, str):
result.extend(self._construct(other)) result.extend(self._construct(other))
@ -361,11 +381,12 @@ class Message(list, abc.ABC):
result.extend(other) result.extend(other)
return result return result
def __radd__(self, other: Union[str, MessageSegment, "Message"]): def __radd__(self: T_Message, other: Union[str, T_MessageSegment,
T_Message]):
result = self.__class__(other) result = self.__class__(other)
return result.__add__(self) return result.__add__(self)
def append(self, obj: Union[str, MessageSegment]) -> "Message": def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
""" """
:说明: :说明:
@ -383,8 +404,8 @@ class Message(list, abc.ABC):
raise ValueError(f"Unexpected type: {type(obj)} {obj}") raise ValueError(f"Unexpected type: {type(obj)} {obj}")
return self return self
def extend(self, obj: Union["Message", def extend(self: T_Message,
Iterable[MessageSegment]]) -> "Message": obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message:
""" """
:说明: :说明:
@ -398,7 +419,7 @@ class Message(list, abc.ABC):
self.append(segment) self.append(segment)
return self return self
def reduce(self) -> None: def reduce(self: T_Message) -> None:
""" """
:说明: :说明:
@ -413,14 +434,14 @@ class Message(list, abc.ABC):
else: else:
index += 1 index += 1
def extract_plain_text(self) -> str: def extract_plain_text(self: T_Message) -> str:
""" """
:说明: :说明:
提取消息内纯文本消息 提取消息内纯文本消息
""" """
def _concat(x: str, y: MessageSegment) -> str: def _concat(x: str, y: T_MessageSegment) -> str:
return f"{x} {y}" if y.is_text() else x return f"{x} {y}" if y.is_text() else x
plain_text = reduce(_concat, self, "") plain_text = reduce(_concat, self, "")

View File

@ -35,6 +35,10 @@ class MessageSegment(BaseMessageSegment):
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
return Message(self) + other return Message(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return Message(other) + self
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"