import re
import sys
import hmac
import json
import asyncio
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union, Optional

import httpx
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.utils import DataclassEncoder, escape_tag
from nonebot.drivers import (
    Driver,
    WebSocket,
    HTTPRequest,
    HTTPResponse,
    ForwardDriver,
    HTTPConnection,
    WebSocketSetup,
)

from .utils import log, escape
from .config import Config as CQHTTPConfig
from .message import Message, MessageSegment
from .event import Event, Reply, MessageEvent, get_event_model
from .exception import ActionFailed, NetworkError, ApiNotAvailable

if TYPE_CHECKING:
    from nonebot.config import Config


def get_auth_bearer(access_token: Optional[str] = None) -> Optional[str]:
    """
    :Description:

      Extract the credential from an ``Authorization`` header value.

    :Parameters:

      * ``access_token: Optional[str]``: raw header value, e.g. ``"Bearer xxx"``

    :Returns:

      - ``Optional[str]``: the text after the first space when the scheme is
        ``Bearer`` or ``Token`` (case-insensitive); ``None`` when the header is
        missing/empty or uses another scheme.
    """
    if not access_token:
        return None
    scheme, _, param = access_token.partition(" ")
    if scheme.lower() not in ["bearer", "token"]:
        return None
    return param


async def _check_reply(bot: "Bot", event: "Event"):
    """
    :Description:

      Look for a reply segment in the message. If one is found, fetch the
      replied-to message into ``event.reply``. Set ``event.to_me`` when that
      message was sent by this bot. Then remove the reply segment, an ``at``
      segment immediately following it, and leading whitespace of the text
      segment after that.

    :Parameters:

      * ``bot: Bot``: Bot object
      * ``event: Event``: Event object
    """
    if not isinstance(event, MessageEvent):
        return

    try:
        index = list(map(lambda x: x.type == "reply", event.message)).index(True)
    except ValueError:
        return

    msg_seg = event.message[index]

    try:
        event.reply = Reply.parse_obj(await bot.get_msg(message_id=msg_seg.data["id"]))
    except Exception as e:
        # Best effort: leave the message untouched if the reply cannot be fetched.
        log("WARNING", f"Error when getting message reply info: {repr(e)}", e)
        return

    # ensure string comparison (ids may be int or str)
    if str(event.reply.sender.user_id) == str(event.self_id):
        event.to_me = True
    del event.message[index]
    if len(event.message) > index and event.message[index].type == "at":
        del event.message[index]
    if len(event.message) > index and event.message[index].type == "text":
        event.message[index].data["text"] = event.message[index].data["text"].lstrip()
        if not event.message[index].data["text"]:
            del event.message[index]
    if not event.message:
        # Keep the message non-empty for later processing.
        event.message.append(MessageSegment.text(""))


def _check_at_me(bot: "Bot", event: "Event"):
    """
    :Description:

      Check whether the message starts or ends with an @ mention of the bot.
      If so, strip it and set ``event.to_me``. Private messages are always
      marked ``to_me``.

    :Parameters:

      * ``bot: Bot``: Bot object
      * ``event: Event``: Event object
    """
    if not isinstance(event, MessageEvent):
        return

    # ensure message not empty
    if not event.message:
        event.message.append(MessageSegment.text(""))

    if event.message_type == "private":
        event.to_me = True
    else:

        def _is_at_me_seg(segment: MessageSegment):
            return segment.type == "at" and str(segment.data.get("qq", "")) == str(
                event.self_id
            )

        # check the first segment (up to two leading @bot mentions are removed)
        if _is_at_me_seg(event.message[0]):
            event.to_me = True
            event.message.pop(0)
            if event.message and event.message[0].type == "text":
                event.message[0].data["text"] = event.message[0].data["text"].lstrip()
                if not event.message[0].data["text"]:
                    del event.message[0]
            if event.message and _is_at_me_seg(event.message[0]):
                event.message.pop(0)
                if event.message and event.message[0].type == "text":
                    event.message[0].data["text"] = (
                        event.message[0].data["text"].lstrip()
                    )
                    if not event.message[0].data["text"]:
                        del event.message[0]

        if not event.to_me:
            # check the last segment, skipping one trailing whitespace-only text
            i = -1
            last_msg_seg = event.message[i]
            if (
                last_msg_seg.type == "text"
                and not last_msg_seg.data["text"].strip()
                and len(event.message) >= 2
            ):
                i -= 1
                last_msg_seg = event.message[i]

            if _is_at_me_seg(last_msg_seg):
                event.to_me = True
                del event.message[i:]

        if not event.message:
            event.message.append(MessageSegment.text(""))


