From 3b4c4d30812031b889b2370a7bc868af454dab14 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Mon, 17 Jan 2022 00:28:36 +0800 Subject: [PATCH] :sparkles: :zap: Implement `.count` and optimize `.get` performance for message slice --- nonebot/adapters/_message.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index d3f050a1..a1a44ba9 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -12,6 +12,7 @@ from typing import ( Mapping, TypeVar, Iterable, + Optional, overload, ) @@ -232,14 +233,20 @@ class Message(List[TMS], abc.ABC): return super().index(first_segment, *args) # type: ignore return super().index(value, *args) - def get(self: TM, type_: str, count: int) -> TM: - iterator = (seg for seg in self if seg.type == type_) - return self.__class__( - filter( - lambda seg: seg is not None, - (next(iterator) for _ in range(count)), - ) - ) + def get(self: TM, type_: str, count: Optional[int] = None) -> TM: + if count is None: + return self[type_] + + iterator, filtered = (seg for seg in self if seg.type == type_), [] + for _ in range(count): + seg = next(iterator, None) + if seg is None: + break + filtered.append(seg) + return self.__class__(filtered) + + def count(self, value: Union[TMS, str]) -> int: + return len(self[value]) if isinstance(value, str) else super().count(value) def append(self: TM, obj: Union[str, TMS]) -> TM: """