Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode 2024-10-26 15:36:01 +08:00 committed by GitHub
parent bd9befbb55
commit ff21ceb946
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 5422 additions and 4080 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

1396
envs/test/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -8,11 +8,11 @@ packages = [{ include = "nonebot-test.py" }]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
nonebug = "^0.3.7" trio = "^0.27.0"
nonebug = "^0.4.1"
wsproto = "^1.2.0" wsproto = "^1.2.0"
pytest-cov = "^5.0.0" pytest-cov = "^5.0.0"
pytest-xdist = "^3.0.2" pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.23.2"
werkzeug = ">=2.3.6,<4.0.0" werkzeug = ">=2.3.6,<4.0.0"
coverage-conditional-plugin = "^0.9.0" coverage-conditional-plugin = "^0.9.0"

View File

@ -8,17 +8,20 @@ FrontMatter:
""" """
import abc import abc
import asyncio
import inspect import inspect
from functools import partial
from dataclasses import field, dataclass from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast from typing import Any, Generic, TypeVar, Callable, Optional, cast
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import _DependentCallable from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
from .utils import check_field_type, get_typed_signature from .utils import check_field_type, get_typed_signature
@ -84,7 +87,16 @@ class Dependent(Generic[R]):
) )
async def __call__(self, **kwargs: Any) -> R: async def __call__(self, **kwargs: Any) -> R:
try: exception: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")
with catch({SkippedException: _handle_skipped}):
# do pre-check # do pre-check
await self.check(**kwargs) await self.check(**kwargs)
@ -96,9 +108,8 @@ class Dependent(Generic[R]):
return await cast(Callable[..., Awaitable[R]], self.call)(**values) return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else: else:
return await run_sync(cast(Callable[..., R], self.call))(**values) return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}") raise exception
raise
@staticmethod @staticmethod
def parse_params( def parse_params(
@ -166,10 +177,13 @@ class Dependent(Generic[R]):
return cls(call, params, parameterless_params) return cls(call, params, parameterless_params)
async def check(self, **params: Any) -> None: async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless)) async with anyio.create_task_group() as tg:
await asyncio.gather( for param in self.parameterless:
*(cast(Param, param.field_info)._check(**params) for param in self.params) tg.start_soon(partial(param._check, **params))
)
async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any: async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info) param = cast(Param, field.field_info)
@ -185,10 +199,17 @@ class Dependent(Generic[R]):
await param._solve(**params) await param._solve(**params)
# solve param values # solve param values
values = await asyncio.gather( result: dict[str, Any] = {}
*(self._solve_field(field, params) for field in self.params)
) async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
return {field.name: value for field, value in zip(self.params, values)} value = await self._solve_field(field, params)
result[field.name] = value
async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)
return result
__autodoc__ = {"CustomConfig": False} __autodoc__ = {"CustomConfig": False}

View File

@ -12,14 +12,18 @@ FrontMatter:
""" """
import signal import signal
import asyncio from typing import Optional
import threading
from typing_extensions import override from typing_extensions import override
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.consts import WINDOWS from nonebot.consts import WINDOWS
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import flatten_exception_group
HANDLED_SIGNALS = ( HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@ -35,8 +39,8 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(env, config) super().__init__(env, config)
self.should_exit: asyncio.Event = asyncio.Event() self.should_exit: anyio.Event = anyio.Event()
self.force_exit: bool = False self.force_exit: anyio.Event = anyio.Event()
@property @property
@override @override
@ -54,85 +58,98 @@ class Driver(BaseDriver):
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
"""启动 none driver""" """启动 none driver"""
super().run(*args, **kwargs) super().run(*args, **kwargs)
loop = asyncio.get_event_loop() anyio.run(self._serve)
loop.run_until_complete(self._serve())
async def _serve(self): async def _serve(self):
self._install_signal_handlers() async with anyio.create_task_group() as driver_tg:
await self._startup() driver_tg.start_soon(self._handle_signals)
if self.should_exit.is_set(): driver_tg.start_soon(self._listen_force_exit, driver_tg)
return driver_tg.start_soon(self._handle_lifespan, driver_tg)
await self._main_loop()
await self._shutdown()
async def _startup(self): async def _handle_signals(self):
try: try:
await self._lifespan.startup() with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signal_receiver:
except Exception as e: async for sig in signal_receiver:
logger.opt(colors=True, exception=e).error( self.exit(force=self.should_exit.is_set())
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
self.should_exit.set()
return
logger.info("Application startup completed.")
async def _main_loop(self):
await self.should_exit.wait()
async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown.")
try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("Application shutdown complete.")
loop = asyncio.get_event_loop()
loop.stop()
def _install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
loop = asyncio.get_event_loop()
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self._handle_exit, sig, None)
except NotImplementedError: except NotImplementedError:
# Windows # Windows
for sig in HANDLED_SIGNALS: for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit) signal.signal(sig, self._handle_legacy_signal)
def _handle_exit(self, sig, frame): # backport for Windows signal handling
def _handle_legacy_signal(self, sig, frame):
self.exit(force=self.should_exit.is_set()) self.exit(force=self.should_exit.is_set())
async def _handle_lifespan(self, tg: TaskGroup):
try:
await self._startup()
if self.should_exit.is_set():
return
await self._listen_exit()
await self._shutdown()
finally:
tg.cancel_scope.cancel()
async def _startup(self):
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
self.should_exit.set()
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running startup hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
with catch({Exception: handle_exception}):
await self._lifespan.startup()
if not self.should_exit.is_set():
logger.info("Application startup completed.")
async def _listen_exit(self, tg: Optional[TaskGroup] = None):
await self.should_exit.wait()
if tg is not None:
tg.cancel_scope.cancel()
async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown. (CTRL+C to force quit)")
error_occurred: bool = False
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
nonlocal error_occurred
error_occurred = True
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running shutdown hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application shutdown failed. "
"Exiting.</bg #f8bbd0></r>"
)
with catch({Exception: handle_exception}):
await self._lifespan.shutdown()
if not error_occurred:
logger.info("Application shutdown complete.")
async def _listen_force_exit(self, tg: TaskGroup):
await self.force_exit.wait()
tg.cancel_scope.cancel()
def exit(self, force: bool = False): def exit(self, force: bool = False):
"""退出 none driver """退出 none driver
@ -142,4 +159,4 @@ class Driver(BaseDriver):
if not self.should_exit.is_set(): if not self.should_exit.is_set():
self.should_exit.set() self.should_exit.set()
if force: if force:
self.force_exit = True self.force_exit.set()

View File

