⚗️ support segment typing for message

This commit is contained in:
yanyongyu 2021-05-10 00:54:15 +08:00
parent 4b38afdcd7
commit f8ad9ef278
5 changed files with 9 additions and 9 deletions

View File

@ -11,8 +11,8 @@ from copy import copy
from functools import reduce, partial
from typing_extensions import Protocol
from dataclasses import dataclass, field
from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
Awaitable, TYPE_CHECKING)
from typing import (Any, Set, List, Dict, Union, TypeVar, Mapping, Optional,
Iterable, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel
@ -316,7 +316,7 @@ class MessageSegment(abc.ABC, Mapping):
raise NotImplementedError
class Message(list, abc.ABC):
class Message(List[T_MessageSegment], abc.ABC):
"""消息数组"""
def __init__(self,

View File

@ -100,14 +100,14 @@ def _check_at_me(bot: "Bot", event: "Event"):
# check the first segment
if event.message[0] == at_me_seg:
event.to_me = True
del event.message[0]
event.message.pop(0)
if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[
"text"].lstrip()
if not event.message[0].data["text"]:
del event.message[0]
if event.message and event.message[0] == at_me_seg:
del event.message[0]
event.message.pop(0)
if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[
"text"].lstrip()

View File

@ -3,7 +3,7 @@ from io import BytesIO
from pathlib import Path
from base64 import b64encode
from functools import reduce
from typing import Any, Dict, Union, Tuple, Mapping, Iterable, Optional
from typing import Any, List, Dict, Union, Tuple, Mapping, Iterable, Optional
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
@ -229,7 +229,7 @@ class MessageSegment(BaseMessageSegment):
return MessageSegment("xml", {"data": data})
class Message(BaseMessage):
class Message(BaseMessage[MessageSegment]):
"""
CQHTTP 协议 Message 适配
"""

View File

@ -155,7 +155,7 @@ class MessageSegment(BaseMessageSegment):
return {"msgtype": self.type, self.type: copy(self.data)}
class Message(BaseMessage):
class Message(BaseMessage[MessageSegment]):
"""
钉钉 协议 Message 适配
"""

View File

@ -266,7 +266,7 @@ class MessageSegment(BaseMessageSegment):
return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage):
class MessageChain(BaseMessage[MessageSegment]):
"""
Mirai 协议 Message 适配