#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import re
import httpx

# from nonebot.event import Event
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.exception import ApiNotAvailable
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment


def escape(s: str, *, escape_comma: bool = True) -> str:
    """Escape a string so it can be embedded in a CQ code.

    ``&``, ``[`` and ``]`` are always replaced by HTML-style entities.
    ``escape_comma`` controls whether ``,`` is escaped as well; CQ code
    parameter values need it, plain text segments do not.
    """
    # "&" must go first, otherwise the "&" of the other entities would be
    # escaped a second time.
    s = s.replace("&", "&amp;") \
        .replace("[", "&#91;") \
        .replace("]", "&#93;")
    if escape_comma:
        s = s.replace(",", "&#44;")
    return s


def unescape(s: str) -> str:
    """Reverse :func:`escape`, turning CQ code entities back into characters.

    ``&amp;`` is decoded last so that an escaped entity such as
    ``&amp;#91;`` decodes to the literal text ``&#91;``.
    """
    return s.replace("&#44;", ",") \
        .replace("&#91;", "[") \
        .replace("&#93;", "]") \
        .replace("&amp;", "&")


def _b2s(b: bool) -> str:
    """Render a bool the way CQ code parameters expect (``true``/``false``)."""
    return str(b).lower()


class Bot(BaseBot):
    """CQHTTP bot bound to a single ``self_id`` over HTTP or websocket."""

    def __init__(self,
                 connection_type: str,
                 config: Config,
                 self_id: int,
                 *,
                 websocket: WebSocket = None):
        if connection_type not in ["http", "websocket"]:
            raise ValueError("Unsupported connection type")
        super().__init__(connection_type, config, self_id, websocket=websocket)

    @property
    @overrides(BaseBot)
    def type(self) -> str:
        """Adapter name; this is NOT the connection type."""
        return "cqhttp"

    @overrides(BaseBot)
    async def handle_message(self, message: dict):
        """Wrap a raw CQHTTP event dict in :class:`Event` and dispatch it."""
        # TODO: convert message into event (e.g. wrap event["message"] in
        # Message once the Message type is wired up)
        event = Event(message)

        if not event:
            return

        await handle_event(self, event)

    @overrides(BaseBot)
    async def call_api(self, api: str, data: dict):
        """Call CQHTTP API ``api`` with ``data`` as its parameters.

        Over HTTP the parameters are POSTed as JSON to
        ``config.api_root[self_id] + api`` and the decoded JSON body is
        returned. The websocket transport is not implemented yet.

        Raises:
            ApiNotAvailable: no API root is configured for this ``self_id``.
            httpx.HTTPError: the server answered with a non-2xx status.
        """
        # Dispatch on the transport validated in __init__. ``self.type`` is
        # the adapter name ("cqhttp") and would never match either branch.
        if self.connection_type == "websocket":
            # TODO: Call API over the websocket connection
            pass
        elif self.connection_type == "http":
            api_root = self.config.api_root.get(self.self_id)
            if not api_root:
                raise ApiNotAvailable
            elif not api_root.endswith("/"):
                api_root += "/"

            headers = {}
            if self.config.access_token:
                headers["Authorization"] = "Bearer " + self.config.access_token

            async with httpx.AsyncClient(headers=headers) as client:
                response = await client.post(api_root + api, json=data)

            if 200 <= response.status_code < 300:
                # TODO: unwrap the CQHTTP response envelope (status/retcode)
                return response.json()
            raise httpx.HTTPError("", response)


class Event(BaseEvent):
    """Raw CQHTTP event exposing its ``*_type`` keys as properties."""

    @property
    @overrides(BaseEvent)
    def type(self):
        """Value of ``post_type`` in the raw event."""
        return self._raw_event["post_type"]

    @type.setter
    @overrides(BaseEvent)
    def type(self, value):
        self._raw_event["post_type"] = value

    @property
    @overrides(BaseEvent)
    def detail_type(self):
        """Value of ``<post_type>_type`` (e.g. ``message_type``)."""
        return self._raw_event[f"{self.type}_type"]

    @detail_type.setter
    @overrides(BaseEvent)
    def detail_type(self, value):
        self._raw_event[f"{self.type}_type"] = value

    @property
    @overrides(BaseEvent)
    def sub_type(self):
        """Value of ``sub_type`` in the raw event."""
        return self._raw_event["sub_type"]

    # Must chain from sub_type: ``@type.setter`` here would build a property
    # whose getter returns post_type.
    @sub_type.setter
    @overrides(BaseEvent)
    def sub_type(self, value):
        self._raw_event["sub_type"] = value