def _check_nickname(bot: "Bot", event: "Event"):
    """
    :Description:

      Check whether the message starts with one of the bot's configured
      nicknames (case-insensitive). If so, strip the nickname plus any
      following whitespace/commas and set ``event.to_me``.

    :Parameters:

      * ``bot: Bot``: Bot object
      * ``event: Event``: Event object
    """
    if not isinstance(event, MessageEvent):
        return

    first_msg_seg = event.message[0]
    if first_msg_seg.type != "text":
        return

    first_text = first_msg_seg.data["text"]

    # Empty nicknames are dropped: an empty alternative would match every
    # message. Nicknames are regex-escaped so metacharacters (e.g. "bot.",
    # "c++") are matched literally instead of being treated as (possibly
    # invalid) pattern syntax. Longest-first ordering makes the alternation
    # deterministic and prefers "bot2" over its prefix "bot".
    nicknames = [
        re.escape(n)
        for n in sorted({n for n in bot.config.nickname if n}, key=len, reverse=True)
    ]
    if nicknames:
        # check if the user is calling me with my nickname
        nickname_regex = "|".join(nicknames)
        m = re.search(rf"^({nickname_regex})([\s,,]*|$)", first_text, re.IGNORECASE)
        if m:
            nickname = m.group(1)
            log("DEBUG", f"User is calling me {nickname}")
            event.to_me = True
            first_msg_seg.data["text"] = first_text[m.end() :]


def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
    """
    :Description:

      Process the response of an API call.

    :Parameters:

      * ``result: Optional[Dict[str, Any]]``: data returned by the API

    :Returns:

      - ``Any``: the ``data`` field of the response (``None`` if ``result`` is
        not a dict)

    :Raises:

      - ``ActionFailed``: the response has ``status == "failed"``
    """
    if isinstance(result, dict):
        if result.get("status") == "failed":
            raise ActionFailed(**result)
        return result.get("data")


class ResultStore:
    """
    Match WebSocket API responses to pending calls via the integer ``seq``
    carried in the request's ``echo`` field.
    """

    _seq = 1
    # Pending calls keyed by sequence number.
    _futures: Dict[int, asyncio.Future] = {}

    @classmethod
    def get_seq(cls) -> int:
        """Return the next sequence number (wraps around at ``sys.maxsize``)."""
        s = cls._seq
        cls._seq = (cls._seq + 1) % sys.maxsize
        return s

    @classmethod
    def add_result(cls, result: Dict[str, Any]):
        """Resolve the pending call whose ``seq`` matches ``result["echo"]``."""
        if isinstance(result.get("echo"), dict) and isinstance(
            result["echo"].get("seq"), int
        ):
            future = cls._futures.get(result["echo"]["seq"])
            # The future may already be done, e.g. cancelled by a timeout whose
            # cleanup has not run yet, or hit by a duplicated response.
            # ``set_result`` on it would raise ``InvalidStateError``.
            if future is not None and not future.done():
                future.set_result(result)

    @classmethod
    async def fetch(cls, seq: int, timeout: Optional[float]) -> Dict[str, Any]:
        """
        Wait for the response carrying ``seq``.

        :Raises:

          - ``NetworkError``: no response arrived within ``timeout`` seconds
        """
        future = asyncio.get_running_loop().create_future()
        cls._futures[seq] = future
        try:
            return await asyncio.wait_for(future, timeout)
        except asyncio.TimeoutError:
            raise NetworkError("WebSocket API call timeout") from None
        finally:
            cls._futures.pop(seq, None)


