From ab61be26a9643b779aca54585273585e72dbf9ec Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 28 Dec 2020 17:39:33 +0800 Subject: [PATCH] :zap: improve radd support for messagesegment --- nonebot/adapters/__init__.py | 57 ++++++++++++++++++++---------- nonebot/adapters/cqhttp/message.py | 4 +++ 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 1d4d971a..22c6f587 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -9,7 +9,7 @@ import abc from typing_extensions import Literal from functools import reduce, partial from dataclasses import dataclass, field -from typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING +from typing import Any, Dict, Union, TypeVar, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING from pydantic import BaseModel @@ -267,6 +267,10 @@ class Event(abc.ABC, BaseModel): raise NotImplementedError +T_Message = TypeVar("T_Message", bound="Message") +T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") + + @dataclass class MessageSegment(abc.ABC): """消息段基类""" @@ -282,19 +286,34 @@ class MessageSegment(abc.ABC): """ @abc.abstractmethod - def __str__(self) -> str: + def __str__(self: T_MessageSegment) -> str: """该消息段所代表的 str,在命令匹配部分使用""" raise NotImplementedError @abc.abstractmethod - def __add__(self, other) -> "Message": + def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, + T_Message]) -> "T_Message": """你需要在这里实现不同消息段的合并: 比如: if isinstance(other, str): ... elif isinstance(other, MessageSegment): ... - 注意:不能返回 self,需要返回一个新生成的对象 + 注意:需要返回一个新生成的对象 + """ + raise NotImplementedError + + @abc.abstractmethod + def __radd__( + self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment, + T_Message]) -> "T_Message": + """你需要在这里实现不同消息段的合并: + 比如: + if isinstance(other, str): + ... + elif isinstance(other, MessageSegment): + ... + 注意:需要返回一个新生成的对象 """ raise NotImplementedError @@ -316,17 +335,17 @@ class Message(list, abc.ABC): """消息数组""" def __init__(self, - message: Union[str, dict, list, BaseModel, MessageSegment, - "Message"] = None, + message: Union[str, dict, list, T_MessageSegment, + T_Message] = None, *args, **kwargs): """ :参数: - * ``message: Union[str, dict, list, BaseModel, MessageSegment, Message]``: 消息内容 + * ``message: Union[str, dict, list, MessageSegment, Message]``: 消息内容 """ super().__init__(*args, **kwargs) - if isinstance(message, (str, dict, list, BaseModel)): + if isinstance(message, (str, dict, list)): self.extend(self._construct(message)) elif isinstance(message, Message): self.extend(message) @@ -347,11 +366,12 @@ class Message(list, abc.ABC): @staticmethod @abc.abstractmethod def _construct( - msg: Union[str, dict, list, BaseModel]) -> Iterable[MessageSegment]: + msg: Union[str, dict, list, + BaseModel]) -> Iterable[T_MessageSegment]: raise NotImplementedError - def __add__(self, other: Union[str, MessageSegment, - "Message"]) -> "Message": + def __add__(self: T_Message, other: Union[str, T_MessageSegment, + T_Message]) -> T_Message: result = self.__class__(self) if isinstance(other, str): result.extend(self._construct(other)) @@ -361,11 +381,12 @@ class Message(list, abc.ABC): result.extend(other) return result - def __radd__(self, other: Union[str, MessageSegment, "Message"]): + def __radd__(self: T_Message, other: Union[str, T_MessageSegment, + T_Message]): result = self.__class__(other) return result.__add__(self) - def append(self, obj: Union[str, MessageSegment]) -> "Message": + def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message: """ :说明: @@ -383,8 +404,8 @@ class Message(list, abc.ABC): raise ValueError(f"Unexpected type: {type(obj)} {obj}") return self - def extend(self, obj: Union["Message", - Iterable[MessageSegment]]) -> "Message": + def extend(self: T_Message, + obj: Union[T_Message, Iterable[T_MessageSegment]]) -> T_Message: """ :说明: @@ -398,7 +419,7 @@ class Message(list, abc.ABC): self.append(segment) return self - def reduce(self) -> None: + def reduce(self: T_Message) -> None: """ :说明: @@ -413,14 +434,14 @@ class Message(list, abc.ABC): else: index += 1 - def extract_plain_text(self) -> str: + def extract_plain_text(self: T_Message) -> str: """ :说明: 提取消息内纯文本消息 """ - def _concat(x: str, y: MessageSegment) -> str: + def _concat(x: str, y: T_MessageSegment) -> str: return f"{x} {y}" if y.is_text() else x plain_text = reduce(_concat, self, "") diff --git a/nonebot/adapters/cqhttp/message.py b/nonebot/adapters/cqhttp/message.py index fae867ac..466da970 100644 --- a/nonebot/adapters/cqhttp/message.py +++ b/nonebot/adapters/cqhttp/message.py @@ -35,6 +35,10 @@ class MessageSegment(BaseMessageSegment): def __add__(self, other) -> "Message": return Message(self) + other + @overrides(BaseMessageSegment) + def __radd__(self, other) -> "Message": + return Message(other) + self + @overrides(BaseMessageSegment) def is_text(self) -> bool: return self.type == "text"