⚗️ add call_api hook

This commit is contained in:
yanyongyu 2021-03-31 16:51:09 +08:00
parent 4e7592de98
commit 8f99b01fb5
8 changed files with 157 additions and 60 deletions

View File

@ -27,6 +27,21 @@ Driver 对象
Config 配置对象
### `_call_api_hook`
* **类型**
`Set[T_CallingAPIHook]`
* **说明**
call_api 时执行的函数
### _abstract_ `__init__(connection_type, self_id, *, websocket=None)`
@ -127,7 +142,26 @@ Adapter 类型
### _abstract async_ `call_api(api, **data)`
### _abstract async_ `_call_api(api, **data)`
* **说明**
`adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
* **参数**
* `api: str`: API 名称
* `**data`: API 数据
### _async_ `call_api(api, **data)`
* **说明**
@ -142,6 +176,9 @@ Adapter 类型
* `api: str`: API 名称
* `self_id: Optional[str]`: 指定调用 API 的机器人
* `**data`: API 数据

View File

@ -129,6 +129,9 @@ sidebarDepth: 0
* `api: str`: API 名称
* `event: Optional[MessageEvent]`: Event 对象
* `**data: Any`: API 参数

View File

@ -6,20 +6,30 @@
"""
import abc
import asyncio
from copy import copy
from functools import reduce, partial
from dataclasses import dataclass, field
from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING
from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
Protocol, Awaitable, TYPE_CHECKING)
from pydantic import BaseModel
from nonebot.log import logger
from nonebot.utils import DataclassEncoder
from nonebot.typing import T_CallingAPIHook
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
class _ApiCall(Protocol):
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
...
class Bot(abc.ABC):
"""
Bot 基类用于处理上报消息并提供 API 调用接口
@ -29,6 +39,11 @@ class Bot(abc.ABC):
"""Driver 对象"""
config: "Config"
"""Config 配置对象"""
_call_api_hook: Set[T_CallingAPIHook] = set()
"""
:类型: ``Set[T_CallingAPIHook]``
:说明: call_api 时执行的函数
"""
@abc.abstractmethod
def __init__(self,
@ -50,7 +65,7 @@ class Bot(abc.ABC):
self.websocket = websocket
"""Websocket 连接对象"""
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
@ -109,7 +124,20 @@ class Bot(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def call_api(self, api: str, **data) -> Any:
async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数实现该方法以调用 api
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError
async def call_api(self, api: str, **data: Any) -> Any:
"""
:说明:
@ -118,6 +146,7 @@ class Bot(abc.ABC):
:参数:
* ``api: str``: API 名称
* ``self_id: Optional[str]``: 指定调用 API 的机器人
* ``**data``: API 数据
:示例:
@ -127,7 +156,23 @@ class Bot(abc.ABC):
await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world")
"""
raise NotImplementedError
coros = list(map(lambda x: x(api, data), self._call_api_hook))
if coros:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>")
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot._call_api(api, **data)
return await self._call_api(api, **data)
@abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message",
@ -146,6 +191,11 @@ class Bot(abc.ABC):
"""
raise NotImplementedError
@classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
cls._call_api_hook.add(func)
return func
T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")

View File

@ -71,6 +71,7 @@ T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
WebSocket 连接断开时执行的函数
"""
T_CallingAPIHook = Callable[[str, Dict[str, Any]], Awaitable[None]]
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
"""

View File

@ -328,32 +328,7 @@ class Bot(BaseBot):
)
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> Any:
"""
:说明:
调用 CQHTTP 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot.call_api(api, **data)
async def _call_api(self, api: str, **data) -> Any:
log("DEBUG", f"Calling API <y>{api}</y>")
if self.connection_type == "websocket":
seq = ResultStore.get_seq()
@ -396,6 +371,29 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> Any:
"""
:说明:
调用 CQHTTP 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
return super().call_api(api, **data)
@overrides(BaseBot)
async def send(self,
event: Event,

View File

@ -68,7 +68,8 @@ class Bot(BaseBot):
async def handle_message(self, message: dict):
...
async def call_api(self, api: str, **data) -> Any:
async def call_api(self, api: str, *, self_id: Optional[str],
**data) -> Any:
...
async def send(self, event: Event, message: Union[str, Message,

View File

@ -109,37 +109,13 @@ class Bot(BaseBot):
return
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
"""
:说明:
调用 钉钉 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
async def _call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
if self.connection_type != "http":
log("ERROR", "Only support http connection.")
return
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot.call_api(api, **data)
log("DEBUG", f"Calling API <y>{api}</y>")
params = {}
@ -192,6 +168,33 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
**data) -> Any:
"""
:说明:
调用 钉钉 协议 API
:参数:
* ``api: str``: API 名称
* ``event: Optional[MessageEvent]``: Event 对象
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
return super().call_api(api, event=event, **data)
@overrides(BaseBot)
async def send(self,
event: MessageEvent,

View File

@ -218,6 +218,10 @@ class Bot(BaseBot):
except Exception as e:
Log.error(f'Failed to handle message: {message}', e)
@overrides(BaseBot)
async def _call_api(self, api: str, **data) -> NoReturn:
raise NotImplementedError
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
"""