nonebot2/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py

472 lines
16 KiB
Python
Raw Normal View History

2020-12-02 19:52:45 +08:00
import re
import sys
import hmac
import json
import asyncio
2021-06-10 21:52:20 +08:00
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
2020-12-02 19:52:45 +08:00
import httpx
from nonebot.log import logger
2020-12-06 02:30:19 +08:00
from nonebot.typing import overrides
2020-12-02 19:52:45 +08:00
from nonebot.message import handle_event
2020-12-07 00:06:09 +08:00
from nonebot.adapters import Bot as BaseBot
2021-07-07 14:08:44 +08:00
from nonebot.utils import escape_tag, DataclassEncoder
2021-07-31 12:24:11 +08:00
from nonebot.drivers import Driver, ForwardDriver, WebSocketSetup
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
2020-12-02 19:52:45 +08:00
from .utils import log, escape
from .config import Config as CQHTTPConfig
from .message import Message, MessageSegment
from .event import Reply, Event, MessageEvent, get_event_model
2020-12-03 16:04:14 +08:00
from .exception import NetworkError, ApiNotAvailable, ActionFailed
2020-12-06 02:30:19 +08:00
if TYPE_CHECKING:
2021-01-17 13:46:29 +08:00
from nonebot.config import Config
2020-12-06 02:30:19 +08:00
2020-12-03 16:04:14 +08:00
2020-12-05 20:32:38 +08:00
def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
    """
    :Description:
      Extract the credential from an ``Authorization`` header value of the
      form ``"Bearer <token>"`` or ``"Token <token>"``; the scheme is matched
      case-insensitively.

    :Parameters:
      * ``access_token: Optional[str]``: raw header value, may be ``None``

    :Returns:
      - ``Optional[str]``: the text after the first space, or ``None`` when
        the value is empty/missing or uses any other scheme
    """
    if access_token:
        scheme, _sep, credential = access_token.partition(" ")
        if scheme.lower() in ("bearer", "token"):
            return credential
    return None
async def _check_reply(bot: "Bot", event: "Event"):
    """
    :Description:
      Look for a ``reply`` segment in a message event. When one is found,
      fetch the quoted message with ``bot.get_msg`` into ``event.reply`` and
      set ``event.to_me`` if the quoted message was sent by this bot. The
      reply segment is removed, together with an ``at`` segment and leading
      whitespace that directly follow it.

    :Parameters:
      * ``bot: Bot``: the Bot instance
      * ``event: Event``: the Event instance (non-message events are ignored)
    """
    if not isinstance(event, MessageEvent):
        return

    pos = next(
        (i for i, seg in enumerate(event.message) if seg.type == "reply"),
        None)
    if pos is None:
        return
    reply_seg = event.message[pos]

    try:
        quoted = await bot.get_msg(message_id=reply_seg.data["id"])
        event.reply = Reply.parse_obj(quoted)
    except Exception as e:
        # best effort: on failure the message is left untouched
        log("WARNING", f"Error when getting message reply info: {repr(e)}", e)
        return

    # ids may arrive as int or str, so compare their string forms
    if str(event.reply.sender.user_id) == str(event.self_id):
        event.to_me = True

    del event.message[pos]
    if len(event.message) > pos and event.message[pos].type == "at":
        del event.message[pos]
    if len(event.message) > pos and event.message[pos].type == "text":
        text_seg = event.message[pos]
        text_seg.data["text"] = text_seg.data["text"].lstrip()
        if not text_seg.data["text"]:
            del event.message[pos]

    if not event.message:
        event.message.append(MessageSegment.text(""))
def _check_at_me(bot: "Bot", event: "Event"):
    """
    :Description:
      Set ``event.to_me`` when the message is private, or when it starts or
      ends with an ``at`` segment addressed to this bot. Leading mentions
      (up to two in a row) and a trailing mention are removed from the
      message; the message is never left empty.

    :Parameters:
      * ``bot: Bot``: the Bot instance
      * ``event: Event``: the Event instance (non-message events are ignored)
    """
    if not isinstance(event, MessageEvent):
        return

    # make sure message[0] / message[-1] are valid below
    if not event.message:
        event.message.append(MessageSegment.text(""))

    if event.message_type == "private":
        event.to_me = True
        return

    def _is_at_me_seg(segment: MessageSegment):
        # ids may be int or str, so compare string forms
        return segment.type == "at" and str(segment.data.get(
            "qq", "")) == str(event.self_id)

    def _pop_head_and_lstrip():
        # drop the first segment, then trim whitespace from a text segment
        # that moves to the front (removing it if nothing is left)
        event.message.pop(0)
        if event.message and event.message[0].type == "text":
            head = event.message[0]
            head.data["text"] = head.data["text"].lstrip()
            if not head.data["text"]:
                del event.message[0]

    # leading mention; a doubled "@bot @bot" is stripped as well
    if _is_at_me_seg(event.message[0]):
        event.to_me = True
        _pop_head_and_lstrip()
        if event.message and _is_at_me_seg(event.message[0]):
            _pop_head_and_lstrip()

    if not event.to_me:
        # trailing mention, tolerating one whitespace-only text after it
        tail = -1
        last_seg = event.message[tail]
        if last_seg.type == "text" and \
                not last_seg.data["text"].strip() and \
                len(event.message) >= 2:
            tail -= 1
            last_seg = event.message[tail]
        if _is_at_me_seg(last_seg):
            event.to_me = True
            del event.message[tail:]

    if not event.message:
        event.message.append(MessageSegment.text(""))
