diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index c12d96bb..9f7c69b7 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -1,19 +1,24 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Any, Dict, Optional +import abc +from functools import reduce +from typing import Dict, Union, Iterable, Optional from nonebot.config import Config -class BaseBot(object): +class BaseBot(abc.ABC): + @abc.abstractmethod def __init__(self, type: str, config: Config, *, websocket=None): raise NotImplementedError + @abc.abstractmethod async def handle_message(self, message: dict): raise NotImplementedError + @abc.abstractmethod async def call_api(self, api: str, data: dict): raise NotImplementedError @@ -64,8 +69,73 @@ class BaseMessageSegment(dict): class BaseMessage(list): - def __init__(self, message: str = None): - raise NotImplementedError + def __init__(self, + message: Union[str, BaseMessageSegment, "BaseMessage"] = None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + if isinstance(message, str): + self.extend(self._construct(message)) + elif isinstance(message, BaseMessage): + self.extend(message) + elif isinstance(message, BaseMessageSegment): + self.append(message) def __str__(self): return ''.join((str(seg) for seg in self)) + + @staticmethod + def _construct(msg: str) -> Iterable[BaseMessageSegment]: + raise NotImplementedError + + def __add__( + self, other: Union[str, BaseMessageSegment, + "BaseMessage"]) -> "BaseMessage": + result = self.__class__(self) + if isinstance(other, str): + result.extend(self._construct(other)) + elif isinstance(other, BaseMessageSegment): + result.append(other) + elif isinstance(other, BaseMessage): + result.extend(other) + return result + + def __radd__(self, other: Union[str, BaseMessageSegment, "BaseMessage"]): + result = self.__class__(other) + return result.__add__(self) + + def append(self, obj: Union[str, BaseMessageSegment]) -> "BaseMessage": + if isinstance(obj, BaseMessageSegment): + if obj.type == "text" and self and self[-1].type == "text": + self[-1].data["text"] += obj.data["text"] + else: + super().append(obj) + elif isinstance(obj, str): + self.extend(self._construct(obj)) + else: + raise ValueError(f"Unexpected type: {type(obj)} {obj}") + return self + + def extend( + self, obj: Union["BaseMessage", + Iterable[BaseMessageSegment]]) -> "BaseMessage": + for segment in obj: + self.append(segment) + return self + + def reduce(self) -> None: + index = 0 + while index < len(self): + if index > 0 and self[ + index - 1].type == "text" and self[index].type == "text": + self[index - 1].data["text"] += self[index].data["text"] + del self[index] + else: + index += 1 + + def extract_plain_text(self) -> str: + + def _concat(x: str, y: BaseMessageSegment) -> str: + return f"{x} {y.data['text']}" if y.type == "text" else x + + return reduce(_concat, self, "") diff --git a/nonebot/adapters/coolq.py b/nonebot/adapters/coolq.py index f31feaab..376e27f7 100644 --- a/nonebot/adapters/coolq.py +++ b/nonebot/adapters/coolq.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import re +from typing import Tuple, Iterable, Optional + import httpx from nonebot.event import Event @@ -32,16 +35,21 @@ def unescape(s: str) -> str: .replace("&", "&") +def _b2s(b: bool) -> str: + return str(b).lower() + + class Bot(BaseBot): def __init__(self, - type_: str, + connection_type: str, config: Config, *, websocket: BaseWebSocket = None): - if type_ not in ["http", "websocket"]: + if connection_type not in ["http", "websocket"]: raise ValueError("Unsupported connection type") - self.type = type_ + self.type = "coolq" + self.connection_type = connection_type self.config = config self.websocket = websocket @@ -49,10 +57,20 @@ class Bot(BaseBot): # TODO: convert message into event event = Event.from_payload(message) + if not event: + return + + if "message" in event.keys(): + event["message"] = Message(event["message"]) + # TODO: Handle Meta Event - await handle_event(self, event) + if event.type == "meta_event": + pass + else: + await handle_event(self, event) async def call_api(self, api: str, data: dict): + # TODO: Call API if self.type == "websocket": pass elif self.type == "http": @@ -66,15 +84,22 @@ class MessageSegment(BaseMessageSegment): data = self.data.copy() # process special types - if type_ == "text": - return escape(data.get("text", ""), escape_comma=False) - elif type_ == "at_all": + if type_ == "at_all": type_ = "at" data = {"qq": "all"} + elif type_ == "poke": + type_ = "shake" + data.clear() + elif type_ == "text": + return escape(data.get("text", ""), escape_comma=False) params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) return f"[CQ:{type_}{',' if params else ''}{params}]" + @staticmethod + def anonymous(ignore_failure: bool = False) -> "MessageSegment": + return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) + @staticmethod def at(user_id: int) -> "MessageSegment": return MessageSegment("at", {"qq": str(user_id)}) @@ -84,9 +109,138 @@ class MessageSegment(BaseMessageSegment): return MessageSegment("at_all") @staticmethod - def dice() -> "MessageSegment": - return MessageSegment(type_="dice") + def contact_group(group_id: int) -> "MessageSegment": + return MessageSegment("contact", {"type": "group", "id": str(group_id)}) + + @staticmethod + def contact_user(user_id: int) -> "MessageSegment": + return MessageSegment("contact", {"type": "qq", "id": str(user_id)}) + + @staticmethod + def face(id_: int) -> "MessageSegment": + return MessageSegment("face", {"id": str(id_)}) + + @staticmethod + def image(file: str) -> "MessageSegment": + return MessageSegment("image", {"file": "file"}) + + @staticmethod + def location(latitude: float, + longitude: float, + title: str = "", + content: str = "") -> "MessageSegment": + return MessageSegment( + "location", { + "lat": str(latitude), + "lon": str(longitude), + "title": title, + "content": content + }) + + @staticmethod + def magic_face(type_: str) -> "MessageSegment": + if type_ not in ["dice", "rpc"]: + raise ValueError( + f"Coolq doesn't support magic face type {type_}. Supported types: dice, rpc." + ) + return MessageSegment("magic_face", {"type": type_}) + + @staticmethod + def music(type_: str, + id_: int, + style: Optional[int] = None) -> "MessageSegment": + if style is None: + return MessageSegment("music", {"type": type_, "id": id_}) + else: + return MessageSegment("music", { + "type": type_, + "id": id_, + "style": style + }) + + @staticmethod + def music_custom(type_: str, + url: str, + audio: str, + title: str, + content: str = "", + img_url: str = "") -> "MessageSegment": + return MessageSegment( + "music", { + "type": type_, + "url": url, + "audio": audio, + "title": title, + "content": content, + "image": img_url + }) + + @staticmethod + def poke(type_: str = "Poke") -> "MessageSegment": + if type_ not in ["Poke"]: + raise ValueError( + f"Coolq doesn't support poke type {type_}. Supported types: Poke." + ) + return MessageSegment("poke", {"type": type_}) + + @staticmethod + def record(file: str, magic: bool = False) -> "MessageSegment": + return MessageSegment("record", {"file": file, "magic": _b2s(magic)}) + + @staticmethod + def share(url: str = "", + title: str = "", + content: str = "", + img_url: str = "") -> "MessageSegment": + return MessageSegment("share", { + "url": url, + "title": title, + "content": content, + "img_url": img_url + }) + + @staticmethod + def text(text: str) -> "MessageSegment": + return MessageSegment("text", {"text": text}) class Message(BaseMessage): - pass + + @staticmethod + def _construct(msg: str) -> Iterable[MessageSegment]: + + def _iter_message() -> Iterable[Tuple[str, str]]: + text_begin = 0 + for cqcode in re.finditer( + r"\[CQ:(?P[a-zA-Z0-9-_.]+)" + r"(?P" + r"(?:,[a-zA-Z0-9-_.]+=?[^,\]]*)*" + r"),?\]", msg): + yield "text", unescape(msg[text_begin:cqcode.pos + + cqcode.start()]) + text_begin = cqcode.pos + cqcode.end() + yield cqcode.group("type"), cqcode.group("params").lstrip(",") + yield "text", unescape(msg[text_begin:]) + + for type_, data in _iter_message(): + if type_ == "text": + if data: + # only yield non-empty text segment + yield MessageSegment(type_, {"text": data}) + else: + data = { + k: v for k, v in map( + lambda x: x.split("=", maxsplit=1), + filter(lambda x: x, ( + x.lstrip() for x in data.split(",")))) + } + if type_ == "at" and data["qq"] == "all": + type_ = "at_all" + data.clear() + elif type_ in ["dice", "rpc"]: + type_ = "magic_face" + data["type"] = type_ + elif type_ == "shake": + type_ = "poke" + data["type"] = "Poke" + yield MessageSegment(type_, data) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 49de914c..e4b88764 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -1,29 +1,35 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import abc from typing import Optional from ipaddress import IPv4Address from nonebot.config import Config -class BaseDriver(object): +class BaseDriver(abc.ABC): + @abc.abstractmethod def __init__(self, config: Config): raise NotImplementedError @property + @abc.abstractmethod def server_app(self): raise NotImplementedError @property + @abc.abstractmethod def asgi(self): raise NotImplementedError @property + @abc.abstractmethod def logger(self): raise NotImplementedError + @abc.abstractmethod def run(self, host: Optional[IPv4Address] = None, port: Optional[int] = None, @@ -31,37 +37,47 @@ class BaseDriver(object): **kwargs): raise NotImplementedError + @abc.abstractmethod async def _handle_http(self): raise NotImplementedError + @abc.abstractmethod async def _handle_ws_reverse(self): raise NotImplementedError + @abc.abstractmethod async def _handle_http_api(self): raise NotImplementedError class BaseWebSocket(object): + @abc.abstractmethod def __init__(self, websocket): self._websocket = websocket @property + @abc.abstractmethod def websocket(self): return self._websocket @property + @abc.abstractmethod def closed(self): raise NotImplementedError + @abc.abstractmethod async def accept(self): raise NotImplementedError + @abc.abstractmethod async def close(self): raise NotImplementedError + @abc.abstractmethod async def receive(self) -> dict: raise NotImplementedError + @abc.abstractmethod async def send(self, data: dict): raise NotImplementedError diff --git a/nonebot/event.py b/nonebot/event.py index e5e576ac..d88138d8 100644 --- a/nonebot/event.py +++ b/nonebot/event.py @@ -29,7 +29,7 @@ class Event(dict): """ 事件类型,有 ``message``、``notice``、``request``、``meta_event`` 等。 """ - return self['post_type'] + return self["post_type"] @property def detail_type(self) -> str: @@ -37,7 +37,7 @@ class Event(dict): 事件具体类型,依 `type` 的不同而不同,以 ``message`` 类型为例,有 ``private``、``group``、``discuss`` 等。 """ - return self[f'{self.type}_type'] + return self[f"{self.type}_type"] @property def sub_type(self) -> Optional[str]: @@ -45,7 +45,7 @@ class Event(dict): 事件子类型,依 `detail_type` 不同而不同,以 ``message.private`` 为例,有 ``friend``、``group``、``discuss``、``other`` 等。 """ - return self.get('sub_type') + return self.get("sub_type") @property def name(self): @@ -53,75 +53,75 @@ class Event(dict): 事件名,对于有 `sub_type` 的事件,为 ``{type}.{detail_type}.{sub_type}``,否则为 ``{type}.{detail_type}``。 """ - n = self.type + '.' + self.detail_type + n = self.type + "." + self.detail_type if self.sub_type: - n += '.' + self.sub_type + n += "." + self.sub_type return n @property def self_id(self) -> int: """机器人自身 ID。""" - return self['self_id'] + return self["self_id"] @property def user_id(self) -> Optional[int]: """用户 ID。""" - return self.get('user_id') + return self.get("user_id") @property def operator_id(self) -> Optional[int]: """操作者 ID。""" - return self.get('operator_id') + return self.get("operator_id") @property def group_id(self) -> Optional[int]: """群 ID。""" - return self.get('group_id') + return self.get("group_id") @property def discuss_id(self) -> Optional[int]: """讨论组 ID。""" - return self.get('discuss_id') + return self.get("discuss_id") @property def message_id(self) -> Optional[int]: """消息 ID。""" - return self.get('message_id') + return self.get("message_id") @property def message(self) -> Optional[Any]: """消息。""" - return self.get('message') + return self.get("message") @property def raw_message(self) -> Optional[str]: """未经 CQHTTP 处理的原始消息。""" - return self.get('raw_message') + return self.get("raw_message") @property def sender(self) -> Optional[Dict[str, Any]]: """消息发送者信息。""" - return self.get('sender') + return self.get("sender") @property def anonymous(self) -> Optional[Dict[str, Any]]: """匿名信息。""" - return self.get('anonymous') + return self.get("anonymous") @property def file(self) -> Optional[Dict[str, Any]]: """文件信息。""" - return self.get('file') + return self.get("file") @property def comment(self) -> Optional[str]: """请求验证消息。""" - return self.get('comment') + return self.get("comment") @property def flag(self) -> Optional[str]: """请求标识。""" - return self.get('flag') + return self.get("flag") def __repr__(self) -> str: - return f'' + return f"" diff --git a/tests/test_plugins/test_matcher.py b/tests/test_plugins/test_matcher.py index 08472417..f7bf957c 100644 --- a/tests/test_plugins/test_matcher.py +++ b/tests/test_plugins/test_matcher.py @@ -4,6 +4,9 @@ from nonebot.rule import Rule from nonebot.event import Event from nonebot.plugin import on_message +from nonebot.adapters.coolq import Message + +print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]"))) test_matcher = on_message(Rule(), state={"default": 1})