mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-19 01:45:45 +08:00
299 lines
10 KiB
Python
299 lines
10 KiB
Python
import re
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from base64 import b64encode
|
|
from typing import Type, Union, Tuple, Mapping, Iterable, Optional
|
|
|
|
from nonebot.typing import overrides
|
|
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
|
|
|
from .utils import log, escape, unescape, _b2s
|
|
|
|
|
|
class MessageSegment(BaseMessageSegment["Message"]):
|
|
"""
|
|
CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
|
"""
|
|
|
|
@classmethod
|
|
@overrides(BaseMessageSegment)
|
|
def get_message_class(cls) -> Type["Message"]:
|
|
return Message
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def __str__(self) -> str:
|
|
type_ = self.type
|
|
data = self.data.copy()
|
|
|
|
# process special types
|
|
if type_ == "text":
|
|
return escape(
|
|
data.get("text", ""), # type: ignore
|
|
escape_comma=False)
|
|
|
|
params = ",".join(
|
|
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None])
|
|
return f"[CQ:{type_}{',' if params else ''}{params}]"
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def __add__(self, other) -> "Message":
|
|
return Message(self) + (MessageSegment.text(other) if isinstance(
|
|
other, str) else other)
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def __radd__(self, other) -> "Message":
|
|
return (MessageSegment.text(other)
|
|
if isinstance(other, str) else Message(other)) + self
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def is_text(self) -> bool:
|
|
return self.type == "text"
|
|
|
|
@staticmethod
|
|
def anonymous(ignore_failure: Optional[bool] = None) -> "MessageSegment":
|
|
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
|
|
|
|
@staticmethod
|
|
def at(user_id: Union[int, str]) -> "MessageSegment":
|
|
return MessageSegment("at", {"qq": str(user_id)})
|
|
|
|
@staticmethod
|
|
def contact(type_: str, id: int) -> "MessageSegment":
|
|
return MessageSegment("contact", {"type": type_, "id": str(id)})
|
|
|
|
@staticmethod
|
|
def contact_group(group_id: int) -> "MessageSegment":
|
|
return MessageSegment("contact", {"type": "group", "id": str(group_id)})
|
|
|
|
@staticmethod
|
|
def contact_user(user_id: int) -> "MessageSegment":
|
|
return MessageSegment("contact", {"type": "qq", "id": str(user_id)})
|
|
|
|
@staticmethod
|
|
def dice() -> "MessageSegment":
|
|
return MessageSegment("dice", {})
|
|
|
|
@staticmethod
|
|
def face(id_: int) -> "MessageSegment":
|
|
return MessageSegment("face", {"id": str(id_)})
|
|
|
|
@staticmethod
|
|
def forward(id_: str) -> "MessageSegment":
|
|
log("WARNING", "Forward Message only can be received!")
|
|
return MessageSegment("forward", {"id": id_})
|
|
|
|
@staticmethod
|
|
def image(file: Union[str, bytes, BytesIO, Path],
|
|
type_: Optional[str] = None,
|
|
cache: bool = True,
|
|
proxy: bool = True,
|
|
timeout: Optional[int] = None) -> "MessageSegment":
|
|
if isinstance(file, BytesIO):
|
|
file = file.getvalue()
|
|
if isinstance(file, bytes):
|
|
file = f"base64://{b64encode(file).decode()}"
|
|
elif isinstance(file, Path):
|
|
file = f"file:///{file.resolve()}"
|
|
return MessageSegment(
|
|
"image", {
|
|
"file": file,
|
|
"type": type_,
|
|
"cache": _b2s(cache),
|
|
"proxy": _b2s(proxy),
|
|
"timeout": timeout
|
|
})
|
|
|
|
@staticmethod
|
|
def json(data: str) -> "MessageSegment":
|
|
return MessageSegment("json", {"data": data})
|
|
|
|
@staticmethod
|
|
def location(latitude: float,
|
|
longitude: float,
|
|
title: Optional[str] = None,
|
|
content: Optional[str] = None) -> "MessageSegment":
|
|
return MessageSegment(
|
|
"location", {
|
|
"lat": str(latitude),
|
|
"lon": str(longitude),
|
|
"title": title,
|
|
"content": content
|
|
})
|
|
|
|
@staticmethod
|
|
def music(type_: str, id_: int) -> "MessageSegment":
|
|
return MessageSegment("music", {"type": type_, "id": id_})
|
|
|
|
@staticmethod
|
|
def music_custom(url: str,
|
|
audio: str,
|
|
title: str,
|
|
content: Optional[str] = None,
|
|
img_url: Optional[str] = None) -> "MessageSegment":
|
|
return MessageSegment(
|
|
"music", {
|
|
"type": "custom",
|
|
"url": url,
|
|
"audio": audio,
|
|
"title": title,
|
|
"content": content,
|
|
"image": img_url
|
|
})
|
|
|
|
@staticmethod
|
|
def node(id_: int) -> "MessageSegment":
|
|
return MessageSegment("node", {"id": str(id_)})
|
|
|
|
@staticmethod
|
|
def node_custom(user_id: int, nickname: str,
|
|
content: Union[str, "Message"]) -> "MessageSegment":
|
|
return MessageSegment("node", {
|
|
"user_id": str(user_id),
|
|
"nickname": nickname,
|
|
"content": content
|
|
})
|
|
|
|
@staticmethod
|
|
def poke(type_: str, id_: str) -> "MessageSegment":
|
|
return MessageSegment("poke", {"type": type_, "id": id_})
|
|
|
|
@staticmethod
|
|
def record(file: Union[str, bytes, BytesIO, Path],
|
|
magic: Optional[bool] = None,
|
|
cache: Optional[bool] = None,
|
|
proxy: Optional[bool] = None,
|
|
timeout: Optional[int] = None) -> "MessageSegment":
|
|
if isinstance(file, BytesIO):
|
|
file = file.getvalue()
|
|
if isinstance(file, bytes):
|
|
file = f"base64://{b64encode(file).decode()}"
|
|
elif isinstance(file, Path):
|
|
file = f"file:///{file.resolve()}"
|
|
return MessageSegment(
|
|
"record", {
|
|
"file": file,
|
|
"magic": _b2s(magic),
|
|
"cache": _b2s(cache),
|
|
"proxy": _b2s(proxy),
|
|
"timeout": timeout
|
|
})
|
|
|
|
@staticmethod
|
|
def reply(id_: int) -> "MessageSegment":
|
|
return MessageSegment("reply", {"id": str(id_)})
|
|
|
|
@staticmethod
|
|
def rps() -> "MessageSegment":
|
|
return MessageSegment("rps", {})
|
|
|
|
@staticmethod
|
|
def shake() -> "MessageSegment":
|
|
return MessageSegment("shake", {})
|
|
|
|
@staticmethod
|
|
def share(url: str = "",
|
|
title: str = "",
|
|
content: Optional[str] = None,
|
|
image: Optional[str] = None) -> "MessageSegment":
|
|
return MessageSegment("share", {
|
|
"url": url,
|
|
"title": title,
|
|
"content": content,
|
|
"image": image
|
|
})
|
|
|
|
@staticmethod
|
|
def text(text: str) -> "MessageSegment":
|
|
return MessageSegment("text", {"text": text})
|
|
|
|
@staticmethod
|
|
def video(file: Union[str, bytes, BytesIO, Path],
|
|
cache: Optional[bool] = None,
|
|
proxy: Optional[bool] = None,
|
|
timeout: Optional[int] = None) -> "MessageSegment":
|
|
if isinstance(file, BytesIO):
|
|
file = file.getvalue()
|
|
if isinstance(file, bytes):
|
|
file = f"base64://{b64encode(file).decode()}"
|
|
elif isinstance(file, Path):
|
|
file = f"file:///{file.resolve()}"
|
|
return MessageSegment(
|
|
"video", {
|
|
"file": file,
|
|
"cache": _b2s(cache),
|
|
"proxy": _b2s(proxy),
|
|
"timeout": timeout
|
|
})
|
|
|
|
@staticmethod
|
|
def xml(data: str) -> "MessageSegment":
|
|
return MessageSegment("xml", {"data": data})
|
|
|
|
|
|
class Message(BaseMessage[MessageSegment]):
|
|
"""
|
|
CQHTTP 协议 Message 适配。
|
|
"""
|
|
|
|
@classmethod
|
|
@overrides(BaseMessage)
|
|
def get_segment_class(cls) -> Type[MessageSegment]:
|
|
return MessageSegment
|
|
|
|
@overrides(BaseMessage)
|
|
def __add__(self, other: Union[str, Mapping,
|
|
Iterable[Mapping]]) -> "Message":
|
|
return super(Message, self).__add__(
|
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
|
|
|
@overrides(BaseMessage)
|
|
def __radd__(self, other: Union[str, Mapping,
|
|
Iterable[Mapping]]) -> "Message":
|
|
return super(Message, self).__radd__(
|
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
|
|
|
@staticmethod
|
|
@overrides(BaseMessage)
|
|
def _construct(
|
|
msg: Union[str, Mapping,
|
|
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
|
if isinstance(msg, Mapping):
|
|
yield MessageSegment(msg["type"], msg.get("data") or {})
|
|
return
|
|
elif isinstance(msg, Iterable) and not isinstance(msg, str):
|
|
for seg in msg:
|
|
yield MessageSegment(seg["type"], seg.get("data") or {})
|
|
return
|
|
elif isinstance(msg, str):
|
|
|
|
def _iter_message(msg: str) -> Iterable[Tuple[str, str]]:
|
|
text_begin = 0
|
|
for cqcode in re.finditer(
|
|
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
|
|
r"(?P<params>"
|
|
r"(?:,[a-zA-Z0-9-_.]+=[^,\]]+)*"
|
|
r"),?\]", msg):
|
|
yield "text", msg[text_begin:cqcode.pos + cqcode.start()]
|
|
text_begin = cqcode.pos + cqcode.end()
|
|
yield cqcode.group("type"), cqcode.group("params").lstrip(
|
|
",")
|
|
yield "text", msg[text_begin:]
|
|
|
|
for type_, data in _iter_message(msg):
|
|
if type_ == "text":
|
|
if data:
|
|
# only yield non-empty text segment
|
|
yield MessageSegment(type_, {"text": unescape(data)})
|
|
else:
|
|
data = {
|
|
k: unescape(v) for k, v in map(
|
|
lambda x: x.split("=", maxsplit=1),
|
|
filter(lambda x: x, (
|
|
x.lstrip() for x in data.split(","))))
|
|
}
|
|
yield MessageSegment(type_, data)
|
|
|
|
@overrides(BaseMessage)
|
|
def extract_plain_text(self) -> str:
|
|
return "".join(seg.data["text"] for seg in self if seg.is_text())
|