⚗️ add call_api hook

This commit is contained in:
yanyongyu 2021-03-31 16:51:09 +08:00
parent 4e7592de98
commit 8f99b01fb5
8 changed files with 157 additions and 60 deletions

View File

@ -27,6 +27,21 @@ Driver 对象
Config 配置对象 Config 配置对象
### `_call_api_hook`
* **类型**
`Set[T_CallingAPIHook]`
* **说明**
call_api 时执行的函数
### _abstract_ `__init__(connection_type, self_id, *, websocket=None)` ### _abstract_ `__init__(connection_type, self_id, *, websocket=None)`
@ -127,7 +142,26 @@ Adapter 类型
### _abstract async_ `call_api(api, **data)` ### _abstract async_ `_call_api(api, **data)`
* **说明**
`adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
* **参数**
* `api: str`: API 名称
* `**data`: API 数据
### _async_ `call_api(api, **data)`
* **说明** * **说明**
@ -142,6 +176,9 @@ Adapter 类型
* `api: str`: API 名称 * `api: str`: API 名称
* `self_id: Optional[str]`: 指定调用 API 的机器人
* `**data`: API 数据 * `**data`: API 数据

View File

@ -129,6 +129,9 @@ sidebarDepth: 0
* `api: str`: API 名称 * `api: str`: API 名称
* `event: Optional[MessageEvent]`: Event 对象
* `**data: Any`: API 参数 * `**data: Any`: API 参数

View File

@ -6,20 +6,30 @@
""" """
import abc import abc
import asyncio
from copy import copy from copy import copy
from functools import reduce, partial from functools import reduce, partial
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
Protocol, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel from pydantic import BaseModel
from nonebot.log import logger
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.typing import T_CallingAPIHook
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver, WebSocket
class _ApiCall(Protocol):
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
...
class Bot(abc.ABC): class Bot(abc.ABC):
""" """
Bot 基类用于处理上报消息并提供 API 调用接口 Bot 基类用于处理上报消息并提供 API 调用接口
@ -29,6 +39,11 @@ class Bot(abc.ABC):
"""Driver 对象""" """Driver 对象"""
config: "Config" config: "Config"
"""Config 配置对象""" """Config 配置对象"""
_call_api_hook: Set[T_CallingAPIHook] = set()
"""
:类型: ``Set[T_CallingAPIHook]``
:说明: call_api 时执行的函数
"""
@abc.abstractmethod @abc.abstractmethod
def __init__(self, def __init__(self,
@ -50,7 +65,7 @@ class Bot(abc.ABC):
self.websocket = websocket self.websocket = websocket
"""Websocket 连接对象""" """Websocket 连接对象"""
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]: def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name) return partial(self.call_api, name)
@property @property
@ -109,7 +124,20 @@ class Bot(abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def call_api(self, api: str, **data) -> Any: async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数实现该方法以调用 api
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError
async def call_api(self, api: str, **data: Any) -> Any:
""" """
:说明: :说明:
@ -118,6 +146,7 @@ class Bot(abc.ABC):
:参数: :参数:
* ``api: str``: API 名称 * ``api: str``: API 名称
* ``self_id: Optional[str]``: 指定调用 API 的机器人
* ``**data``: API 数据 * ``**data``: API 数据
:示例: :示例:
@ -127,7 +156,23 @@ class Bot(abc.ABC):
await bot.call_api("send_msg", message="hello world") await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world") await bot.send_msg(message="hello world")
""" """
raise NotImplementedError coros = list(map(lambda x: x(api, data), self._call_api_hook))
if coros:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>")
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot._call_api(api, **data)
return await self._call_api(api, **data)
@abc.abstractmethod @abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message", async def send(self, event: "Event", message: Union[str, "Message",
@ -146,6 +191,11 @@ class Bot(abc.ABC):
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
cls._call_api_hook.add(func)
return func
T_Message = TypeVar("T_Message", bound="Message") T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")

View File

@ -71,6 +71,7 @@ T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
WebSocket 连接断开时执行的函数 WebSocket 连接断开时执行的函数
""" """
T_CallingAPIHook = Callable[[str, Dict[str, Any]], Awaitable[None]]
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
""" """

View File

@ -328,32 +328,7 @@ class Bot(BaseBot):
) )
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> Any: async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
调用 CQHTTP 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot.call_api(api, **data)
log("DEBUG", f"Calling API <y>{api}</y>") log("DEBUG", f"Calling API <y>{api}</y>")
if self.connection_type == "websocket": if self.connection_type == "websocket":
seq = ResultStore.get_seq() seq = ResultStore.get_seq()
@ -396,6 +371,29 @@ class Bot(BaseBot):
except httpx.HTTPError: except httpx.HTTPError:
raise NetworkError("HTTP request failed") raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> Any:
"""
:说明:
调用 CQHTTP 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
return super().call_api(api, **data)
@overrides(BaseBot) @overrides(BaseBot)
async def send(self, async def send(self,
event: Event, event: Event,

View File

@ -68,7 +68,8 @@ class Bot(BaseBot):
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
... ...
async def call_api(self, api: str, **data) -> Any: async def call_api(self, api: str, *, self_id: Optional[str],
**data) -> Any:
... ...
async def send(self, event: Event, message: Union[str, Message, async def send(self, event: Event, message: Union[str, Message,

View File

@ -109,37 +109,13 @@ class Bot(BaseBot):
return return
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, async def _call_api(self,
api: str, api: str,
event: Optional[MessageEvent] = None, event: Optional[MessageEvent] = None,
**data) -> Any: **data) -> Any:
"""
:说明:
调用 钉钉 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
if self.connection_type != "http": if self.connection_type != "http":
log("ERROR", "Only support http connection.") log("ERROR", "Only support http connection.")
return return
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot.call_api(api, **data)
log("DEBUG", f"Calling API <y>{api}</y>") log("DEBUG", f"Calling API <y>{api}</y>")
params = {} params = {}
@ -192,6 +168,33 @@ class Bot(BaseBot):
except httpx.HTTPError: except httpx.HTTPError:
raise NetworkError("HTTP request failed") raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
"""
:说明:
调用 钉钉 协议 API
:参数:
* ``api: str``: API 名称
* ``event: Optional[MessageEvent]``: Event 对象
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
return super().call_api(api, event=event, **data)
@overrides(BaseBot) @overrides(BaseBot)
async def send(self, async def send(self,
event: MessageEvent, event: MessageEvent,

View File

@ -218,6 +218,10 @@ class Bot(BaseBot):
except Exception as e: except Exception as e:
Log.error(f'Failed to handle message: {message}', e) Log.error(f'Failed to handle message: {message}', e)
@overrides(BaseBot)
async def _call_api(self, api: str, **data) -> NoReturn:
raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn: async def call_api(self, api: str, **data) -> NoReturn:
""" """