Feature: 迁移至结构化并发框架 AnyIO (#3053)

This commit is contained in:
Ju4tCode 2024-10-26 15:36:01 +08:00 committed by GitHub
parent bd9befbb55
commit ff21ceb946
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 5422 additions and 4080 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

1396
envs/test/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -8,11 +8,11 @@ packages = [{ include = "nonebot-test.py" }]
[tool.poetry.dependencies]
python = "^3.9"
nonebug = "^0.3.7"
trio = "^0.27.0"
nonebug = "^0.4.1"
wsproto = "^1.2.0"
pytest-cov = "^5.0.0"
pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.23.2"
werkzeug = ">=2.3.6,<4.0.0"
coverage-conditional-plugin = "^0.9.0"

View File

@ -8,17 +8,20 @@ FrontMatter:
"""
import abc
import asyncio
import inspect
from functools import partial
from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.utils import run_sync, is_coroutine_callable
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
from .utils import check_field_type, get_typed_signature
@ -84,7 +87,16 @@ class Dependent(Generic[R]):
)
async def __call__(self, **kwargs: Any) -> R:
try:
exception: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")
with catch({SkippedException: _handle_skipped}):
# do pre-check
await self.check(**kwargs)
@ -96,9 +108,8 @@ class Dependent(Generic[R]):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise
raise exception
@staticmethod
def parse_params(
@ -166,10 +177,13 @@ class Dependent(Generic[R]):
return cls(call, params, parameterless_params)
async def check(self, **params: Any) -> None:
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))
async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
@ -185,10 +199,17 @@ class Dependent(Generic[R]):
await param._solve(**params)
# solve param values
values = await asyncio.gather(
*(self._solve_field(field, params) for field in self.params)
)
return {field.name: value for field, value in zip(self.params, values)}
result: dict[str, Any] = {}
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
result[field.name] = value
async with anyio.create_task_group() as tg:
for field in self.params:
tg.start_soon(_solve_field, field, params)
return result
__autodoc__ = {"CustomConfig": False}

View File

@ -12,14 +12,18 @@ FrontMatter:
"""
import signal
import asyncio
import threading
from typing import Optional
from typing_extensions import override
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.consts import WINDOWS
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import flatten_exception_group
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
@ -35,8 +39,8 @@ class Driver(BaseDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
self.should_exit: anyio.Event = anyio.Event()
self.force_exit: anyio.Event = anyio.Event()
@property
@override
@ -54,85 +58,98 @@ class Driver(BaseDriver):
def run(self, *args, **kwargs):
"""启动 none driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self._serve())
anyio.run(self._serve)
async def _serve(self):
self._install_signal_handlers()
await self._startup()
if self.should_exit.is_set():
return
await self._main_loop()
await self._shutdown()
async with anyio.create_task_group() as driver_tg:
driver_tg.start_soon(self._handle_signals)
driver_tg.start_soon(self._listen_force_exit, driver_tg)
driver_tg.start_soon(self._handle_lifespan, driver_tg)
async def _startup(self):
async def _handle_signals(self):
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
self.should_exit.set()
return
logger.info("Application startup completed.")
async def _main_loop(self):
await self.should_exit.wait()
async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown.")
try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("Application shutdown complete.")
loop = asyncio.get_event_loop()
loop.stop()
def _install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
loop = asyncio.get_event_loop()
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self._handle_exit, sig, None)
with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signal_receiver:
async for sig in signal_receiver:
self.exit(force=self.should_exit.is_set())
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit)
signal.signal(sig, self._handle_legacy_signal)
def _handle_exit(self, sig, frame):
# backport for Windows signal handling
def _handle_legacy_signal(self, sig, frame):
self.exit(force=self.should_exit.is_set())
async def _handle_lifespan(self, tg: TaskGroup):
try:
await self._startup()
if self.should_exit.is_set():
return
await self._listen_exit()
await self._shutdown()
finally:
tg.cancel_scope.cancel()
async def _startup(self):
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
self.should_exit.set()
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running startup hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application startup failed. "
"Exiting.</bg #f8bbd0></r>"
)
with catch({Exception: handle_exception}):
await self._lifespan.startup()
if not self.should_exit.is_set():
logger.info("Application startup completed.")
async def _listen_exit(self, tg: Optional[TaskGroup] = None):
await self.should_exit.wait()
if tg is not None:
tg.cancel_scope.cancel()
async def _shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown. (CTRL+C to force quit)")
error_occurred: bool = False
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
nonlocal error_occurred
error_occurred = True
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error occurred while running shutdown hook."
"</bg #f8bbd0></r>"
)
logger.error(
"<r><bg #f8bbd0>Application shutdown failed. "
"Exiting.</bg #f8bbd0></r>"
)
with catch({Exception: handle_exception}):
await self._lifespan.shutdown()
if not error_occurred:
logger.info("Application shutdown complete.")
async def _listen_force_exit(self, tg: TaskGroup):
await self.force_exit.wait()
tg.cancel_scope.cancel()
def exit(self, force: bool = False):
"""退出 none driver
@ -142,4 +159,4 @@ class Driver(BaseDriver):
if not self.should_exit.is_set():
self.should_exit.set()
if force:
self.force_exit = True
self.force_exit.set()

View File

@ -1,11 +1,14 @@
import abc
import asyncio
from functools import partial
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Config
from nonebot.exception import MockApiException
from nonebot.utils import flatten_exception_group
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
if TYPE_CHECKING:
@ -76,47 +79,98 @@ class Bot(abc.ABC):
skip_calling_api: bool = False
exception: Optional[Exception] = None
if coros := [hook(self, api, data) for hook in self._calling_api_hook]:
try:
logger.debug("Running CallingAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
if self._calling_api_hook:
logger.debug("Running CallingAPI hooks...")
def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal skip_calling_api, result
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
skip_calling_api = True
result = e.result
result = excs[0].result
logger.debug(
f"Calling API {api} is cancelled. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
f"Calling API {api} is cancelled. Return {result!r} instead."
)
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._calling_api_hook:
tg.start_soon(hook, self, api, data)
if not skip_calling_api:
try:
result = await self.adapter._call_api(self, api, **data)
except Exception as e:
exception = e
if coros := [
hook(self, exception, api, data, result) for hook in self._called_api_hook
]:
try:
logger.debug("Running CalledAPI hooks...")
await asyncio.gather(*coros)
except MockApiException as e:
# mock api result
result = e.result
# ignore exception
if self._called_api_hook:
logger.debug("Running CalledAPI hooks...")
def _handle_mock_api_exception(
exc_group: BaseExceptionGroup[MockApiException],
) -> None:
nonlocal result, exception
excs = [
exc
for exc in flatten_exception_group(exc_group)
if isinstance(exc, MockApiException)
]
if not excs:
return
elif len(excs) > 1:
logger.warning(
"Multiple hooks want to mock API result. Use the first one."
)
result = excs[0].result
exception = None
logger.debug(
f"Calling API {api} result is mocked. Return {result} instead."
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
"Running cancelled!</bg #f8bbd0></r>"
)
with catch(
{
MockApiException: _handle_mock_api_exception,
Exception: _handle_exception,
}
):
async with anyio.create_task_group() as tg:
for hook in self._called_api_hook:
tg.start_soon(hook, self, exception, api, data, result)
if exception:
raise exception

View File

@ -1,6 +1,11 @@
from collections.abc import Awaitable
from types import TracebackType
from typing_extensions import TypeAlias
from typing import Any, Union, Callable, cast
from collections.abc import Iterable, Awaitable
from typing import Any, Union, Callable, Optional, cast
import anyio
from anyio.abc import TaskGroup
from exceptiongroup import suppress
from nonebot.utils import run_sync, is_coroutine_callable
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
class Lifespan:
def __init__(self) -> None:
self._task_group: Optional[TaskGroup] = None
self._startup_funcs: list[LIFESPAN_FUNC] = []
self._ready_funcs: list[LIFESPAN_FUNC] = []
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
@property
def task_group(self) -> TaskGroup:
if self._task_group is None:
raise RuntimeError("Lifespan not started")
return self._task_group
@task_group.setter
def task_group(self, task_group: TaskGroup) -> None:
if self._task_group is not None:
raise RuntimeError("Lifespan already started")
self._task_group = task_group
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func
@ -29,7 +48,7 @@ class Lifespan:
@staticmethod
async def _run_lifespan_func(
funcs: list[LIFESPAN_FUNC],
funcs: Iterable[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
@ -38,18 +57,44 @@ class Lifespan:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
async def startup(self) -> None:
# create background task group
self.task_group = anyio.create_task_group()
await self.task_group.__aenter__()
# run startup funcs
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)
# run ready funcs
if self._ready_funcs:
await self._run_lifespan_func(self._ready_funcs)
async def shutdown(self) -> None:
async def shutdown(
self,
*,
exc_type: Optional[type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional[TracebackType] = None,
) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)
# reverse shutdown funcs to ensure stack order
await self._run_lifespan_func(reversed(self._shutdown_funcs))
# shutdown background task group
self.task_group.cancel_scope.cancel()
with suppress(Exception):
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
self._task_group = None
async def __aenter__(self) -> None:
await self.startup()
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

View File

@ -1,17 +1,20 @@
import abc
import asyncio
from types import TracebackType
from collections.abc import AsyncGenerator
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
from anyio.abc import TaskGroup
from anyio import CancelScope, create_task_group
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.internal.params import BotParam, DependParam, DefaultParam
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
from nonebot.typing import (
T_DependencyCache,
T_BotConnectionHook,
@ -61,7 +64,6 @@ class Driver(abc.ABC):
self.config: Config = config
"""全局配置对象"""
self._bots: dict[str, "Bot"] = {}
self._bot_tasks: set[asyncio.Task] = set()
self._lifespan = Lifespan()
def __repr__(self) -> str:
@ -75,6 +77,10 @@ class Driver(abc.ABC):
"""获取当前所有已连接的 Bot"""
return self._bots
@property
def task_group(self) -> TaskGroup:
return self._lifespan.task_group
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
"""注册一个协议适配器
@ -112,8 +118,6 @@ class Driver(abc.ABC):
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
)
self.on_shutdown(self._cleanup)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@ -154,66 +158,57 @@ class Driver(abc.ABC):
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._bots[bot.self_id] = bot
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_connection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with AsyncExitStack() as stack, create_task_group() as tg:
for hook in self._bot_connection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
self.task_group.start_soon(_run_hook, bot)
def _bot_disconnect(self, bot: "Bot") -> None:
"""在连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._bots:
del self._bots[bot.self_id]
def handle_exception(exc_group: BaseExceptionGroup) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook:"
"</bg #f8bbd0></r>"
)
async def _run_hook(bot: "Bot") -> None:
dependency_cache: T_DependencyCache = {}
async with AsyncExitStack() as stack:
if coros := [
run_coro_with_catch(
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
(SkippedException,),
)
for hook in self._bot_disconnection_hook
]:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>"
"Error when running WebSocketDisConnection hook. "
"Running cancelled!"
"</bg #f8bbd0></r>"
# shield cancellation to ensure bot disconnect hooks are always run
with CancelScope(shield=True), catch({Exception: handle_exception}):
async with create_task_group() as tg, AsyncExitStack() as stack:
for hook in self._bot_disconnection_hook:
tg.start_soon(
run_coro_with_catch,
hook(
bot=bot, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
)
task = asyncio.create_task(_run_hook(bot))
task.add_done_callback(self._bot_tasks.discard)
self._bot_tasks.add(task)
async def _cleanup(self) -> None:
"""清理驱动器资源"""
if self._bot_tasks:
logger.opt(colors=True).debug(
"<y>Waiting for running bot connection hooks...</y>"
)
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
self.task_group.start_soon(_run_hook, bot)
class Mixin(abc.ABC):

View File

@ -22,11 +22,13 @@ from typing import ( # noqa: UP035
overload,
)
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.internal.rule import Rule
from nonebot.utils import classproperty
from nonebot.dependencies import Param, Dependent
from nonebot.internal.permission import User, Permission
from nonebot.utils import classproperty, flatten_exception_group
from nonebot.internal.adapter import (
Bot,
Event,
@ -812,28 +814,34 @@ class Matcher(metaclass=MatcherMeta):
f"bot={bot}, event={event!r}, state={state!r}"
)
def _handle_stop_propagation(exc_group: BaseExceptionGroup[StopPropagation]):
self.block = True
with self.ensure_context(bot, event):
try:
# Refresh preprocess state
self.state.update(state)
with catch({StopPropagation: _handle_stop_propagation}):
# Refresh preprocess state
self.state.update(state)
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
try:
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
logger.debug(f"Handler {handler} skipped")
except StopPropagation:
self.block = True
while self.remain_handlers:
handler = self.remain_handlers.pop(0)
current_handler.set(handler)
logger.debug(f"Running handler {handler}")
def _handle_skipped(
exc_group: BaseExceptionGroup[SkippedException],
):
logger.debug(f"Handler {handler} skipped")
with catch({SkippedException: _handle_skipped}):
await handler(
matcher=self,
bot=bot,
event=event,
state=self.state,
stack=stack,
dependency_cache=dependency_cache,
)
finally:
logger.info(f"{self} running complete")
@ -846,10 +854,54 @@ class Matcher(metaclass=MatcherMeta):
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
):
try:
exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
None
)
def _handle_special_exception(
exc_group: BaseExceptionGroup[
Union[FinishedException, RejectedException, PausedException]
]
):
nonlocal exc
excs = list(flatten_exception_group(exc_group))
if len(excs) > 1:
logger.warning(
"Multiple session control exceptions occurred. "
"NoneBot will choose the proper one."
)
finished_exc = next(
(e for e in excs if isinstance(e, FinishedException)),
None,
)
rejected_exc = next(
(e for e in excs if isinstance(e, RejectedException)),
None,
)
paused_exc = next(
(e for e in excs if isinstance(e, PausedException)),
None,
)
exc = finished_exc or rejected_exc or paused_exc
elif isinstance(
excs[0], (FinishedException, RejectedException, PausedException)
):
exc = excs[0]
with catch(
{
(
FinishedException,
RejectedException,
PausedException,
): _handle_special_exception
}
):
await self.simple_run(bot, event, state, stack, dependency_cache)
except RejectedException:
if isinstance(exc, FinishedException):
pass
elif isinstance(exc, RejectedException):
await self.resolve_reject()
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
@ -870,7 +922,7 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except PausedException:
elif isinstance(exc, PausedException):
type_ = await self.update_type(bot, event, stack, dependency_cache)
permission = await self.update_permission(
bot, event, stack, dependency_cache
@ -890,5 +942,3 @@ class Matcher(metaclass=MatcherMeta):
default_type_updater=self.__class__._default_type_updater,
default_permission_updater=self.__class__._default_permission_updater,
)
except FinishedException:
pass

View File

@ -1,5 +1,5 @@
import asyncio
import inspect
from enum import Enum
from typing_extensions import Self, get_args, override, get_origin
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import (
@ -13,8 +13,11 @@ from typing import (
cast,
)
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from pydantic.fields import FieldInfo as PydanticFieldInfo
from nonebot.exception import SkippedException
from nonebot.dependencies import Param, Dependent
from nonebot.dependencies.utils import check_field_type
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
@ -93,6 +96,75 @@ def Depends(
return DependsInner(dependency, use_cache=use_cache, validate=validate)
class CacheState(str, Enum):
"""子依赖缓存状态"""
PENDING = "PENDING"
FINISHED = "FINISHED"
class DependencyCache:
"""子依赖结果缓存。
用于缓存子依赖的结果以避免重复计算
"""
def __init__(self):
self._state = CacheState.PENDING
self._result: Any = None
self._exception: Optional[BaseException] = None
self._waiter = anyio.Event()
def result(self) -> Any:
"""获取子依赖结果"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
if self._exception is not None:
raise self._exception
return self._result
def exception(self) -> Optional[BaseException]:
"""获取子依赖异常"""
if self._state != CacheState.FINISHED:
raise RuntimeError("Result is not ready")
return self._exception
def set_result(self, result: Any) -> None:
"""设置子依赖结果"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._result = result
self._state = CacheState.FINISHED
self._waiter.set()
def set_exception(self, exception: BaseException) -> None:
"""设置子依赖异常"""
if self._state != CacheState.PENDING:
raise RuntimeError(f"Cache state invalid: {self._state}")
self._exception = exception
self._state = CacheState.FINISHED
self._waiter.set()
async def wait(self):
"""等待子依赖结果"""
await self._waiter.wait()
if self._state != CacheState.FINISHED:
raise RuntimeError("Invalid cache state")
if self._exception is not None:
raise self._exception
return self._result
class DependParam(Param):
"""子依赖注入参数。
@ -194,17 +266,27 @@ class DependParam(Param):
call = cast(Callable[..., Any], sub_dependent.call)
# solve sub dependency with current cache
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
exc: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exc
exc = exc_group
with catch({SkippedException: _handle_skipped}):
sub_values = await sub_dependent.solve(
stack=stack,
dependency_cache=dependency_cache,
**kwargs,
)
if exc is not None:
raise exc
# run dependency function
task: asyncio.Task[Any]
if use_cache and call in dependency_cache:
return await dependency_cache[call]
elif is_gen_callable(call) or is_async_gen_callable(call):
return await dependency_cache[call].wait()
if is_gen_callable(call) or is_async_gen_callable(call):
assert isinstance(
stack, AsyncExitStack
), "Generator dependency should be called in context"
@ -212,17 +294,21 @@ class DependParam(Param):
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
else:
cm = asynccontextmanager(call)(**sub_values)
task = asyncio.create_task(stack.enter_async_context(cm))
dependency_cache[call] = task
return await task
target = stack.enter_async_context(cm)
elif is_coroutine_callable(call):
task = asyncio.create_task(call(**sub_values))
dependency_cache[call] = task
return await task
target = call(**sub_values)
else:
task = asyncio.create_task(run_sync(call)(**sub_values))
dependency_cache[call] = task
return await task
target = run_sync(call)(**sub_values)
dependency_cache[call] = cache = DependencyCache()
try:
result = await target
cache.set_result(result)
return result
except BaseException as e:
cache.set_exception(e)
raise
@override
async def _check(self, **kwargs: Any) -> None:

View File

@ -1,8 +1,9 @@
import asyncio
from typing_extensions import Self
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from nonebot.dependencies import Dependent
from nonebot.utils import run_coro_with_catch
from nonebot.exception import SkippedException
@ -70,22 +71,26 @@ class Permission:
"""
if not self.checkers:
return True
results = await asyncio.gather(
*(
run_coro_with_catch(
checker(
bot=bot,
event=event,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
False,
)
for checker in self.checkers
),
)
return any(results)
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await run_coro_with_catch(
checker(
bot=bot, event=event, stack=stack, dependency_cache=dependency_cache
),
(SkippedException,),
False,
)
result |= is_passed
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: object) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")

View File

@ -1,7 +1,9 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
@ -71,22 +73,33 @@ class Rule:
"""
if not self.checkers:
return True
try:
results = await asyncio.gather(
*(
checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
for checker in self.checkers
)
result = True
def _handle_skipped_exception(
exc_group: BaseExceptionGroup[SkippedException],
) -> None:
nonlocal result
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
except SkippedException:
return False
return all(results)
result &= is_passed
with catch({SkippedException: _handle_skipped_exception}):
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None:

View File

@ -9,23 +9,30 @@ FrontMatter:
description: nonebot.message 模块
"""
import asyncio
import contextlib
from datetime import datetime
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.rule import TrieRule
from nonebot.dependencies import Dependent
from nonebot.matcher import Matcher, matchers
from nonebot.utils import escape_tag, run_coro_with_catch
from nonebot.exception import (
NoLogException,
StopPropagation,
IgnoredException,
SkippedException,
)
from nonebot.utils import (
escape_tag,
run_coro_with_catch,
run_coro_with_shield,
flatten_exception_group,
)
from nonebot.typing import (
T_State,
T_DependencyCache,
@ -125,6 +132,21 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func
def _handle_ignored_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[IgnoredException]) -> None:
logger.opt(colors=True).info(msg)
return _handle
def _handle_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
def _handle(exc_group: BaseExceptionGroup[Exception]) -> None:
for exc in flatten_exception_group(exc_group):
logger.opt(colors=True, exception=exc).error(msg)
return _handle
async def _apply_event_preprocessors(
bot: "Bot",
event: "Event",
@ -152,10 +174,21 @@ async def _apply_event_preprocessors(
if show_log:
logger.debug("Running PreProcessors...")
try:
await asyncio.gather(
*(
run_coro_with_catch(
with catch(
{
IgnoredException: _handle_ignored_exception(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
),
}
):
async with anyio.create_task_group() as tg:
for proc in _event_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
bot=bot,
event=event,
@ -165,22 +198,10 @@ async def _apply_event_preprocessors(
),
(SkippedException,),
)
for proc in _event_preprocessors
)
)
except IgnoredException:
logger.opt(colors=True).info(
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
)
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
"Event ignored!</bg #f8bbd0></r>"
)
return False
return True
return True
return False
async def _apply_event_postprocessors(
@ -207,10 +228,17 @@ async def _apply_event_postprocessors(
if show_log:
logger.debug("Running PostProcessors...")
try:
await asyncio.gather(
*(
run_coro_with_catch(
with catch(
{
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
}
):
async with anyio.create_task_group() as tg:
for proc in _event_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
bot=bot,
event=event,
@ -220,13 +248,6 @@ async def _apply_event_postprocessors(
),
(SkippedException,),
)
for proc in _event_postprocessors
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
)
async def _apply_run_preprocessors(
@ -254,35 +275,38 @@ async def _apply_run_preprocessors(
return True
# ensure matcher function can be correctly called
with matcher.ensure_context(bot, event):
try:
await asyncio.gather(
*(
run_coro_with_catch(
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
for proc in _run_preprocessors
with (
matcher.ensure_context(bot, event),
catch(
{
IgnoredException: _handle_ignored_exception(
f"{matcher} running is <b>cancelled</b>"
),
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
),
}
),
):
async with anyio.create_task_group() as tg:
for proc in _run_preprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
)
except IgnoredException:
logger.opt(colors=True).info(f"{matcher} running is <b>cancelled</b>")
return False
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
"Running cancelled!</bg #f8bbd0></r>"
)
return False
return True
return True
return False
async def _apply_run_postprocessors(
@ -306,29 +330,32 @@ async def _apply_run_postprocessors(
if not _run_postprocessors:
return
with matcher.ensure_context(bot, event):
try:
await asyncio.gather(
*(
run_coro_with_catch(
proc(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
for proc in _run_postprocessors
with (
matcher.ensure_context(bot, event),
catch(
{
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when running RunPostProcessors"
"</bg #f8bbd0></r>"
)
}
),
):
async with anyio.create_task_group() as tg:
for proc in _run_postprocessors:
tg.start_soon(
run_coro_with_catch,
proc(
matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=matcher.state,
stack=stack,
dependency_cache=dependency_cache,
),
(SkippedException,),
)
)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running RunPostProcessors</bg #f8bbd0></r>"
)
async def _check_matcher(
@ -425,8 +452,9 @@ async def _run_matcher(
exception = None
logger.debug(f"Running {matcher}")
try:
logger.debug(f"Running {matcher}")
await matcher.run(bot, event, state, stack, dependency_cache)
except Exception as e:
logger.opt(colors=True, exception=e).error(
@ -494,8 +522,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
用法:
```python
import asyncio
asyncio.create_task(handle_event(bot, event))
driver.task_group.start_soon(handle_event, bot, event)
```
"""
show_log = True
@ -530,6 +557,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
)
break_flag = False
def _handle_stop_propagation(exc_group: BaseExceptionGroup) -> None:
nonlocal break_flag
break_flag = True
logger.debug("Stop event propagation")
# iterate through all priority until stop propagation
for priority in sorted(matchers.keys()):
if break_flag:
@ -538,23 +572,30 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
if show_log:
logger.debug(f"Checking for matchers in priority {priority}...")
pending_tasks = [
check_and_run_matcher(
matcher, bot, event, state.copy(), stack, dependency_cache
)
for matcher in matchers[priority]
]
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
for result in results:
if not isinstance(result, Exception):
continue
if isinstance(result, StopPropagation):
break_flag = True
logger.debug("Stop event propagation")
else:
logger.opt(colors=True, exception=result).error(
if not (priority_matchers := matchers[priority]):
continue
with catch(
{
StopPropagation: _handle_stop_propagation,
Exception: _handle_exception(
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
)
),
}
):
async with anyio.create_task_group() as tg:
for matcher in priority_matchers:
tg.start_soon(
run_coro_with_shield,
check_and_run_matcher(
matcher,
bot,
event,
state.copy(),
stack,
dependency_cache,
),
)
if show_log:
logger.debug("Checking for matchers completed")

View File

@ -22,7 +22,7 @@ from . import _managers, get_plugin, _module_name_to_plugin_id
try: # pragma: py-gte-311
import tomllib # pyright: ignore[reportMissingImports]
except ModuleNotFoundError: # pragma: py-lt-311
import tomli as tomllib
import tomli as tomllib # pyright: ignore[reportMissingImports]
def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:

View File

@ -21,10 +21,9 @@ from typing import TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec, TypeAlias, get_args, override, get_origin
if TYPE_CHECKING:
from asyncio import Task
from nonebot.adapters import Bot
from nonebot.permission import Permission
from nonebot.internal.params import DependencyCache
T = TypeVar("T")
P = ParamSpec("P")
@ -258,5 +257,5 @@ T_PermissionUpdater: TypeAlias = _DependentCallable["Permission"]
- MatcherParam: Matcher 对象
- DefaultParam: 带有默认值的参数
"""
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "Task[t.Any]"]
T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "DependencyCache"]
"""依赖缓存, 用于存储依赖函数的返回值"""

View File

@ -9,21 +9,22 @@ FrontMatter:
import re
import json
import asyncio
import inspect
import importlib
import contextlib
import dataclasses
from pathlib import Path
from collections import deque
from contextvars import copy_context
from functools import wraps, partial
from contextlib import AbstractContextManager, asynccontextmanager
from typing_extensions import ParamSpec, get_args, override, get_origin
from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
from collections.abc import Mapping, Sequence, Coroutine, Generator, AsyncGenerator
import anyio
import anyio.to_thread
from pydantic import BaseModel
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.log import logger
from nonebot.typing import (
@ -39,6 +40,7 @@ R = TypeVar("R")
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
E = TypeVar("E", bound=BaseException)
def escape_tag(s: str) -> str:
@ -178,11 +180,9 @@ def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:
@wraps(call)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
loop = asyncio.get_running_loop()
pfunc = partial(call, *args, **kwargs)
context = copy_context()
result = await loop.run_in_executor(None, partial(context.run, pfunc))
return result
return await anyio.to_thread.run_sync(
partial(call, *args, **kwargs), abandon_on_cancel=True
)
return _wrapper
@ -234,10 +234,34 @@ async def run_coro_with_catch(
协程的返回值或发生异常时的指定值
"""
try:
with catch({exc: lambda exc_group: None}):
return await coro
except exc:
return return_on_err
return return_on_err
async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T:
"""运行协程并在取消时屏蔽取消异常。
参数:
coro: 要运行的协程
返回:
协程的返回值
"""
with anyio.CancelScope(shield=True):
return await coro
def flatten_exception_group(
exc_group: BaseExceptionGroup[E],
) -> Generator[E, None, None]:
for exc in exc_group.exceptions:
if isinstance(exc, BaseExceptionGroup):
yield from flatten_exception_group(exc)
else:
yield exc
def get_name(obj: Any) -> str:

2274
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,9 @@ include = ["nonebot/py.typed"]
[tool.poetry.dependencies]
python = "^3.9"
yarl = "^1.7.2"
anyio = "^4.4.0"
pygtrie = "^2.4.1"
exceptiongroup = "^1.2.2"
loguru = ">=0.6.0,<1.0.0"
python-dotenv = ">=0.21.0,<2.0.0"
typing-extensions = ">=4.4.0,<5.0.0"
@ -65,7 +67,6 @@ fastapi = ["fastapi", "uvicorn"]
all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"]
[tool.pytest.ini_options]
asyncio_mode = "strict"
addopts = "--cov=nonebot --cov-report=term-missing"
filterwarnings = ["error", "ignore::DeprecationWarning"]

View File

@ -1,8 +1,10 @@
import os
import threading
from pathlib import Path
from typing import TYPE_CHECKING
from functools import wraps
from collections.abc import Generator
from typing_extensions import ParamSpec
from typing import TYPE_CHECKING, TypeVar, Callable
import pytest
from nonebug import NONEBOT_INIT_KWARGS
@ -20,6 +22,9 @@ os.environ["CONFIG_OVERRIDE"] = "new"
if TYPE_CHECKING:
from nonebot.plugin import Plugin
P = ParamSpec("P")
R = TypeVar("R")
collect_ignore = ["plugins/", "dynamic/", "bad_plugins/"]
@ -38,14 +43,36 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
return DriverClass(Env(environment=global_driver.env), global_driver.config)
@pytest.fixture(scope="session", params=[pytest.param("asyncio"), pytest.param("trio")])
def anyio_backend(request: pytest.FixtureRequest):
return request.param
def run_once(func: Callable[P, R]) -> Callable[P, R]:
result = ...
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
nonlocal result
if result is not Ellipsis:
return result
result = func(*args, **kwargs)
return result
return _wrapper
@pytest.fixture(scope="session", autouse=True)
def load_plugin(nonebug_init: None) -> set["Plugin"]:
@run_once
def load_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
# preload global plugins
return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))
@pytest.fixture(scope="session", autouse=True)
def load_builtin_plugin(nonebug_init: None) -> set["Plugin"]:
@run_once
def load_builtin_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
# preload builtin plugins
return nonebot.load_builtin_plugins("echo", "single_session")

View File

@ -17,7 +17,7 @@ from nonebot.drivers import (
)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_adapter_connect(app: App, driver: Driver):
last_connect_bot: Optional[Bot] = None
last_disconnect_bot: Optional[Bot] = None
@ -45,7 +45,6 @@ async def test_adapter_connect(app: App, driver: Driver):
assert bot.self_id not in adapter.bots
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
@ -75,7 +74,7 @@ async def test_adapter_connect(app: App, driver: Driver):
],
indirect=True,
)
async def test_adapter_server(driver: Driver):
def test_adapter_server(driver: Driver):
last_http_setup: Optional[HTTPServerSetup] = None
last_ws_setup: Optional[WebSocketServerSetup] = None
@ -112,7 +111,7 @@ async def test_adapter_server(driver: Driver):
assert last_ws_setup is setup
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -159,7 +158,7 @@ async def test_adapter_http_client(driver: Driver):
assert last_request is request
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[

View File

@ -1,5 +1,6 @@
from typing import Any, Optional
import anyio
import pytest
from nonebug import App
@ -7,7 +8,7 @@ from nonebot.adapters import Bot
from nonebot.exception import MockApiException
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot_call_api(app: App):
async with app.test_api() as ctx:
bot = ctx.create_bot()
@ -23,7 +24,7 @@ async def test_bot_call_api(app: App):
await bot.call_api("test")
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot_calling_api_hook_simple(app: App):
runned: bool = False
@ -49,7 +50,7 @@ async def test_bot_calling_api_hook_simple(app: App):
assert result is True
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot_calling_api_hook_mock(app: App):
runned: bool = False
@ -76,7 +77,47 @@ async def test_bot_calling_api_hook_mock(app: App):
assert result is False
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot_calling_api_hook_multi_mock(app: App):
runned1: bool = False
runned2: bool = False
event = anyio.Event()
async def calling_api_hook1(bot: Bot, api: str, data: dict[str, Any]):
nonlocal runned1
runned1 = True
event.set()
raise MockApiException(1)
async def calling_api_hook2(bot: Bot, api: str, data: dict[str, Any]):
nonlocal runned2
runned2 = True
with anyio.fail_after(1):
await event.wait()
raise MockApiException(2)
hooks = set()
with pytest.MonkeyPatch.context() as m:
m.setattr(Bot, "_calling_api_hook", hooks)
Bot.on_calling_api(calling_api_hook1)
Bot.on_calling_api(calling_api_hook2)
assert hooks == {calling_api_hook1, calling_api_hook2}
async with app.test_api() as ctx:
bot = ctx.create_bot()
result = await bot.call_api("test")
assert runned1 is True
assert runned2 is True
assert result == 1
@pytest.mark.anyio
async def test_bot_called_api_hook_simple(app: App):
runned: bool = False
@ -108,7 +149,7 @@ async def test_bot_called_api_hook_simple(app: App):
assert result is True
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot_called_api_hook_mock(app: App):
runned: bool = False
@ -150,3 +191,56 @@ async def test_bot_called_api_hook_mock(app: App):
assert runned is True
assert result is False
@pytest.mark.anyio
async def test_bot_called_api_hook_multi_mock(app: App):
runned1: bool = False
runned2: bool = False
event = anyio.Event()
async def called_api_hook1(
bot: Bot,
exception: Optional[Exception],
api: str,
data: dict[str, Any],
result: Any,
):
nonlocal runned1
runned1 = True
event.set()
raise MockApiException(1)
async def called_api_hook2(
bot: Bot,
exception: Optional[Exception],
api: str,
data: dict[str, Any],
result: Any,
):
nonlocal runned2
runned2 = True
with anyio.fail_after(1):
await event.wait()
raise MockApiException(2)
hooks = set()
with pytest.MonkeyPatch.context() as m:
m.setattr(Bot, "_called_api_hook", hooks)
Bot.on_called_api(called_api_hook1)
Bot.on_called_api(called_api_hook2)
assert hooks == {called_api_hook1, called_api_hook2}
async with app.test_api() as ctx:
bot = ctx.create_bot()
ctx.should_call_api("test", {}, True)
result = await bot.call_api("test")
assert runned1 is True
assert runned2 is True
assert result == 1

View File

@ -25,7 +25,7 @@ async def _dependency() -> int:
return 1
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(message, "_event_preprocessors", set())
@ -58,7 +58,7 @@ async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "event_preprocessor should runned"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(message, "_event_preprocessors", set())
@ -88,7 +88,7 @@ async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPat
assert not runned, "matcher should not runned"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_event_preprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
):
@ -132,7 +132,7 @@ async def test_event_preprocessor_exception(
assert "RuntimeError: test" in capsys.readouterr().out
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(message, "_event_postprocessors", set())
@ -165,7 +165,7 @@ async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "event_postprocessor should runned"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_event_postprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
):
@ -202,7 +202,7 @@ async def test_event_postprocessor_exception(
assert "RuntimeError: test" in capsys.readouterr().out
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(message, "_run_preprocessors", set())
@ -239,7 +239,7 @@ async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "run_preprocessor should runned"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(message, "_run_preprocessors", set())
@ -269,7 +269,7 @@ async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch
assert not runned, "matcher should not runned"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_run_preprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
):
@ -313,7 +313,7 @@ async def test_run_preprocessor_exception(
assert "RuntimeError: test" in capsys.readouterr().out
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(message, "_run_postprocessors", set())
@ -351,7 +351,7 @@ async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned, "run_postprocessor should runned"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_run_postprocessor_exception(
app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
):

View File

@ -17,14 +17,12 @@ from nonebot.compat import (
)
@pytest.mark.asyncio
async def test_default_config():
def test_default_config():
assert DEFAULT_CONFIG.get("extra") == "allow"
assert DEFAULT_CONFIG.get("arbitrary_types_allowed") is True
@pytest.mark.asyncio
async def test_field_info():
def test_field_info():
# required should be convert to PydanticUndefined
assert FieldInfo(Required).default is PydanticUndefined
@ -32,8 +30,7 @@ async def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test"
@pytest.mark.asyncio
async def test_type_adapter():
def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])
assert t.validate_python(2) == 2
@ -47,8 +44,7 @@ async def test_type_adapter():
t.validate_json("0")
@pytest.mark.asyncio
async def test_model_dump():
def test_model_dump():
class TestModel(BaseModel):
test1: int
test2: int
@ -57,8 +53,7 @@ async def test_model_dump():
assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2}
@pytest.mark.asyncio
async def test_custom_validation():
def test_custom_validation():
called = []
@custom_validation
@ -85,8 +80,7 @@ async def test_custom_validation():
assert called == [1, 2]
@pytest.mark.asyncio
async def test_validate_json():
def test_validate_json():
class TestModel(BaseModel):
test1: int
test2: str

View File

@ -50,16 +50,14 @@ class ExampleWithoutDelimiter(Example):
env_nested_delimiter = None
@pytest.mark.asyncio
async def test_config_no_env():
def test_config_no_env():
config = Example(_env_file=None)
assert config.simple == ""
with pytest.raises(AttributeError):
config.common_config
@pytest.mark.asyncio
async def test_config_with_env():
def test_config_with_env():
config = Example(_env_file=(".env", ".env.example"))
assert config.simple == "simple"
@ -102,8 +100,7 @@ async def test_config_with_env():
config.other_nested_inner__b
@pytest.mark.asyncio
async def test_config_error_env():
def test_config_error_env():
with pytest.MonkeyPatch().context() as m:
m.setenv("COMPLEX", "not json")
@ -111,8 +108,7 @@ async def test_config_error_env():
Example(_env_file=(".env", ".env.example"))
@pytest.mark.asyncio
async def test_config_without_delimiter():
def test_config_without_delimiter():
config = ExampleWithoutDelimiter()
assert config.nested.a == 1
assert config.nested.b == 0

View File

@ -1,8 +1,8 @@
import json
import asyncio
from typing import Any, Optional
from http.cookies import SimpleCookie
import anyio
import pytest
from nonebug import App
@ -25,7 +25,7 @@ from nonebot.drivers import (
)
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
)
@ -59,22 +59,22 @@ async def test_lifespan(driver: Driver):
@driver.on_shutdown
async def _shutdown1():
assert shutdown_log == []
assert shutdown_log == [2]
shutdown_log.append(1)
@driver.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
assert shutdown_log == []
shutdown_log.append(2)
async with driver._lifespan:
assert start_log == [1, 2]
assert ready_log == [1, 2]
assert shutdown_log == [1, 2]
assert shutdown_log == [2, 1]
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -99,10 +99,10 @@ async def test_http_server(app: App, driver: Driver):
assert response.status_code == 200
assert response.text == "test"
await asyncio.sleep(1)
await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -155,10 +155,10 @@ async def test_websocket_server(app: App, driver: Driver):
await ws.close(code=1000)
await asyncio.sleep(1)
await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -171,9 +171,10 @@ async def test_cross_context(app: App, driver: Driver):
assert isinstance(driver, ASGIMixin)
ws: Optional[WebSocket] = None
ws_ready = asyncio.Event()
ws_should_close = asyncio.Event()
ws_ready = anyio.Event()
ws_should_close = anyio.Event()
# create a background task before the ws connection established
async def background_task():
try:
await ws_ready.wait()
@ -185,8 +186,6 @@ async def test_cross_context(app: App, driver: Driver):
finally:
ws_should_close.set()
task = asyncio.create_task(background_task())
async def _handle_ws(websocket: WebSocket) -> None:
nonlocal ws
await websocket.accept()
@ -199,7 +198,9 @@ async def test_cross_context(app: App, driver: Driver):
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
driver.setup_websocket_server(ws_setup)
async with app.test_server(driver.asgi) as ctx:
async with anyio.create_task_group() as tg, app.test_server(driver.asgi) as ctx:
tg.start_soon(background_task)
client = ctx.get_client()
async with client.websocket_connect("/ws_test") as websocket:
@ -211,11 +212,10 @@ async def test_cross_context(app: App, driver: Driver):
if not e.args or "websocket.close" not in str(e.args[0]):
raise
await task
await asyncio.sleep(1)
await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -304,10 +304,10 @@ async def test_http_client(driver: Driver, server_url: URL):
"test3": "test",
}, "file parsing error"
await asyncio.sleep(1)
await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -419,10 +419,10 @@ async def test_http_client_session(driver: Driver, server_url: URL):
"test3": "test",
}, "file parsing error"
await asyncio.sleep(1)
await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
"driver",
[
@ -452,10 +452,9 @@ async def test_websocket_client(driver: Driver, server_url: URL):
with pytest.raises(WebSocketClosed, match=r"code=1000"):
await ws.receive()
await asyncio.sleep(1)
await anyio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("driver", "driver_type"),
[
@ -472,11 +471,11 @@ async def test_websocket_client(driver: Driver, server_url: URL):
],
indirect=["driver"],
)
async def test_combine_driver(driver: Driver, driver_type: str):
def test_combine_driver(driver: Driver, driver_type: str):
assert driver.type == driver_type
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot_connect_hook(app: App, driver: Driver):
with pytest.MonkeyPatch.context() as m:
conn_hooks: set[Dependent[Any]] = set()
@ -533,7 +532,7 @@ async def test_bot_connect_hook(app: App, driver: Driver):
async with app.test_api() as ctx:
bot = ctx.create_bot()
await asyncio.sleep(1)
await anyio.sleep(1)
if not conn_should_be_called:
pytest.fail("on_bot_connect hook not called")

View File

@ -4,7 +4,7 @@ from nonebug import App
from utils import FakeMessage, FakeMessageSegment, make_fake_event
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_echo(app: App):
from nonebot.plugins.echo import echo

View File

@ -14,8 +14,7 @@ from nonebot import (
)
@pytest.mark.asyncio
async def test_init():
def test_init():
env = nonebot.get_driver().env
assert env == "test"
@ -35,31 +34,28 @@ async def test_init():
assert config.not_nested == "some string"
@pytest.mark.asyncio
async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
def test_get_driver(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(nonebot, "_driver", None)
with pytest.raises(ValueError, match="initialized"):
get_driver()
@pytest.mark.asyncio
async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
def test_get_asgi():
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert isinstance(driver, ASGIMixin)
assert get_asgi() == driver.asgi
@pytest.mark.asyncio
async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
def test_get_app():
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert isinstance(driver, ASGIMixin)
assert get_app() == driver.server_app
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
async with app.test_api() as ctx:
adapter = ctx.create_adapter()
@ -74,8 +70,7 @@ async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
get_adapter("not exist")
@pytest.mark.asyncio
async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
def test_run(monkeypatch: pytest.MonkeyPatch):
runned = False
def mock_run(*args, **kwargs):
@ -93,8 +88,7 @@ async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
assert runned
@pytest.mark.asyncio
async def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
with pytest.raises(ValueError, match="no bots"):

View File

@ -12,8 +12,7 @@ from nonebot.permission import User, Permission
from nonebot.message import _check_matcher, check_and_run_matcher
@pytest.mark.asyncio
async def test_matcher_info(app: App):
def test_matcher_info(app: App):
from plugins.matcher.matcher_info import matcher
assert issubclass(matcher, Matcher)
@ -43,7 +42,7 @@ async def test_matcher_info(app: App):
assert matcher._source.lineno == 3
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_check(app: App):
async def falsy():
return False
@ -87,7 +86,7 @@ async def test_matcher_check(app: App):
assert await _check_matcher(test_rule_error, bot, event, {}) is False
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_handle(app: App):
from plugins.matcher.matcher_process import test_handle
@ -102,7 +101,7 @@ async def test_matcher_handle(app: App):
ctx.should_finished()
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_got(app: App):
from plugins.matcher.matcher_process import test_got
@ -124,7 +123,7 @@ async def test_matcher_got(app: App):
ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_receive(app: App):
from plugins.matcher.matcher_process import test_receive
@ -141,7 +140,7 @@ async def test_matcher_receive(app: App):
ctx.should_paused()
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_combine(app: App):
from plugins.matcher.matcher_process import test_combine
@ -164,7 +163,7 @@ async def test_matcher_combine(app: App):
ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_preset(app: App):
from plugins.matcher.matcher_process import test_preset
@ -182,7 +181,7 @@ async def test_matcher_preset(app: App):
ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_overload(app: App):
from plugins.matcher.matcher_process import test_overload
@ -196,7 +195,7 @@ async def test_matcher_overload(app: App):
ctx.should_finished()
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_destroy(app: App):
from plugins.matcher.matcher_process import test_destroy
@ -210,7 +209,7 @@ async def test_matcher_destroy(app: App):
assert len(matchers[test_destroy.priority]) == 0
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_type_updater(app: App):
from plugins.matcher.matcher_type import test_type_updater, test_custom_updater
@ -231,7 +230,7 @@ async def test_type_updater(app: App):
assert new_type == "custom"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_default_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
default_permission,
@ -252,7 +251,7 @@ async def test_default_permission_updater(app: App):
assert checker.perm is default_permission
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_user_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
default_permission,
@ -274,7 +273,7 @@ async def test_user_permission_updater(app: App):
assert checker.perm is default_permission
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_custom_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
new_permission,
@ -291,7 +290,7 @@ async def test_custom_permission_updater(app: App):
assert new_perm is new_permission
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_run(app: App):
with app.provider.context({}):
assert not matchers
@ -322,37 +321,46 @@ async def test_run(app: App):
assert len(matchers[0][0].handlers) == 0
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_temp(app: App):
from plugins.matcher.matcher_expire import test_temp_matcher
event = make_fake_event(_type="test")()
async with app.test_api() as ctx:
bot = ctx.create_bot()
assert test_temp_matcher in matchers[test_temp_matcher.priority]
await check_and_run_matcher(test_temp_matcher, bot, event, {})
assert test_temp_matcher not in matchers[test_temp_matcher.priority]
with app.provider.context({test_temp_matcher.priority: [test_temp_matcher]}):
async with app.test_api() as ctx:
bot = ctx.create_bot()
assert test_temp_matcher in matchers[test_temp_matcher.priority]
await check_and_run_matcher(test_temp_matcher, bot, event, {})
assert test_temp_matcher not in matchers[test_temp_matcher.priority]
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_datetime_expire(app: App):
from plugins.matcher.matcher_expire import test_datetime_matcher
event = make_fake_event()()
async with app.test_api() as ctx:
bot = ctx.create_bot()
assert test_datetime_matcher in matchers[test_datetime_matcher.priority]
await check_and_run_matcher(test_datetime_matcher, bot, event, {})
assert test_datetime_matcher not in matchers[test_datetime_matcher.priority]
with app.provider.context(
{test_datetime_matcher.priority: [test_datetime_matcher]}
):
async with app.test_matcher(test_datetime_matcher) as ctx:
bot = ctx.create_bot()
assert test_datetime_matcher in matchers[test_datetime_matcher.priority]
await check_and_run_matcher(test_datetime_matcher, bot, event, {})
assert test_datetime_matcher not in matchers[test_datetime_matcher.priority]
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_timedelta_expire(app: App):
from plugins.matcher.matcher_expire import test_timedelta_matcher
event = make_fake_event()()
async with app.test_api() as ctx:
bot = ctx.create_bot()
assert test_timedelta_matcher in matchers[test_timedelta_matcher.priority]
await check_and_run_matcher(test_timedelta_matcher, bot, event, {})
assert test_timedelta_matcher not in matchers[test_timedelta_matcher.priority]
with app.provider.context(
{test_timedelta_matcher.priority: [test_timedelta_matcher]}
):
async with app.test_api() as ctx:
bot = ctx.create_bot()
assert test_timedelta_matcher in matchers[test_timedelta_matcher.priority]
await check_and_run_matcher(test_timedelta_matcher, bot, event, {})
assert (
test_timedelta_matcher not in matchers[test_timedelta_matcher.priority]
)

View File

@ -1,11 +1,9 @@
import pytest
from nonebug import App
from nonebot.matcher import DEFAULT_PROVIDER_CLASS, matchers
@pytest.mark.asyncio
async def test_manager(app: App):
def test_manager(app: App):
try:
default_provider = matchers.provider
matchers.set_provider(DEFAULT_PROVIDER_CLASS)

View File

@ -2,6 +2,7 @@ import re
import pytest
from nonebug import App
from exceptiongroup import BaseExceptionGroup
from nonebot.matcher import Matcher
from nonebot.dependencies import Dependent
@ -36,7 +37,7 @@ from nonebot.consts import (
UNKNOWN_PARAM = "Unknown parameter"
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_depend(app: App):
from plugins.param.param_depend import (
ClassDependency,
@ -90,36 +91,47 @@ async def test_depend(app: App):
assert runned == [1, 1, 1]
runned.clear()
async with app.test_dependent(
annotated_class_depend, allow_types=[DependParam]
) as ctx:
ctx.should_return(ClassDependency(x=1, y=2))
with pytest.raises(TypeMisMatch): # noqa: PT012
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info: # noqa: PT012
async with app.test_dependent(
sub_type_mismatch, allow_types=[DependParam, BotParam]
) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(validate, allow_types=[DependParam]) as ctx:
ctx.should_return(1)
with pytest.raises(TypeMisMatch):
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx:
...
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx:
ctx.should_return(1)
with pytest.raises(TypeMisMatch):
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(
validate_field_fail, allow_types=[DependParam]
) as ctx:
...
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_bot(app: App):
from plugins.param.param_bot import (
FooBot,
@ -157,11 +169,14 @@ async def test_bot(app: App):
ctx.pass_params(bot=bot)
ctx.should_return(bot)
with pytest.raises(TypeMisMatch): # noqa: PT012
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info: # noqa: PT012
async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot(base=FooBot)
ctx.pass_params(bot=bot)
@ -181,7 +196,7 @@ async def test_bot(app: App):
app.test_dependent(not_bot, allow_types=[BotParam])
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_event(app: App):
from plugins.param.param_event import (
FooEvent,
@ -223,10 +238,13 @@ async def test_event(app: App):
ctx.pass_params(event=fake_fooevent)
ctx.should_return(fake_fooevent)
with pytest.raises(TypeMisMatch):
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_fooevent)
ctx.should_return(fake_fooevent)
@ -267,7 +285,7 @@ async def test_event(app: App):
ctx.should_return(fake_event.is_tome())
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_state(app: App):
from plugins.param.param_state import (
state,
@ -418,7 +436,7 @@ async def test_state(app: App):
ctx.should_return(fake_state[KEYWORD_KEY])
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher(app: App):
from plugins.param.param_matcher import (
FooMatcher,
@ -457,10 +475,13 @@ async def test_matcher(app: App):
ctx.pass_params(matcher=foo_matcher)
ctx.should_return(foo_matcher)
with pytest.raises(TypeMisMatch):
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
async with app.test_dependent(sub_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
if isinstance(exc_info.value, BaseExceptionGroup):
assert exc_info.group_contains(TypeMisMatch)
async with app.test_dependent(union_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=foo_matcher)
ctx.should_return(foo_matcher)
@ -496,7 +517,7 @@ async def test_matcher(app: App):
ctx.should_return(event_next)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_arg(app: App):
from plugins.param.param_arg import (
arg,
@ -548,7 +569,7 @@ async def test_arg(app: App):
ctx.should_return(message.extract_plain_text())
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_exception(app: App):
from plugins.param.param_exception import exc, legacy_exc
@ -562,7 +583,7 @@ async def test_exception(app: App):
ctx.should_return(exception)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_default(app: App):
from plugins.param.param_default import default
@ -570,8 +591,7 @@ async def test_default(app: App):
ctx.should_return(1)
@pytest.mark.asyncio
async def test_priority():
def test_priority():
from plugins.param.priority import complex_priority
dependent = Dependent[None].parse(

View File

@ -22,7 +22,7 @@ from nonebot.permission import (
)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_permission(app: App):
async def falsy():
return False
@ -54,7 +54,7 @@ async def test_permission(app: App):
assert await Permission(truthy, skipped)(bot, event) is True
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(("type", "expected"), [("message", True), ("notice", False)])
async def test_message(type: str, expected: bool):
dependent = next(iter(MESSAGE.checkers))
@ -66,7 +66,7 @@ async def test_message(type: str, expected: bool):
assert await dependent(event=event) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(("type", "expected"), [("message", False), ("notice", True)])
async def test_notice(type: str, expected: bool):
dependent = next(iter(NOTICE.checkers))
@ -78,7 +78,7 @@ async def test_notice(type: str, expected: bool):
assert await dependent(event=event) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(("type", "expected"), [("message", False), ("request", True)])
async def test_request(type: str, expected: bool):
dependent = next(iter(REQUEST.checkers))
@ -90,7 +90,7 @@ async def test_request(type: str, expected: bool):
assert await dependent(event=event) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("type", "expected"), [("message", False), ("meta_event", True)]
)
@ -104,7 +104,7 @@ async def test_metaevent(type: str, expected: bool):
assert await dependent(event=event) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("type", "user_id", "expected"),
[
@ -128,7 +128,7 @@ async def test_superuser(app: App, type: str, user_id: str, expected: bool):
assert await dependent(bot=bot, event=event) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("session_ids", "session_id", "expected"),
[

View File

@ -1,12 +1,10 @@
import pytest
from pydantic import BaseModel
import nonebot
from nonebot.plugin import PluginManager, _managers
@pytest.mark.asyncio
async def test_get_plugin():
def test_get_plugin():
# check simple plugin
plugin = nonebot.get_plugin("export")
assert plugin
@ -22,8 +20,7 @@ async def test_get_plugin():
assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"
@pytest.mark.asyncio
async def test_get_plugin_by_module_name():
def test_get_plugin_by_module_name():
# check get plugin by exact module name
plugin = nonebot.get_plugin_by_module_name("plugins.nested")
assert plugin
@ -48,8 +45,7 @@ async def test_get_plugin_by_module_name():
assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"
@pytest.mark.asyncio
async def test_get_available_plugin():
def test_get_available_plugin():
old_managers = _managers.copy()
_managers.clear()
try:
@ -63,8 +59,7 @@ async def test_get_available_plugin():
_managers.extend(old_managers)
@pytest.mark.asyncio
async def test_get_plugin_config():
def test_get_plugin_config():
class Config(BaseModel):
plugin_config: int

View File

@ -1,15 +1,44 @@
import sys
from pathlib import Path
from functools import wraps
from dataclasses import asdict
from typing import TypeVar, Callable
from typing_extensions import ParamSpec
import pytest
import nonebot
from nonebot.plugin import Plugin, PluginManager, _managers, inherit_supported_adapters
from nonebot.plugin import (
Plugin,
PluginManager,
_plugins,
_managers,
inherit_supported_adapters,
)
P = ParamSpec("P")
R = TypeVar("R")
@pytest.mark.asyncio
async def test_load_plugin():
def _recover(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
origin_managers = _managers.copy()
origin_plugins = _plugins.copy()
try:
return func(*args, **kwargs)
finally:
_managers.clear()
_managers.extend(origin_managers)
_plugins.clear()
_plugins.update(origin_plugins)
return _wrapper
@_recover
def test_load_plugin():
# check regular
assert nonebot.load_plugin("dynamic.simple")
@ -20,8 +49,7 @@ async def test_load_plugin():
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[Plugin]):
def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[Plugin]):
loaded_plugins = {
plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
}
@ -44,8 +72,7 @@ async def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[P
PluginManager(search_path=["plugins"]).load_all_plugins()
@pytest.mark.asyncio
async def test_load_nested_plugin():
def test_load_nested_plugin():
parent_plugin = nonebot.get_plugin("nested")
sub_plugin = nonebot.get_plugin("nested:nested_subplugin")
sub_plugin2 = nonebot.get_plugin("nested:nested_subplugin2")
@ -57,16 +84,16 @@ async def test_load_nested_plugin():
assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2}
@pytest.mark.asyncio
async def test_load_json():
@_recover
def test_load_json():
nonebot.load_from_json("./plugins.json")
with pytest.raises(TypeError):
nonebot.load_from_json("./plugins.invalid.json")
@pytest.mark.asyncio
async def test_load_toml():
@_recover
def test_load_toml():
nonebot.load_from_toml("./plugins.toml")
with pytest.raises(ValueError, match="Cannot find"):
@ -76,52 +103,54 @@ async def test_load_toml():
nonebot.load_from_toml("./plugins.invalid.toml")
@pytest.mark.asyncio
async def test_bad_plugin():
@_recover
def test_bad_plugin():
nonebot.load_plugins("bad_plugins")
assert nonebot.get_plugin("bad_plugin") is None
@pytest.mark.asyncio
async def test_require_loaded(monkeypatch: pytest.MonkeyPatch):
@_recover
def test_require_loaded(monkeypatch: pytest.MonkeyPatch):
def _patched_find(name: str):
pytest.fail("require existing plugin should not call find_manager_by_name")
monkeypatch.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find)
with monkeypatch.context() as m:
m.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find)
# require use module name
nonebot.require("plugins.export")
# require use plugin id
nonebot.require("export")
nonebot.require("nested:nested_subplugin")
# require use module name
nonebot.require("plugins.export")
# require use plugin id
nonebot.require("export")
nonebot.require("nested:nested_subplugin")
@pytest.mark.asyncio
async def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch):
m = PluginManager(["dynamic.require_not_loaded"], ["dynamic/require_not_loaded/"])
_managers.append(m)
@_recover
def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch):
pm = PluginManager(["dynamic.require_not_loaded"], ["dynamic/require_not_loaded/"])
_managers.append(pm)
num_managers = len(_managers)
origin_load = PluginManager.load_plugin
def _patched_load(self: PluginManager, name: str):
assert self is m
assert self is pm
return origin_load(self, name)
monkeypatch.setattr(PluginManager, "load_plugin", _patched_load)
with monkeypatch.context() as m:
m.setattr(PluginManager, "load_plugin", _patched_load)
# require standalone plugin
nonebot.require("dynamic.require_not_loaded")
# require searched plugin
nonebot.require("dynamic.require_not_loaded.subplugin1")
nonebot.require("require_not_loaded:subplugin2")
# require standalone plugin
nonebot.require("dynamic.require_not_loaded")
# require searched plugin
nonebot.require("dynamic.require_not_loaded.subplugin1")
nonebot.require("require_not_loaded:subplugin2")
assert len(_managers) == num_managers
@pytest.mark.asyncio
async def test_require_not_declared():
@_recover
def test_require_not_declared():
num_managers = len(_managers)
nonebot.require("dynamic.require_not_declared")
@ -130,14 +159,13 @@ async def test_require_not_declared():
assert _managers[-1].plugins == {"dynamic.require_not_declared"}
@pytest.mark.asyncio
async def test_require_not_found():
@_recover
def test_require_not_found():
with pytest.raises(RuntimeError):
nonebot.require("some_plugin_not_exist")
@pytest.mark.asyncio
async def test_plugin_metadata():
def test_plugin_metadata():
from plugins.metadata import Config, FakeAdapter
plugin = nonebot.get_plugin("metadata")
@ -157,8 +185,7 @@ async def test_plugin_metadata():
assert plugin.metadata.get_supported_adapters() == {FakeAdapter}
@pytest.mark.asyncio
async def test_inherit_supported_adapters_not_found():
def test_inherit_supported_adapters_not_found():
with pytest.raises(RuntimeError):
inherit_supported_adapters("some_plugin_not_exist")
@ -166,7 +193,6 @@ async def test_inherit_supported_adapters_not_found():
inherit_supported_adapters("export")
@pytest.mark.asyncio
@pytest.mark.parametrize(
("inherit_plugins", "expected"),
[
@ -233,7 +259,7 @@ async def test_inherit_supported_adapters_not_found():
),
],
)
async def test_inherit_supported_adapters_combine(
def test_inherit_supported_adapters_combine(
inherit_plugins: tuple[str], expected: set[str]
):
assert inherit_supported_adapters(*inherit_plugins) == expected

View File

@ -1,17 +1,17 @@
import pytest
from nonebot.plugin import PluginManager, _managers
@pytest.mark.asyncio
async def test_load_plugin_name():
def test_load_plugin_name():
m = PluginManager(plugins=["dynamic.manager"])
_managers.append(m)
try:
_managers.append(m)
# load by plugin id
module1 = m.load_plugin("manager")
# load by module name
module2 = m.load_plugin("dynamic.manager")
assert module1
assert module2
assert module1 is module2
# load by plugin id
module1 = m.load_plugin("manager")
# load by module name
module2 = m.load_plugin("dynamic.manager")
assert module1
assert module2
assert module1 is module2
finally:
_managers.remove(m)

View File

@ -18,7 +18,6 @@ from nonebot.rule import (
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("matcher_name", "pre_rule_factory", "has_permission"),
[
@ -102,7 +101,7 @@ from nonebot.rule import (
pytest.param("matcher_group_on_type", lambda e: IsTypeRule(e), True),
],
)
async def test_on(
def test_on(
matcher_name: str,
pre_rule_factory: Optional[Callable[[type[Event]], T_RuleChecker]],
has_permission: bool,
@ -150,8 +149,7 @@ async def test_on(
assert matcher.module_name == "plugins.plugin.matchers"
@pytest.mark.asyncio
async def test_runtime_on():
def test_runtime_on():
import plugins.plugin.matchers as module
from plugins.plugin.matchers import matcher_on_factory

View File

@ -49,7 +49,7 @@ from nonebot.rule import (
)
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_rule(app: App):
async def falsy():
return False
@ -81,7 +81,7 @@ async def test_rule(app: App):
assert await Rule(truthy, skipped)(bot, event, {}) is False
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_trie(app: App):
TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",)))
@ -146,7 +146,7 @@ async def test_trie(app: App):
del TrieRule.prefix["/fake-prefix"]
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("msg", "ignorecase", "type", "text", "expected"),
[
@ -186,7 +186,7 @@ async def test_startswith(
assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("msg", "ignorecase", "type", "text", "expected"),
[
@ -226,7 +226,7 @@ async def test_endswith(
assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("msg", "ignorecase", "type", "text", "expected"),
[
@ -266,7 +266,7 @@ async def test_fullmatch(
assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("kws", "type", "text", "expected"),
[
@ -298,7 +298,7 @@ async def test_keyword(
assert await dependent(event=event, state=state) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("cmds", "force_whitespace", "cmd", "whitespace", "arg_text", "expected"),
[
@ -344,7 +344,7 @@ async def test_command(
assert await dependent(state=state) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_shell_command():
state: T_State
CMD = ("test",)
@ -451,7 +451,7 @@ async def test_shell_command():
assert state[SHELL_ARGS].status != 0
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize(
("pattern", "type", "text", "expected", "matched"),
[
@ -494,7 +494,7 @@ async def test_regex(
assert result.span() == matched.span()
@pytest.mark.asyncio
@pytest.mark.anyio
@pytest.mark.parametrize("expected", [True, False])
async def test_to_me(expected: bool):
test_to_me = to_me()
@ -507,7 +507,7 @@ async def test_to_me(expected: bool):
assert await dependent(event=event) == expected
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_is_type():
Event1 = make_fake_event()
Event2 = make_fake_event()

View File

@ -5,7 +5,7 @@ import pytest
from utils import make_fake_event
@pytest.mark.asyncio
@pytest.mark.anyio
async def test_matcher_mutex():
from nonebot.plugins.single_session import matcher_mutex, _running_matcher

View File

@ -14,7 +14,7 @@ NoneBot2 是一个现代、跨平台、可扩展的 Python 聊天机器人框架
### 异步优先
NoneBot 基于 Python [asyncio](https://docs.python.org/zh-cn/3/library/asyncio.html) 编写,并在异步机制的基础上进行了一定程度的同步函数兼容。
NoneBot 基于 Python [asyncio](https://docs.python.org/zh-cn/3/library/asyncio.html) / [trio](https://trio.readthedocs.io/en/stable/) 编写,并在异步机制的基础上进行了一定程度的同步函数兼容。
### 完整的类型注解