⚗️ add called api hook

This commit is contained in:
yanyongyu 2021-04-01 20:23:55 +08:00
parent 6b763e20d0
commit f0a6ff4627
3 changed files with 57 additions and 10 deletions

View File

@ -17,7 +17,7 @@ from pydantic import BaseModel
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.typing import T_CallingAPIHook from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.config import Config from nonebot.config import Config
@ -39,11 +39,16 @@ class Bot(abc.ABC):
"""Driver 对象""" """Driver 对象"""
config: "Config" config: "Config"
"""Config 配置对象""" """Config 配置对象"""
_call_api_hook: Set[T_CallingAPIHook] = set() _calling_api_hook: Set[T_CallingAPIHook] = set()
""" """
:类型: ``Set[T_CallingAPIHook]`` :类型: ``Set[T_CallingAPIHook]``
:说明: call_api 时执行的函数 :说明: call_api 时执行的函数
""" """
_called_api_hook: Set[T_CalledAPIHook] = set()
"""
:类型: ``Set[T_CalledAPIHook]``
:说明: call_api 后执行的函数
"""
@abc.abstractmethod @abc.abstractmethod
def __init__(self, def __init__(self,
@ -156,7 +161,7 @@ class Bot(abc.ABC):
await bot.call_api("send_msg", message="hello world") await bot.call_api("send_msg", message="hello world")
await bot.send_msg(message="hello world") await bot.send_msg(message="hello world")
""" """
coros = list(map(lambda x: x(self, api, data), self._call_api_hook)) coros = list(map(lambda x: x(self, api, data), self._calling_api_hook))
if coros: if coros:
try: try:
logger.debug("Running CallingAPI hooks...") logger.debug("Running CallingAPI hooks...")
@ -166,13 +171,33 @@ class Bot(abc.ABC):
"<r><bg #f8bbd0>Error when running CallingAPI hook. " "<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>") "Running cancelled!</bg #f8bbd0></r>")
if "self_id" in data: exception = None
self_id = data.pop("self_id") result = None
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot._call_api(api, **data)
return await self._call_api(api, **data) try:
if "self_id" in data and data["self_id"]:
bot = self.driver.bots[str(data["self_id"])]
result = await bot._call_api(api, **data)
else:
result = await self._call_api(api, **data)
except Exception as e:
exception = e
coros = list(
map(lambda x: x(self, exception, api, data, result),
self._called_api_hook))
if coros:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>")
if exception:
raise exception
return result
@abc.abstractmethod @abc.abstractmethod
async def send(self, event: "Event", message: Union[str, "Message", async def send(self, event: "Event", message: Union[str, "Message",
@ -193,7 +218,12 @@ class Bot(abc.ABC):
@classmethod @classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook: def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
cls._call_api_hook.add(func) cls._calling_api_hook.add(func)
return func
@classmethod
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
cls._called_api_hook.add(func)
return func return func

View File

@ -72,6 +72,22 @@ T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
WebSocket 连接断开时执行的函数 WebSocket 连接断开时执行的函数
""" """
T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]] T_CallingAPIHook = Callable[["Bot", str, Dict[str, Any]], Awaitable[None]]
"""
:类型: ``Callable[[Bot, str, Dict[str, Any]], Awaitable[None]]``
:说明:
``bot.call_api`` 时执行的函数
"""
T_CalledAPIHook = Callable[
["Bot", Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]]
"""
:类型: ``Callable[[Bot, Optional[Exception], str, Dict[str, Any], Any], Awaitable[None]]``
:说明:
``bot.call_api`` 后执行的函数参数分别为 bot, exception, api, data, result
"""
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
""" """

View File

@ -58,6 +58,7 @@ async def _check_reply(bot: "Bot", event: "Event"):
)) ))
except Exception as e: except Exception as e:
log("WARNING", f"Error when getting message reply info: {repr(e)}", e) log("WARNING", f"Error when getting message reply info: {repr(e)}", e)
return
# ensure string comparation # ensure string comparation
if str(event.reply.sender.user_id) == str(event.self_id): if str(event.reply.sender.user_id) == str(event.self_id):
event.to_me = True event.to_me = True