@ -1,11 +1,14 @@
import abc import abc
import asyncio
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.exception import MockApiException from nonebot.exception import MockApiException
from nonebot.utils import flatten_exception_group
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
if TYPE_CHECKING: if TYPE_CHECKING:
@ -76,47 +79,98 @@ class Bot(abc.ABC):
skip_calling_api: bool = False skip_calling_api: bool = False
exception: Optional[Exception] = None exception: Optional[Exception] = None
if coros := [hook(self, api, data) for hook in self._calling_api_hook]: if self._calling_api_hook:
try: logger.debug("Running CallingAPI hooks...")
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros) def _handle_mock_api_exception(
except MockApiException as e: exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal skip_calling_api, result
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
skip_calling_api = True skip_calling_api = True
result = e.result result = excs[0].result
logger.debug( logger.debug(
f"Calling API {api} is cancelled. Return {result} instead." f"Calling API {api} is cancelled. Return {result!r} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
) )
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._calling_api_hook:
tg.start_soon(hook, self, api, data)
if not skip_calling_api: if not skip_calling_api:
try: try:
result = await self.adapter._call_api(self, api, **data) result = await self.adapter._call_api(self, api, **data)
except Exception as e: except Exception as e:
exception = e exception = e
if coros := [ if self._called_api_hook:
hook(self, exception, api, data, result) for hook in self._called_api_hook logger.debug("Running CalledAPI hooks...")
]:
try: def _handle_mock_api_exception(
logger.debug("Running CalledAPI hooks...") exc_group: BaseExceptionGroup[MockApiException],
await asyncio.gather(*coros) ) -> None:
except MockApiException as e: nonlocal result, exception
# mock api result
result = e.result excs = [
# ignore exception exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
result = excs[0].result
exception = None exception = None
logger.debug( logger.debug(
f"Calling API {api} result is mocked. Return {result} instead." f"Calling API {api} result is mocked. Return {result} instead."
) )
except Exception as e:
logger.opt(colors=True, exception=e).error( def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
"<r><bg #f8bbd0>Error when running CalledAPI hook. " for exc in flatten_exception_group(exc_group):
"Running cancelled!</bg #f8bbd0></r>" logger.opt(colors=True, exception=exc).error(
) "<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._called_api_hook:
tg.start_soon(hook, self, exception, api, data, result)
if exception: if exception:
raise exception raise exception

View File

@ -1,6 +1,11 @@
from collections.abc import Awaitable from types import TracebackType
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress
from nonebot.utils import run_sync, is_coroutine_callable from nonebot.utils import run_sync, is_coroutine_callable
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan: class Lifespan:
def __init__(self) -> None: def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._startup_funcs: list[LIFESPAN_FUNC] = [] self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = [] self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = [] self._shutdown_funcs: list[LIFESPAN_FUNC] = []
@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group
@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func) self._startup_funcs.append(func)
return func return func
@ -29,7 +48,7 @@ class Lifespan:
@staticmethod @staticmethod
async def _run_lifespan_func( async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC], funcs: Iterable[LIFESPAN_FUNC],
) -> None: ) -> None:
for func in funcs: for func in funcs:
if is_coroutine_callable(func): if is_coroutine_callable(func):
@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))() await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None: async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()
# run startup funcs
if self._startup_funcs: if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs) await self._run_lifespan_func(self._startup_funcs)
# run ready funcs
if self._ready_funcs: if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs) await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None: async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs: if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs) # reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))
# shutdown background task group
self.task_group.cancel_scope.cancel()
with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
self._task_group = None
async def __aenter__(self) -> None: async def __aenter__(self) -> None:
await self.startup() await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: async def __aexit__(
await self.shutdown() self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

View File

@ -1,17 +1,20 @@
import abc import abc
import asyncio
from types import TracebackType from types import TracebackType
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import ( from nonebot.typing import (
T_DependencyCache, T_DependencyCache,
T_BotConnectionHook, T_BotConnectionHook,
@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config self.config: Config = config
"""全局配置对象""" """全局配置对象"""
self._bots: dict[str, "Bot"] = {} self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan() self._lifespan = Lifespan()
def __repr__(self) -> str: def __repr__(self) -> str:
@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot""" """获取当前所有已连接的 Bot"""
return self._bots return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None: def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器 """注册一个协议适配器
@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>" f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
) )
self.on_shutdown(self._cleanup)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC: def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数""" """注册一个启动时执行的函数"""
return self._lifespan.on_startup(func) return self._lifespan.on_startup(func)
@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}") raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot self._bots[bot.self_id] = bot
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None: async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {} dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack: with CancelScope(shield=True), catch({Exception: handle_exception}):
if coros := [ async with AsyncExitStack() as stack, create_task_group() as tg:
run_coro_with_catch( for hook in self._bot_connection_hook:
hook(bot=bot, stack=stack, dependency_cache=dependency_cache), tg.start_soon(
(SkippedException,), run_coro_with_catch,
) hook(
for hook in self._bot_connection_hook bot=bot, stack=stack, dependency_cache=dependency_cache
]: ),
try: (SkippedException,),
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
) )
task = asyncio.create_task(_run_hook(bot)) self.task_group.start_soon(_run_hook, bot)
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
def _bot_disconnect(self, bot: "Bot") -> None: def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象""" """在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots: if bot.self_id in self._bots:
del self._bots[bot.self_id] del self._bots[bot.self_id]
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None: async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {} dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack: # shield cancellation to ensure bot disconnect hooks are always run
if coros := [ with CancelScope(shield=True), catch({Exception: handle_exception}):
run_coro_with_catch( async with create_task_group() as tg, AsyncExitStack() as stack:
hook(bot=bot, stack=stack, dependency_cache=dependency_cache), for hook in self._bot_disconnection_hook:
(SkippedException,), tg.start_soon(
) run_coro_with_catch,
for hook in self._bot_disconnection_hook hook(
]: bot=bot, stack=stack, dependency_cache=dependency_cache
try: ),
await asyncio.gather(*coros) (SkippedException,),
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
) )
task = asyncio.create_task(_run_hook(bot)) self.task_group.start_soon(_run_hook, bot)
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
class Mixin(abc.ABC): class Mixin(abc.ABC):

View File

@ -22,11 +22,13 @@ from typing import ( # noqa: UP035
overload, overload,
) )
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.internal.rule import Rule from nonebot.internal.rule import Rule
from nonebot.utils import classproperty
from nonebot.dependencies import Param, Dependent from nonebot.dependencies import Param, Dependent
from nonebot.internal.permission import User, Permission from nonebot.internal.permission import User, Permission
from nonebot.utils import classproperty, flatten_exception_group
from nonebot.internal.adapter import ( from nonebot.internal.adapter import (
Bot, Bot,
Event, Event,
@ -812,28 +814,34 @@ class Matcher(metaclass=MatcherMeta):
f"bot={bot}, event={event!r}, state={state!r}" f"bot={bot}, event={event!r}, state={state!r}"
) )
def _handle_stop_propagation(exc_group: BaseExceptionGroup[StopPropagation]):
self.block = True
with self.ensure_context(bot, event): with self.ensure_context(bot, event):
try: try:
# Refresh preprocess state with catch({StopPropagation: _handle_stop_propagation}):
self.state.update(state) # Refresh preprocess state
self.state.update(state)
while self.remain_handlers: while self.remain_handlers:
handler = self.remain_handlers.pop(0) handler = self.remain_handlers.pop(0)
current_handler.set(handler) current_handler.set(handler)
logger.debug(f"Running handler {handler}") logger.debug(f"Running handler {handler}")
try:
await handler( def _handle_skipped(
matcher=self, exc_group: BaseExceptionGroup[SkippedException],
bot=bot, ):
event=event, logger.debug(f"Handler {handler} skipped")
state=self.state,
stack=stack, with catch({SkippedException: _handle_skipped}):
dependency_cache=dependency_cache, await handler(
) matcher=self,
except SkippedException: bot=bot,
logger.debug(f"Handler {handler} skipped") event=event,
except StopPropagation: state=self.state,
self.block = True stack=stack,
dependency_cache=dependency_cache,
)
finally: finally:
logger.info(f"{self} running complete") logger.info(f"{self} running complete")
@ -846,10 +854,54 @@ class Matcher(metaclass=MatcherMeta):
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None, dependency_cache: Optional[T_DependencyCache] = None,
): ):
try: exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
None
)
def _handle_special_exception(
exc_group: BaseExceptionGroup[
Union[FinishedException, RejectedException, PausedException]
]
):
nonlocal exc
excs = list(flatten_exception_group(exc_group))
if len(excs) > 1:
logger.warning(
"Multiple session control exceptions occurred. "
"NoneBot will choose the proper one."
)
finished_exc = next(
(e for e in excs if isinstance(e, FinishedException)),
None,
)
rejected_exc = next(
(e for e in excs if isinstance(e, RejectedException)),
None,
)
paused_exc = next(
(e for e in excs if isinstance(e, PausedException)),
None,
)
exc = finished_exc or rejected_exc or paused_exc
elif isinstance(
excs[0], (FinishedException, RejectedException, PausedException)
):
exc = excs[0]
with catch(
{
(
FinishedException,
RejectedException,
PausedException,
): _handle_special_exception
}
):
await self.simple_run(bot, event, state, stack, dependency_cache) await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException: if isinstance(exc, FinishedException):
pass
elif isinstance(exc, RejectedException):
await self.resolve_reject() await self.resolve_reject()
type_ = await self.update_type(bot, event, stack, dependency_cache) type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission( permission = await self.update_permission(
@ -870,7 +922,7 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater, default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater, default_permission_updater=self.__class__._default_permission_updater,
) )
except PausedException: elif isinstance(exc, PausedException):
type_ = await self.update_type(bot, event, stack, dependency_cache) type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission( permission = await self.update_permission(
bot, event, stack, dependency_cache bot, event, stack, dependency_cache
@ -890,5 +942,3 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater, default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater, default_permission_updater=self.__class__._default_permission_updater,
) )
except FinishedException:
pass

View File

@ -1,5 +1,5 @@
import asyncio
import inspect import inspect
from enum import Enum
from typing_extensions import Self, get_args, override, get_origin from typing_extensions import Self, get_args, override, get_origin
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import ( from typing import (
@ -13,8 +13,11 @@ from typing import (
cast, cast,
) )
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.exception import SkippedException
from nonebot.dependencies import Param, Dependent from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type from nonebot.dependencies.utils import check_field_type
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
@ -93,6 +96,75 @@ def Depends(
return DependsInner(dependency, use_cache=use_cache, validate=validate) return DependsInner(dependency, use_cache=use_cache, validate=validate)
class CacheState(str, Enum):
"""子依赖缓存状态"""
PENDING = "PENDING"
FINISHED = "FINISHED"
class DependencyCache:
"""子依赖结果缓存。
用于缓存子依赖的结果以避免重复计算
"""
def __init__(self):
self._state = CacheState.PENDING
self._result: Any = None
self._exception: Optional[BaseException] = None
self._waiter = anyio.Event()
def result(self) -> Any:
"""获取子依赖结果"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
if self._exception is not None:
raise self._exception
return self._result
def exception(self) -> Optional[BaseException]:
"""获取子依赖异常"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
return self._exception
def set_result(self, result: Any) -> None:
"""设置子依赖结果"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._result = result
self._state = CacheState.FINISHED
self._waiter.set()
def set_exception(self, exception: BaseException) -> None:
"""设置子依赖异常"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._exception = exception
self._state = CacheState.FINISHED
self._waiter.set()
async def wait(self):
"""等待子依赖结果"""
await self._waiter.wait()
if self._state != CacheState.FINISHED:
raise RuntimeError("Invalid cache state")
if self._exception is not None:
raise self._exception
return self._result
class DependParam(Param): class DependParam(Param):
"""子依赖注入参数。 """子依赖注入参数。
@ -194,17 +266,27 @@ class DependParam(Param):
call = cast(Callable[..., Any], sub_dependent.call) call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache # solve sub dependency with current cache
sub_values = await sub_dependent.solve( exc: Optional[BaseExceptionGroup[SkippedException]] = None
stack=stack,
dependency_cache=dependency_cache, def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
**kwargs, nonlocal exc
) exc = exc_group
with catch({SkippedException: _handle_skipped}):
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
if exc is not None:
raise exc
# run dependency function # run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache: if use_cache and call in dependency_cache:
return await dependency_cache[call] return await dependency_cache[call].wait()
elif is_gen_callable(call) or is_async_gen_callable(call):
if is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance( assert isinstance(
stack, AsyncExitStack stack, AsyncExitStack
), "Generator dependency should be called in context" ), "Generator dependency should be called in context"
@ -212,17 +294,21 @@ class DependParam(Param):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values)) cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else: else:
cm = asynccontextmanager(call)(**sub_values) cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task target = stack.enter_async_context(cm)
return await task
elif is_coroutine_callable(call): elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values)) target = call(**sub_values)
dependency_cache[call] = task
return await task
else: else:
task = asyncio.create_task(run_sync(call)(**sub_values)) target = run_sync(call)(**sub_values)
dependency_cache[call] = task
return await task dependency_cache[call] = cache = DependencyCache()
try:
result = await target
cache.set_result(result)
return result
except BaseException as e:
cache.set_exception(e)
raise
@override @override
async def _check(self, **kwargs: Any) -> None: async def _check(self, **kwargs: Any) -> None:

View File

@ -1,8 +1,9 @@
import asyncio
from typing_extensions import Self from typing_extensions import Self
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional from typing import Union, ClassVar, NoReturn, Optional
import anyio
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.utils import run_coro_with_catch from nonebot.utils import run_coro_with_catch
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
@ -70,22 +71,26 @@ class Permission:
""" """
if not self.checkers: if not self.checkers:
return True return True
results = await asyncio.gather(
*( result = False
run_coro_with_catch(
checker( async def _run_checker(checker: Dependent[bool]) -> None:
bot=bot, nonlocal result
event=event, # calculate the result first to avoid data racing
stack=stack, is_passed = await run_coro_with_catch(
dependency_cache=dependency_cache, checker(
), bot=bot, event=event, stack=stack, dependency_cache=dependency_cache
(SkippedException,), ),
False, (SkippedException,),
) False,
for checker in self.checkers )
), result |= is_passed
)
return any(results) async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: object) -> NoReturn: def __and__(self, other: object) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.") raise RuntimeError("And operation between Permissions is not allowed.")

View File

@ -1,7 +1,9 @@
import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional from typing import Union, ClassVar, NoReturn, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
@ -71,22 +73,33 @@ class Rule:
""" """
if not self.checkers: if not self.checkers:
return True return True
try:
results = await asyncio.gather( result = True
*(
checker( def _handle_skipped_exception(
bot=bot, exc_group: BaseExceptionGroup[SkippedException],
event=event, ) -> None:
state=state, nonlocal result
stack=stack, result = False
dependency_cache=dependency_cache,
) async def _run_checker(checker: Dependent[bool]) -> None:
for checker in self.checkers nonlocal result
) # calculate the result first to avoid data racing
is_passed = await checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
) )
except SkippedException: result &= is_passed
return False
return all(results) with catch({SkippedException: _handle_skipped_exception}):
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule": def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None: if other is None:

View File

@ -9,23 +9,30 @@ FrontMatter:
description: nonebot.message 模块 description: nonebot.message 模块
""" """
import asyncio
import contextlib import contextlib
from datetime import datetime from datetime import datetime
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import TrieRule from nonebot.rule import TrieRule
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers from nonebot.matcher import Matcher, matchers
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.exception import ( from nonebot.exception import (
NoLogException, NoLogException,
StopPropagation, StopPropagation,
IgnoredException, IgnoredException,
SkippedException, SkippedException,
) )
from nonebot.utils import (
escape_tag,
run_coro_with_catch,
run_coro_with_shield,
flatten_exception_group,
)
from nonebot.typing import ( from nonebot.typing import (
T_State, T_State,
T_DependencyCache, T_DependencyCache,
@ -125,6 +132,21 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func return func
def _handle_ignored_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[IgnoredException]) -> None:
logger.opt(colors=True).info(msg)
return _handle
def _handle_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(msg)
return _handle
async def _apply_event_preprocessors( async def _apply_event_preprocessors(
bot: "Bot", bot: "Bot",
event: "Event", event: "Event",
@ -152,10 +174,21 @@ async def _apply_event_preprocessors(
if show_log: if show_log:
logger.debug("Running PreProcessors...") logger.debug("Running PreProcessors...")
try: with catch(
await asyncio.gather( {
*( IgnoredException: _handle_ignored_exception(
run_coro_with_catch( f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
),
}
):
async with anyio.create_task_group() as tg:
for proc in _event_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc( proc(
bot=bot, bot=bot,
event=event, event=event,
@ -165,22 +198,10 @@ async def _apply_event_preprocessors(
), ),
(SkippedException,), (SkippedException,),
) )
for proc in _event_preprocessors
)
)
except IgnoredException:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
)
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
)
return False
return True return True
return False
async def _apply_event_postprocessors( async def _apply_event_postprocessors(
@ -207,10 +228,17 @@ async def _apply_event_postprocessors(
if show_log: if show_log:
logger.debug("Running PostProcessors...") logger.debug("Running PostProcessors...")
try: with catch(
await asyncio.gather( {
*( Exception: _handle_exception(
run_coro_with_catch( "<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
}
):
async with anyio.create_task_group() as tg:
for proc in _event_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc( proc(
bot=bot, bot=bot,
event=event, event=event,
@ -220,13 +248,6 @@ async def _apply_event_postprocessors(
), ),
(SkippedException,), (SkippedException,),
) )
for proc in _event_postprocessors
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
async def _apply_run_preprocessors( async def _apply_run_preprocessors(
@ -254,35 +275,38 @@ async def _apply_run_preprocessors(
return True return True
# ensure matcher function can be correctly called # ensure matcher function can be correctly called
with matcher.ensure_context(bot, event): with (
try: matcher.ensure_context(bot, event),
await asyncio.gather( catch(
*( {
run_coro_with_catch( IgnoredException: _handle_ignored_exception(
proc( f"{matcher} running is <b>cancelled</b>"
matcher=matcher, ),
bot=bot, Exception: _handle_exception(
event=event, "<r><bg #f8bbd0>Error when running RunPreProcessors. "
state=state, "Running cancelled!</bg #f8bbd0></r>"
stack=stack, ),
dependency_cache=dependency_cache, }
), ),
(SkippedException,), ):
) async with anyio.create_task_group() as tg:
for proc in _run_preprocessors for proc in _run_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
) )
)
except IgnoredException:
logger.opt(colors=True).info(f"{matcher} running is <b>cancelled</b>")
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
)
return False
return True return True
return False
async def _apply_run_postprocessors( async def _apply_run_postprocessors(
@ -306,29 +330,32 @@ async def _apply_run_postprocessors(
if not _run_postprocessors: if not _run_postprocessors:
return return
with matcher.ensure_context(bot, event): with (
try: matcher.ensure_context(bot, event),
await asyncio.gather( catch(
*( {
run_coro_with_catch( Exception: _handle_exception(
proc( "<r><bg #f8bbd0>Error when running RunPostProcessors"
matcher=matcher, "</bg #f8bbd0></r>"
exception=exception, )
bot=bot, }
event=event, ),
state=matcher.state, ):
stack=stack, async with anyio.create_task_group() as tg:
dependency_cache=dependency_cache, for proc in _run_postprocessors:
), tg.start_soon(
(SkippedException,), run_coro_with_catch,
) proc(
for proc in _run_postprocessors matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
) )
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPostProcessors</bg #f8bbd0></r>"
)
async def _check_matcher( async def _check_matcher(
@ -425,8 +452,9 @@ async def _run_matcher(
exception = None exception = None
logger.debug(f"Running {matcher}")
try: try:
logger.debug(f"Running {matcher}")
await matcher.run(bot, event, state, stack, dependency_cache) await matcher.run(bot, event, state, stack, dependency_cache)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
@ -494,8 +522,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
用法: 用法:
```python ```python
import asyncio driver.task_group.start_soon(handle_event, bot, event)
asyncio.create_task(handle_event(bot, event))
``` ```
""" """
show_log = True show_log = True
@ -530,6 +557,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
) )
break_flag = False break_flag = False
def _handle_stop_propagation(exc_group: BaseExceptionGroup) -> None:
nonlocal break_flag
break_flag = True
logger.debug("Stop event propagation")
# iterate through all priority until stop propagation # iterate through all priority until stop propagation
for priority in sorted(matchers.keys()): for priority in sorted(matchers.keys()):
if break_flag: if break_flag:
@ -538,23 +572,30 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
if show_log: if show_log:
logger.debug(f"Checking for matchers in priority {priority}...") logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [ if not (priority_matchers := matchers[priority]):
check_and_run_matcher( continue
matcher, bot, event, state.copy(), stack, dependency_cache
) with catch(
for matcher in matchers[priority] {
] StopPropagation: _handle_stop_propagation,
results = await asyncio.gather(*pending_tasks, return_exceptions=True) Exception: _handle_exception(
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>" "<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
) ),
}
):
async with anyio.create_task_group() as tg:
for matcher in priority_matchers:
tg.start_soon(
run_coro_with_shield,
check_and_run_matcher(
matcher,
bot,
event,
state.copy(),
stack,
dependency_cache,
),
)
if show_log: if show_log:
logger.debug("Checking for matchers completed") logger.debug("Checking for matchers completed")

View File

@ -22,7 +22,7 @@ from . import _managers, get_plugin, _module_name_to_plugin_id
try: # pragma: py-gte-311 try: # pragma: py-gte-311
import tomllib # pyright: ignore[reportMissingImports] import tomllib # pyright: ignore[reportMissingImports]
except ModuleNotFoundError: # pragma: py-lt-311 except ModuleNotFoundError: # pragma: py-lt-311
import tomli as tomllib import tomli as tomllib # pyright: ignore[reportMissingImports]
def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]: def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:

View File

@ -21,10 +21,9 @@ from typing import TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec, TypeAlias, get_args, override, get_origin from typing_extensions import ParamSpec, TypeAlias, get_args, override, get_origin
if TYPE_CHECKING: if TYPE_CHECKING:
from asyncio import Task
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.internal.params import DependencyCache
T = TypeVar("T") T = TypeVar("T")
P = ParamSpec("P") P = ParamSpec("P")
@ -258,5 +257,5 @@ T_PermissionUpdater: TypeAlias = _DependentCallable["Permission"]
- MatcherParam: Matcher 对象 - MatcherParam: Matcher 对象
- DefaultParam: 带有默认值的参数 - DefaultParam: 带有默认值的参数
""" """
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "Task[t.Any]"] T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "DependencyCache"]
"""依赖缓存, 用于存储依赖函数的返回值""" """依赖缓存, 用于存储依赖函数的返回值"""

View File

@ -9,21 +9,22 @@ FrontMatter:
import re import re
import json import json
import asyncio
import inspect import inspect
import importlib import importlib
import contextlib import contextlib
import dataclasses import dataclasses
from pathlib import Path from pathlib import Path
from collections import deque from collections import deque
from contextvars import copy_context
from functools import wraps, partial from functools import wraps, partial
from contextlib import AbstractContextManager, asynccontextmanager from contextlib import AbstractContextManager, asynccontextmanager
from typing_extensions import ParamSpec, get_args, override, get_origin from typing_extensions import ParamSpec, get_args, override, get_origin
from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
from collections.abc import Mapping, Sequence, Coroutine, Generator, AsyncGenerator
import anyio
import anyio.to_thread
from pydantic import BaseModel from pydantic import BaseModel
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import ( from nonebot.typing import (
@ -39,6 +40,7 @@ R = TypeVar("R")
T = TypeVar("T") T = TypeVar("T")
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
E = TypeVar("E", bound=BaseException)
def escape_tag(s: str) -> str: def escape_tag(s: str) -> str:
@ -178,11 +180,9 @@ def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
@wraps(call) @wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop() return await anyio.to_thread.run_sync(
pfunc = partial(call, *args, **kwargs) partial(call, *args, **kwargs), abandon_on_cancel=True
context = copy_context() )
result = await loop.run_in_executor(None, partial(context.run, pfunc))
return result
return _wrapper return _wrapper
@ -234,10 +234,34 @@ async def run_coro_with_catch(
协程的返回值或发生异常时的指定值 协程的返回值或发生异常时的指定值
""" """
try: with catch({exc: lambda exc_group: None}):
return await coro return await coro
except exc:
return return_on_err return return_on_err
async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T:
"""运行协程并在取消时屏蔽取消异常。
参数:
coro: 要运行的协程
返回:
协程的返回值
"""
with anyio.CancelScope(shield=True):
return await coro
def flatten_exception_group(
exc_group: BaseExceptionGroup[E],
) -> Generator[E, None, None]:
for exc in exc_group.exceptions:
if isinstance(exc, BaseExceptionGroup):
yield from flatten_exception_group(exc)
else:
yield exc
def get_name(obj: Any) -> str: def get_name(obj: Any) -> str:

2274
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,9 @@ include = ["nonebot/py.typed"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
yarl = "^1.7.2" yarl = "^1.7.2"
anyio = "^4.4.0"
pygtrie = "^2.4.1" pygtrie = "^2.4.1"
exceptiongroup = "^1.2.2"
loguru = ">=0.6.0,<1.0.0" loguru = ">=0.6.0,<1.0.0"
python-dotenv = ">=0.21.0,<2.0.0" python-dotenv = ">=0.21.0,<2.0.0"
typing-extensions = ">=4.4.0,<5.0.0" typing-extensions = ">=4.4.0,<5.0.0"
@ -65,7 +67,6 @@ fastapi = ["fastapi", "uvicorn"]
all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"] all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "strict"
addopts = "--cov=nonebot --cov-report=term-missing" addopts = "--cov=nonebot --cov-report=term-missing"
filterwarnings = ["error", "ignore::DeprecationWarning"] filterwarnings = ["error", "ignore::DeprecationWarning"]

View File

@ -1,8 +1,10 @@
import os import os
import threading import threading
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from functools import wraps
from collections.abc import Generator from collections.abc import Generator
from typing_extensions import ParamSpec
from typing import TYPE_CHECKING, TypeVar, Callable
import pytest import pytest
from nonebug import NONEBOT_INIT_KWARGS from nonebug import NONEBOT_INIT_KWARGS
@ -20,6 +22,9 @@ os.environ["CONFIG_OVERRIDE"] = "new"
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.plugin import Plugin from nonebot.plugin import Plugin
P = ParamSpec("P")
R = TypeVar("R")
collect_ignore = ["plugins/", "dynamic/", "bad_plugins/"] collect_ignore = ["plugins/", "dynamic/", "bad_plugins/"]
@ -38,14 +43,36 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
return DriverClass(Env(environment=global_driver.env), global_driver.config) return DriverClass(Env(environment=global_driver.env), global_driver.config)
@pytest.fixture(scope="session", params=[pytest.param("asyncio"), pytest.param("trio")])
def anyio_backend(request: pytest.FixtureRequest):
return request.param
def run_once(func: Callable[P, R]) -> Callable[P, R]:
result = ...
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
nonlocal result
if result is not Ellipsis:
return result
result = func(*args, **kwargs)
return result
return _wrapper
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def load_plugin(nonebug_init: None) -> set["Plugin"]: @run_once
def load_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
# preload global plugins # preload global plugins
return nonebot.load_plugins(str(Path(__file__).parent / "plugins")) return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def load_builtin_plugin(nonebug_init: None) -> set["Plugin"]: @run_once
def load_builtin_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
# preload builtin plugins # preload builtin plugins
return nonebot.load_builtin_plugins("echo", "single_session") return nonebot.load_builtin_plugins("echo", "single_session")

View File

@ -17,7 +17,7 @@ from nonebot.drivers import (
) )
@pytest.mark.asyncio @pytest.mark.anyio
async def test_adapter_connect(app: App, driver: Driver): async def test_adapter_connect(app: App, driver: Driver):
last_connect_bot: Optional[Bot] = None last_connect_bot: Optional[Bot] = None
last_disconnect_bot: Optional[Bot] = None last_disconnect_bot: Optional[Bot] = None
@ -45,7 +45,6 @@ async def test_adapter_connect(app: App, driver: Driver):
assert bot.self_id not in adapter.bots assert bot.self_id not in adapter.bots
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -75,7 +74,7 @@ async def test_adapter_connect(app: App, driver: Driver):
], ],
indirect=True, indirect=True,
) )
async def test_adapter_server(driver: Driver): def test_adapter_server(driver: Driver):
last_http_setup: Optional[HTTPServerSetup] = None last_http_setup: Optional[HTTPServerSetup] = None
last_ws_setup: Optional[WebSocketServerSetup] = None last_ws_setup: Optional[WebSocketServerSetup] = None
@ -112,7 +111,7 @@ async def test_adapter_server(driver: Driver):
assert last_ws_setup is setup assert last_ws_setup is setup
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -159,7 +158,7 @@ async def test_adapter_http_client(driver: Driver):
assert last_request is request assert last_request is request
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [

View File

@ -1,5 +1,6 @@
from typing import Any, Optional from typing import Any, Optional
import anyio
import pytest import pytest
from nonebug import App from nonebug import App
@ -7,7 +8,7 @@ from nonebot.adapters import Bot
from nonebot.exception import MockApiException from nonebot.exception import MockApiException
@pytest.mark.asyncio @pytest.mark.anyio
async def test_bot_call_api(app: App): async def test_bot_call_api(app: App):
async with app.test_api() as ctx: async with app.test_api() as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
@ -23,7 +24,7 @@ async def test_bot_call_api(app: App):
await bot.call_api("test") await bot.call_api("test")
@pytest.mark.asyncio @pytest.mark.anyio
async def test_bot_calling_api_hook_simple(app: App): async def test_bot_calling_api_hook_simple(app: App):
runned: bool = False runned: bool = False
@ -49,7 +50,7 @@ async def test_bot_calling_api_hook_simple(app: App):
assert result is True assert result is True
@pytest.mark.asyncio @pytest.mark.anyio
async def test_bot_calling_api_hook_mock(app: App): async def test_bot_calling_api_hook_mock(app: App):
runned: bool = False runned: bool = False
@ -76,7 +77,47 @@ async def test_bot_calling_api_hook_mock(app: App):
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.anyio
async def test_bot_calling_api_hook_multi_mock(app: App):
runned1: bool = False
runned2: bool = False
event = anyio.Event()
async def calling_api_hook1(bot: Bot, api: str, data: dict[str, Any]):
nonlocal runned1
runned1 = True
event.set()
raise MockApiException(1)
async def calling_api_hook2(bot: Bot, api: str, data: dict[str, Any]):
nonlocal runned2
runned2 = True
with anyio.fail_after(1):
await event.wait()
raise MockApiException(2)
hooks = set()
with pytest.MonkeyPatch.context() as m:
m.setattr(Bot, "_calling_api_hook", hooks)
Bot.on_calling_api(calling_api_hook1)
Bot.on_calling_api(calling_api_hook2)
assert hooks == {calling_api_hook1, calling_api_hook2}
async with app.test_api() as ctx:
bot = ctx.create_bot()
result = await bot.call_api("test")
assert runned1 is True
assert runned2 is True
assert result == 1
@pytest.mark.anyio
async def test_bot_called_api_hook_simple(app: App): async def test_bot_called_api_hook_simple(app: App):
runned: bool = False runned: bool = False
@ -108,7 +149,7 @@ async def test_bot_called_api_hook_simple(app: App):
assert result is True assert result is True
@pytest.mark.asyncio @pytest.mark.anyio
async def test_bot_called_api_hook_mock(app: App): async def test_bot_called_api_hook_mock(app: App):
runned: bool = False runned: bool = False
@ -150,3 +191,56 @@ async def test_bot_called_api_hook_mock(app: App):
assert runned is True assert runned is True
assert result is False assert result is False
@pytest.mark.anyio
async def test_bot_called_api_hook_multi_mock(app: App):
runned1: bool = False
runned2: bool = False
event = anyio.Event()
async def called_api_hook1(
bot: Bot,
exception: Optional[Exception],
api: str,
data: dict[str, Any],
result: Any,
):
nonlocal runned1
runned1 = True
event.set()
raise MockApiException(1)
async def called_api_hook2(
bot: Bot,
exception: Optional[Exception],
api: str,
data: dict[str, Any],
result: Any,
):
nonlocal runned2
runned2 = True
with anyio.fail_after(1):
await event.wait()
raise MockApiException(2)
hooks = set()
with pytest.MonkeyPatch.context() as m:
m.setattr(Bot, "_called_api_hook", hooks)
Bot.on_called_api(called_api_hook1)
Bot.on_called_api(called_api_hook2)
assert hooks == {called_api_hook1, called_api_hook2}
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_api("test", {}, True)
result = await bot.call_api("test")
assert runned1 is True
assert runned2 is True
assert result == 1

View File

@ -25,7 +25,7 @@ async def _dependency() -> int:
return 1 return 1
@pytest.mark.asyncio @pytest.mark.anyio
async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch): async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(message, "_event_preprocessors", set()) m.setattr(message, "_event_preprocessors", set())
@ -58,7 +58,7 @@ async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "event_preprocessor should runned" assert runned, "event_preprocessor should runned"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch): async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(message, "_event_preprocessors", set()) m.setattr(message, "_event_preprocessors", set())
@ -88,7 +88,7 @@ async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPat
assert not runned, "matcher should not runned" assert not runned, "matcher should not runned"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_event_preprocessor_exception( async def test_event_preprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
): ):
@ -132,7 +132,7 @@ async def test_event_preprocessor_exception(
assert "RuntimeError: test" in capsys.readouterr().out assert "RuntimeError: test" in capsys.readouterr().out
@pytest.mark.asyncio @pytest.mark.anyio
async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch): async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(message, "_event_postprocessors", set()) m.setattr(message, "_event_postprocessors", set())
@ -165,7 +165,7 @@ async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "event_postprocessor should runned" assert runned, "event_postprocessor should runned"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_event_postprocessor_exception( async def test_event_postprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
): ):
@ -202,7 +202,7 @@ async def test_event_postprocessor_exception(
assert "RuntimeError: test" in capsys.readouterr().out assert "RuntimeError: test" in capsys.readouterr().out
@pytest.mark.asyncio @pytest.mark.anyio
async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch): async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(message, "_run_preprocessors", set()) m.setattr(message, "_run_preprocessors", set())
@ -239,7 +239,7 @@ async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "run_preprocessor should runned" assert runned, "run_preprocessor should runned"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch): async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(message, "_run_preprocessors", set()) m.setattr(message, "_run_preprocessors", set())
@ -269,7 +269,7 @@ async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch
assert not runned, "matcher should not runned" assert not runned, "matcher should not runned"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_run_preprocessor_exception( async def test_run_preprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
): ):
@ -313,7 +313,7 @@ async def test_run_preprocessor_exception(
assert "RuntimeError: test" in capsys.readouterr().out assert "RuntimeError: test" in capsys.readouterr().out
@pytest.mark.asyncio @pytest.mark.anyio
async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch): async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(message, "_run_postprocessors", set()) m.setattr(message, "_run_postprocessors", set())
@ -351,7 +351,7 @@ async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "run_postprocessor should runned" assert runned, "run_postprocessor should runned"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_run_postprocessor_exception( async def test_run_postprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
): ):

View File

@ -17,14 +17,12 @@ from nonebot.compat import (
) )
@pytest.mark.asyncio def test_default_config():
async def test_default_config():
assert DEFAULT_CONFIG.get("extra") == "allow" assert DEFAULT_CONFIG.get("extra") == "allow"
assert DEFAULT_CONFIG.get("arbitrary_types_allowed") is True assert DEFAULT_CONFIG.get("arbitrary_types_allowed") is True
@pytest.mark.asyncio def test_field_info():
async def test_field_info():
# required should be convert to PydanticUndefined # required should be convert to PydanticUndefined
assert FieldInfo(Required).default is PydanticUndefined assert FieldInfo(Required).default is PydanticUndefined
@ -32,8 +30,7 @@ async def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test" assert FieldInfo(test="test").extra["test"] == "test"
@pytest.mark.asyncio def test_type_adapter():
async def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)]) t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])
assert t.validate_python(2) == 2 assert t.validate_python(2) == 2
@ -47,8 +44,7 @@ async def test_type_adapter():
t.validate_json("0") t.validate_json("0")
@pytest.mark.asyncio def test_model_dump():
async def test_model_dump():
class TestModel(BaseModel): class TestModel(BaseModel):
test1: int test1: int
test2: int test2: int
@ -57,8 +53,7 @@ async def test_model_dump():
assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2} assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2}
@pytest.mark.asyncio def test_custom_validation():
async def test_custom_validation():
called = [] called = []
@custom_validation @custom_validation
@ -85,8 +80,7 @@ async def test_custom_validation():
assert called == [1, 2] assert called == [1, 2]
@pytest.mark.asyncio def test_validate_json():
async def test_validate_json():
class TestModel(BaseModel): class TestModel(BaseModel):
test1: int test1: int
test2: str test2: str

View File

@ -50,16 +50,14 @@ class ExampleWithoutDelimiter(Example):
env_nested_delimiter = None env_nested_delimiter = None
@pytest.mark.asyncio def test_config_no_env():
async def test_config_no_env():
config = Example(_env_file=None) config = Example(_env_file=None)
assert config.simple == "" assert config.simple == ""
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
config.common_config config.common_config
@pytest.mark.asyncio def test_config_with_env():
async def test_config_with_env():
config = Example(_env_file=(".env", ".env.example")) config = Example(_env_file=(".env", ".env.example"))
assert config.simple == "simple" assert config.simple == "simple"
@ -102,8 +100,7 @@ async def test_config_with_env():
config.other_nested_inner__b config.other_nested_inner__b
@pytest.mark.asyncio def test_config_error_env():
async def test_config_error_env():
with pytest.MonkeyPatch().context() as m: with pytest.MonkeyPatch().context() as m:
m.setenv("COMPLEX", "not json") m.setenv("COMPLEX", "not json")
@ -111,8 +108,7 @@ async def test_config_error_env():
Example(_env_file=(".env", ".env.example")) Example(_env_file=(".env", ".env.example"))
@pytest.mark.asyncio def test_config_without_delimiter():
async def test_config_without_delimiter():
config = ExampleWithoutDelimiter() config = ExampleWithoutDelimiter()
assert config.nested.a == 1 assert config.nested.a == 1
assert config.nested.b == 0 assert config.nested.b == 0

View File

@ -1,8 +1,8 @@
import json import json
import asyncio
from typing import Any, Optional from typing import Any, Optional
from http.cookies import SimpleCookie from http.cookies import SimpleCookie
import anyio
import pytest import pytest
from nonebug import App from nonebug import App
@ -25,7 +25,7 @@ from nonebot.drivers import (
) )
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True "driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
) )
@ -59,22 +59,22 @@ async def test_lifespan(driver: Driver):
@driver.on_shutdown @driver.on_shutdown
async def _shutdown1(): async def _shutdown1():
assert shutdown_log == [] assert shutdown_log == [2]
shutdown_log.append(1) shutdown_log.append(1)
@driver.on_shutdown @driver.on_shutdown
async def _shutdown2(): async def _shutdown2():
assert shutdown_log == [1] assert shutdown_log == []
shutdown_log.append(2) shutdown_log.append(2)
async with driver._lifespan: async with driver._lifespan:
assert start_log == [1, 2] assert start_log == [1, 2]
assert ready_log == [1, 2] assert ready_log == [1, 2]
assert shutdown_log == [1, 2] assert shutdown_log == [2, 1]
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -99,10 +99,10 @@ async def test_http_server(app: App, driver: Driver):
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "test" assert response.text == "test"
await asyncio.sleep(1) await anyio.sleep(1)
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -155,10 +155,10 @@ async def test_websocket_server(app: App, driver: Driver):
await ws.close(code=1000) await ws.close(code=1000)
await asyncio.sleep(1) await anyio.sleep(1)
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -171,9 +171,10 @@ async def test_cross_context(app: App, driver: Driver):
assert isinstance(driver, ASGIMixin) assert isinstance(driver, ASGIMixin)
ws: Optional[WebSocket] = None ws: Optional[WebSocket] = None
ws_ready = asyncio.Event() ws_ready = anyio.Event()
ws_should_close = asyncio.Event() ws_should_close = anyio.Event()
# create a background task before the ws connection established
async def background_task(): async def background_task():
try: try:
await ws_ready.wait() await ws_ready.wait()
@ -185,8 +186,6 @@ async def test_cross_context(app: App, driver: Driver):
finally: finally:
ws_should_close.set() ws_should_close.set()
task = asyncio.create_task(background_task())
async def _handle_ws(websocket: WebSocket) -> None: async def _handle_ws(websocket: WebSocket) -> None:
nonlocal ws nonlocal ws
await websocket.accept() await websocket.accept()
@ -199,7 +198,9 @@ async def test_cross_context(app: App, driver: Driver):
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
driver.setup_websocket_server(ws_setup) driver.setup_websocket_server(ws_setup)
async with app.test_server(driver.asgi) as ctx: async with anyio.create_task_group() as tg, app.test_server(driver.asgi) as ctx:
tg.start_soon(background_task)
client = ctx.get_client() client = ctx.get_client()
async with client.websocket_connect("/ws_test") as websocket: async with client.websocket_connect("/ws_test") as websocket:
@ -211,11 +212,10 @@ async def test_cross_context(app: App, driver: Driver):
if not e.args or "websocket.close" not in str(e.args[0]): if not e.args or "websocket.close" not in str(e.args[0]):
raise raise
await task await anyio.sleep(1)
await asyncio.sleep(1)
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -304,10 +304,10 @@ async def test_http_client(driver: Driver, server_url: URL):
"test3": "test", "test3": "test",
}, "file parsing error" }, "file parsing error"
await asyncio.sleep(1) await anyio.sleep(1)
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -419,10 +419,10 @@ async def test_http_client_session(driver: Driver, server_url: URL):
"test3": "test", "test3": "test",
}, "file parsing error" }, "file parsing error"
await asyncio.sleep(1) await anyio.sleep(1)
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
[ [
@ -452,10 +452,9 @@ async def test_websocket_client(driver: Driver, server_url: URL):
with pytest.raises(WebSocketClosed, match=r"code=1000"): with pytest.raises(WebSocketClosed, match=r"code=1000"):
await ws.receive() await ws.receive()
await asyncio.sleep(1) await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("driver", "driver_type"), ("driver", "driver_type"),
[ [
@ -472,11 +471,11 @@ async def test_websocket_client(driver: Driver, server_url: URL):
], ],
indirect=["driver"], indirect=["driver"],
) )
async def test_combine_driver(driver: Driver, driver_type: str): def test_combine_driver(driver: Driver, driver_type: str):
assert driver.type == driver_type assert driver.type == driver_type
@pytest.mark.asyncio @pytest.mark.anyio
async def test_bot_connect_hook(app: App, driver: Driver): async def test_bot_connect_hook(app: App, driver: Driver):
with pytest.MonkeyPatch.context() as m: with pytest.MonkeyPatch.context() as m:
conn_hooks: set[Dependent[Any]] = set() conn_hooks: set[Dependent[Any]] = set()
@ -533,7 +532,7 @@ async def test_bot_connect_hook(app: App, driver: Driver):
async with app.test_api() as ctx: async with app.test_api() as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
await asyncio.sleep(1) await anyio.sleep(1)
if not conn_should_be_called: if not conn_should_be_called:
pytest.fail("on_bot_connect hook not called") pytest.fail("on_bot_connect hook not called")

View File

@ -4,7 +4,7 @@ from nonebug import App
from utils import FakeMessage, FakeMessageSegment, make_fake_event from utils import FakeMessage, FakeMessageSegment, make_fake_event
@pytest.mark.asyncio @pytest.mark.anyio
async def test_echo(app: App): async def test_echo(app: App):
from nonebot.plugins.echo import echo from nonebot.plugins.echo import echo

View File

@ -14,8 +14,7 @@ from nonebot import (
) )
@pytest.mark.asyncio def test_init():
async def test_init():
env = nonebot.get_driver().env env = nonebot.get_driver().env
assert env == "test" assert env == "test"
@ -35,31 +34,28 @@ async def test_init():
assert config.not_nested == "some string" assert config.not_nested == "some string"
@pytest.mark.asyncio def test_get_driver(monkeypatch: pytest.MonkeyPatch):
async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(nonebot, "_driver", None) m.setattr(nonebot, "_driver", None)
with pytest.raises(ValueError, match="initialized"): with pytest.raises(ValueError, match="initialized"):
get_driver() get_driver()
@pytest.mark.asyncio def test_get_asgi():
async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver() driver = get_driver()
assert isinstance(driver, ReverseDriver) assert isinstance(driver, ReverseDriver)
assert isinstance(driver, ASGIMixin) assert isinstance(driver, ASGIMixin)
assert get_asgi() == driver.asgi assert get_asgi() == driver.asgi
@pytest.mark.asyncio def test_get_app():
async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver() driver = get_driver()
assert isinstance(driver, ReverseDriver) assert isinstance(driver, ReverseDriver)
assert isinstance(driver, ASGIMixin) assert isinstance(driver, ASGIMixin)
assert get_app() == driver.server_app assert get_app() == driver.server_app
@pytest.mark.asyncio @pytest.mark.anyio
async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch): async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
async with app.test_api() as ctx: async with app.test_api() as ctx:
adapter = ctx.create_adapter() adapter = ctx.create_adapter()
@ -74,8 +70,7 @@ async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
get_adapter("not exist") get_adapter("not exist")
@pytest.mark.asyncio def test_run(monkeypatch: pytest.MonkeyPatch):
async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
runned = False runned = False
def mock_run(*args, **kwargs): def mock_run(*args, **kwargs):
@ -93,8 +88,7 @@ async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned assert runned
@pytest.mark.asyncio def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
async def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver() driver = get_driver()
with pytest.raises(ValueError, match="no bots"): with pytest.raises(ValueError, match="no bots"):

View File

@ -12,8 +12,7 @@ from nonebot.permission import User, Permission
from nonebot.message import _check_matcher, check_and_run_matcher from nonebot.message import _check_matcher, check_and_run_matcher
@pytest.mark.asyncio def test_matcher_info(app: App):
async def test_matcher_info(app: App):
from plugins.matcher.matcher_info import matcher from plugins.matcher.matcher_info import matcher
assert issubclass(matcher, Matcher) assert issubclass(matcher, Matcher)
@ -43,7 +42,7 @@ async def test_matcher_info(app: App):
assert matcher._source.lineno == 3 assert matcher._source.lineno == 3
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_check(app: App): async def test_matcher_check(app: App):
async def falsy(): async def falsy():
return False return False
@ -87,7 +86,7 @@ async def test_matcher_check(app: App):
assert await _check_matcher(test_rule_error, bot, event, {}) is False assert await _check_matcher(test_rule_error, bot, event, {}) is False
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_handle(app: App): async def test_matcher_handle(app: App):
from plugins.matcher.matcher_process import test_handle from plugins.matcher.matcher_process import test_handle
@ -102,7 +101,7 @@ async def test_matcher_handle(app: App):
ctx.should_finished() ctx.should_finished()
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_got(app: App): async def test_matcher_got(app: App):
from plugins.matcher.matcher_process import test_got from plugins.matcher.matcher_process import test_got
@ -124,7 +123,7 @@ async def test_matcher_got(app: App):
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_receive(app: App): async def test_matcher_receive(app: App):
from plugins.matcher.matcher_process import test_receive from plugins.matcher.matcher_process import test_receive
@ -141,7 +140,7 @@ async def test_matcher_receive(app: App):
ctx.should_paused() ctx.should_paused()
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_combine(app: App): async def test_matcher_combine(app: App):
from plugins.matcher.matcher_process import test_combine from plugins.matcher.matcher_process import test_combine
@ -164,7 +163,7 @@ async def test_matcher_combine(app: App):
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_preset(app: App): async def test_matcher_preset(app: App):
from plugins.matcher.matcher_process import test_preset from plugins.matcher.matcher_process import test_preset
@ -182,7 +181,7 @@ async def test_matcher_preset(app: App):
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_overload(app: App): async def test_matcher_overload(app: App):
from plugins.matcher.matcher_process import test_overload from plugins.matcher.matcher_process import test_overload
@ -196,7 +195,7 @@ async def test_matcher_overload(app: App):
ctx.should_finished() ctx.should_finished()
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_destroy(app: App): async def test_matcher_destroy(app: App):
from plugins.matcher.matcher_process import test_destroy from plugins.matcher.matcher_process import test_destroy
@ -210,7 +209,7 @@ async def test_matcher_destroy(app: App):
assert len(matchers[test_destroy.priority]) == 0 assert len(matchers[test_destroy.priority]) == 0
@pytest.mark.asyncio @pytest.mark.anyio
async def test_type_updater(app: App): async def test_type_updater(app: App):
from plugins.matcher.matcher_type import test_type_updater, test_custom_updater from plugins.matcher.matcher_type import test_type_updater, test_custom_updater
@ -231,7 +230,7 @@ async def test_type_updater(app: App):
assert new_type == "custom" assert new_type == "custom"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_default_permission_updater(app: App): async def test_default_permission_updater(app: App):
from plugins.matcher.matcher_permission import ( from plugins.matcher.matcher_permission import (
default_permission, default_permission,
@ -252,7 +251,7 @@ async def test_default_permission_updater(app: App):
assert checker.perm is default_permission assert checker.perm is default_permission
@pytest.mark.asyncio @pytest.mark.anyio
async def test_user_permission_updater(app: App): async def test_user_permission_updater(app: App):
from plugins.matcher.matcher_permission import ( from plugins.matcher.matcher_permission import (
default_permission, default_permission,
@ -274,7 +273,7 @@ async def test_user_permission_updater(app: App):
assert checker.perm is default_permission assert checker.perm is default_permission
@pytest.mark.asyncio @pytest.mark.anyio
async def test_custom_permission_updater(app: App): async def test_custom_permission_updater(app: App):
from plugins.matcher.matcher_permission import ( from plugins.matcher.matcher_permission import (
new_permission, new_permission,
@ -291,7 +290,7 @@ async def test_custom_permission_updater(app: App):
assert new_perm is new_permission assert new_perm is new_permission
@pytest.mark.asyncio @pytest.mark.anyio
async def test_run(app: App): async def test_run(app: App):
with app.provider.context({}): with app.provider.context({}):
assert not matchers assert not matchers
@ -322,37 +321,46 @@ async def test_run(app: App):
assert len(matchers[0][0].handlers) == 0 assert len(matchers[0][0].handlers) == 0
@pytest.mark.asyncio @pytest.mark.anyio
async def test_temp(app: App): async def test_temp(app: App):
from plugins.matcher.matcher_expire import test_temp_matcher from plugins.matcher.matcher_expire import test_temp_matcher
event = make_fake_event(_type="test")() event = make_fake_event(_type="test")()
async with app.test_api() as ctx: with app.provider.context({test_temp_matcher.priority: [test_temp_matcher]}):
bot = ctx.create_bot() async with app.test_api() as ctx:
assert test_temp_matcher in matchers[test_temp_matcher.priority] bot = ctx.create_bot()
await check_and_run_matcher(test_temp_matcher, bot, event, {}) assert test_temp_matcher in matchers[test_temp_matcher.priority]
assert test_temp_matcher not in matchers[test_temp_matcher.priority] await check_and_run_matcher(test_temp_matcher, bot, event, {})
assert test_temp_matcher not in matchers[test_temp_matcher.priority]
@pytest.mark.asyncio @pytest.mark.anyio
async def test_datetime_expire(app: App): async def test_datetime_expire(app: App):
from plugins.matcher.matcher_expire import test_datetime_matcher from plugins.matcher.matcher_expire import test_datetime_matcher
event = make_fake_event()() event = make_fake_event()()
async with app.test_api() as ctx: with app.provider.context(
bot = ctx.create_bot() {test_datetime_matcher.priority: [test_datetime_matcher]}
assert test_datetime_matcher in matchers[test_datetime_matcher.priority] ):
await check_and_run_matcher(test_datetime_matcher, bot, event, {}) async with app.test_matcher(test_datetime_matcher) as ctx:
assert test_datetime_matcher not in matchers[test_datetime_matcher.priority] bot = ctx.create_bot()
assert test_datetime_matcher in matchers[test_datetime_matcher.priority]
await check_and_run_matcher(test_datetime_matcher, bot, event, {})
assert test_datetime_matcher not in matchers[test_datetime_matcher.priority]
@pytest.mark.asyncio @pytest.mark.anyio
async def test_timedelta_expire(app: App): async def test_timedelta_expire(app: App):
from plugins.matcher.matcher_expire import test_timedelta_matcher from plugins.matcher.matcher_expire import test_timedelta_matcher
event = make_fake_event()() event = make_fake_event()()
async with app.test_api() as ctx: with app.provider.context(
bot = ctx.create_bot() {test_timedelta_matcher.priority: [test_timedelta_matcher]}
assert test_timedelta_matcher in matchers[test_timedelta_matcher.priority] ):
await check_and_run_matcher(test_timedelta_matcher, bot, event, {}) async with app.test_api() as ctx:
assert test_timedelta_matcher not in matchers[test_timedelta_matcher.priority] bot = ctx.create_bot()
assert test_timedelta_matcher in matchers[test_timedelta_matcher.priority]
await check_and_run_matcher(test_timedelta_matcher, bot, event, {})
assert (
test_timedelta_matcher not in matchers[test_timedelta_matcher.priority]
)

View File

@ -1,11 +1,9 @@
import pytest
from nonebug import App from nonebug import App
from nonebot.matcher import DEFAULT_PROVIDER_CLASS, matchers from nonebot.matcher import DEFAULT_PROVIDER_CLASS, matchers
@pytest.mark.asyncio def test_manager(app: App):
async def test_manager(app: App):
try: try:
default_provider = matchers.provider default_provider = matchers.provider
matchers.set_provider(DEFAULT_PROVIDER_CLASS) matchers.set_provider(DEFAULT_PROVIDER_CLASS)

View File

@ -2,6 +2,7 @@ import re
import pytest import pytest
from nonebug import App from nonebug import App
from exceptiongroup import BaseExceptionGroup
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
@ -36,7 +37,7 @@ from nonebot.consts import (
UNKNOWN_PARAM = "Unknown parameter" UNKNOWN_PARAM = "Unknown parameter"
@pytest.mark.asyncio @pytest.mark.anyio
async def test_depend(app: App): async def test_depend(app: App):
from plugins.param.param_depend import ( from plugins.param.param_depend import (
ClassDependency, ClassDependency,
@ -90,36 +91,47 @@ async def test_depend(app: App):
assert runned == [1, 1, 1] assert runned == [1, 1, 1]
runned.clear()
async with app.test_dependent( async with app.test_dependent(
annotated_class_depend, allow_types=[DependParam] annotated_class_depend, allow_types=[DependParam]
) as ctx: ) as ctx:
ctx.should_return(ClassDependency(x=1, y=2)) ctx.should_return(ClassDependency(x=1, y=2))
with pytest.raises(TypeMisMatch): # noqa: PT012 with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info: # noqa: PT012
async with app.test_dependent( async with app.test_dependent(
sub_type_mismatch, allow_types=[DependParam, BotParam] sub_type_mismatch, allow_types=[DependParam, BotParam]
) as ctx: ) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(validate, allow_types=[DependParam]) as ctx: async with app.test_dependent(validate, allow_types=[DependParam]) as ctx:
ctx.should_return(1) ctx.should_return(1)
with pytest.raises(TypeMisMatch): with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx: async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx:
... ...
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx: async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx:
ctx.should_return(1) ctx.should_return(1)
with pytest.raises(TypeMisMatch): with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent( async with app.test_dependent(
validate_field_fail, allow_types=[DependParam] validate_field_fail, allow_types=[DependParam]
) as ctx: ) as ctx:
... ...
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot(app: App): async def test_bot(app: App):
from plugins.param.param_bot import ( from plugins.param.param_bot import (
FooBot, FooBot,
@ -157,11 +169,14 @@ async def test_bot(app: App):
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
ctx.should_return(bot) ctx.should_return(bot)
with pytest.raises(TypeMisMatch): # noqa: PT012 with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info: # noqa: PT012
async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx: async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx: async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot(base=FooBot) bot = ctx.create_bot(base=FooBot)
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
@ -181,7 +196,7 @@ async def test_bot(app: App):
app.test_dependent(not_bot, allow_types=[BotParam]) app.test_dependent(not_bot, allow_types=[BotParam])
@pytest.mark.asyncio @pytest.mark.anyio
async def test_event(app: App): async def test_event(app: App):
from plugins.param.param_event import ( from plugins.param.param_event import (
FooEvent, FooEvent,
@ -223,10 +238,13 @@ async def test_event(app: App):
ctx.pass_params(event=fake_fooevent) ctx.pass_params(event=fake_fooevent)
ctx.should_return(fake_fooevent) ctx.should_return(fake_fooevent)
with pytest.raises(TypeMisMatch): with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event) ctx.pass_params(event=fake_event)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx: async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_fooevent) ctx.pass_params(event=fake_fooevent)
ctx.should_return(fake_fooevent) ctx.should_return(fake_fooevent)
@ -267,7 +285,7 @@ async def test_event(app: App):
ctx.should_return(fake_event.is_tome()) ctx.should_return(fake_event.is_tome())
@pytest.mark.asyncio @pytest.mark.anyio
async def test_state(app: App): async def test_state(app: App):
from plugins.param.param_state import ( from plugins.param.param_state import (
state, state,
@ -418,7 +436,7 @@ async def test_state(app: App):
ctx.should_return(fake_state[KEYWORD_KEY]) ctx.should_return(fake_state[KEYWORD_KEY])
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher(app: App): async def test_matcher(app: App):
from plugins.param.param_matcher import ( from plugins.param.param_matcher import (
FooMatcher, FooMatcher,
@ -457,10 +475,13 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=foo_matcher) ctx.pass_params(matcher=foo_matcher)
ctx.should_return(foo_matcher) ctx.should_return(foo_matcher)
with pytest.raises(TypeMisMatch): with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(sub_matcher, allow_types=[MatcherParam]) as ctx: async with app.test_dependent(sub_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher) ctx.pass_params(matcher=fake_matcher)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(union_matcher, allow_types=[MatcherParam]) as ctx: async with app.test_dependent(union_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=foo_matcher) ctx.pass_params(matcher=foo_matcher)
ctx.should_return(foo_matcher) ctx.should_return(foo_matcher)
@ -496,7 +517,7 @@ async def test_matcher(app: App):
ctx.should_return(event_next) ctx.should_return(event_next)
@pytest.mark.asyncio @pytest.mark.anyio
async def test_arg(app: App): async def test_arg(app: App):
from plugins.param.param_arg import ( from plugins.param.param_arg import (
arg, arg,
@ -548,7 +569,7 @@ async def test_arg(app: App):
ctx.should_return(message.extract_plain_text()) ctx.should_return(message.extract_plain_text())
@pytest.mark.asyncio @pytest.mark.anyio
async def test_exception(app: App): async def test_exception(app: App):
from plugins.param.param_exception import exc, legacy_exc from plugins.param.param_exception import exc, legacy_exc
@ -562,7 +583,7 @@ async def test_exception(app: App):
ctx.should_return(exception) ctx.should_return(exception)
@pytest.mark.asyncio @pytest.mark.anyio
async def test_default(app: App): async def test_default(app: App):
from plugins.param.param_default import default from plugins.param.param_default import default
@ -570,8 +591,7 @@ async def test_default(app: App):
ctx.should_return(1) ctx.should_return(1)
@pytest.mark.asyncio def test_priority():
async def test_priority():
from plugins.param.priority import complex_priority from plugins.param.priority import complex_priority
dependent = Dependent[None].parse( dependent = Dependent[None].parse(

View File

@ -22,7 +22,7 @@ from nonebot.permission import (
) )
@pytest.mark.asyncio @pytest.mark.anyio
async def test_permission(app: App): async def test_permission(app: App):
async def falsy(): async def falsy():
return False return False
@ -54,7 +54,7 @@ async def test_permission(app: App):
assert await Permission(truthy, skipped)(bot, event) is True assert await Permission(truthy, skipped)(bot, event) is True
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize(("type", "expected"), [("message", True), ("notice", False)]) @pytest.mark.parametrize(("type", "expected"), [("message", True), ("notice", False)])
async def test_message(type: str, expected: bool): async def test_message(type: str, expected: bool):
dependent = next(iter(MESSAGE.checkers)) dependent = next(iter(MESSAGE.checkers))
@ -66,7 +66,7 @@ async def test_message(type: str, expected: bool):
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize(("type", "expected"), [("message", False), ("notice", True)]) @pytest.mark.parametrize(("type", "expected"), [("message", False), ("notice", True)])
async def test_notice(type: str, expected: bool): async def test_notice(type: str, expected: bool):
dependent = next(iter(NOTICE.checkers)) dependent = next(iter(NOTICE.checkers))
@ -78,7 +78,7 @@ async def test_notice(type: str, expected: bool):
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize(("type", "expected"), [("message", False), ("request", True)]) @pytest.mark.parametrize(("type", "expected"), [("message", False), ("request", True)])
async def test_request(type: str, expected: bool): async def test_request(type: str, expected: bool):
dependent = next(iter(REQUEST.checkers)) dependent = next(iter(REQUEST.checkers))
@ -90,7 +90,7 @@ async def test_request(type: str, expected: bool):
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("type", "expected"), [("message", False), ("meta_event", True)] ("type", "expected"), [("message", False), ("meta_event", True)]
) )
@ -104,7 +104,7 @@ async def test_metaevent(type: str, expected: bool):
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("type", "user_id", "expected"), ("type", "user_id", "expected"),
[ [
@ -128,7 +128,7 @@ async def test_superuser(app: App, type: str, user_id: str, expected: bool):
assert await dependent(bot=bot, event=event) == expected assert await dependent(bot=bot, event=event) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("session_ids", "session_id", "expected"), ("session_ids", "session_id", "expected"),
[ [

View File

@ -1,12 +1,10 @@
import pytest
from pydantic import BaseModel from pydantic import BaseModel
import nonebot import nonebot
from nonebot.plugin import PluginManager, _managers from nonebot.plugin import PluginManager, _managers
@pytest.mark.asyncio def test_get_plugin():
async def test_get_plugin():
# check simple plugin # check simple plugin
plugin = nonebot.get_plugin("export") plugin = nonebot.get_plugin("export")
assert plugin assert plugin
@ -22,8 +20,7 @@ async def test_get_plugin():
assert plugin.module_name == "plugins.nested.plugins.nested_subplugin" assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"
@pytest.mark.asyncio def test_get_plugin_by_module_name():
async def test_get_plugin_by_module_name():
# check get plugin by exact module name # check get plugin by exact module name
plugin = nonebot.get_plugin_by_module_name("plugins.nested") plugin = nonebot.get_plugin_by_module_name("plugins.nested")
assert plugin assert plugin
@ -48,8 +45,7 @@ async def test_get_plugin_by_module_name():
assert plugin.module_name == "plugins.nested.plugins.nested_subplugin" assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"
@pytest.mark.asyncio def test_get_available_plugin():
async def test_get_available_plugin():
old_managers = _managers.copy() old_managers = _managers.copy()
_managers.clear() _managers.clear()
try: try:
@ -63,8 +59,7 @@ async def test_get_available_plugin():
_managers.extend(old_managers) _managers.extend(old_managers)
@pytest.mark.asyncio def test_get_plugin_config():
async def test_get_plugin_config():
class Config(BaseModel): class Config(BaseModel):
plugin_config: int plugin_config: int

View File

@ -1,15 +1,44 @@
import sys import sys
from pathlib import Path from pathlib import Path
from functools import wraps
from dataclasses import asdict from dataclasses import asdict
from typing import TypeVar, Callable
from typing_extensions import ParamSpec
import pytest import pytest
import nonebot import nonebot
from nonebot.plugin import Plugin, PluginManager, _managers, inherit_supported_adapters from nonebot.plugin import (
Plugin,
PluginManager,
_plugins,
_managers,
inherit_supported_adapters,
)
P = ParamSpec("P")
R = TypeVar("R")
@pytest.mark.asyncio def _recover(func: Callable[P, R]) -> Callable[P, R]:
async def test_load_plugin():
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
origin_managers = _managers.copy()
origin_plugins = _plugins.copy()
try:
return func(*args, **kwargs)
finally:
_managers.clear()
_managers.extend(origin_managers)
_plugins.clear()
_plugins.update(origin_plugins)
return _wrapper
@_recover
def test_load_plugin():
# check regular # check regular
assert nonebot.load_plugin("dynamic.simple") assert nonebot.load_plugin("dynamic.simple")
@ -20,8 +49,7 @@ async def test_load_plugin():
assert nonebot.load_plugin("some_plugin_not_exist") is None assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[Plugin]):
async def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[Plugin]):
loaded_plugins = { loaded_plugins = {
plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
} }
@ -44,8 +72,7 @@ async def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[P
PluginManager(search_path=["plugins"]).load_all_plugins() PluginManager(search_path=["plugins"]).load_all_plugins()
@pytest.mark.asyncio def test_load_nested_plugin():
async def test_load_nested_plugin():
parent_plugin = nonebot.get_plugin("nested") parent_plugin = nonebot.get_plugin("nested")
sub_plugin = nonebot.get_plugin("nested:nested_subplugin") sub_plugin = nonebot.get_plugin("nested:nested_subplugin")
sub_plugin2 = nonebot.get_plugin("nested:nested_subplugin2") sub_plugin2 = nonebot.get_plugin("nested:nested_subplugin2")
@ -57,16 +84,16 @@ async def test_load_nested_plugin():
assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2} assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2}
@pytest.mark.asyncio @_recover
async def test_load_json(): def test_load_json():
nonebot.load_from_json("./plugins.json") nonebot.load_from_json("./plugins.json")
with pytest.raises(TypeError): with pytest.raises(TypeError):
nonebot.load_from_json("./plugins.invalid.json") nonebot.load_from_json("./plugins.invalid.json")
@pytest.mark.asyncio @_recover
async def test_load_toml(): def test_load_toml():
nonebot.load_from_toml("./plugins.toml") nonebot.load_from_toml("./plugins.toml")
with pytest.raises(ValueError, match="Cannot find"): with pytest.raises(ValueError, match="Cannot find"):
@ -76,52 +103,54 @@ async def test_load_toml():
nonebot.load_from_toml("./plugins.invalid.toml") nonebot.load_from_toml("./plugins.invalid.toml")
@pytest.mark.asyncio @_recover
async def test_bad_plugin(): def test_bad_plugin():
nonebot.load_plugins("bad_plugins") nonebot.load_plugins("bad_plugins")
assert nonebot.get_plugin("bad_plugin") is None assert nonebot.get_plugin("bad_plugin") is None
@pytest.mark.asyncio @_recover
async def test_require_loaded(monkeypatch: pytest.MonkeyPatch): def test_require_loaded(monkeypatch: pytest.MonkeyPatch):
def _patched_find(name: str): def _patched_find(name: str):
pytest.fail("require existing plugin should not call find_manager_by_name") pytest.fail("require existing plugin should not call find_manager_by_name")
monkeypatch.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find) with monkeypatch.context() as m:
m.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find)
# require use module name # require use module name
nonebot.require("plugins.export") nonebot.require("plugins.export")
# require use plugin id # require use plugin id
nonebot.require("export") nonebot.require("export")
nonebot.require("nested:nested_subplugin") nonebot.require("nested:nested_subplugin")
@pytest.mark.asyncio @_recover
async def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch): def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch):
m = PluginManager(["dynamic.require_not_loaded"], ["dynamic/require_not_loaded/"]) pm = PluginManager(["dynamic.require_not_loaded"], ["dynamic/require_not_loaded/"])
_managers.append(m) _managers.append(pm)
num_managers = len(_managers) num_managers = len(_managers)
origin_load = PluginManager.load_plugin origin_load = PluginManager.load_plugin
def _patched_load(self: PluginManager, name: str): def _patched_load(self: PluginManager, name: str):
assert self is m assert self is pm
return origin_load(self, name) return origin_load(self, name)
monkeypatch.setattr(PluginManager, "load_plugin", _patched_load) with monkeypatch.context() as m:
m.setattr(PluginManager, "load_plugin", _patched_load)
# require standalone plugin # require standalone plugin
nonebot.require("dynamic.require_not_loaded") nonebot.require("dynamic.require_not_loaded")
# require searched plugin # require searched plugin
nonebot.require("dynamic.require_not_loaded.subplugin1") nonebot.require("dynamic.require_not_loaded.subplugin1")
nonebot.require("require_not_loaded:subplugin2") nonebot.require("require_not_loaded:subplugin2")
assert len(_managers) == num_managers assert len(_managers) == num_managers
@pytest.mark.asyncio @_recover
async def test_require_not_declared(): def test_require_not_declared():
num_managers = len(_managers) num_managers = len(_managers)
nonebot.require("dynamic.require_not_declared") nonebot.require("dynamic.require_not_declared")
@ -130,14 +159,13 @@ async def test_require_not_declared():
assert _managers[-1].plugins == {"dynamic.require_not_declared"} assert _managers[-1].plugins == {"dynamic.require_not_declared"}
@pytest.mark.asyncio @_recover
async def test_require_not_found(): def test_require_not_found():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
nonebot.require("some_plugin_not_exist") nonebot.require("some_plugin_not_exist")
@pytest.mark.asyncio def test_plugin_metadata():
async def test_plugin_metadata():
from plugins.metadata import Config, FakeAdapter from plugins.metadata import Config, FakeAdapter
plugin = nonebot.get_plugin("metadata") plugin = nonebot.get_plugin("metadata")
@ -157,8 +185,7 @@ async def test_plugin_metadata():
assert plugin.metadata.get_supported_adapters() == {FakeAdapter} assert plugin.metadata.get_supported_adapters() == {FakeAdapter}
@pytest.mark.asyncio def test_inherit_supported_adapters_not_found():
async def test_inherit_supported_adapters_not_found():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
inherit_supported_adapters("some_plugin_not_exist") inherit_supported_adapters("some_plugin_not_exist")
@ -166,7 +193,6 @@ async def test_inherit_supported_adapters_not_found():
inherit_supported_adapters("export") inherit_supported_adapters("export")
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("inherit_plugins", "expected"), ("inherit_plugins", "expected"),
[ [
@ -233,7 +259,7 @@ async def test_inherit_supported_adapters_not_found():
), ),
], ],
) )
async def test_inherit_supported_adapters_combine( def test_inherit_supported_adapters_combine(
inherit_plugins: tuple[str], expected: set[str] inherit_plugins: tuple[str], expected: set[str]
): ):
assert inherit_supported_adapters(*inherit_plugins) == expected assert inherit_supported_adapters(*inherit_plugins) == expected

View File

@ -1,17 +1,17 @@
import pytest
from nonebot.plugin import PluginManager, _managers from nonebot.plugin import PluginManager, _managers
@pytest.mark.asyncio def test_load_plugin_name():
async def test_load_plugin_name():
m = PluginManager(plugins=["dynamic.manager"]) m = PluginManager(plugins=["dynamic.manager"])
_managers.append(m) try:
_managers.append(m)
# load by plugin id # load by plugin id
module1 = m.load_plugin("manager") module1 = m.load_plugin("manager")
# load by module name # load by module name
module2 = m.load_plugin("dynamic.manager") module2 = m.load_plugin("dynamic.manager")
assert module1 assert module1
assert module2 assert module2
assert module1 is module2 assert module1 is module2
finally:
_managers.remove(m)

View File

@ -18,7 +18,6 @@ from nonebot.rule import (
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("matcher_name", "pre_rule_factory", "has_permission"), ("matcher_name", "pre_rule_factory", "has_permission"),
[ [
@ -102,7 +101,7 @@ from nonebot.rule import (
pytest.param("matcher_group_on_type", lambda e: IsTypeRule(e), True), pytest.param("matcher_group_on_type", lambda e: IsTypeRule(e), True),
], ],
) )
async def test_on( def test_on(
matcher_name: str, matcher_name: str,
pre_rule_factory: Optional[Callable[[type[Event]], T_RuleChecker]], pre_rule_factory: Optional[Callable[[type[Event]], T_RuleChecker]],
has_permission: bool, has_permission: bool,
@ -150,8 +149,7 @@ async def test_on(
assert matcher.module_name == "plugins.plugin.matchers" assert matcher.module_name == "plugins.plugin.matchers"
@pytest.mark.asyncio def test_runtime_on():
async def test_runtime_on():
import plugins.plugin.matchers as module import plugins.plugin.matchers as module
from plugins.plugin.matchers import matcher_on_factory from plugins.plugin.matchers import matcher_on_factory

View File

@ -49,7 +49,7 @@ from nonebot.rule import (
) )
@pytest.mark.asyncio @pytest.mark.anyio
async def test_rule(app: App): async def test_rule(app: App):
async def falsy(): async def falsy():
return False return False
@ -81,7 +81,7 @@ async def test_rule(app: App):
assert await Rule(truthy, skipped)(bot, event, {}) is False assert await Rule(truthy, skipped)(bot, event, {}) is False
@pytest.mark.asyncio @pytest.mark.anyio
async def test_trie(app: App): async def test_trie(app: App):
TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",))) TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",)))
@ -146,7 +146,7 @@ async def test_trie(app: App):
del TrieRule.prefix["/fake-prefix"] del TrieRule.prefix["/fake-prefix"]
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("msg", "ignorecase", "type", "text", "expected"), ("msg", "ignorecase", "type", "text", "expected"),
[ [
@ -186,7 +186,7 @@ async def test_startswith(
assert await dependent(event=event, state=state) == expected assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("msg", "ignorecase", "type", "text", "expected"), ("msg", "ignorecase", "type", "text", "expected"),
[ [
@ -226,7 +226,7 @@ async def test_endswith(
assert await dependent(event=event, state=state) == expected assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("msg", "ignorecase", "type", "text", "expected"), ("msg", "ignorecase", "type", "text", "expected"),
[ [
@ -266,7 +266,7 @@ async def test_fullmatch(
assert await dependent(event=event, state=state) == expected assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("kws", "type", "text", "expected"), ("kws", "type", "text", "expected"),
[ [
@ -298,7 +298,7 @@ async def test_keyword(
assert await dependent(event=event, state=state) == expected assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("cmds", "force_whitespace", "cmd", "whitespace", "arg_text", "expected"), ("cmds", "force_whitespace", "cmd", "whitespace", "arg_text", "expected"),
[ [
@ -344,7 +344,7 @@ async def test_command(
assert await dependent(state=state) == expected assert await dependent(state=state) == expected
@pytest.mark.asyncio @pytest.mark.anyio
async def test_shell_command(): async def test_shell_command():
state: T_State state: T_State
CMD = ("test",) CMD = ("test",)
@ -451,7 +451,7 @@ async def test_shell_command():
assert state[SHELL_ARGS].status != 0 assert state[SHELL_ARGS].status != 0
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
("pattern", "type", "text", "expected", "matched"), ("pattern", "type", "text", "expected", "matched"),
[ [
@ -494,7 +494,7 @@ async def test_regex(
assert result.span() == matched.span() assert result.span() == matched.span()
@pytest.mark.asyncio @pytest.mark.anyio
@pytest.mark.parametrize("expected", [True, False]) @pytest.mark.parametrize("expected", [True, False])
async def test_to_me(expected: bool): async def test_to_me(expected: bool):
test_to_me = to_me() test_to_me = to_me()
@ -507,7 +507,7 @@ async def test_to_me(expected: bool):
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@pytest.mark.asyncio @pytest.mark.anyio
async def test_is_type(): async def test_is_type():
Event1 = make_fake_event() Event1 = make_fake_event()
Event2 = make_fake_event() Event2 = make_fake_event()

View File

@ -5,7 +5,7 @@ import pytest
from utils import make_fake_event from utils import make_fake_event
@pytest.mark.asyncio @pytest.mark.anyio
async def test_matcher_mutex(): async def test_matcher_mutex():
from nonebot.plugins.single_session import matcher_mutex, _running_matcher from nonebot.plugins.single_session import matcher_mutex, _running_matcher

View File

@ -14,7 +14,7 @@ NoneBot2 是一个现代、跨平台、可扩展的 Python 聊天机器人框架
### 异步优先 ### 异步优先
NoneBot 基于 Python [asyncio](https://docs.python.org/zh-cn/3/library/asyncio.html) 编写,并在异步机制的基础上进行了一定程度的同步函数兼容。 NoneBot 基于 Python [asyncio](https://docs.python.org/zh-cn/3/library/asyncio.html) / [trio](https://trio.readthedocs.io/en/stable/) 编写,并在异步机制的基础上进行了一定程度的同步函数兼容。
### 完整的类型注解 ### 完整的类型注解