def _check_nickname(bot: "Bot", event: "Event"):
    """
    :Description:
      Check whether the message starts with one of the bot's nicknames
      (``bot.config.nickname``, matched case-insensitively). If it does,
      strip the nickname plus following whitespace / commas from the first
      text segment and set ``event.to_me``.

    :Parameters:
      * ``bot: Bot``: the Bot instance
      * ``event: Event``: the Event instance (non-message events are ignored)
    """
    if not isinstance(event, MessageEvent) or not event.message:
        return
    first_msg_seg = event.message[0]
    if first_msg_seg.type != "text":
        return
    first_text = first_msg_seg.data["text"]

    # Longest names first: with nicknames "bot" and "bot2", the text
    # "bot2 hi" must match "bot2", not "bot" leaving "2 hi" behind.
    nicknames = sorted({n for n in bot.config.nickname if n},
                       key=lambda n: (-len(n), n))
    if not nicknames:
        return

    # Nicknames are literal text, so regex metacharacters must be escaped.
    nickname_regex = "|".join(map(re.escape, nicknames))
    # separators: whitespace, ASCII comma and full-width comma
    m = re.search(rf"^({nickname_regex})([\s,，]*|$)", first_text,
                  re.IGNORECASE)
    if m:
        nickname = m.group(1)
        log("DEBUG", f"User is calling me {nickname}")
        event.to_me = True
        first_msg_seg.data["text"] = first_text[m.end():]
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
2020-12-02 19:52:45 +08:00
"""
:说明:
处理 API 请求返回值
:参数:
* ``result: Optional[Dict[str, Any]]``: API 返回数据
:返回:
- ``Any``: API 调用返回数据
:异常:
2020-12-03 16:04:14 +08:00
- ``ActionFailed``: API 调用失败
2020-12-02 19:52:45 +08:00
"""
if isinstance(result, dict):
if result.get("status") == "failed":
2020-12-19 00:50:17 +08:00
raise ActionFailed(**result)
2020-12-02 19:52:45 +08:00
return result.get("data")
class ResultStore:
    """
    Matches WebSocket API responses to pending calls using the integer
    ``echo.seq`` attached to each outgoing request.
    """
    _seq = 1
    # seq -> future awaiting that response; shared by all bots (class level)
    _futures: Dict[int, asyncio.Future] = {}

    @classmethod
    def get_seq(cls) -> int:
        """Return the next sequence number (wraps modulo ``sys.maxsize``)."""
        s = cls._seq
        cls._seq = (cls._seq + 1) % sys.maxsize
        return s

    @classmethod
    def add_result(cls, result: Dict[str, Any]):
        """
        Resolve the pending call whose seq equals ``result["echo"]["seq"]``.

        Responses without a usable ``echo.seq``, with no pending call, or whose
        call has already been resolved are ignored.
        """
        echo = result.get("echo")
        if not isinstance(echo, dict):
            return
        seq = echo.get("seq")
        if not isinstance(seq, int):
            return
        future = cls._futures.get(seq)
        # A duplicate/late response can arrive after set_result() but before
        # the waiter resumes; setting it again would raise InvalidStateError.
        if future is not None and not future.done():
            future.set_result(result)

    @classmethod
    async def fetch(cls, seq: int, timeout: Optional[float]) -> Dict[str, Any]:
        """
        Wait for the response tagged with ``seq``.

        :Raises:
          - ``NetworkError``: no response arrived within ``timeout`` seconds
            (``None`` waits indefinitely)
        """
        future = asyncio.get_running_loop().create_future()
        cls._futures[seq] = future
        try:
            return await asyncio.wait_for(future, timeout)
        except asyncio.TimeoutError:
            raise NetworkError("WebSocket API call timeout") from None
        finally:
            cls._futures.pop(seq, None)
