Implement .get and .index methods for Message

This commit is contained in:
Mix 2022-01-16 17:13:26 +08:00
parent 39822378a7
commit 1221baaa94

View File

@ -89,7 +89,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]):
raise NotImplementedError raise NotImplementedError
class Message(List[TMS], Generic[TMS], abc.ABC): class Message(List[TMS], abc.ABC):
"""消息数组""" """消息数组"""
def __init__( def __init__(
@ -204,7 +204,7 @@ class Message(List[TMS], Generic[TMS], abc.ABC):
def __getitem__( def __getitem__(
self: TM, self: TM,
__args: Union[ args: Union[
str, str,
Tuple[str, int], Tuple[str, int],
Tuple[str, slice], Tuple[str, slice],
@ -212,7 +212,7 @@ class Message(List[TMS], Generic[TMS], abc.ABC):
slice, slice,
], ],
) -> Union[TMS, TM]: ) -> Union[TMS, TM]:
arg1, arg2 = __args if isinstance(__args, tuple) else (__args, None) arg1, arg2 = args if isinstance(args, tuple) else (args, None)
if isinstance(arg1, int) and arg2 is None: if isinstance(arg1, int) and arg2 is None:
return super().__getitem__(arg1) return super().__getitem__(arg1)
elif isinstance(arg1, slice) and arg2 is None: elif isinstance(arg1, slice) and arg2 is None:
@ -224,7 +224,22 @@ class Message(List[TMS], Generic[TMS], abc.ABC):
elif isinstance(arg1, str) and isinstance(arg2, slice): elif isinstance(arg1, str) and isinstance(arg2, slice):
return self.__class__([seg for seg in self if seg.type == arg1][arg2]) return self.__class__([seg for seg in self if seg.type == arg1][arg2])
else: else:
raise ValueError("Invalid arguments to __getitem__") raise ValueError("Incorrect arguments to slice")
def index(self, value: Union[TMS, str], *args) -> int:
if isinstance(value, str):
first_segment = next((seg for seg in self if seg.type == value), None) # type: ignore
return super().index(first_segment, *args) # type: ignore
return super().index(value, *args)
def get(self: TM, type_: str, count: int) -> TM:
iterator = (seg for seg in self if seg.type == type_)
return self.__class__(
filter(
lambda seg: seg is not None,
(next(iterator) for _ in range(count)),
)
)
def append(self: TM, obj: Union[str, TMS]) -> TM: def append(self: TM, obj: Union[str, TMS]) -> TM:
""" """