add message segment for coolq

This commit is contained in:
yanyongyu 2020-07-18 18:18:43 +08:00
parent 3dbd927a2a
commit 9355ed4baf
5 changed files with 277 additions and 34 deletions

View File

@ -1,19 +1,24 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import Any, Dict, Optional import abc
from functools import reduce
from typing import Dict, Union, Iterable, Optional
from nonebot.config import Config from nonebot.config import Config
class BaseBot(object): class BaseBot(abc.ABC):
@abc.abstractmethod
def __init__(self, type: str, config: Config, *, websocket=None): def __init__(self, type: str, config: Config, *, websocket=None):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, data: dict):
raise NotImplementedError raise NotImplementedError
@ -64,8 +69,73 @@ class BaseMessageSegment(dict):
class BaseMessage(list): class BaseMessage(list):
def __init__(self, message: str = None): def __init__(self,
raise NotImplementedError message: Union[str, BaseMessageSegment, "BaseMessage"] = None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
if isinstance(message, str):
self.extend(self._construct(message))
elif isinstance(message, BaseMessage):
self.extend(message)
elif isinstance(message, BaseMessageSegment):
self.append(message)
def __str__(self): def __str__(self):
return ''.join((str(seg) for seg in self)) return ''.join((str(seg) for seg in self))
@staticmethod
def _construct(msg: str) -> Iterable[BaseMessageSegment]:
raise NotImplementedError
def __add__(
self, other: Union[str, BaseMessageSegment,
"BaseMessage"]) -> "BaseMessage":
result = self.__class__(self)
if isinstance(other, str):
result.extend(self._construct(other))
elif isinstance(other, BaseMessageSegment):
result.append(other)
elif isinstance(other, BaseMessage):
result.extend(other)
return result
def __radd__(self, other: Union[str, BaseMessageSegment, "BaseMessage"]):
result = self.__class__(other)
return result.__add__(self)
def append(self, obj: Union[str, BaseMessageSegment]) -> "BaseMessage":
if isinstance(obj, BaseMessageSegment):
if obj.type == "text" and self and self[-1].type == "text":
self[-1].data["text"] += obj.data["text"]
else:
super().append(obj)
elif isinstance(obj, str):
self.extend(self._construct(obj))
else:
raise ValueError(f"Unexpected type: {type(obj)} {obj}")
return self
def extend(
self, obj: Union["BaseMessage",
Iterable[BaseMessageSegment]]) -> "BaseMessage":
for segment in obj:
self.append(segment)
return self
def reduce(self) -> None:
index = 0
while index < len(self):
if index > 0 and self[
index - 1].type == "text" and self[index].type == "text":
self[index - 1].data["text"] += self[index].data["text"]
del self[index]
else:
index += 1
def extract_plain_text(self) -> str:
def _concat(x: str, y: BaseMessageSegment) -> str:
return f"{x} {y.data['text']}" if y.type == "text" else x
return reduce(_concat, self, "")

View File

@ -1,6 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re
from typing import Tuple, Iterable, Optional
import httpx import httpx
from nonebot.event import Event from nonebot.event import Event
@ -32,16 +35,21 @@ def unescape(s: str) -> str:
.replace("&amp;", "&") .replace("&amp;", "&")
def _b2s(b: bool) -> str:
return str(b).lower()
class Bot(BaseBot): class Bot(BaseBot):
def __init__(self, def __init__(self,
type_: str, connection_type: str,
config: Config, config: Config,
*, *,
websocket: BaseWebSocket = None): websocket: BaseWebSocket = None):
if type_ not in ["http", "websocket"]: if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type") raise ValueError("Unsupported connection type")
self.type = type_ self.type = "coolq"
self.connection_type = connection_type
self.config = config self.config = config
self.websocket = websocket self.websocket = websocket
@ -49,10 +57,20 @@ class Bot(BaseBot):
# TODO: convert message into event # TODO: convert message into event
event = Event.from_payload(message) event = Event.from_payload(message)
if not event:
return
if "message" in event.keys():
event["message"] = Message(event["message"])
# TODO: Handle Meta Event # TODO: Handle Meta Event
await handle_event(self, event) if event.type == "meta_event":
pass
else:
await handle_event(self, event)
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, data: dict):
# TODO: Call API
if self.type == "websocket": if self.type == "websocket":
pass pass
elif self.type == "http": elif self.type == "http":
@ -66,15 +84,22 @@ class MessageSegment(BaseMessageSegment):
data = self.data.copy() data = self.data.copy()
# process special types # process special types
if type_ == "text": if type_ == "at_all":
return escape(data.get("text", ""), escape_comma=False)
elif type_ == "at_all":
type_ = "at" type_ = "at"
data = {"qq": "all"} data = {"qq": "all"}
elif type_ == "poke":
type_ = "shake"
data.clear()
elif type_ == "text":
return escape(data.get("text", ""), escape_comma=False)
params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()])
return f"[CQ:{type_}{',' if params else ''}{params}]" return f"[CQ:{type_}{',' if params else ''}{params}]"
@staticmethod
def anonymous(ignore_failure: bool = False) -> "MessageSegment":
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
@staticmethod @staticmethod
def at(user_id: int) -> "MessageSegment": def at(user_id: int) -> "MessageSegment":
return MessageSegment("at", {"qq": str(user_id)}) return MessageSegment("at", {"qq": str(user_id)})
@ -84,9 +109,138 @@ class MessageSegment(BaseMessageSegment):
return MessageSegment("at_all") return MessageSegment("at_all")
@staticmethod @staticmethod
def dice() -> "MessageSegment": def contact_group(group_id: int) -> "MessageSegment":
return MessageSegment(type_="dice") return MessageSegment("contact", {"type": "group", "id": str(group_id)})
@staticmethod
def contact_user(user_id: int) -> "MessageSegment":
return MessageSegment("contact", {"type": "qq", "id": str(user_id)})
@staticmethod
def face(id_: int) -> "MessageSegment":
return MessageSegment("face", {"id": str(id_)})
@staticmethod
def image(file: str) -> "MessageSegment":
return MessageSegment("image", {"file": "file"})
@staticmethod
def location(latitude: float,
longitude: float,
title: str = "",
content: str = "") -> "MessageSegment":
return MessageSegment(
"location", {
"lat": str(latitude),
"lon": str(longitude),
"title": title,
"content": content
})
@staticmethod
def magic_face(type_: str) -> "MessageSegment":
if type_ not in ["dice", "rpc"]:
raise ValueError(
f"Coolq doesn't support magic face type {type_}. Supported types: dice, rpc."
)
return MessageSegment("magic_face", {"type": type_})
@staticmethod
def music(type_: str,
id_: int,
style: Optional[int] = None) -> "MessageSegment":
if style is None:
return MessageSegment("music", {"type": type_, "id": id_})
else:
return MessageSegment("music", {
"type": type_,
"id": id_,
"style": style
})
@staticmethod
def music_custom(type_: str,
url: str,
audio: str,
title: str,
content: str = "",
img_url: str = "") -> "MessageSegment":
return MessageSegment(
"music", {
"type": type_,
"url": url,
"audio": audio,
"title": title,
"content": content,
"image": img_url
})
@staticmethod
def poke(type_: str = "Poke") -> "MessageSegment":
if type_ not in ["Poke"]:
raise ValueError(
f"Coolq doesn't support poke type {type_}. Supported types: Poke."
)
return MessageSegment("poke", {"type": type_})
@staticmethod
def record(file: str, magic: bool = False) -> "MessageSegment":
return MessageSegment("record", {"file": file, "magic": _b2s(magic)})
@staticmethod
def share(url: str = "",
title: str = "",
content: str = "",
img_url: str = "") -> "MessageSegment":
return MessageSegment("share", {
"url": url,
"title": title,
"content": content,
"img_url": img_url
})
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text})
class Message(BaseMessage): class Message(BaseMessage):
pass
@staticmethod
def _construct(msg: str) -> Iterable[MessageSegment]:
def _iter_message() -> Iterable[Tuple[str, str]]:
text_begin = 0
for cqcode in re.finditer(
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
r"(?P<params>"
r"(?:,[a-zA-Z0-9-_.]+=?[^,\]]*)*"
r"),?\]", msg):
yield "text", unescape(msg[text_begin:cqcode.pos +
cqcode.start()])
text_begin = cqcode.pos + cqcode.end()
yield cqcode.group("type"), cqcode.group("params").lstrip(",")
yield "text", unescape(msg[text_begin:])
for type_, data in _iter_message():
if type_ == "text":
if data:
# only yield non-empty text segment
yield MessageSegment(type_, {"text": data})
else:
data = {
k: v for k, v in map(
lambda x: x.split("=", maxsplit=1),
filter(lambda x: x, (
x.lstrip() for x in data.split(","))))
}
if type_ == "at" and data["qq"] == "all":
type_ = "at_all"
data.clear()
elif type_ in ["dice", "rpc"]:
type_ = "magic_face"
data["type"] = type_
elif type_ == "shake":
type_ = "poke"
data["type"] = "Poke"
yield MessageSegment(type_, data)

