2021-08-29 00:24:28 +08:00
import abc
from copy import deepcopy
2023-04-04 21:42:01 +08:00
from typing_extensions import Self
2024-04-16 00:33:48 +08:00
from collections . abc import Iterable
2021-09-27 00:19:30 +08:00
from dataclasses import field , asdict , dataclass
2024-04-16 00:33:48 +08:00
from typing import ( # noqa: UP035
2021-11-22 23:21:26 +08:00
Any ,
Type ,
Union ,
Generic ,
TypeVar ,
2022-01-17 00:28:36 +08:00
Optional ,
2023-04-04 21:42:01 +08:00
SupportsIndex ,
2022-01-16 13:17:05 +08:00
overload ,
2021-11-22 23:21:26 +08:00
)
2021-08-29 00:24:28 +08:00
2024-01-26 11:12:57 +08:00
from nonebot . compat import custom_validation , type_validate_python
2022-01-29 18:20:30 +08:00
2022-02-06 14:52:50 +08:00
from . template import MessageTemplate
2021-08-29 00:24:28 +08:00
2022-01-29 13:56:54 +08:00
# Type variable for concrete message (segment sequence) classes.
TM = TypeVar("TM", bound="Message")
# Type variable for concrete message segment classes.
TMS = TypeVar("TMS", bound="MessageSegment")
2024-01-26 11:12:57 +08:00
@custom_validation
@dataclass
class MessageSegment(abc.ABC, Generic[TM]):
    """Base class for message segments."""

    type: str
    """Segment type."""
    data: dict[str, Any] = field(default_factory=dict)
    """Segment data."""

    @classmethod
    @abc.abstractmethod
    def get_message_class(cls) -> Type[TM]:  # noqa: UP006
        """Return the message (segment array) class for this segment type."""
        raise NotImplementedError

    @abc.abstractmethod
    def __str__(self) -> str:
        """Return the str this segment stands for; used by command matching."""
        raise NotImplementedError

    def __len__(self) -> int:
        # Length is defined as the length of the string representation.
        text = str(self)
        return len(text)

    def __ne__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, other: Self
    ) -> bool:
        is_equal = self == other
        return not is_equal

    def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        # Wrap this segment in a message first, then let the message do the
        # concatenation.
        message_cls = self.get_message_class()
        return message_cls(self) + other

    def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        # Reflected form: build the message from the left operand.
        message_cls = self.get_message_class()
        return message_cls(other) + self

    @classmethod
    def __get_validators__(cls):
        yield cls._validate

    @classmethod
    def _validate(cls, value) -> Self:
        """Validate ``value`` as an instance of this segment class.

        An instance of ``cls`` is returned unchanged. A dict containing a
        ``"type"`` key (and optionally ``"data"``) is turned into a new
        segment. Any other ``MessageSegment``, a non-dict value, or a dict
        without ``"type"`` raises ``ValueError``.
        """
        if isinstance(value, cls):
            return value
        if isinstance(value, MessageSegment):
            # A segment of an unrelated class is never converted implicitly.
            raise ValueError(f"Type {type(value)} can not be converted to {cls}")
        if not isinstance(value, dict):
            raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
        if "type" in value:
            return cls(type=value["type"], data=value.get("data", {}))
        raise ValueError(
            f"Expected dict with 'type' for MessageSegment, got {value}"
        )

    def get(self, key: str, default: Any = None):
        # Dict-style access over the dataclass fields (via ``asdict``).
        as_mapping = asdict(self)
        return as_mapping.get(key, default)

    def keys(self):
        as_mapping = asdict(self)
        return as_mapping.keys()

    def values(self):
        as_mapping = asdict(self)
        return as_mapping.values()

    def items(self):
        as_mapping = asdict(self)
        return as_mapping.items()

    def join(self: TMS, iterable: Iterable[Union[TMS, TM]]) -> TM:
        # Delegate to the message's ``join`` with this segment as separator.
        separator = self.get_message_class()(self)
        return separator.join(iterable)

    def copy(self) -> Self:
        # Deep copy of the segment.
        duplicate = deepcopy(self)
        return duplicate

    @abc.abstractmethod
    def is_text(self) -> bool:
        """Return whether this segment is plain text."""
        raise NotImplementedError
