From 8f99b01fb5374a1a015a76cd0cd980a31abf78a3 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Wed, 31 Mar 2021 16:51:09 +0800 Subject: [PATCH] :alembic: add call_api hook --- docs/api/adapters/README.md | 39 +++++++++++- docs/api/adapters/ding.md | 3 + nonebot/adapters/_base.py | 58 ++++++++++++++++-- nonebot/typing.py | 1 + .../nonebot/adapters/cqhttp/bot.py | 50 ++++++++-------- .../nonebot/adapters/cqhttp/bot.pyi | 3 +- .../nonebot/adapters/ding/bot.py | 59 ++++++++++--------- .../nonebot/adapters/mirai/bot.py | 4 ++ 8 files changed, 157 insertions(+), 60 deletions(-) diff --git a/docs/api/adapters/README.md b/docs/api/adapters/README.md index 05fe3b49..bd4aab2a 100644 --- a/docs/api/adapters/README.md +++ b/docs/api/adapters/README.md @@ -27,6 +27,21 @@ Driver 对象 Config 配置对象 +### `_call_api_hook` + + +* **类型** + + `Set[T_CallingAPIHook]` + + + +* **说明** + + call_api 时执行的函数 + + + ### _abstract_ `__init__(connection_type, self_id, *, websocket=None)` @@ -127,7 +142,26 @@ Adapter 类型 -### _abstract async_ `call_api(api, **data)` +### _abstract async_ `_call_api(api, **data)` + + +* **说明** + + `adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。 + + + +* **参数** + + + * `api: str`: API 名称 + + + * `**data`: API 数据 + + + +### _async_ `call_api(api, **data)` * **说明** @@ -142,6 +176,9 @@ Adapter 类型 * `api: str`: API 名称 + * `self_id: Optional[str]`: 指定调用 API 的机器人 + + * `**data`: API 数据 diff --git a/docs/api/adapters/ding.md b/docs/api/adapters/ding.md index 7cfe5932..2c531a7b 100644 --- a/docs/api/adapters/ding.md +++ b/docs/api/adapters/ding.md @@ -129,6 +129,9 @@ sidebarDepth: 0 * `api: str`: API 名称 + * `event: Optional[MessageEvent]`: Event 对象 + + * `**data: Any`: API 参数 diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 8f4c2898..d36407a9 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -6,20 +6,30 @@ """ import abc +import asyncio from copy import copy from functools import reduce, partial from dataclasses import dataclass, field -from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING +from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable, + Protocol, Awaitable, TYPE_CHECKING) from pydantic import BaseModel +from nonebot.log import logger from nonebot.utils import DataclassEncoder +from nonebot.typing import T_CallingAPIHook if TYPE_CHECKING: from nonebot.config import Config from nonebot.drivers import Driver, WebSocket +class _ApiCall(Protocol): + + def __call__(self, **kwargs: Any) -> Awaitable[Any]: + ... + + class Bot(abc.ABC): """ Bot 基类。用于处理上报消息,并提供 API 调用接口。 @@ -29,6 +39,11 @@ class Bot(abc.ABC): """Driver 对象""" config: "Config" """Config 配置对象""" + _call_api_hook: Set[T_CallingAPIHook] = set() + """ + :类型: ``Set[T_CallingAPIHook]`` + :说明: call_api 时执行的函数 + """ @abc.abstractmethod def __init__(self, @@ -50,7 +65,7 @@ class Bot(abc.ABC): self.websocket = websocket """Websocket 连接对象""" - def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]: + def __getattr__(self, name: str) -> _ApiCall: return partial(self.call_api, name) @property @@ -109,7 +124,20 @@ class Bot(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def call_api(self, api: str, **data) -> Any: + async def _call_api(self, api: str, **data) -> Any: + """ + :说明: + + ``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。 + + :参数: + + * ``api: str``: API 名称 + * ``**data``: API 数据 + """ + raise NotImplementedError + + async def call_api(self, api: str, **data: Any) -> Any: """ :说明: @@ -118,6 +146,7 @@ class Bot(abc.ABC): :参数: * ``api: str``: API 名称 + * ``self_id: Optional[str]``: 指定调用 API 的机器人 * ``**data``: API 数据 :示例: @@ -127,7 +156,23 @@ class Bot(abc.ABC): await bot.call_api("send_msg", message="hello world") await bot.send_msg(message="hello world") """ - raise NotImplementedError + coros = list(map(lambda x: x(api, data), self._call_api_hook)) + if coros: + try: + logger.debug("Running CallingAPI hooks...") + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running CallingAPI hook. " + "Running cancelled!") + + if "self_id" in data: + self_id = data.pop("self_id") + if self_id: + bot = self.driver.bots[str(self_id)] + return await bot._call_api(api, **data) + + return await self._call_api(api, **data) @abc.abstractmethod async def send(self, event: "Event", message: Union[str, "Message", @@ -146,6 +191,11 @@ class Bot(abc.ABC): """ raise NotImplementedError + @classmethod + def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook: + cls._call_api_hook.add(func) + return func + T_Message = TypeVar("T_Message", bound="Message") T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") diff --git a/nonebot/typing.py b/nonebot/typing.py index dd2f24c5..c1dc008a 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -71,6 +71,7 @@ T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]] WebSocket 连接断开时执行的函数 """ +T_CallingAPIHook = Callable[[str, Dict[str, Any]], Awaitable[None]] T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] """ diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py index ca477559..20e5015d 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py @@ -328,32 +328,7 @@ class Bot(BaseBot): ) @overrides(BaseBot) - async def call_api(self, api: str, **data) -> Any: - """ - :说明: - - 调用 CQHTTP 协议 API - - :参数: - - * ``api: str``: API 名称 - * ``**data: Any``: API 参数 - - :返回: - - - ``Any``: API 调用返回数据 - - :异常: - - - ``NetworkError``: 网络错误 - - ``ActionFailed``: API 调用失败 - """ - if "self_id" in data: - self_id = data.pop("self_id") - if self_id: - bot = self.driver.bots[str(self_id)] - return await bot.call_api(api, **data) - + async def _call_api(self, api: str, **data) -> Any: log("DEBUG", f"Calling API {api}") if self.connection_type == "websocket": seq = ResultStore.get_seq() @@ -396,6 +371,29 @@ class Bot(BaseBot): except httpx.HTTPError: raise NetworkError("HTTP request failed") + @overrides(BaseBot) + async def call_api(self, api: str, **data) -> Any: + """ + :说明: + + 调用 CQHTTP 协议 API + + :参数: + + * ``api: str``: API 名称 + * ``**data: Any``: API 参数 + + :返回: + + - ``Any``: API 调用返回数据 + + :异常: + + - ``NetworkError``: 网络错误 + - ``ActionFailed``: API 调用失败 + """ + return super().call_api(api, **data) + @overrides(BaseBot) async def send(self, event: Event, diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi index 7ba09f8a..ad8d459c 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.pyi @@ -68,7 +68,8 @@ class Bot(BaseBot): async def handle_message(self, message: dict): ... - async def call_api(self, api: str, **data) -> Any: + async def call_api(self, api: str, *, self_id: Optional[str], + **data) -> Any: ... async def send(self, event: Event, message: Union[str, Message, diff --git a/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py b/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py index 08175ce4..410515bb 100644 --- a/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py +++ b/packages/nonebot-adapter-ding/nonebot/adapters/ding/bot.py @@ -109,37 +109,13 @@ class Bot(BaseBot): return @overrides(BaseBot) - async def call_api(self, - api: str, - event: Optional[MessageEvent] = None, - **data) -> Any: - """ - :说明: - - 调用 钉钉 协议 API - - :参数: - - * ``api: str``: API 名称 - * ``**data: Any``: API 参数 - - :返回: - - - ``Any``: API 调用返回数据 - - :异常: - - - ``NetworkError``: 网络错误 - - ``ActionFailed``: API 调用失败 - """ + async def _call_api(self, + api: str, + event: Optional[MessageEvent] = None, + **data) -> Any: if self.connection_type != "http": log("ERROR", "Only support http connection.") return - if "self_id" in data: - self_id = data.pop("self_id") - if self_id: - bot = self.driver.bots[str(self_id)] - return await bot.call_api(api, **data) log("DEBUG", f"Calling API {api}") params = {} @@ -192,6 +168,33 @@ class Bot(BaseBot): except httpx.HTTPError: raise NetworkError("HTTP request failed") + @overrides(BaseBot) + async def call_api(self, + api: str, + event: Optional[MessageEvent] = None, + **data) -> Any: + """ + :说明: + + 调用 钉钉 协议 API + + :参数: + + * ``api: str``: API 名称 + * ``event: Optional[MessageEvent]``: Event 对象 + * ``**data: Any``: API 参数 + + :返回: + + - ``Any``: API 调用返回数据 + + :异常: + + - ``NetworkError``: 网络错误 + - ``ActionFailed``: API 调用失败 + """ + return super().call_api(api, event=event, **data) + @overrides(BaseBot) async def send(self, event: MessageEvent, diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py index 1b598ebf..ebce2d74 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot.py @@ -218,6 +218,10 @@ class Bot(BaseBot): except Exception as e: Log.error(f'Failed to handle message: {message}', e) + @overrides(BaseBot) + async def _call_api(self, api: str, **data) -> NoReturn: + raise NotImplementedError + @overrides(BaseBot) async def call_api(self, api: str, **data) -> NoReturn: """