class Bot(BaseBot):
    """
    Bot adapter for the CQHTTP (OneBot) protocol. For inherited attributes see
    `BaseBot <./#class-basebot>`_.
    """

    cqhttp_config: CQHTTPConfig

    @property
    @overrides(BaseBot)
    def type(self) -> str:
        """
        - Returns: ``"cqhttp"``
        """
        return "cqhttp"

    @classmethod
    def register(cls, driver: Driver, config: "Config"):
        """
        :Description:

          Register the adapter with ``driver`` and load the CQHTTP config.
          When ``ws_urls`` is configured, set up one forward WebSocket
          connection per bot on a ``ForwardDriver``. On any other driver, log
          a warning instead.
        """
        super().register(driver, config)
        cls.cqhttp_config = CQHTTPConfig(**config.dict())
        if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
            logger.warning(
                f"Current driver {cls.config.driver} don't support forward connections"
            )
        elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
            for self_id, url in cls.cqhttp_config.ws_urls.items():
                try:
                    headers = (
                        {"authorization": f"Bearer {cls.cqhttp_config.access_token}"}
                        if cls.cqhttp_config.access_token
                        else {}
                    )
                    driver.setup_websocket(
                        WebSocketSetup("cqhttp", self_id, url, headers=headers)
                    )
                except Exception as e:
                    logger.opt(colors=True, exception=e).error(
                        f"Bad url {escape_tag(url)} for bot {escape_tag(self_id)} "
                        "in cqhttp forward websocket"
                    )

    @classmethod
    @overrides(BaseBot)
    async def check_permission(
        cls, driver: Driver, request: HTTPConnection
    ) -> Tuple[Optional[str], HTTPResponse]:
        """
        :Description:

          Authenticate an incoming CQHTTP (OneBot) connection.

          - ``X-Self-ID`` header is required (400 otherwise).
          - If ``secret`` is set, HTTP requests must carry a valid
            ``X-Signature`` (HMAC-SHA1 of the body): 401 if missing, 403 if
            wrong.
          - If ``access_token`` is set, WebSocket connections must carry a
            matching ``Authorization`` header (403 otherwise).

        :Returns:

          - ``(self_id, 204 response)`` on success, ``(None, error response)``
            otherwise.
        """
        x_self_id = request.headers.get("x-self-id")
        x_signature = request.headers.get("x-signature")
        token = get_auth_bearer(request.headers.get("authorization"))
        cqhttp_config = CQHTTPConfig(**driver.config.dict())

        # check self_id
        if not x_self_id:
            log("WARNING", "Missing X-Self-ID Header")
            return None, HTTPResponse(400, b"Missing X-Self-ID Header")

        # check signature
        secret = cqhttp_config.secret
        if secret and isinstance(request, HTTPRequest):
            if not x_signature:
                log("WARNING", "Missing Signature Header")
                return None, HTTPResponse(401, b"Missing Signature")
            sig = hmac.new(secret.encode("utf-8"), request.body, "sha1").hexdigest()
            # Constant-time comparison: a plain ``!=`` leaks, through response
            # timing, how many leading characters of the signature are right.
            if not hmac.compare_digest(
                x_signature.encode("utf-8"), ("sha1=" + sig).encode("utf-8")
            ):
                log("WARNING", "Signature Header is invalid")
                return None, HTTPResponse(403, b"Signature is invalid")

        # check access token (only WebSocket connections carry one)
        access_token = cqhttp_config.access_token
        if access_token and isinstance(request, WebSocket):
            token_valid = token is not None and hmac.compare_digest(
                token.encode("utf-8"), access_token.encode("utf-8")
            )
            if not token_valid:
                log(
                    "WARNING",
                    "Authorization Header is invalid"
                    if token
                    else "Missing Authorization Header",
                )
                return None, HTTPResponse(
                    403,
                    b"Authorization Header is invalid"
                    if token
                    else b"Missing Authorization Header",
                )
        return str(x_self_id), HTTPResponse(204, b"")

    @overrides(BaseBot)
    async def handle_message(self, message: bytes):
        """
        :Description:

          Handle one raw payload. Payloads without ``post_type`` are treated as
          API responses and passed to ``ResultStore``. Anything else is parsed
          into an `Event <#class-event>`_, run through
          `_check_reply <#async-check-reply-bot-event>`_,
          `_check_at_me <#check-at-me-bot-event>`_ and
          `_check_nickname <#check-nickname-bot-event>`_, then dispatched.
        """
        data: dict = json.loads(message)

        if not data:
            return

        if "post_type" not in data:
            ResultStore.add_result(data)
            return

        try:
            post_type = data["post_type"]
            detail_type = data.get(f"{post_type}_type")
            detail_type = f".{detail_type}" if detail_type else ""
            sub_type = data.get("sub_type")
            sub_type = f".{sub_type}" if sub_type else ""
            models = get_event_model(post_type + detail_type + sub_type)
            for model in models:
                try:
                    event = model.parse_obj(data)
                    break
                except Exception as e:
                    log("DEBUG", "Event Parser Error", e)
            else:
                # No specific model matched: fall back to the generic Event.
                event = Event.parse_obj(data)

            # Check whether user is calling me
            await _check_reply(self, event)
            _check_at_me(self, event)
            _check_nickname(self, event)

            await handle_event(self, event)
        except Exception as e:
            logger.opt(colors=True, exception=e).error(
                f"Failed to handle event. Raw: {escape_tag(str(data))}"
            )

    @overrides(BaseBot)
    async def _call_api(self, api: str, **data) -> Any:
        """
        Perform the API call over the bot's current connection.

        Over WebSocket, the call is sent with an ``echo`` sequence number, and
        the method waits for the matching response. Over HTTP, the call is
        POSTed to ``<api_root>/<api>``.
        """
        log("DEBUG", f"Calling API {api}")
        if isinstance(self.request, WebSocket):
            seq = ResultStore.get_seq()
            json_data = json.dumps(
                {"action": api, "params": data, "echo": {"seq": seq}},
                cls=DataclassEncoder,
            )
            await self.request.send(json_data)
            return _handle_api_result(
                await ResultStore.fetch(seq, self.config.api_timeout)
            )

        elif isinstance(self.request, HTTPRequest):
            api_root = self.config.api_root.get(self.self_id)
            if not api_root:
                raise ApiNotAvailable
            elif not api_root.endswith("/"):
                api_root += "/"

            headers = {"Content-Type": "application/json"}
            # Same truthiness test as ``register`` / ``check_permission``: an
            # empty token means "no authentication", so no header is sent.
            if self.cqhttp_config.access_token:
                headers["Authorization"] = "Bearer " + self.cqhttp_config.access_token

            try:
                async with httpx.AsyncClient(
                    headers=headers, follow_redirects=True
                ) as client:
                    response = await client.post(
                        api_root + api,
                        content=json.dumps(data, cls=DataclassEncoder),
                        timeout=self.config.api_timeout,
                    )

                if 200 <= response.status_code < 300:
                    result = response.json()
                    return _handle_api_result(result)
                raise NetworkError(
                    f"HTTP request received unexpected "
                    f"status code: {response.status_code}"
                )
            except httpx.InvalidURL:
                raise NetworkError("API root url invalid")
            except httpx.HTTPError:
                raise NetworkError("HTTP request failed")

    @overrides(BaseBot)
    async def call_api(self, api: str, **data) -> Any:
        """
        :Description:

          Call a CQHTTP protocol API.

        :Parameters:

          * ``api: str``: API name
          * ``**data: Any``: API parameters

        :Returns:

          - ``Any``: data returned by the API call

        :Raises:

          - ``NetworkError``: network error
          - ``ActionFailed``: the API call failed
        """
        return await super().call_api(api, **data)

    @overrides(BaseBot)
    async def send(
        self,
        event: Event,
        message: Union[str, Message, MessageSegment],
        at_sender: bool = False,
        **kwargs,
    ) -> Any:
        """
        :Description:

          Send a message to the subject that triggered ``event``. Plain strings
          are escaped (commas excluded) before sending.

        :Parameters:

          * ``event: Event``: Event object
          * ``message: Union[str, Message, MessageSegment]``: message to send
          * ``at_sender: bool``: whether to @ the event's sender (ignored in
            private chats)
          * ``**kwargs``: overrides for the default ``send_msg`` parameters

        :Returns:

          - ``Any``: data returned by the API call

        :Raises:

          - ``ValueError``: neither ``user_id`` nor ``group_id`` is available,
            so the message type cannot be guessed
          - ``NetworkError``: network error
          - ``ActionFailed``: the API call failed
        """
        message = (
            escape(message, escape_comma=False) if isinstance(message, str) else message
        )
        msg = message if isinstance(message, Message) else Message(message)

        at_sender = at_sender and bool(getattr(event, "user_id", None))

        params = {}
        if getattr(event, "user_id", None):
            params["user_id"] = getattr(event, "user_id")
        if getattr(event, "group_id", None):
            params["group_id"] = getattr(event, "group_id")
        params.update(kwargs)

        if "message_type" not in params:
            if params.get("group_id", None):
                params["message_type"] = "group"
            elif params.get("user_id", None):
                params["message_type"] = "private"
            else:
                raise ValueError("Cannot guess message type to reply!")

        if at_sender and params["message_type"] != "private":
            params["message"] = MessageSegment.at(params["user_id"]) + " " + msg
        else:
            params["message"] = msg
        return await self.send_msg(**params)