diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 6e08553f..efc34acc 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -179,7 +179,7 @@ class BaseEvent(abc.ABC): @dataclass class BaseMessageSegment(abc.ABC): type: str - data: Dict[str, str] = field(default_factory=lambda: {}) + data: Dict[str, Union[str, list]] = field(default_factory=lambda: {}) @abc.abstractmethod def __str__(self): diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 9e5ba061..bb9ec12d 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -1,5 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +""" +CQHTTP (OneBot) v11 协议适配 +============================ + +协议详情请看: `CQHTTP`_ | `OneBot`_ + +.. _CQHTTP: + http://cqhttp.cc/ +.. _OneBot: + https://github.com/howmanybots/onebot +""" import re import sys @@ -38,8 +49,8 @@ def unescape(s: str) -> str: .replace("&", "&") -def _b2s(b: bool) -> str: - return str(b).lower() +def _b2s(b: Optional[bool]) -> Optional[str]: + return b if b is None else str(b).lower() def _check_at_me(bot: "Bot", event: "Event"): @@ -389,14 +400,8 @@ class Event(BaseEvent): class MessageSegment(BaseMessageSegment): @overrides(BaseMessageSegment) - def __init__(self, type: str, data: Dict[str, str]) -> None: - if type == "at" and data.get("qq") == "all": - type = "at_all" - data.clear() - elif type == "shake": - type = "poke" - data = {"type": "Poke"} - elif type == "text": + def __init__(self, type: str, data: Dict[str, Union[str, list]]) -> None: + if type == "text": data["text"] = unescape(data["text"]) super().__init__(type=type, data=data) @@ -406,16 +411,11 @@ class MessageSegment(BaseMessageSegment): data = self.data.copy() # process special types - if type_ == "at_all": - type_ = "at" - data = {"qq": "all"} - elif type_ == "poke": - type_ = "shake" - data.clear() - elif type_ == "text": + if type_ == "text": return escape(data.get("text", ""), escape_comma=False) - params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) + params = ",".join( + [f"{k}={escape(str(v))}" for k, v in data.items() if v is not None]) return f"[CQ:{type_}{',' if params else ''}{params}]" @overrides(BaseMessageSegment) @@ -423,17 +423,13 @@ class MessageSegment(BaseMessageSegment): return Message(self) + other @staticmethod - def anonymous(ignore_failure: bool = False) -> "MessageSegment": + def anonymous(ignore_failure: Optional[bool] = None) -> "MessageSegment": return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) @staticmethod def at(user_id: Union[int, str]) -> "MessageSegment": return MessageSegment("at", {"qq": str(user_id)}) - @staticmethod - def at_all() -> "MessageSegment": - return MessageSegment("at_all") - @staticmethod def contact_group(group_id: int) -> "MessageSegment": return MessageSegment("contact", {"type": "group", "id": str(group_id)}) @@ -442,23 +438,43 @@ class MessageSegment(BaseMessageSegment): def contact_user(user_id: int) -> "MessageSegment": return MessageSegment("contact", {"type": "qq", "id": str(user_id)}) + @staticmethod + def dice() -> "MessageSegment": + return MessageSegment("dice", {}) + @staticmethod def face(id_: int) -> "MessageSegment": return MessageSegment("face", {"id": str(id_)}) @staticmethod def forward(id_: str) -> "MessageSegment": + logger.warning("Forward Message only can be received!") return MessageSegment("forward", {"id": id_}) @staticmethod - def image(file: str) -> "MessageSegment": - return MessageSegment("image", {"file": file}) + def image(file: str, + type_: Optional[str] = None, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None) -> "MessageSegment": + return MessageSegment( + "image", { + "file": file, + "type": type_, + "cache": cache, + "proxy": proxy, + "timeout": timeout + }) + + @staticmethod + def json(data: str) -> "MessageSegment": + return MessageSegment("json", {"data": data}) @staticmethod def location(latitude: float, longitude: float, - title: str = "", - content: str = "") -> "MessageSegment": + title: Optional[str] = None, + content: Optional[str] = None) -> "MessageSegment": return MessageSegment( "location", { "lat": str(latitude), @@ -468,36 +484,18 @@ class MessageSegment(BaseMessageSegment): }) @staticmethod - def magic_face(type_: str) -> "MessageSegment": - if type_ not in ["dice", "rpc"]: - raise ValueError( - f"Coolq doesn't support magic face type {type_}. Supported types: dice, rpc." - ) - return MessageSegment("magic_face", {"type": type_}) + def music(type_: str, id_: int) -> "MessageSegment": + return MessageSegment("music", {"type": type_, "id": id_}) @staticmethod - def music(type_: str, - id_: int, - style: Optional[int] = None) -> "MessageSegment": - if style is None: - return MessageSegment("music", {"type": type_, "id": id_}) - else: - return MessageSegment("music", { - "type": type_, - "id": id_, - "style": style - }) - - @staticmethod - def music_custom(type_: str, - url: str, + def music_custom(url: str, audio: str, title: str, - content: str = "", - img_url: str = "") -> "MessageSegment": + content: Optional[str] = None, + img_url: Optional[str] = None) -> "MessageSegment": return MessageSegment( "music", { - "type": type_, + "type": "custom", "url": url, "audio": audio, "title": title, @@ -510,35 +508,43 @@ class MessageSegment(BaseMessageSegment): return MessageSegment("node", {"id": str(id_)}) @staticmethod - def node_custom(name: str, uin: int, - content: "Message") -> "MessageSegment": + def node_custom(user_id: int, nickname: str, + content: Union[str, "Message"]) -> "MessageSegment": return MessageSegment("node", { - "name": name, - "uin": str(uin), - "content": str(content) + "user_id": str(user_id), + "nickname": nickname, + "content": content }) @staticmethod - def poke(type_: str = "Poke") -> "MessageSegment": - if type_ not in ["Poke"]: - raise ValueError( - f"Coolq doesn't support poke type {type_}. Supported types: Poke." - ) - return MessageSegment("poke", {"type": type_}) + def poke(type_: str, id_: str) -> "MessageSegment": + return MessageSegment("poke", {"type": type_, "id": id_}) @staticmethod - def record(file: str, magic: bool = False) -> "MessageSegment": + def record(file: str, + magic: Optional[bool] = None, + cache: Optional[bool] = None, + proxy: Optional[bool] = None, + timeout: Optional[int] = None) -> "MessageSegment": return MessageSegment("record", {"file": file, "magic": _b2s(magic)}) @staticmethod - def replay(id_: int) -> "MessageSegment": - return MessageSegment("replay", {"id": str(id_)}) + def reply(id_: int) -> "MessageSegment": + return MessageSegment("reply", {"id": str(id_)}) + + @staticmethod + def rps() -> "MessageSegment": + return MessageSegment("rps", {}) + + @staticmethod + def shake() -> "MessageSegment": + return MessageSegment("shake", {}) @staticmethod def share(url: str = "", title: str = "", - content: str = "", - img_url: str = "") -> "MessageSegment": + content: Optional[str] = None, + img_url: Optional[str] = None) -> "MessageSegment": return MessageSegment("share", { "url": url, "title": title, @@ -550,6 +556,22 @@ class MessageSegment(BaseMessageSegment): def text(text: str) -> "MessageSegment": return MessageSegment("text", {"text": text}) + @staticmethod + def video(file: str, + cache: Optional[bool] = None, + proxy: Optional[bool] = None, + timeout: Optional[int] = None) -> "MessageSegment": + return MessageSegment("video", { + "file": file, + "cache": cache, + "proxy": proxy, + "timeout": timeout + }) + + @staticmethod + def xml(data: str) -> "MessageSegment": + return MessageSegment("xml", {"data": data}) + class Message(BaseMessage): @@ -564,7 +586,7 @@ class Message(BaseMessage): yield MessageSegment(seg["type"], seg.get("data") or {}) return - def _iter_message() -> Iterable[Tuple[str, str]]: + def _iter_message(msg: str) -> Iterable[Tuple[str, str]]: text_begin = 0 for cqcode in re.finditer( r"\[CQ:(?P[a-zA-Z0-9-_.]+)" @@ -577,7 +599,7 @@ class Message(BaseMessage): yield cqcode.group("type"), cqcode.group("params").lstrip(",") yield "text", unescape(msg[text_begin:]) - for type_, data in _iter_message(): + for type_, data in _iter_message(msg): if type_ == "text": if data: # only yield non-empty text segment @@ -589,13 +611,4 @@ class Message(BaseMessage): filter(lambda x: x, ( x.lstrip() for x in data.split(",")))) } - if type_ == "at" and data["qq"] == "all": - type_ = "at_all" - data.clear() - elif type_ in ["dice", "rpc"]: - type_ = "magic_face" - data["type"] = type_ - elif type_ == "shake": - type_ = "poke" - data["type"] = "Poke" yield MessageSegment(type_, data) diff --git a/package.json b/package.json index ecc13c8c..0baf5f4d 100644 --- a/package.json +++ b/package.json @@ -4,10 +4,12 @@ "description": "An asynchronous python bot framework.", "homepage": "https://docs.nonebot.dev/", "main": "index.js", - "contributors": [{ - "name": "yanyongyu", - "email": "yanyongyu_1@126.com" - }], + "contributors": [ + { + "name": "yanyongyu", + "email": "yanyongyu_1@126.com" + } + ], "repository": "https://github.com/nonebot/nonebot/", "bugs": { "url": "https://github.com/nonebot/nonebot/issues"