From 1221baaa942faf8163e0443a2f2541fa4504b135 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 16 Jan 2022 17:13:26 +0800 Subject: [PATCH] :sparkles: Implement `.get` and `.index` methods for `Message` --- nonebot/adapters/_message.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 9b1996d9..d3f050a1 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -89,7 +89,7 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]): raise NotImplementedError -class Message(List[TMS], Generic[TMS], abc.ABC): +class Message(List[TMS], abc.ABC): """消息数组""" def __init__( @@ -204,7 +204,7 @@ class Message(List[TMS], Generic[TMS], abc.ABC): def __getitem__( self: TM, - __args: Union[ + args: Union[ str, Tuple[str, int], Tuple[str, slice], @@ -212,7 +212,7 @@ class Message(List[TMS], Generic[TMS], abc.ABC): slice, ], ) -> Union[TMS, TM]: - arg1, arg2 = __args if isinstance(__args, tuple) else (__args, None) + arg1, arg2 = args if isinstance(args, tuple) else (args, None) if isinstance(arg1, int) and arg2 is None: return super().__getitem__(arg1) elif isinstance(arg1, slice) and arg2 is None: @@ -224,7 +224,22 @@ class Message(List[TMS], Generic[TMS], abc.ABC): elif isinstance(arg1, str) and isinstance(arg2, slice): return self.__class__([seg for seg in self if seg.type == arg1][arg2]) else: - raise ValueError("Invalid arguments to __getitem__") + raise ValueError("Incorrect arguments to slice") + + def index(self, value: Union[TMS, str], *args) -> int: + if isinstance(value, str): + first_segment = next((seg for seg in self if seg.type == value), None) # type: ignore + return super().index(first_segment, *args) # type: ignore + return super().index(value, *args) + + def get(self: TM, type_: str, count: int) -> TM: + iterator = (seg for seg in self if seg.type == type_) + return self.__class__( + filter( + lambda seg: seg is not None, + (next(iterator) for _ in range(count)), + ) + ) def append(self: TM, obj: Union[str, TMS]) -> TM: """