mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🐛 Fix: 修复 bot hook 缺少依赖缓存和上下文管理 (#1826)
This commit is contained in:
parent
05a6af46b9
commit
78bbf9e623
@ -1,6 +1,6 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
|
||||
|
||||
from nonebot.log import logger
|
||||
@ -8,8 +8,12 @@ from nonebot.config import Env, Config
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
||||
from nonebot.typing import (
|
||||
T_DependencyCache,
|
||||
T_BotConnectionHook,
|
||||
T_BotDisconnectionHook,
|
||||
)
|
||||
|
||||
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
@ -135,13 +139,15 @@ class Driver(abc.ABC):
|
||||
self._bots[bot.self_id] = bot
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
|
||||
self._bot_connection_hook,
|
||||
dependency_cache: T_DependencyCache = {}
|
||||
async with AsyncExitStack() as stack:
|
||||
if coros := [
|
||||
run_coro_with_catch(
|
||||
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
||||
(SkippedException,),
|
||||
)
|
||||
)
|
||||
if coros:
|
||||
for hook in self._bot_connection_hook
|
||||
]:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
@ -158,13 +164,15 @@ class Driver(abc.ABC):
|
||||
del self._bots[bot.self_id]
|
||||
|
||||
async def _run_hook(bot: "Bot") -> None:
|
||||
coros = list(
|
||||
map(
|
||||
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
|
||||
self._bot_disconnection_hook,
|
||||
dependency_cache: T_DependencyCache = {}
|
||||
async with AsyncExitStack() as stack:
|
||||
if coros := [
|
||||
run_coro_with_catch(
|
||||
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
||||
(SkippedException,),
|
||||
)
|
||||
)
|
||||
if coros:
|
||||
for hook in self._bot_disconnection_hook
|
||||
]:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
except Exception as e:
|
||||
|
@ -1,13 +1,16 @@
|
||||
import json
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from typing import Any, Set, cast
|
||||
|
||||
import pytest
|
||||
from nonebug import App
|
||||
|
||||
import nonebot
|
||||
from nonebot.config import Env
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.params import Depends
|
||||
from nonebot import _resolve_combine_expr
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers import (
|
||||
URL,
|
||||
@ -169,3 +172,69 @@ async def test_http_driver(driver: Driver):
|
||||
)
|
||||
async def test_combine_driver(driver: Driver, driver_type: str):
|
||||
assert driver.type == driver_type
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bot_connect_hook(app: App, driver: Driver):
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
conn_hooks: Set[Dependent[Any]] = set()
|
||||
disconn_hooks: Set[Dependent[Any]] = set()
|
||||
m.setattr(Driver, "_bot_connection_hook", conn_hooks)
|
||||
m.setattr(Driver, "_bot_disconnection_hook", disconn_hooks)
|
||||
|
||||
conn_should_be_called = False
|
||||
disconn_should_be_called = False
|
||||
dependency_should_be_run = False
|
||||
dependency_should_be_cleaned = False
|
||||
|
||||
async def dependency():
|
||||
nonlocal dependency_should_be_run, dependency_should_be_cleaned
|
||||
dependency_should_be_run = True
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
dependency_should_be_cleaned = True
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
|
||||
nonlocal conn_should_be_called
|
||||
conn_should_be_called = True
|
||||
|
||||
if foo is not bot:
|
||||
pytest.fail("on_bot_connect hook called with wrong bot")
|
||||
if dep != 1:
|
||||
pytest.fail("on_bot_connect hook called with wrong dependency")
|
||||
if default != 1:
|
||||
pytest.fail("on_bot_connect hook called with wrong default value")
|
||||
|
||||
@driver.on_bot_disconnect
|
||||
async def disconn_hook(
|
||||
foo: Bot, dep: int = Depends(dependency), default: int = 1
|
||||
):
|
||||
nonlocal disconn_should_be_called
|
||||
disconn_should_be_called = True
|
||||
|
||||
if foo is not bot:
|
||||
pytest.fail("on_bot_disconnect hook called with wrong bot")
|
||||
if dep != 1:
|
||||
pytest.fail("on_bot_connect hook called with wrong dependency")
|
||||
if default != 1:
|
||||
pytest.fail("on_bot_connect hook called with wrong default value")
|
||||
|
||||
if conn_hook not in {hook.call for hook in conn_hooks}:
|
||||
pytest.fail("on_bot_connect hook not registered")
|
||||
if disconn_hook not in {hook.call for hook in disconn_hooks}:
|
||||
pytest.fail("on_bot_disconnect hook not registered")
|
||||
|
||||
async with app.test_api() as ctx:
|
||||
bot = ctx.create_bot()
|
||||
|
||||
await asyncio.sleep(1)
|
||||
if not conn_should_be_called:
|
||||
pytest.fail("on_bot_connect hook not called")
|
||||
if not disconn_should_be_called:
|
||||
pytest.fail("on_bot_disconnect hook not called")
|
||||
if not dependency_should_be_run:
|
||||
pytest.fail("dependency not run")
|
||||
if not dependency_should_be_cleaned:
|
||||
pytest.fail("dependency not cleaned")
|
||||
|
Loading…
Reference in New Issue
Block a user