mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
⚗️ add call_api hook
This commit is contained in:
parent
4e7592de98
commit
8f99b01fb5
@ -27,6 +27,21 @@ Driver 对象
|
||||
Config 配置对象
|
||||
|
||||
|
||||
### `_call_api_hook`
|
||||
|
||||
|
||||
* **类型**
|
||||
|
||||
`Set[T_CallingAPIHook]`
|
||||
|
||||
|
||||
|
||||
* **说明**
|
||||
|
||||
call_api 时执行的函数
|
||||
|
||||
|
||||
|
||||
### _abstract_ `__init__(connection_type, self_id, *, websocket=None)`
|
||||
|
||||
|
||||
@ -127,7 +142,26 @@ Adapter 类型
|
||||
|
||||
|
||||
|
||||
### _abstract async_ `call_api(api, **data)`
|
||||
### _abstract async_ `_call_api(api, **data)`
|
||||
|
||||
|
||||
* **说明**
|
||||
|
||||
`adapter` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||
|
||||
|
||||
|
||||
* **参数**
|
||||
|
||||
|
||||
* `api: str`: API 名称
|
||||
|
||||
|
||||
* `**data`: API 数据
|
||||
|
||||
|
||||
|
||||
### _async_ `call_api(api, **data)`
|
||||
|
||||
|
||||
* **说明**
|
||||
@ -142,6 +176,9 @@ Adapter 类型
|
||||
* `api: str`: API 名称
|
||||
|
||||
|
||||
* `self_id: Optional[str]`: 指定调用 API 的机器人
|
||||
|
||||
|
||||
* `**data`: API 数据
|
||||
|
||||
|
||||
|
@ -129,6 +129,9 @@ sidebarDepth: 0
|
||||
* `api: str`: API 名称
|
||||
|
||||
|
||||
* `event: Optional[MessageEvent]`: Event 对象
|
||||
|
||||
|
||||
* `**data: Any`: API 参数
|
||||
|
||||
|
||||
|
@ -6,20 +6,30 @@
|
||||
"""
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from copy import copy
|
||||
from functools import reduce, partial
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Union, TypeVar, Mapping, Optional, Callable, Iterable, Iterator, Awaitable, TYPE_CHECKING
|
||||
from typing import (Any, Set, Dict, Union, TypeVar, Mapping, Optional, Iterable,
|
||||
Protocol, Awaitable, TYPE_CHECKING)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.typing import T_CallingAPIHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
|
||||
|
||||
class _ApiCall(Protocol):
|
||||
|
||||
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
|
||||
...
|
||||
|
||||
|
||||
class Bot(abc.ABC):
|
||||
"""
|
||||
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
||||
@ -29,6 +39,11 @@ class Bot(abc.ABC):
|
||||
"""Driver 对象"""
|
||||
config: "Config"
|
||||
"""Config 配置对象"""
|
||||
_call_api_hook: Set[T_CallingAPIHook] = set()
|
||||
"""
|
||||
:类型: ``Set[T_CallingAPIHook]``
|
||||
:说明: call_api 时执行的函数
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self,
|
||||
@ -50,7 +65,7 @@ class Bot(abc.ABC):
|
||||
self.websocket = websocket
|
||||
"""Websocket 连接对象"""
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
|
||||
def __getattr__(self, name: str) -> _ApiCall:
|
||||
return partial(self.call_api, name)
|
||||
|
||||
@property
|
||||
@ -109,7 +124,20 @@ class Bot(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def call_api(self, api: str, **data) -> Any:
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data``: API 数据
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def call_api(self, api: str, **data: Any) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -118,6 +146,7 @@ class Bot(abc.ABC):
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``self_id: Optional[str]``: 指定调用 API 的机器人
|
||||
* ``**data``: API 数据
|
||||
|
||||
:示例:
|
||||
@ -127,7 +156,23 @@ class Bot(abc.ABC):
|
||||
await bot.call_api("send_msg", message="hello world")
|
||||
await bot.send_msg(message="hello world")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
coros = list(map(lambda x: x(api, data), self._call_api_hook))
|
||||
if coros:
|
||||
try:
|
||||
logger.debug("Running CallingAPI hooks...")
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
|
||||
"Running cancelled!</bg #f8bbd0></r>")
|
||||
|
||||
if "self_id" in data:
|
||||
self_id = data.pop("self_id")
|
||||
if self_id:
|
||||
bot = self.driver.bots[str(self_id)]
|
||||
return await bot._call_api(api, **data)
|
||||
|
||||
return await self._call_api(api, **data)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, event: "Event", message: Union[str, "Message",
|
||||
@ -146,6 +191,11 @@ class Bot(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
|
||||
cls._call_api_hook.add(func)
|
||||
return func
|
||||
|
||||
|
||||
T_Message = TypeVar("T_Message", bound="Message")
|
||||
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
|
||||
|
@ -71,6 +71,7 @@ T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
|
||||
|
||||
WebSocket 连接断开时执行的函数
|
||||
"""
|
||||
T_CallingAPIHook = Callable[[str, Dict[str, Any]], Awaitable[None]]
|
||||
|
||||
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
|
||||
"""
|
||||
|
@ -328,32 +328,7 @@ class Bot(BaseBot):
|
||||
)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 CQHTTP 协议 API
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data: Any``: API 参数
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
|
||||
:异常:
|
||||
|
||||
- ``NetworkError``: 网络错误
|
||||
- ``ActionFailed``: API 调用失败
|
||||
"""
|
||||
if "self_id" in data:
|
||||
self_id = data.pop("self_id")
|
||||
if self_id:
|
||||
bot = self.driver.bots[str(self_id)]
|
||||
return await bot.call_api(api, **data)
|
||||
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
log("DEBUG", f"Calling API <y>{api}</y>")
|
||||
if self.connection_type == "websocket":
|
||||
seq = ResultStore.get_seq()
|
||||
@ -396,6 +371,29 @@ class Bot(BaseBot):
|
||||
except httpx.HTTPError:
|
||||
raise NetworkError("HTTP request failed")
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 CQHTTP 协议 API
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data: Any``: API 参数
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
|
||||
:异常:
|
||||
|
||||
- ``NetworkError``: 网络错误
|
||||
- ``ActionFailed``: API 调用失败
|
||||
"""
|
||||
return super().call_api(api, **data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self,
|
||||
event: Event,
|
||||
|
@ -68,7 +68,8 @@ class Bot(BaseBot):
|
||||
async def handle_message(self, message: dict):
|
||||
...
|
||||
|
||||
async def call_api(self, api: str, **data) -> Any:
|
||||
async def call_api(self, api: str, *, self_id: Optional[str],
|
||||
**data) -> Any:
|
||||
...
|
||||
|
||||
async def send(self, event: Event, message: Union[str, Message,
|
||||
|
@ -109,37 +109,13 @@ class Bot(BaseBot):
|
||||
return
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self,
|
||||
api: str,
|
||||
event: Optional[MessageEvent] = None,
|
||||
**data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 钉钉 协议 API
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data: Any``: API 参数
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
|
||||
:异常:
|
||||
|
||||
- ``NetworkError``: 网络错误
|
||||
- ``ActionFailed``: API 调用失败
|
||||
"""
|
||||
async def _call_api(self,
|
||||
api: str,
|
||||
event: Optional[MessageEvent] = None,
|
||||
**data) -> Any:
|
||||
if self.connection_type != "http":
|
||||
log("ERROR", "Only support http connection.")
|
||||
return
|
||||
if "self_id" in data:
|
||||
self_id = data.pop("self_id")
|
||||
if self_id:
|
||||
bot = self.driver.bots[str(self_id)]
|
||||
return await bot.call_api(api, **data)
|
||||
|
||||
log("DEBUG", f"Calling API <y>{api}</y>")
|
||||
params = {}
|
||||
@ -192,6 +168,33 @@ class Bot(BaseBot):
|
||||
except httpx.HTTPError:
|
||||
raise NetworkError("HTTP request failed")
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self,
|
||||
api: str,
|
||||
event: Optional[MessageEvent] = None,
|
||||
**data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 钉钉 协议 API
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``event: Optional[MessageEvent]``: Event 对象
|
||||
* ``**data: Any``: API 参数
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
|
||||
:异常:
|
||||
|
||||
- ``NetworkError``: 网络错误
|
||||
- ``ActionFailed``: API 调用失败
|
||||
"""
|
||||
return super().call_api(api, event=event, **data)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self,
|
||||
event: MessageEvent,
|
||||
|
@ -218,6 +218,10 @@ class Bot(BaseBot):
|
||||
except Exception as e:
|
||||
Log.error(f'Failed to handle message: {message}', e)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def _call_api(self, api: str, **data) -> NoReturn:
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data) -> NoReturn:
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user