class MessageSegment(BaseMessageSegment):
    """One CQ message segment; ``str()`` renders it as a CQ code."""

    @overrides(BaseMessageSegment)
    def __str__(self):
        type_ = self.type
        data = self.data.copy()

        # Map the adapter's synthetic segment types back onto real CQ codes
        # (the inverse of the mapping done in Message._construct).
        if type_ == "at_all":
            type_ = "at"
            data = {"qq": "all"}
        elif type_ == "poke":
            type_ = "shake"
            data.clear()
        elif type_ == "magic_face":
            # magic_face {"type": "dice"} is sent as [CQ:dice]
            type_ = data.pop("type", type_)
        elif type_ == "text":
            return escape(data.get("text", ""), escape_comma=False)

        params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()])
        return f"[CQ:{type_}{',' if params else ''}{params}]"

    @overrides(BaseMessageSegment)
    def __add__(self, other) -> "Message":
        return Message(self) + other

    @staticmethod
    def anonymous(ignore_failure: bool = False) -> "MessageSegment":
        return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})

    @staticmethod
    def at(user_id: int) -> "MessageSegment":
        return MessageSegment("at", {"qq": str(user_id)})

    @staticmethod
    def at_all() -> "MessageSegment":
        return MessageSegment("at_all")

    @staticmethod
    def contact_group(group_id: int) -> "MessageSegment":
        return MessageSegment("contact", {"type": "group", "id": str(group_id)})

    @staticmethod
    def contact_user(user_id: int) -> "MessageSegment":
        return MessageSegment("contact", {"type": "qq", "id": str(user_id)})

    @staticmethod
    def face(id_: int) -> "MessageSegment":
        return MessageSegment("face", {"id": str(id_)})

    @staticmethod
    def forward(id_: str) -> "MessageSegment":
        return MessageSegment("forward", {"id": id_})

    @staticmethod
    def image(file: str) -> "MessageSegment":
        return MessageSegment("image", {"file": file})

    @staticmethod
    def location(latitude: float,
                 longitude: float,
                 title: str = "",
                 content: str = "") -> "MessageSegment":
        return MessageSegment(
            "location", {
                "lat": str(latitude),
                "lon": str(longitude),
                "title": title,
                "content": content
            })

    @staticmethod
    def magic_face(type_: str) -> "MessageSegment":
        # NOTE(review): CoolQ's rock-paper-scissors code is probably "rps";
        # confirm before relying on "rpc".
        if type_ not in ["dice", "rpc"]:
            raise ValueError(
                f"Coolq doesn't support magic face type {type_}. Supported types: dice, rpc."
            )
        return MessageSegment("magic_face", {"type": type_})

    @staticmethod
    def music(type_: str,
              id_: int,
              style: Optional[int] = None) -> "MessageSegment":
        # Stringify like the other factories; "style" is only sent when given.
        data = {"type": type_, "id": str(id_)}
        if style is not None:
            data["style"] = str(style)
        return MessageSegment("music", data)

    @staticmethod
    def music_custom(type_: str,
                     url: str,
                     audio: str,
                     title: str,
                     content: str = "",
                     img_url: str = "") -> "MessageSegment":
        return MessageSegment(
            "music", {
                "type": type_,
                "url": url,
                "audio": audio,
                "title": title,
                "content": content,
                "image": img_url
            })

    @staticmethod
    def node(id_: int) -> "MessageSegment":
        return MessageSegment("node", {"id": str(id_)})

    @staticmethod
    def node_custom(name: str, uin: int,
                    content: "Message") -> "MessageSegment":
        return MessageSegment("node", {
            "name": name,
            "uin": str(uin),
            "content": str(content)
        })

    @staticmethod
    def poke(type_: str = "Poke") -> "MessageSegment":
        if type_ not in ["Poke"]:
            raise ValueError(
                f"Coolq doesn't support poke type {type_}. Supported types: Poke."
            )
        return MessageSegment("poke", {"type": type_})

    @staticmethod
    def record(file: str, magic: bool = False) -> "MessageSegment":
        return MessageSegment("record", {"file": file, "magic": _b2s(magic)})

    @staticmethod
    def reply(id_: int) -> "MessageSegment":
        return MessageSegment("reply", {"id": str(id_)})

    @staticmethod
    def replay(id_: int) -> "MessageSegment":
        """Misspelled alias of :meth:`reply`, kept for backward compatibility."""
        return MessageSegment.reply(id_)

    @staticmethod
    def share(url: str = "",
              title: str = "",
              content: str = "",
              img_url: str = "") -> "MessageSegment":
        return MessageSegment("share", {
            "url": url,
            "title": title,
            "content": content,
            "img_url": img_url
        })

    @staticmethod
    def text(text: str) -> "MessageSegment":
        return MessageSegment("text", {"text": text})


# One CQ code such as ``[CQ:image,file=a.jpg]``: ``type`` captures the code
# name, ``params`` the (possibly empty) ``,key=value`` list.
_CQCODE_RE = re.compile(r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
                        r"(?P<params>"
                        r"(?:,[a-zA-Z0-9-_.]+=?[^,\]]*)*"
                        r"),?\]")


class Message(BaseMessage):
    """CQ message: a sequence of :class:`MessageSegment` parsed from CQ text."""

    @staticmethod
    @overrides(BaseMessage)
    def _construct(msg: str) -> Iterable[MessageSegment]:
        """Split ``msg`` into text and CQ code segments, in order.

        Text between codes is unescaped; empty text pieces are skipped.
        Parameter values are unescaped too, so ``str()`` output round-trips.
        """

        def _iter_message() -> Iterable[Tuple[str, str]]:
            text_begin = 0
            for cqcode in _CQCODE_RE.finditer(msg):
                yield "text", unescape(msg[text_begin:cqcode.start()])
                text_begin = cqcode.end()
                yield cqcode.group("type"), cqcode.group("params").lstrip(",")
            yield "text", unescape(msg[text_begin:])

        for type_, raw in _iter_message():
            if type_ == "text":
                if raw:
                    # only yield non-empty text segment
                    yield MessageSegment(type_, {"text": raw})
                continue

            data = {}
            for item in raw.split(","):
                item = item.lstrip()
                if not item:
                    continue
                # The regex allows a bare key without "=", so use partition
                # instead of a two-way unpack that would raise ValueError.
                key, _, value = item.partition("=")
                data[key] = unescape(value)

            if type_ == "at" and data.get("qq") == "all":
                type_ = "at_all"
                data.clear()
            elif type_ in ["dice", "rpc"]:
                # Record the original code before renaming the segment.
                data["type"] = type_
                type_ = "magic_face"
            elif type_ == "shake":
                type_ = "poke"
                data["type"] = "Poke"
            yield MessageSegment(type_, data)