View File

@ -1,29 +1,35 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import abc
from typing import Optional from typing import Optional
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.config import Config from nonebot.config import Config
class BaseDriver(object): class BaseDriver(abc.ABC):
@abc.abstractmethod
def __init__(self, config: Config): def __init__(self, config: Config):
raise NotImplementedError raise NotImplementedError
@property @property
@abc.abstractmethod
def server_app(self): def server_app(self):
raise NotImplementedError raise NotImplementedError
@property @property
@abc.abstractmethod
def asgi(self): def asgi(self):
raise NotImplementedError raise NotImplementedError
@property @property
@abc.abstractmethod
def logger(self): def logger(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def run(self, def run(self,
host: Optional[IPv4Address] = None, host: Optional[IPv4Address] = None,
port: Optional[int] = None, port: Optional[int] = None,
@ -31,37 +37,47 @@ class BaseDriver(object):
**kwargs): **kwargs):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def _handle_http(self): async def _handle_http(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def _handle_ws_reverse(self): async def _handle_ws_reverse(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def _handle_http_api(self): async def _handle_http_api(self):
raise NotImplementedError raise NotImplementedError
class BaseWebSocket(object): class BaseWebSocket(object):
@abc.abstractmethod
def __init__(self, websocket): def __init__(self, websocket):
self._websocket = websocket self._websocket = websocket
@property @property
@abc.abstractmethod
def websocket(self): def websocket(self):
return self._websocket return self._websocket
@property @property
@abc.abstractmethod
def closed(self): def closed(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def accept(self): async def accept(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def close(self): async def close(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> dict: async def receive(self) -> dict:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def send(self, data: dict): async def send(self, data: dict):
raise NotImplementedError raise NotImplementedError

View File

@ -29,7 +29,7 @@ class Event(dict):
""" """
事件类型 ``message````notice````request````meta_event`` 事件类型 ``message````notice````request````meta_event``
""" """
return self['post_type'] return self["post_type"]
@property @property
def detail_type(self) -> str: def detail_type(self) -> str:
@ -37,7 +37,7 @@ class Event(dict):
事件具体类型 `type` 的不同而不同 ``message`` 类型为例 事件具体类型 `type` 的不同而不同 ``message`` 类型为例
``private````group````discuss`` ``private````group````discuss``
""" """
return self[f'{self.type}_type'] return self[f"{self.type}_type"]
@property @property
def sub_type(self) -> Optional[str]: def sub_type(self) -> Optional[str]:
@ -45,7 +45,7 @@ class Event(dict):
事件子类型 `detail_type` 不同而不同 ``message.private`` 为例 事件子类型 `detail_type` 不同而不同 ``message.private`` 为例
``friend````group````discuss````other`` ``friend````group````discuss````other``
""" """
return self.get('sub_type') return self.get("sub_type")
@property @property
def name(self): def name(self):
@ -53,75 +53,75 @@ class Event(dict):
事件名对于有 `sub_type` 的事件 ``{type}.{detail_type}.{sub_type}``否则为 事件名对于有 `sub_type` 的事件 ``{type}.{detail_type}.{sub_type}``否则为
``{type}.{detail_type}`` ``{type}.{detail_type}``
""" """
n = self.type + '.' + self.detail_type n = self.type + "." + self.detail_type
if self.sub_type: if self.sub_type:
n += '.' + self.sub_type n += "." + self.sub_type
return n return n
@property @property
def self_id(self) -> int: def self_id(self) -> int:
"""机器人自身 ID。""" """机器人自身 ID。"""
return self['self_id'] return self["self_id"]
@property @property
def user_id(self) -> Optional[int]: def user_id(self) -> Optional[int]:
"""用户 ID。""" """用户 ID。"""
return self.get('user_id') return self.get("user_id")
@property @property
def operator_id(self) -> Optional[int]: def operator_id(self) -> Optional[int]:
"""操作者 ID。""" """操作者 ID。"""
return self.get('operator_id') return self.get("operator_id")
@property @property
def group_id(self) -> Optional[int]: def group_id(self) -> Optional[int]:
"""群 ID。""" """群 ID。"""
return self.get('group_id') return self.get("group_id")
@property @property
def discuss_id(self) -> Optional[int]: def discuss_id(self) -> Optional[int]:
"""讨论组 ID。""" """讨论组 ID。"""
return self.get('discuss_id') return self.get("discuss_id")
@property @property
def message_id(self) -> Optional[int]: def message_id(self) -> Optional[int]:
"""消息 ID。""" """消息 ID。"""
return self.get('message_id') return self.get("message_id")
@property @property
def message(self) -> Optional[Any]: def message(self) -> Optional[Any]:
"""消息。""" """消息。"""
return self.get('message') return self.get("message")
@property @property
def raw_message(self) -> Optional[str]: def raw_message(self) -> Optional[str]:
"""未经 CQHTTP 处理的原始消息。""" """未经 CQHTTP 处理的原始消息。"""
return self.get('raw_message') return self.get("raw_message")
@property @property
def sender(self) -> Optional[Dict[str, Any]]: def sender(self) -> Optional[Dict[str, Any]]:
"""消息发送者信息。""" """消息发送者信息。"""
return self.get('sender') return self.get("sender")
@property @property
def anonymous(self) -> Optional[Dict[str, Any]]: def anonymous(self) -> Optional[Dict[str, Any]]:
"""匿名信息。""" """匿名信息。"""
return self.get('anonymous') return self.get("anonymous")
@property @property
def file(self) -> Optional[Dict[str, Any]]: def file(self) -> Optional[Dict[str, Any]]:
"""文件信息。""" """文件信息。"""
return self.get('file') return self.get("file")
@property @property
def comment(self) -> Optional[str]: def comment(self) -> Optional[str]:
"""请求验证消息。""" """请求验证消息。"""
return self.get('comment') return self.get("comment")
@property @property
def flag(self) -> Optional[str]: def flag(self) -> Optional[str]:
"""请求标识。""" """请求标识。"""
return self.get('flag') return self.get("flag")
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Event, {super().__repr__()}>' return f"<Event, {super().__repr__()}>"

View File

@ -4,6 +4,9 @@
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.event import Event from nonebot.event import Event
from nonebot.plugin import on_message from nonebot.plugin import on_message
from nonebot.adapters.coolq import Message
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
test_matcher = on_message(Rule(), state={"default": 1}) test_matcher = on_message(Rule(), state={"default": 1})