class Bot(BaseBot):
    """
    CQHTTP protocol Bot adapter. See `BaseBot <./#class-basebot>`_ for the
    inherited attributes.
    """
    cqhttp_config: CQHTTPConfig

    @property
    @overrides(BaseBot)
    def type(self) -> str:
        """
        - Returns: ``"cqhttp"``
        """
        return "cqhttp"

    @classmethod
    def register(cls, driver: Driver, config: "Config"):
        """
        :Description:
          Register the adapter: build the CQHTTP config from ``config`` and,
          when the driver is a ``ForwardDriver``, set up a forward WebSocket
          for every ``self_id -> url`` entry in ``ws_urls``.
        """
        super().register(driver, config)
        cls.cqhttp_config = CQHTTPConfig(**config.dict())
        if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
            logger.warning(
                f"Current driver {cls.config.driver} don't support forward connections"
            )
        elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
            for self_id, url in cls.cqhttp_config.ws_urls.items():
                try:
                    headers = {
                        "authorization":
                            f"Bearer {cls.cqhttp_config.access_token}"
                    } if cls.cqhttp_config.access_token else {}
                    driver.setup_websocket(
                        WebSocketSetup("cqhttp", self_id, url, headers=headers))
                except Exception as e:
                    logger.opt(colors=True, exception=e).error(
                        f"<r><bg #f8bbd0>Bad url {escape_tag(url)} for bot {escape_tag(self_id)} "
                        "in cqhttp forward websocket</bg #f8bbd0></r>")

    @classmethod
    @overrides(BaseBot)
    async def check_permission(
            cls, driver: Driver,
            request: HTTPConnection) -> Tuple[Optional[str], HTTPResponse]:
        """
        :Description:
          CQHTTP (OneBot) protocol authentication, see `Authorization <https://github.com/howmanybots/onebot/blob/master/v11/specs/communication/authorization.md>`_

        :Returns:
          - ``(self_id, HTTPResponse(204))`` when the request is accepted,
            otherwise ``(None, <4xx HTTPResponse>)``
        """
        x_self_id = request.headers.get("x-self-id")
        x_signature = request.headers.get("x-signature")
        token = get_auth_bearer(request.headers.get("authorization"))
        cqhttp_config = CQHTTPConfig(**driver.config.dict())

        # check self_id
        if not x_self_id:
            log("WARNING", "Missing X-Self-ID Header")
            return None, HTTPResponse(400, b"Missing X-Self-ID Header")

        # check the HMAC-SHA1 signature (HTTP requests only)
        secret = cqhttp_config.secret
        if secret and isinstance(request, HTTPRequest):
            if not x_signature:
                log("WARNING", "Missing Signature Header")
                return None, HTTPResponse(401, b"Missing Signature")
            sig = hmac.new(secret.encode("utf-8"), request.body,
                           "sha1").hexdigest()
            # constant-time comparison so the signature cannot be
            # recovered byte by byte through response timing
            if not hmac.compare_digest(x_signature.encode("utf-8"),
                                       ("sha1=" + sig).encode("utf-8")):
                log("WARNING", "Signature Header is invalid")
                return None, HTTPResponse(403, b"Signature is invalid")

        # check the access token (WebSocket connections only)
        access_token = cqhttp_config.access_token
        if access_token and isinstance(request, WebSocket) and not (
                token and hmac.compare_digest(access_token.encode("utf-8"),
                                              token.encode("utf-8"))):
            log(
                "WARNING", "Authorization Header is invalid"
                if token else "Missing Authorization Header")
            return None, HTTPResponse(
                403, b"Authorization Header is invalid"
                if token else b"Missing Authorization Header")
        return str(x_self_id), HTTPResponse(204, b'')

    @overrides(BaseBot)
    async def handle_message(self, message: bytes):
        """
        :Description:
          Handle a raw payload. Payloads without ``post_type`` are API
          responses and go to ``ResultStore``; everything else is converted
          to an `Event <#class-event>`_, run through
          `_check_reply <#async-check-reply-bot-event>`_,
          `_check_at_me <#check-at-me-bot-event>`_ and
          `_check_nickname <#check-nickname-bot-event>`_, then dispatched.
        """
        data: dict = json.loads(message)
        if not data:
            return

        if "post_type" not in data:
            ResultStore.add_result(data)
            return

        try:
            # most specific model first, e.g. "message.group.normal"
            post_type = data['post_type']
            detail_type = data.get(f"{post_type}_type")
            detail_type = f".{detail_type}" if detail_type else ""
            sub_type = data.get("sub_type")
            sub_type = f".{sub_type}" if sub_type else ""
            models = get_event_model(post_type + detail_type + sub_type)
            for model in models:
                try:
                    event = model.parse_obj(data)
                    break
                except Exception as e:
                    log("DEBUG", "Event Parser Error", e)
            else:
                # no specific model fits: fall back to the generic Event
                event = Event.parse_obj(data)

            # Check whether user is calling me
            await _check_reply(self, event)
            _check_at_me(self, event)
            _check_nickname(self, event)

            await handle_event(self, event)
        except Exception as e:
            logger.opt(colors=True, exception=e).error(
                f"<r><bg #f8bbd0>Failed to handle event. Raw: {escape_tag(str(data))}</bg #f8bbd0></r>"
            )

    @overrides(BaseBot)
    async def _call_api(self, api: str, **data) -> Any:
        """
        Send the API call over the bot's WebSocket (response matched through
        ``ResultStore``) or, for HTTP bots, POST it to ``api_root``.
        """
        log("DEBUG", f"Calling API <y>{api}</y>")
        if isinstance(self.request, WebSocket):
            seq = ResultStore.get_seq()
            json_data = json.dumps(
                {
                    "action": api,
                    "params": data,
                    "echo": {
                        "seq": seq
                    }
                },
                cls=DataclassEncoder)
            await self.request.send(json_data)
            return _handle_api_result(await ResultStore.fetch(
                seq, self.config.api_timeout))

        elif isinstance(self.request, HTTPRequest):
            api_root = self.config.api_root.get(self.self_id)
            if not api_root:
                raise ApiNotAvailable
            elif not api_root.endswith("/"):
                api_root += "/"

            headers = {"Content-Type": "application/json"}
            # truthiness, consistent with register()/check_permission():
            # an empty token means "no auth", not "Bearer "
            if self.cqhttp_config.access_token:
                headers[
                    "Authorization"] = "Bearer " + self.cqhttp_config.access_token

            try:
                async with httpx.AsyncClient(headers=headers) as client:
                    response = await client.post(
                        api_root + api,
                        content=json.dumps(data, cls=DataclassEncoder),
                        timeout=self.config.api_timeout)

                if 200 <= response.status_code < 300:
                    result = response.json()
                    return _handle_api_result(result)
                raise NetworkError(f"HTTP request received unexpected "
                                   f"status code: {response.status_code}")
            except httpx.InvalidURL as e:
                raise NetworkError("API root url invalid") from e
            except httpx.HTTPError as e:
                raise NetworkError("HTTP request failed") from e
        # NOTE(review): any other connection type falls through and the call
        # returns None -- confirm whether that is intended.

    @overrides(BaseBot)
    async def call_api(self, api: str, **data) -> Any:
        """
        :Description:
          Call a CQHTTP protocol API.

        :Parameters:
          * ``api: str``: API name
          * ``**data: Any``: API parameters

        :Returns:
          - ``Any``: data returned by the API

        :Raises:
          - ``NetworkError``: network error
          - ``ActionFailed``: the API call failed
        """
        return await super().call_api(api, **data)

    @overrides(BaseBot)
    async def send(self,
                   event: Event,
                   message: Union[str, Message, MessageSegment],
                   at_sender: bool = False,
                   **kwargs) -> Any:
        """
        :Description:
          Send a message to the subject that triggered ``event``.

        :Parameters:
          * ``event: Event``: the Event instance
          * ``message: Union[str, Message, MessageSegment]``: message to send
          * ``at_sender: bool``: whether to @ the event's sender
          * ``**kwargs``: overrides for the default parameters

        :Returns:
          - ``Any``: data returned by the API

        :Raises:
          - ``ValueError``: neither ``user_id`` nor ``group_id`` is available
          - ``NetworkError``: network error
          - ``ActionFailed``: the API call failed
        """
        # plain strings are escaped so they are sent as literal text
        message = escape(message, escape_comma=False) if isinstance(
            message, str) else message
        msg = message if isinstance(message, Message) else Message(message)

        at_sender = at_sender and bool(getattr(event, "user_id", None))

        params = {}
        if getattr(event, "user_id", None):
            params["user_id"] = getattr(event, "user_id")
        if getattr(event, "group_id", None):
            params["group_id"] = getattr(event, "group_id")
        params.update(kwargs)

        if "message_type" not in params:
            if params.get("group_id", None):
                params["message_type"] = "group"
            elif params.get("user_id", None):
                params["message_type"] = "private"
            else:
                raise ValueError("Cannot guess message type to reply!")

        if at_sender and params["message_type"] != "private":
            params["message"] = MessageSegment.at(params["user_id"]) + " " + msg
        else:
            params["message"] = msg
        return await self.send_msg(**params)