mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 05:06:56 +08:00
🐛 Fix: 修复 bot hook 缺少依赖缓存和上下文管理 (#1826)
This commit is contained in:
parent
05a6af46b9
commit
78bbf9e623
@ -1,6 +1,6 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
|
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
@ -8,8 +8,12 @@ from nonebot.config import Env, Config
|
|||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
from nonebot.utils import escape_tag, run_coro_with_catch
|
||||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
|
||||||
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
||||||
|
from nonebot.typing import (
|
||||||
|
T_DependencyCache,
|
||||||
|
T_BotConnectionHook,
|
||||||
|
T_BotDisconnectionHook,
|
||||||
|
)
|
||||||
|
|
||||||
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
||||||
|
|
||||||
@ -135,13 +139,15 @@ class Driver(abc.ABC):
|
|||||||
self._bots[bot.self_id] = bot
|
self._bots[bot.self_id] = bot
|
||||||
|
|
||||||
async def _run_hook(bot: "Bot") -> None:
|
async def _run_hook(bot: "Bot") -> None:
|
||||||
coros = list(
|
dependency_cache: T_DependencyCache = {}
|
||||||
map(
|
async with AsyncExitStack() as stack:
|
||||||
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
|
if coros := [
|
||||||
self._bot_connection_hook,
|
run_coro_with_catch(
|
||||||
|
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
||||||
|
(SkippedException,),
|
||||||
)
|
)
|
||||||
)
|
for hook in self._bot_connection_hook
|
||||||
if coros:
|
]:
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -158,13 +164,15 @@ class Driver(abc.ABC):
|
|||||||
del self._bots[bot.self_id]
|
del self._bots[bot.self_id]
|
||||||
|
|
||||||
async def _run_hook(bot: "Bot") -> None:
|
async def _run_hook(bot: "Bot") -> None:
|
||||||
coros = list(
|
dependency_cache: T_DependencyCache = {}
|
||||||
map(
|
async with AsyncExitStack() as stack:
|
||||||
lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
|
if coros := [
|
||||||
self._bot_disconnection_hook,
|
run_coro_with_catch(
|
||||||
|
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
||||||
|
(SkippedException,),
|
||||||
)
|
)
|
||||||
)
|
for hook in self._bot_disconnection_hook
|
||||||
if coros:
|
]:
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import cast
|
from typing import Any, Set, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from nonebug import App
|
from nonebug import App
|
||||||
|
|
||||||
import nonebot
|
import nonebot
|
||||||
from nonebot.config import Env
|
from nonebot.config import Env
|
||||||
|
from nonebot.adapters import Bot
|
||||||
|
from nonebot.params import Depends
|
||||||
from nonebot import _resolve_combine_expr
|
from nonebot import _resolve_combine_expr
|
||||||
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
URL,
|
URL,
|
||||||
@ -169,3 +172,69 @@ async def test_http_driver(driver: Driver):
|
|||||||
)
|
)
|
||||||
async def test_combine_driver(driver: Driver, driver_type: str):
|
async def test_combine_driver(driver: Driver, driver_type: str):
|
||||||
assert driver.type == driver_type
|
assert driver.type == driver_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bot_connect_hook(app: App, driver: Driver):
|
||||||
|
with pytest.MonkeyPatch.context() as m:
|
||||||
|
conn_hooks: Set[Dependent[Any]] = set()
|
||||||
|
disconn_hooks: Set[Dependent[Any]] = set()
|
||||||
|
m.setattr(Driver, "_bot_connection_hook", conn_hooks)
|
||||||
|
m.setattr(Driver, "_bot_disconnection_hook", disconn_hooks)
|
||||||
|
|
||||||
|
conn_should_be_called = False
|
||||||
|
disconn_should_be_called = False
|
||||||
|
dependency_should_be_run = False
|
||||||
|
dependency_should_be_cleaned = False
|
||||||
|
|
||||||
|
async def dependency():
|
||||||
|
nonlocal dependency_should_be_run, dependency_should_be_cleaned
|
||||||
|
dependency_should_be_run = True
|
||||||
|
try:
|
||||||
|
yield 1
|
||||||
|
finally:
|
||||||
|
dependency_should_be_cleaned = True
|
||||||
|
|
||||||
|
@driver.on_bot_connect
|
||||||
|
async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
|
||||||
|
nonlocal conn_should_be_called
|
||||||
|
conn_should_be_called = True
|
||||||
|
|
||||||
|
if foo is not bot:
|
||||||
|
pytest.fail("on_bot_connect hook called with wrong bot")
|
||||||
|
if dep != 1:
|
||||||
|
pytest.fail("on_bot_connect hook called with wrong dependency")
|
||||||
|
if default != 1:
|
||||||
|
pytest.fail("on_bot_connect hook called with wrong default value")
|
||||||
|
|
||||||
|
@driver.on_bot_disconnect
|
||||||
|
async def disconn_hook(
|
||||||
|
foo: Bot, dep: int = Depends(dependency), default: int = 1
|
||||||
|
):
|
||||||
|
nonlocal disconn_should_be_called
|
||||||
|
disconn_should_be_called = True
|
||||||
|
|
||||||
|
if foo is not bot:
|
||||||
|
pytest.fail("on_bot_disconnect hook called with wrong bot")
|
||||||
|
if dep != 1:
|
||||||
|
pytest.fail("on_bot_connect hook called with wrong dependency")
|
||||||
|
if default != 1:
|
||||||
|
pytest.fail("on_bot_connect hook called with wrong default value")
|
||||||
|
|
||||||
|
if conn_hook not in {hook.call for hook in conn_hooks}:
|
||||||
|
pytest.fail("on_bot_connect hook not registered")
|
||||||
|
if disconn_hook not in {hook.call for hook in disconn_hooks}:
|
||||||
|
pytest.fail("on_bot_disconnect hook not registered")
|
||||||
|
|
||||||
|
async with app.test_api() as ctx:
|
||||||
|
bot = ctx.create_bot()
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
if not conn_should_be_called:
|
||||||
|
pytest.fail("on_bot_connect hook not called")
|
||||||
|
if not disconn_should_be_called:
|
||||||
|
pytest.fail("on_bot_disconnect hook not called")
|
||||||
|
if not dependency_should_be_run:
|
||||||
|
pytest.fail("dependency not run")
|
||||||
|
if not dependency_should_be_cleaned:
|
||||||
|
pytest.fail("dependency not cleaned")
|
||||||
|
Loading…
Reference in New Issue
Block a user