⚗️ support segment typing for message

This commit is contained in:
yanyongyu 2021-05-10 00:54:15 +08:00
parent 4b38afdcd7
commit f8ad9ef278
5 changed files with 9 additions and 9 deletions

View File

@ -11,8 +11,8 @@ from copy import copy
from functools import reduce, partial from functools import reduce, partial
from typing_extensions import Protocol from typing_extensions import Protocol
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable, from typing import (Any, Set, List, Dict, Union, TypeVar, Mapping, Optional,
Awaitable, TYPE_CHECKING) Iterable, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel from pydantic import BaseModel
@ -316,7 +316,7 @@ class MessageSegment(abc.ABC, Mapping):
raise NotImplementedError raise NotImplementedError
class Message(list, abc.ABC): class Message(List[T_MessageSegment], abc.ABC):
"""消息数组""" """消息数组"""
def __init__(self, def __init__(self,

View File

@ -100,14 +100,14 @@ def _check_at_me(bot: "Bot", event: "Event"):
# check the first segment # check the first segment
if event.message[0] == at_me_seg: if event.message[0] == at_me_seg:
event.to_me = True event.to_me = True
del event.message[0] event.message.pop(0)
if event.message and event.message[0].type == "text": if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[ event.message[0].data["text"] = event.message[0].data[
"text"].lstrip() "text"].lstrip()
if not event.message[0].data["text"]: if not event.message[0].data["text"]:
del event.message[0] del event.message[0]
if event.message and event.message[0] == at_me_seg: if event.message and event.message[0] == at_me_seg:
del event.message[0] event.message.pop(0)
if event.message and event.message[0].type == "text": if event.message and event.message[0].type == "text":
event.message[0].data["text"] = event.message[0].data[ event.message[0].data["text"] = event.message[0].data[
"text"].lstrip() "text"].lstrip()

View File

@ -3,7 +3,7 @@ from io import BytesIO
from pathlib import Path from pathlib import Path
from base64 import b64encode from base64 import b64encode
from functools import reduce from functools import reduce
from typing import Any, Dict, Union, Tuple, Mapping, Iterable, Optional from typing import Any, List, Dict, Union, Tuple, Mapping, Iterable, Optional
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
@ -229,7 +229,7 @@ class MessageSegment(BaseMessageSegment):
return MessageSegment("xml", {"data": data}) return MessageSegment("xml", {"data": data})
class Message(BaseMessage): class Message(BaseMessage[MessageSegment]):
""" """
CQHTTP 协议 Message 适配 CQHTTP 协议 Message 适配
""" """

View File

@ -155,7 +155,7 @@ class MessageSegment(BaseMessageSegment):
return {"msgtype": self.type, self.type: copy(self.data)} return {"msgtype": self.type, self.type: copy(self.data)}
class Message(BaseMessage): class Message(BaseMessage[MessageSegment]):
""" """
钉钉 协议 Message 适配 钉钉 协议 Message 适配
""" """

View File

@ -266,7 +266,7 @@ class MessageSegment(BaseMessageSegment):
return cls(type=MessageType.POKE, name=name) return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage): class MessageChain(BaseMessage[MessageSegment]):
""" """
Mirai 协议 Message 适配 Mirai 协议 Message 适配