2024-01-26 11:12:57 +08:00
@custom_validation
class Message(list[TMS], abc.ABC):
    """Message sequence.

    Args:
        message: message content
    """

    def __init__(
        self,
        message: Union[str, None, Iterable[TMS], TMS] = None,
    ):
        super().__init__()
        if message is None:
            return
        elif isinstance(message, str):
            # Strings are parsed into segments by the subclass.
            self.extend(self._construct(message))
        elif isinstance(message, MessageSegment):
            self.append(message)
        elif isinstance(message, Iterable):
            self.extend(message)
        else:
            self.extend(self._construct(message))  # pragma: no cover

    @classmethod
    def template(cls, format_string: Union[str, TM]) -> MessageTemplate[Self]:
        """Create a message template.

        Usage is roughly the same as `str.format`; a `Message` object may also
        be used as the template, in which case a message object is produced.
        Extended format specifiers are provided as well, which create message
        content through the `MessageSegment` factory methods of this message
        type.

        Args:
            format_string: the format template

        Returns:
            The message formatter
        """
        return MessageTemplate(format_string, cls)

    @classmethod
    @abc.abstractmethod
    def get_segment_class(cls) -> type[TMS]:
        """Return the message segment class."""
        raise NotImplementedError

    def __str__(self) -> str:
        return "".join(str(seg) for seg in self)

    @classmethod
    def __get_validators__(cls):
        yield cls._validate

    @classmethod
    def _validate(cls, value) -> Self:
        """Validate ``value`` as an instance of this message class.

        An instance of ``cls`` is returned unchanged. A str is passed to the
        constructor as is, a dict is validated as a single segment, and any
        other iterable is validated element by element as segments.

        Raises:
            ValueError: ``value`` is a ``Message`` of another class, or is not
                a str, dict or iterable
        """
        if isinstance(value, cls):
            return value
        elif isinstance(value, Message):
            raise ValueError(f"Type {type(value)} can not be converted to {cls}")
        elif isinstance(value, str):
            pass
        elif isinstance(value, dict):
            value = type_validate_python(cls.get_segment_class(), value)
        elif isinstance(value, Iterable):
            value = [type_validate_python(cls.get_segment_class(), v) for v in value]
        else:
            raise ValueError(
                f"Expected str, dict or iterable for Message, got {type(value)}"
            )
        return cls(value)

    @staticmethod
    @abc.abstractmethod
    def _construct(msg: str) -> Iterable[TMS]:
        """Construct the message array from a string."""
        raise NotImplementedError

    def __add__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, other: Union[str, TMS, Iterable[TMS]]
    ) -> Self:
        # Work on a (deep) copy so the left operand is left untouched.
        result = self.copy()
        result += other
        return result

    def __radd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
        result = self.__class__(other)
        return result + self

    def __iadd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
        if isinstance(other, str):
            self.extend(self._construct(other))
        elif isinstance(other, MessageSegment):
            self.append(other)
        elif isinstance(other, Iterable):
            self.extend(other)
        else:
            raise TypeError(f"Unsupported type {type(other)!r}")
        return self

    @overload
    def __getitem__(self, args: str) -> Self:
        """Get a message containing only segments of the given type.

        Args:
            args: segment type

        Returns:
            All segments whose type is `args`
        """

    @overload
    def __getitem__(self, args: tuple[str, int]) -> TMS:
        """Index the segments of the given type.

        Args:
            args: segment type and index

        Returns:
            The `args[1]`-th segment of type `args[0]`
        """

    @overload
    def __getitem__(self, args: tuple[str, slice]) -> Self:
        """Slice the segments of the given type.

        Args:
            args: segment type and slice

        Returns:
            The slice `args[1]` of the segments of type `args[0]`
        """

    @overload
    def __getitem__(self, args: int) -> TMS:
        """Index a segment.

        Args:
            args: index

        Returns:
            The `args`-th segment
        """

    @overload
    def __getitem__(self, args: slice) -> Self:
        """Slice the message.

        Args:
            args: slice

        Returns:
            The message slice `args`
        """

    def __getitem__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        args: Union[
            str,
            tuple[str, int],
            tuple[str, slice],
            int,
            slice,
        ],
    ) -> Union[TMS, Self]:
        # Normalise every accepted form to a (first, second) pair.
        arg1, arg2 = args if isinstance(args, tuple) else (args, None)
        if isinstance(arg1, int) and arg2 is None:
            return super().__getitem__(arg1)
        elif isinstance(arg1, slice) and arg2 is None:
            return self.__class__(super().__getitem__(arg1))
        elif isinstance(arg1, str) and arg2 is None:
            return self.__class__(seg for seg in self if seg.type == arg1)
        elif isinstance(arg1, str) and isinstance(arg2, int):
            return [seg for seg in self if seg.type == arg1][arg2]
        elif isinstance(arg1, str) and isinstance(arg2, slice):
            return self.__class__([seg for seg in self if seg.type == arg1][arg2])
        else:
            raise ValueError("Incorrect arguments to slice")  # pragma: no cover

    def __contains__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, value: Union[TMS, str]
    ) -> bool:
        """Check whether a segment exists.

        Args:
            value: a segment or a segment type

        Returns:
            Whether the message contains the given segment, or any segment of
            the given type
        """
        if isinstance(value, str):
            return any(seg.type == value for seg in self)
        return super().__contains__(value)

    def has(self, value: Union[TMS, str]) -> bool:
        """Same as {ref}``__contains__` <nonebot.adapters.Message.__contains__>`."""
        return value in self

    def index(self, value: Union[TMS, str], *args: SupportsIndex) -> int:
        """Find the index of a segment.

        Args:
            value: a segment or a segment type
            args: optional start and end, interpreted as in `list.index`

        Returns:
            The index of the first matching segment within the range

        Raises:
            ValueError: no matching segment exists within the range
            TypeError: more than two range arguments are given
        """
        if not isinstance(value, str):
            return super().index(value, *args)
        if len(args) > 2:
            raise TypeError(
                f"index expected at most 3 arguments, got {len(args) + 1}"
            )
        # Search by type inside the requested range itself. Looking up the
        # first segment of that type in the whole message and then searching
        # for an *equal* segment inside the range would miss same-typed
        # segments with different data.
        start, stop, _ = slice(
            args[0] if args else None,
            args[1] if len(args) > 1 else None,
        ).indices(len(self))
        for position in range(start, stop):
            if self[position].type == value:
                return position
        raise ValueError(f"Segment with type {value!r} is not in message")

    def get(self, type_: str, count: Optional[int] = None) -> Self:
        """Get the segments of the given type.

        Args:
            type_: segment type
            count: maximum number of segments to take; all of them when `None`

        Returns:
            The newly built message
        """
        if count is None:
            return self[type_]

        matching = (seg for seg in self if seg.type == type_)
        filtered = self.__class__()
        # ``zip`` stops at whichever runs out first: the quota or the segments.
        for _, seg in zip(range(count), matching):
            filtered.append(seg)
        return filtered

    def count(self, value: Union[TMS, str]) -> int:
        """Count the occurrences of a segment.

        Args:
            value: a segment or a segment type

        Returns:
            The number of occurrences
        """
        if isinstance(value, str):
            return sum(1 for seg in self if seg.type == value)
        return super().count(value)

    def only(self, value: Union[TMS, str]) -> bool:
        """Check whether the message contains nothing but the given segment.

        Args:
            value: a segment or a segment type

        Returns:
            Whether every segment matches (also true for an empty message)
        """
        if isinstance(value, str):
            return all(seg.type == value for seg in self)
        return all(seg == value for seg in self)

    def append(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, obj: Union[str, TMS]
    ) -> Self:
        """Append a segment to the end of the message array.

        A str is first converted to segments with `_construct`.

        Args:
            obj: the segment to append

        Returns:
            The message itself
        """
        if isinstance(obj, MessageSegment):
            super().append(obj)
        elif isinstance(obj, str):
            self.extend(self._construct(obj))
        else:
            raise ValueError(f"Unexpected type: {type(obj)} {obj}")  # pragma: no cover
        return self

    def extend(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, obj: Union[Self, Iterable[TMS]]
    ) -> Self:
        """Append a message array or several segments to the end of the message.

        Args:
            obj: the message array to append

        Returns:
            The message itself
        """
        # Extending a message with itself (e.g. ``msg += msg``) must iterate a
        # snapshot: appending while iterating the same list never terminates.
        segments = tuple(obj) if obj is self else obj
        for segment in segments:
            self.append(segment)
        return self

    def join(self, iterable: Iterable[Union[TMS, Self]]) -> Self:
        """Concatenate several messages, using this message as the separator.

        Args:
            iterable: the messages to join

        Returns:
            The joined message
        """
        ret = self.__class__()
        for index, msg in enumerate(iterable):
            if index != 0:
                ret.extend(self)
            if isinstance(msg, MessageSegment):
                ret.append(msg.copy())
            else:
                ret.extend(msg.copy())
        return ret

    def copy(self) -> Self:
        """Deep-copy the message."""
        return deepcopy(self)

    def include(self, *types: str) -> Self:
        """Filter the message.

        Args:
            types: segment types to keep

        Returns:
            The newly built message
        """
        return self.__class__(seg for seg in self if seg.type in types)

    def exclude(self, *types: str) -> Self:
        """Filter the message.

        Args:
            types: segment types to drop

        Returns:
            The newly built message
        """
        return self.__class__(seg for seg in self if seg.type not in types)

    def extract_plain_text(self) -> str:
        """Extract the plain-text content of the message."""
        return "".join(str(seg) for seg in self if seg.is_text())