🐛 Fix: 修复 bot hook 缺少依赖缓存和上下文管理 (#1826)

This commit is contained in:
Ju4tCode 2023-03-20 22:37:57 +08:00 committed by GitHub
parent 05a6af46b9
commit 78bbf9e623
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 106 additions and 29 deletions

View File

@ -1,6 +1,6 @@
import abc import abc
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from nonebot.log import logger from nonebot.log import logger
@ -8,8 +8,12 @@ from nonebot.config import Env, Config
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
from nonebot.internal.params import BotParam, DependParam, DefaultParam from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
T_BotDisconnectionHook,
)
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
@ -135,13 +139,15 @@ class Driver(abc.ABC):
self._bots[bot.self_id] = bot self._bots[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None: async def _run_hook(bot: "Bot") -> None:
coros = list( dependency_cache: T_DependencyCache = {}
map( async with AsyncExitStack() as stack:
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)), if coros := [
self._bot_connection_hook, run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
) )
) for hook in self._bot_connection_hook
if coros: ]:
try: try:
await asyncio.gather(*coros) await asyncio.gather(*coros)
except Exception as e: except Exception as e:
@ -158,13 +164,15 @@ class Driver(abc.ABC):
del self._bots[bot.self_id] del self._bots[bot.self_id]
async def _run_hook(bot: "Bot") -> None: async def _run_hook(bot: "Bot") -> None:
coros = list( dependency_cache: T_DependencyCache = {}
map( async with AsyncExitStack() as stack:
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)), if coros := [
self._bot_disconnection_hook, run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
) )
) for hook in self._bot_disconnection_hook
if coros: ]:
try: try:
await asyncio.gather(*coros) await asyncio.gather(*coros)
except Exception as e: except Exception as e:

View File

@ -1,13 +1,16 @@
import json import json
import asyncio import asyncio
from typing import cast from typing import Any, Set, cast
import pytest import pytest
from nonebug import App from nonebug import App
import nonebot import nonebot
from nonebot.config import Env from nonebot.config import Env
from nonebot.adapters import Bot
from nonebot.params import Depends
from nonebot import _resolve_combine_expr from nonebot import _resolve_combine_expr
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.drivers import ( from nonebot.drivers import (
URL, URL,
@ -169,3 +172,69 @@ async def test_http_driver(driver: Driver):
) )
async def test_combine_driver(driver: Driver, driver_type: str): async def test_combine_driver(driver: Driver, driver_type: str):
assert driver.type == driver_type assert driver.type == driver_type
@pytest.mark.asyncio
async def test_bot_connect_hook(app: App, driver: Driver):
with pytest.MonkeyPatch.context() as m:
conn_hooks: Set[Dependent[Any]] = set()
disconn_hooks: Set[Dependent[Any]] = set()
m.setattr(Driver, "_bot_connection_hook", conn_hooks)
m.setattr(Driver, "_bot_disconnection_hook", disconn_hooks)
conn_should_be_called = False
disconn_should_be_called = False
dependency_should_be_run = False
dependency_should_be_cleaned = False
async def dependency():
nonlocal dependency_should_be_run, dependency_should_be_cleaned
dependency_should_be_run = True
try:
yield 1
finally:
dependency_should_be_cleaned = True
@driver.on_bot_connect
async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
nonlocal conn_should_be_called
conn_should_be_called = True
if foo is not bot:
pytest.fail("on_bot_connect hook called with wrong bot")
if dep != 1:
pytest.fail("on_bot_connect hook called with wrong dependency")
if default != 1:
pytest.fail("on_bot_connect hook called with wrong default value")
@driver.on_bot_disconnect
async def disconn_hook(
foo: Bot, dep: int = Depends(dependency), default: int = 1
):
nonlocal disconn_should_be_called
disconn_should_be_called = True
if foo is not bot:
pytest.fail("on_bot_disconnect hook called with wrong bot")
if dep != 1:
pytest.fail("on_bot_connect hook called with wrong dependency")
if default != 1:
pytest.fail("on_bot_connect hook called with wrong default value")
if conn_hook not in {hook.call for hook in conn_hooks}:
pytest.fail("on_bot_connect hook not registered")
if disconn_hook not in {hook.call for hook in disconn_hooks}:
pytest.fail("on_bot_disconnect hook not registered")
async with app.test_api() as ctx:
bot = ctx.create_bot()
await asyncio.sleep(1)
if not conn_should_be_called:
pytest.fail("on_bot_connect hook not called")
if not disconn_should_be_called:
pytest.fail("on_bot_disconnect hook not called")
if not dependency_should_be_run:
pytest.fail("dependency not run")
if not dependency_should_be_cleaned:
pytest.fail("dependency not cleaned")