mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ Feature: 迁移至结构化并发框架 AnyIO (#3053)
This commit is contained in:
parent
bd9befbb55
commit
ff21ceb946
2023
envs/pydantic-v1/poetry.lock
generated
2023
envs/pydantic-v1/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
2126
envs/pydantic-v2/poetry.lock
generated
2126
envs/pydantic-v2/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
1396
envs/test/poetry.lock
generated
1396
envs/test/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -8,11 +8,11 @@ packages = [{ include = "nonebot-test.py" }]
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
nonebug = "^0.3.7"
|
trio = "^0.27.0"
|
||||||
|
nonebug = "^0.4.1"
|
||||||
wsproto = "^1.2.0"
|
wsproto = "^1.2.0"
|
||||||
pytest-cov = "^5.0.0"
|
pytest-cov = "^5.0.0"
|
||||||
pytest-xdist = "^3.0.2"
|
pytest-xdist = "^3.0.2"
|
||||||
pytest-asyncio = "^0.23.2"
|
|
||||||
werkzeug = ">=2.3.6,<4.0.0"
|
werkzeug = ">=2.3.6,<4.0.0"
|
||||||
coverage-conditional-plugin = "^0.9.0"
|
coverage-conditional-plugin = "^0.9.0"
|
||||||
|
|
||||||
|
@ -8,17 +8,20 @@ FrontMatter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from functools import partial
|
||||||
from dataclasses import field, dataclass
|
from dataclasses import field, dataclass
|
||||||
from collections.abc import Iterable, Awaitable
|
from collections.abc import Iterable, Awaitable
|
||||||
from typing import Any, Generic, TypeVar, Callable, Optional, cast
|
from typing import Any, Generic, TypeVar, Callable, Optional, cast
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import _DependentCallable
|
from nonebot.typing import _DependentCallable
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable
|
|
||||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
||||||
|
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
|
||||||
|
|
||||||
from .utils import check_field_type, get_typed_signature
|
from .utils import check_field_type, get_typed_signature
|
||||||
|
|
||||||
@ -84,7 +87,16 @@ class Dependent(Generic[R]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(self, **kwargs: Any) -> R:
|
async def __call__(self, **kwargs: Any) -> R:
|
||||||
try:
|
exception: Optional[BaseExceptionGroup[SkippedException]] = None
|
||||||
|
|
||||||
|
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
|
||||||
|
nonlocal exception
|
||||||
|
exception = exc_group
|
||||||
|
# raise one of the exceptions instead
|
||||||
|
excs = list(flatten_exception_group(exc_group))
|
||||||
|
logger.trace(f"{self} skipped due to {excs}")
|
||||||
|
|
||||||
|
with catch({SkippedException: _handle_skipped}):
|
||||||
# do pre-check
|
# do pre-check
|
||||||
await self.check(**kwargs)
|
await self.check(**kwargs)
|
||||||
|
|
||||||
@ -96,9 +108,8 @@ class Dependent(Generic[R]):
|
|||||||
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
|
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
|
||||||
else:
|
else:
|
||||||
return await run_sync(cast(Callable[..., R], self.call))(**values)
|
return await run_sync(cast(Callable[..., R], self.call))(**values)
|
||||||
except SkippedException as e:
|
|
||||||
logger.trace(f"{self} skipped due to {e}")
|
raise exception
|
||||||
raise
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_params(
|
def parse_params(
|
||||||
@ -166,10 +177,13 @@ class Dependent(Generic[R]):
|
|||||||
return cls(call, params, parameterless_params)
|
return cls(call, params, parameterless_params)
|
||||||
|
|
||||||
async def check(self, **params: Any) -> None:
|
async def check(self, **params: Any) -> None:
|
||||||
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
|
async with anyio.create_task_group() as tg:
|
||||||
await asyncio.gather(
|
for param in self.parameterless:
|
||||||
*(cast(Param, param.field_info)._check(**params) for param in self.params)
|
tg.start_soon(partial(param._check, **params))
|
||||||
)
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for param in self.params:
|
||||||
|
tg.start_soon(partial(cast(Param, param.field_info)._check, **params))
|
||||||
|
|
||||||
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
|
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
|
||||||
param = cast(Param, field.field_info)
|
param = cast(Param, field.field_info)
|
||||||
@ -185,10 +199,17 @@ class Dependent(Generic[R]):
|
|||||||
await param._solve(**params)
|
await param._solve(**params)
|
||||||
|
|
||||||
# solve param values
|
# solve param values
|
||||||
values = await asyncio.gather(
|
result: dict[str, Any] = {}
|
||||||
*(self._solve_field(field, params) for field in self.params)
|
|
||||||
)
|
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
|
||||||
return {field.name: value for field, value in zip(self.params, values)}
|
value = await self._solve_field(field, params)
|
||||||
|
result[field.name] = value
|
||||||
|
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for field in self.params:
|
||||||
|
tg.start_soon(_solve_field, field, params)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
__autodoc__ = {"CustomConfig": False}
|
__autodoc__ = {"CustomConfig": False}
|
||||||
|
@ -12,14 +12,18 @@ FrontMatter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
import asyncio
|
from typing import Optional
|
||||||
import threading
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from anyio.abc import TaskGroup
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.consts import WINDOWS
|
from nonebot.consts import WINDOWS
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.drivers import Driver as BaseDriver
|
from nonebot.drivers import Driver as BaseDriver
|
||||||
|
from nonebot.utils import flatten_exception_group
|
||||||
|
|
||||||
HANDLED_SIGNALS = (
|
HANDLED_SIGNALS = (
|
||||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||||
@ -35,8 +39,8 @@ class Driver(BaseDriver):
|
|||||||
def __init__(self, env: Env, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
super().__init__(env, config)
|
super().__init__(env, config)
|
||||||
|
|
||||||
self.should_exit: asyncio.Event = asyncio.Event()
|
self.should_exit: anyio.Event = anyio.Event()
|
||||||
self.force_exit: bool = False
|
self.force_exit: anyio.Event = anyio.Event()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@override
|
@override
|
||||||
@ -54,85 +58,98 @@ class Driver(BaseDriver):
|
|||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
"""启动 none driver"""
|
"""启动 none driver"""
|
||||||
super().run(*args, **kwargs)
|
super().run(*args, **kwargs)
|
||||||
loop = asyncio.get_event_loop()
|
anyio.run(self._serve)
|
||||||
loop.run_until_complete(self._serve())
|
|
||||||
|
|
||||||
async def _serve(self):
|
async def _serve(self):
|
||||||
self._install_signal_handlers()
|
async with anyio.create_task_group() as driver_tg:
|
||||||
await self._startup()
|
driver_tg.start_soon(self._handle_signals)
|
||||||
if self.should_exit.is_set():
|
driver_tg.start_soon(self._listen_force_exit, driver_tg)
|
||||||
return
|
driver_tg.start_soon(self._handle_lifespan, driver_tg)
|
||||||
await self._main_loop()
|
|
||||||
await self._shutdown()
|
|
||||||
|
|
||||||
async def _startup(self):
|
async def _handle_signals(self):
|
||||||
try:
|
try:
|
||||||
await self._lifespan.startup()
|
with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signal_receiver:
|
||||||
except Exception as e:
|
async for sig in signal_receiver:
|
||||||
logger.opt(colors=True, exception=e).error(
|
self.exit(force=self.should_exit.is_set())
|
||||||
"<r><bg #f8bbd0>Application startup failed. "
|
|
||||||
"Exiting.</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
self.should_exit.set()
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Application startup completed.")
|
|
||||||
|
|
||||||
async def _main_loop(self):
|
|
||||||
await self.should_exit.wait()
|
|
||||||
|
|
||||||
async def _shutdown(self):
|
|
||||||
logger.info("Shutting down")
|
|
||||||
|
|
||||||
logger.info("Waiting for application shutdown.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._lifespan.shutdown()
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
|
||||||
"Ignored!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
for task in asyncio.all_tasks():
|
|
||||||
if task is not asyncio.current_task() and not task.done():
|
|
||||||
task.cancel()
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
|
||||||
if tasks and not self.force_exit:
|
|
||||||
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
|
|
||||||
while tasks and not self.force_exit:
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
logger.info("Application shutdown complete.")
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.stop()
|
|
||||||
|
|
||||||
def _install_signal_handlers(self) -> None:
|
|
||||||
if threading.current_thread() is not threading.main_thread():
|
|
||||||
# Signals can only be listened to from the main thread.
|
|
||||||
return
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
try:
|
|
||||||
for sig in HANDLED_SIGNALS:
|
|
||||||
loop.add_signal_handler(sig, self._handle_exit, sig, None)
|
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
# Windows
|
# Windows
|
||||||
for sig in HANDLED_SIGNALS:
|
for sig in HANDLED_SIGNALS:
|
||||||
signal.signal(sig, self._handle_exit)
|
signal.signal(sig, self._handle_legacy_signal)
|
||||||
|
|
||||||
def _handle_exit(self, sig, frame):
|
# backport for Windows signal handling
|
||||||
|
def _handle_legacy_signal(self, sig, frame):
|
||||||
self.exit(force=self.should_exit.is_set())
|
self.exit(force=self.should_exit.is_set())
|
||||||
|
|
||||||
|
async def _handle_lifespan(self, tg: TaskGroup):
|
||||||
|
try:
|
||||||
|
await self._startup()
|
||||||
|
|
||||||
|
if self.should_exit.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._listen_exit()
|
||||||
|
|
||||||
|
await self._shutdown()
|
||||||
|
finally:
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
|
|
||||||
|
async def _startup(self):
|
||||||
|
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
|
||||||
|
self.should_exit.set()
|
||||||
|
|
||||||
|
for exc in flatten_exception_group(exc_group):
|
||||||
|
logger.opt(colors=True, exception=exc).error(
|
||||||
|
"<r><bg #f8bbd0>Error occurred while running startup hook."
|
||||||
|
"</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
logger.error(
|
||||||
|
"<r><bg #f8bbd0>Application startup failed. "
|
||||||
|
"Exiting.</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
|
||||||
|
with catch({Exception: handle_exception}):
|
||||||
|
await self._lifespan.startup()
|
||||||
|
|
||||||
|
if not self.should_exit.is_set():
|
||||||
|
logger.info("Application startup completed.")
|
||||||
|
|
||||||
|
async def _listen_exit(self, tg: Optional[TaskGroup] = None):
|
||||||
|
await self.should_exit.wait()
|
||||||
|
|
||||||
|
if tg is not None:
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
|
|
||||||
|
async def _shutdown(self):
|
||||||
|
logger.info("Shutting down")
|
||||||
|
logger.info("Waiting for application shutdown. (CTRL+C to force quit)")
|
||||||
|
|
||||||
|
error_occurred: bool = False
|
||||||
|
|
||||||
|
def handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
|
||||||
|
nonlocal error_occurred
|
||||||
|
|
||||||
|
error_occurred = True
|
||||||
|
|
||||||
|
for exc in flatten_exception_group(exc_group):
|
||||||
|
logger.opt(colors=True, exception=exc).error(
|
||||||
|
"<r><bg #f8bbd0>Error occurred while running shutdown hook."
|
||||||
|
"</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
logger.error(
|
||||||
|
"<r><bg #f8bbd0>Application shutdown failed. "
|
||||||
|
"Exiting.</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
|
||||||
|
with catch({Exception: handle_exception}):
|
||||||
|
await self._lifespan.shutdown()
|
||||||
|
|
||||||
|
if not error_occurred:
|
||||||
|
logger.info("Application shutdown complete.")
|
||||||
|
|
||||||
|
async def _listen_force_exit(self, tg: TaskGroup):
|
||||||
|
await self.force_exit.wait()
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
|
|
||||||
def exit(self, force: bool = False):
|
def exit(self, force: bool = False):
|
||||||
"""退出 none driver
|
"""退出 none driver
|
||||||
|
|
||||||
@ -142,4 +159,4 @@ class Driver(BaseDriver):
|
|||||||
if not self.should_exit.is_set():
|
if not self.should_exit.is_set():
|
||||||
self.should_exit.set()
|
self.should_exit.set()
|
||||||
if force:
|
if force:
|
||||||
self.force_exit = True
|
self.force_exit.set()
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol
|
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional, Protocol
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.exception import MockApiException
|
from nonebot.exception import MockApiException
|
||||||
|
from nonebot.utils import flatten_exception_group
|
||||||
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
|
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -76,48 +79,99 @@ class Bot(abc.ABC):
|
|||||||
skip_calling_api: bool = False
|
skip_calling_api: bool = False
|
||||||
exception: Optional[Exception] = None
|
exception: Optional[Exception] = None
|
||||||
|
|
||||||
if coros := [hook(self, api, data) for hook in self._calling_api_hook]:
|
if self._calling_api_hook:
|
||||||
try:
|
|
||||||
logger.debug("Running CallingAPI hooks...")
|
logger.debug("Running CallingAPI hooks...")
|
||||||
await asyncio.gather(*coros)
|
|
||||||
except MockApiException as e:
|
def _handle_mock_api_exception(
|
||||||
skip_calling_api = True
|
exc_group: BaseExceptionGroup[MockApiException],
|
||||||
result = e.result
|
) -> None:
|
||||||
logger.debug(
|
nonlocal skip_calling_api, result
|
||||||
f"Calling API {api} is cancelled. Return {result} instead."
|
|
||||||
|
excs = [
|
||||||
|
exc
|
||||||
|
for exc in flatten_exception_group(exc_group)
|
||||||
|
if isinstance(exc, MockApiException)
|
||||||
|
]
|
||||||
|
if not excs:
|
||||||
|
return
|
||||||
|
elif len(excs) > 1:
|
||||||
|
logger.warning(
|
||||||
|
"Multiple hooks want to mock API result. Use the first one."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
skip_calling_api = True
|
||||||
|
result = excs[0].result
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Calling API {api} is cancelled. Return {result!r} instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
|
||||||
|
for exc in flatten_exception_group(exc_group):
|
||||||
|
logger.opt(colors=True, exception=exc).error(
|
||||||
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
|
"<r><bg #f8bbd0>Error when running CallingAPI hook. "
|
||||||
"Running cancelled!</bg #f8bbd0></r>"
|
"Running cancelled!</bg #f8bbd0></r>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with catch(
|
||||||
|
{
|
||||||
|
MockApiException: _handle_mock_api_exception,
|
||||||
|
Exception: _handle_exception,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for hook in self._calling_api_hook:
|
||||||
|
tg.start_soon(hook, self, api, data)
|
||||||
|
|
||||||
if not skip_calling_api:
|
if not skip_calling_api:
|
||||||
try:
|
try:
|
||||||
result = await self.adapter._call_api(self, api, **data)
|
result = await self.adapter._call_api(self, api, **data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exception = e
|
exception = e
|
||||||
|
|
||||||
if coros := [
|
if self._called_api_hook:
|
||||||
hook(self, exception, api, data, result) for hook in self._called_api_hook
|
|
||||||
]:
|
|
||||||
try:
|
|
||||||
logger.debug("Running CalledAPI hooks...")
|
logger.debug("Running CalledAPI hooks...")
|
||||||
await asyncio.gather(*coros)
|
|
||||||
except MockApiException as e:
|
def _handle_mock_api_exception(
|
||||||
# mock api result
|
exc_group: BaseExceptionGroup[MockApiException],
|
||||||
result = e.result
|
) -> None:
|
||||||
# ignore exception
|
nonlocal result, exception
|
||||||
|
|
||||||
|
excs = [
|
||||||
|
exc
|
||||||
|
for exc in flatten_exception_group(exc_group)
|
||||||
|
if isinstance(exc, MockApiException)
|
||||||
|
]
|
||||||
|
if not excs:
|
||||||
|
return
|
||||||
|
elif len(excs) > 1:
|
||||||
|
logger.warning(
|
||||||
|
"Multiple hooks want to mock API result. Use the first one."
|
||||||
|
)
|
||||||
|
|
||||||
|
result = excs[0].result
|
||||||
exception = None
|
exception = None
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Calling API {api} result is mocked. Return {result} instead."
|
f"Calling API {api} result is mocked. Return {result} instead."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
def _handle_exception(exc_group: BaseExceptionGroup[Exception]) -> None:
|
||||||
|
for exc in flatten_exception_group(exc_group):
|
||||||
|
logger.opt(colors=True, exception=exc).error(
|
||||||
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
|
"<r><bg #f8bbd0>Error when running CalledAPI hook. "
|
||||||
"Running cancelled!</bg #f8bbd0></r>"
|
"Running cancelled!</bg #f8bbd0></r>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with catch(
|
||||||
|
{
|
||||||
|
MockApiException: _handle_mock_api_exception,
|
||||||
|
Exception: _handle_exception,
|
||||||
|
}
|
||||||
|
):
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for hook in self._called_api_hook:
|
||||||
|
tg.start_soon(hook, self, exception, api, data, result)
|
||||||
|
|
||||||
if exception:
|
if exception:
|
||||||
raise exception
|
raise exception
|
||||||
return result
|
return result
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
from collections.abc import Awaitable
|
from types import TracebackType
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
from typing import Any, Union, Callable, cast
|
from collections.abc import Iterable, Awaitable
|
||||||
|
from typing import Any, Union, Callable, Optional, cast
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from anyio.abc import TaskGroup
|
||||||
|
from exceptiongroup import suppress
|
||||||
|
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable
|
from nonebot.utils import run_sync, is_coroutine_callable
|
||||||
|
|
||||||
@ -11,10 +16,24 @@ LIFESPAN_FUNC: TypeAlias = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]
|
|||||||
|
|
||||||
class Lifespan:
|
class Lifespan:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
self._task_group: Optional[TaskGroup] = None
|
||||||
|
|
||||||
self._startup_funcs: list[LIFESPAN_FUNC] = []
|
self._startup_funcs: list[LIFESPAN_FUNC] = []
|
||||||
self._ready_funcs: list[LIFESPAN_FUNC] = []
|
self._ready_funcs: list[LIFESPAN_FUNC] = []
|
||||||
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
|
self._shutdown_funcs: list[LIFESPAN_FUNC] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_group(self) -> TaskGroup:
|
||||||
|
if self._task_group is None:
|
||||||
|
raise RuntimeError("Lifespan not started")
|
||||||
|
return self._task_group
|
||||||
|
|
||||||
|
@task_group.setter
|
||||||
|
def task_group(self, task_group: TaskGroup) -> None:
|
||||||
|
if self._task_group is not None:
|
||||||
|
raise RuntimeError("Lifespan already started")
|
||||||
|
self._task_group = task_group
|
||||||
|
|
||||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
self._startup_funcs.append(func)
|
self._startup_funcs.append(func)
|
||||||
return func
|
return func
|
||||||
@ -29,7 +48,7 @@ class Lifespan:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _run_lifespan_func(
|
async def _run_lifespan_func(
|
||||||
funcs: list[LIFESPAN_FUNC],
|
funcs: Iterable[LIFESPAN_FUNC],
|
||||||
) -> None:
|
) -> None:
|
||||||
for func in funcs:
|
for func in funcs:
|
||||||
if is_coroutine_callable(func):
|
if is_coroutine_callable(func):
|
||||||
@ -38,18 +57,44 @@ class Lifespan:
|
|||||||
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
|
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()
|
||||||
|
|
||||||
async def startup(self) -> None:
|
async def startup(self) -> None:
|
||||||
|
# create background task group
|
||||||
|
self.task_group = anyio.create_task_group()
|
||||||
|
await self.task_group.__aenter__()
|
||||||
|
|
||||||
|
# run startup funcs
|
||||||
if self._startup_funcs:
|
if self._startup_funcs:
|
||||||
await self._run_lifespan_func(self._startup_funcs)
|
await self._run_lifespan_func(self._startup_funcs)
|
||||||
|
|
||||||
|
# run ready funcs
|
||||||
if self._ready_funcs:
|
if self._ready_funcs:
|
||||||
await self._run_lifespan_func(self._ready_funcs)
|
await self._run_lifespan_func(self._ready_funcs)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
exc_type: Optional[type[BaseException]] = None,
|
||||||
|
exc_val: Optional[BaseException] = None,
|
||||||
|
exc_tb: Optional[TracebackType] = None,
|
||||||
|
) -> None:
|
||||||
if self._shutdown_funcs:
|
if self._shutdown_funcs:
|
||||||
await self._run_lifespan_func(self._shutdown_funcs)
|
# reverse shutdown funcs to ensure stack order
|
||||||
|
await self._run_lifespan_func(reversed(self._shutdown_funcs))
|
||||||
|
|
||||||
|
# shutdown background task group
|
||||||
|
self.task_group.cancel_scope.cancel()
|
||||||
|
|
||||||
|
with suppress(Exception):
|
||||||
|
await self.task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
self._task_group = None
|
||||||
|
|
||||||
async def __aenter__(self) -> None:
|
async def __aenter__(self) -> None:
|
||||||
await self.startup()
|
await self.startup()
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
async def __aexit__(
|
||||||
await self.shutdown()
|
self,
|
||||||
|
exc_type: Optional[type[BaseException]],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
await self.shutdown(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)
|
||||||
|
@ -1,17 +1,20 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing_extensions import Self, TypeAlias
|
from typing_extensions import Self, TypeAlias
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
|
from typing import TYPE_CHECKING, Any, Union, ClassVar, Optional
|
||||||
|
|
||||||
|
from anyio.abc import TaskGroup
|
||||||
|
from anyio import CancelScope, create_task_group
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
|
||||||
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
from nonebot.internal.params import BotParam, DependParam, DefaultParam
|
||||||
|
from nonebot.utils import escape_tag, run_coro_with_catch, flatten_exception_group
|
||||||
from nonebot.typing import (
|
from nonebot.typing import (
|
||||||
T_DependencyCache,
|
T_DependencyCache,
|
||||||
T_BotConnectionHook,
|
T_BotConnectionHook,
|
||||||
@ -61,7 +64,6 @@ class Driver(abc.ABC):
|
|||||||
self.config: Config = config
|
self.config: Config = config
|
||||||
"""全局配置对象"""
|
"""全局配置对象"""
|
||||||
self._bots: dict[str, "Bot"] = {}
|
self._bots: dict[str, "Bot"] = {}
|
||||||
self._bot_tasks: set[asyncio.Task] = set()
|
|
||||||
self._lifespan = Lifespan()
|
self._lifespan = Lifespan()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -75,6 +77,10 @@ class Driver(abc.ABC):
|
|||||||
"""获取当前所有已连接的 Bot"""
|
"""获取当前所有已连接的 Bot"""
|
||||||
return self._bots
|
return self._bots
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_group(self) -> TaskGroup:
|
||||||
|
return self._lifespan.task_group
|
||||||
|
|
||||||
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
|
def register_adapter(self, adapter: type["Adapter"], **kwargs) -> None:
|
||||||
"""注册一个协议适配器
|
"""注册一个协议适配器
|
||||||
|
|
||||||
@ -112,8 +118,6 @@ class Driver(abc.ABC):
|
|||||||
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
|
f"<g>Loaded adapters: {escape_tag(', '.join(self._adapters))}</g>"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.on_shutdown(self._cleanup)
|
|
||||||
|
|
||||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||||
"""注册一个启动时执行的函数"""
|
"""注册一个启动时执行的函数"""
|
||||||
return self._lifespan.on_startup(func)
|
return self._lifespan.on_startup(func)
|
||||||
@ -154,66 +158,57 @@ class Driver(abc.ABC):
|
|||||||
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
|
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
|
||||||
self._bots[bot.self_id] = bot
|
self._bots[bot.self_id] = bot
|
||||||
|
|
||||||
async def _run_hook(bot: "Bot") -> None:
|
def handle_exception(exc_group: BaseExceptionGroup) -> None:
|
||||||
dependency_cache: T_DependencyCache = {}
|
for exc in flatten_exception_group(exc_group):
|
||||||
async with AsyncExitStack() as stack:
|
logger.opt(colors=True, exception=exc).error(
|
||||||
if coros := [
|
|
||||||
run_coro_with_catch(
|
|
||||||
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
|
||||||
(SkippedException,),
|
|
||||||
)
|
|
||||||
for hook in self._bot_connection_hook
|
|
||||||
]:
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*coros)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>"
|
"<r><bg #f8bbd0>"
|
||||||
"Error when running WebSocketConnection hook. "
|
"Error when running WebSocketConnection hook:"
|
||||||
"Running cancelled!"
|
|
||||||
"</bg #f8bbd0></r>"
|
"</bg #f8bbd0></r>"
|
||||||
)
|
)
|
||||||
|
|
||||||
task = asyncio.create_task(_run_hook(bot))
|
async def _run_hook(bot: "Bot") -> None:
|
||||||
task.add_done_callback(self._bot_tasks.discard)
|
dependency_cache: T_DependencyCache = {}
|
||||||
self._bot_tasks.add(task)
|
with CancelScope(shield=True), catch({Exception: handle_exception}):
|
||||||
|
async with AsyncExitStack() as stack, create_task_group() as tg:
|
||||||
|
for hook in self._bot_connection_hook:
|
||||||
|
tg.start_soon(
|
||||||
|
run_coro_with_catch,
|
||||||
|
hook(
|
||||||
|
bot=bot, stack=stack, dependency_cache=dependency_cache
|
||||||
|
),
|
||||||
|
(SkippedException,),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.task_group.start_soon(_run_hook, bot)
|
||||||
|
|
||||||
def _bot_disconnect(self, bot: "Bot") -> None:
|
def _bot_disconnect(self, bot: "Bot") -> None:
|
||||||
"""在连接断开后,调用该函数来注销 bot 对象"""
|
"""在连接断开后,调用该函数来注销 bot 对象"""
|
||||||
if bot.self_id in self._bots:
|
if bot.self_id in self._bots:
|
||||||
del self._bots[bot.self_id]
|
del self._bots[bot.self_id]
|
||||||
|
|
||||||
async def _run_hook(bot: "Bot") -> None:
|
def handle_exception(exc_group: BaseExceptionGroup) -> None:
|
||||||
dependency_cache: T_DependencyCache = {}
|
for exc in flatten_exception_group(exc_group):
|
||||||
async with AsyncExitStack() as stack:
|
logger.opt(colors=True, exception=exc).error(
|
||||||
if coros := [
|
|
||||||
run_coro_with_catch(
|
|
||||||
hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
|
|
||||||
(SkippedException,),
|
|
||||||
)
|
|
||||||
for hook in self._bot_disconnection_hook
|
|
||||||
]:
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*coros)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>"
|
"<r><bg #f8bbd0>"
|
||||||
"Error when running WebSocketDisConnection hook. "
|
"Error when running WebSocketDisConnection hook:"
|
||||||
"Running cancelled!"
|
|
||||||
"</bg #f8bbd0></r>"
|
"</bg #f8bbd0></r>"
|
||||||
)
|
)
|
||||||
|
|
||||||
task = asyncio.create_task(_run_hook(bot))
|
async def _run_hook(bot: "Bot") -> None:
|
||||||
task.add_done_callback(self._bot_tasks.discard)
|
dependency_cache: T_DependencyCache = {}
|
||||||
self._bot_tasks.add(task)
|
# shield cancellation to ensure bot disconnect hooks are always run
|
||||||
|
with CancelScope(shield=True), catch({Exception: handle_exception}):
|
||||||
async def _cleanup(self) -> None:
|
async with create_task_group() as tg, AsyncExitStack() as stack:
|
||||||
"""清理驱动器资源"""
|
for hook in self._bot_disconnection_hook:
|
||||||
if self._bot_tasks:
|
tg.start_soon(
|
||||||
logger.opt(colors=True).debug(
|
run_coro_with_catch,
|
||||||
"<y>Waiting for running bot connection hooks...</y>"
|
hook(
|
||||||
|
bot=bot, stack=stack, dependency_cache=dependency_cache
|
||||||
|
),
|
||||||
|
(SkippedException,),
|
||||||
)
|
)
|
||||||
await asyncio.gather(*self._bot_tasks, return_exceptions=True)
|
|
||||||
|
self.task_group.start_soon(_run_hook, bot)
|
||||||
|
|
||||||
|
|
||||||
class Mixin(abc.ABC):
|
class Mixin(abc.ABC):
|
||||||
|
@ -22,11 +22,13 @@ from typing import ( # noqa: UP035
|
|||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.internal.rule import Rule
|
from nonebot.internal.rule import Rule
|
||||||
from nonebot.utils import classproperty
|
|
||||||
from nonebot.dependencies import Param, Dependent
|
from nonebot.dependencies import Param, Dependent
|
||||||
from nonebot.internal.permission import User, Permission
|
from nonebot.internal.permission import User, Permission
|
||||||
|
from nonebot.utils import classproperty, flatten_exception_group
|
||||||
from nonebot.internal.adapter import (
|
from nonebot.internal.adapter import (
|
||||||
Bot,
|
Bot,
|
||||||
Event,
|
Event,
|
||||||
@ -812,8 +814,12 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
f"bot={bot}, event={event!r}, state={state!r}"
|
f"bot={bot}, event={event!r}, state={state!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _handle_stop_propagation(exc_group: BaseExceptionGroup[StopPropagation]):
|
||||||
|
self.block = True
|
||||||
|
|
||||||
with self.ensure_context(bot, event):
|
with self.ensure_context(bot, event):
|
||||||
try:
|
try:
|
||||||
|
with catch({StopPropagation: _handle_stop_propagation}):
|
||||||
# Refresh preprocess state
|
# Refresh preprocess state
|
||||||
self.state.update(state)
|
self.state.update(state)
|
||||||
|
|
||||||
@ -821,7 +827,13 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
handler = self.remain_handlers.pop(0)
|
handler = self.remain_handlers.pop(0)
|
||||||
current_handler.set(handler)
|
current_handler.set(handler)
|
||||||
logger.debug(f"Running handler {handler}")
|
logger.debug(f"Running handler {handler}")
|
||||||
try:
|
|
||||||
|
def _handle_skipped(
|
||||||
|
exc_group: BaseExceptionGroup[SkippedException],
|
||||||
|
):
|
||||||
|
logger.debug(f"Handler {handler} skipped")
|
||||||
|
|
||||||
|
with catch({SkippedException: _handle_skipped}):
|
||||||
await handler(
|
await handler(
|
||||||
matcher=self,
|
matcher=self,
|
||||||
bot=bot,
|
bot=bot,
|
||||||
@ -830,10 +842,6 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
stack=stack,
|
stack=stack,
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
)
|
)
|
||||||
except SkippedException:
|
|
||||||
logger.debug(f"Handler {handler} skipped")
|
|
||||||
except StopPropagation:
|
|
||||||
self.block = True
|
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"{self} running complete")
|
logger.info(f"{self} running complete")
|
||||||
|
|
||||||
@ -846,10 +854,54 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
stack: Optional[AsyncExitStack] = None,
|
stack: Optional[AsyncExitStack] = None,
|
||||||
dependency_cache: Optional[T_DependencyCache] = None,
|
dependency_cache: Optional[T_DependencyCache] = None,
|
||||||
):
|
):
|
||||||
try:
|
exc: Optional[Union[FinishedException, RejectedException, PausedException]] = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_special_exception(
|
||||||
|
exc_group: BaseExceptionGroup[
|
||||||
|
Union[FinishedException, RejectedException, PausedException]
|
||||||
|
]
|
||||||
|
):
|
||||||
|
nonlocal exc
|
||||||
|
excs = list(flatten_exception_group(exc_group))
|
||||||
|
if len(excs) > 1:
|
||||||
|
logger.warning(
|
||||||
|
"Multiple session control exceptions occurred. "
|
||||||
|
"NoneBot will choose the proper one."
|
||||||
|
)
|
||||||
|
finished_exc = next(
|
||||||
|
(e for e in excs if isinstance(e, FinishedException)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
rejected_exc = next(
|
||||||
|
(e for e in excs if isinstance(e, RejectedException)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
paused_exc = next(
|
||||||
|
(e for e in excs if isinstance(e, PausedException)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
exc = finished_exc or rejected_exc or paused_exc
|
||||||
|
elif isinstance(
|
||||||
|
excs[0], (FinishedException, RejectedException, PausedException)
|
||||||
|
):
|
||||||
|
exc = excs[0]
|
||||||
|
|
||||||
|
with catch(
|
||||||
|
{
|
||||||
|
(
|
||||||
|
FinishedException,
|
||||||
|
RejectedException,
|
||||||
|
PausedException,
|
||||||
|
): _handle_special_exception
|
||||||
|
}
|
||||||
|
):
|
||||||
await self.simple_run(bot, event, state, stack, dependency_cache)
|
await self.simple_run(bot, event, state, stack, dependency_cache)
|
||||||
|
|
||||||
except RejectedException:
|
if isinstance(exc, FinishedException):
|
||||||
|
pass
|
||||||
|
elif isinstance(exc, RejectedException):
|
||||||
await self.resolve_reject()
|
await self.resolve_reject()
|
||||||
type_ = await self.update_type(bot, event, stack, dependency_cache)
|
type_ = await self.update_type(bot, event, stack, dependency_cache)
|
||||||
permission = await self.update_permission(
|
permission = await self.update_permission(
|
||||||
@ -870,7 +922,7 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
default_type_updater=self.__class__._default_type_updater,
|
default_type_updater=self.__class__._default_type_updater,
|
||||||
default_permission_updater=self.__class__._default_permission_updater,
|
default_permission_updater=self.__class__._default_permission_updater,
|
||||||
)
|
)
|
||||||
except PausedException:
|
elif isinstance(exc, PausedException):
|
||||||
type_ = await self.update_type(bot, event, stack, dependency_cache)
|
type_ = await self.update_type(bot, event, stack, dependency_cache)
|
||||||
permission = await self.update_permission(
|
permission = await self.update_permission(
|
||||||
bot, event, stack, dependency_cache
|
bot, event, stack, dependency_cache
|
||||||
@ -890,5 +942,3 @@ class Matcher(metaclass=MatcherMeta):
|
|||||||
default_type_updater=self.__class__._default_type_updater,
|
default_type_updater=self.__class__._default_type_updater,
|
||||||
default_permission_updater=self.__class__._default_permission_updater,
|
default_permission_updater=self.__class__._default_permission_updater,
|
||||||
)
|
)
|
||||||
except FinishedException:
|
|
||||||
pass
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from enum import Enum
|
||||||
from typing_extensions import Self, get_args, override, get_origin
|
from typing_extensions import Self, get_args, override, get_origin
|
||||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -13,8 +13,11 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||||
|
|
||||||
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.dependencies import Param, Dependent
|
from nonebot.dependencies import Param, Dependent
|
||||||
from nonebot.dependencies.utils import check_field_type
|
from nonebot.dependencies.utils import check_field_type
|
||||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
||||||
@ -93,6 +96,75 @@ def Depends(
|
|||||||
return DependsInner(dependency, use_cache=use_cache, validate=validate)
|
return DependsInner(dependency, use_cache=use_cache, validate=validate)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheState(str, Enum):
|
||||||
|
"""子依赖缓存状态"""
|
||||||
|
|
||||||
|
PENDING = "PENDING"
|
||||||
|
FINISHED = "FINISHED"
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyCache:
|
||||||
|
"""子依赖结果缓存。
|
||||||
|
|
||||||
|
用于缓存子依赖的结果,以避免重复计算。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._state = CacheState.PENDING
|
||||||
|
self._result: Any = None
|
||||||
|
self._exception: Optional[BaseException] = None
|
||||||
|
self._waiter = anyio.Event()
|
||||||
|
|
||||||
|
def result(self) -> Any:
|
||||||
|
"""获取子依赖结果"""
|
||||||
|
|
||||||
|
if self._state != CacheState.FINISHED:
|
||||||
|
raise RuntimeError("Result is not ready")
|
||||||
|
|
||||||
|
if self._exception is not None:
|
||||||
|
raise self._exception
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
def exception(self) -> Optional[BaseException]:
|
||||||
|
"""获取子依赖异常"""
|
||||||
|
|
||||||
|
if self._state != CacheState.FINISHED:
|
||||||
|
raise RuntimeError("Result is not ready")
|
||||||
|
|
||||||
|
return self._exception
|
||||||
|
|
||||||
|
def set_result(self, result: Any) -> None:
|
||||||
|
"""设置子依赖结果"""
|
||||||
|
|
||||||
|
if self._state != CacheState.PENDING:
|
||||||
|
raise RuntimeError(f"Cache state invalid: {self._state}")
|
||||||
|
|
||||||
|
self._result = result
|
||||||
|
self._state = CacheState.FINISHED
|
||||||
|
self._waiter.set()
|
||||||
|
|
||||||
|
def set_exception(self, exception: BaseException) -> None:
|
||||||
|
"""设置子依赖异常"""
|
||||||
|
|
||||||
|
if self._state != CacheState.PENDING:
|
||||||
|
raise RuntimeError(f"Cache state invalid: {self._state}")
|
||||||
|
|
||||||
|
self._exception = exception
|
||||||
|
self._state = CacheState.FINISHED
|
||||||
|
self._waiter.set()
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
"""等待子依赖结果"""
|
||||||
|
await self._waiter.wait()
|
||||||
|
if self._state != CacheState.FINISHED:
|
||||||
|
raise RuntimeError("Invalid cache state")
|
||||||
|
|
||||||
|
if self._exception is not None:
|
||||||
|
raise self._exception
|
||||||
|
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
class DependParam(Param):
|
class DependParam(Param):
|
||||||
"""子依赖注入参数。
|
"""子依赖注入参数。
|
||||||
|
|
||||||
@ -194,17 +266,27 @@ class DependParam(Param):
|
|||||||
call = cast(Callable[..., Any], sub_dependent.call)
|
call = cast(Callable[..., Any], sub_dependent.call)
|
||||||
|
|
||||||
# solve sub dependency with current cache
|
# solve sub dependency with current cache
|
||||||
|
exc: Optional[BaseExceptionGroup[SkippedException]] = None
|
||||||
|
|
||||||
|
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
|
||||||
|
nonlocal exc
|
||||||
|
exc = exc_group
|
||||||
|
|
||||||
|
with catch({SkippedException: _handle_skipped}):
|
||||||
sub_values = await sub_dependent.solve(
|
sub_values = await sub_dependent.solve(
|
||||||
stack=stack,
|
stack=stack,
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if exc is not None:
|
||||||
|
raise exc
|
||||||
|
|
||||||
# run dependency function
|
# run dependency function
|
||||||
task: asyncio.Task[Any]
|
|
||||||
if use_cache and call in dependency_cache:
|
if use_cache and call in dependency_cache:
|
||||||
return await dependency_cache[call]
|
return await dependency_cache[call].wait()
|
||||||
elif is_gen_callable(call) or is_async_gen_callable(call):
|
|
||||||
|
if is_gen_callable(call) or is_async_gen_callable(call):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
stack, AsyncExitStack
|
stack, AsyncExitStack
|
||||||
), "Generator dependency should be called in context"
|
), "Generator dependency should be called in context"
|
||||||
@ -212,17 +294,21 @@ class DependParam(Param):
|
|||||||
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
cm = run_sync_ctx_manager(contextmanager(call)(**sub_values))
|
||||||
else:
|
else:
|
||||||
cm = asynccontextmanager(call)(**sub_values)
|
cm = asynccontextmanager(call)(**sub_values)
|
||||||
task = asyncio.create_task(stack.enter_async_context(cm))
|
|
||||||
dependency_cache[call] = task
|
target = stack.enter_async_context(cm)
|
||||||
return await task
|
|
||||||
elif is_coroutine_callable(call):
|
elif is_coroutine_callable(call):
|
||||||
task = asyncio.create_task(call(**sub_values))
|
target = call(**sub_values)
|
||||||
dependency_cache[call] = task
|
|
||||||
return await task
|
|
||||||
else:
|
else:
|
||||||
task = asyncio.create_task(run_sync(call)(**sub_values))
|
target = run_sync(call)(**sub_values)
|
||||||
dependency_cache[call] = task
|
|
||||||
return await task
|
dependency_cache[call] = cache = DependencyCache()
|
||||||
|
try:
|
||||||
|
result = await target
|
||||||
|
cache.set_result(result)
|
||||||
|
return result
|
||||||
|
except BaseException as e:
|
||||||
|
cache.set_exception(e)
|
||||||
|
raise
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def _check(self, **kwargs: Any) -> None:
|
async def _check(self, **kwargs: Any) -> None:
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import Union, ClassVar, NoReturn, Optional
|
from typing import Union, ClassVar, NoReturn, Optional
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.utils import run_coro_with_catch
|
from nonebot.utils import run_coro_with_catch
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
@ -70,22 +71,26 @@ class Permission:
|
|||||||
"""
|
"""
|
||||||
if not self.checkers:
|
if not self.checkers:
|
||||||
return True
|
return True
|
||||||
results = await asyncio.gather(
|
|
||||||
*(
|
result = False
|
||||||
run_coro_with_catch(
|
|
||||||
|
async def _run_checker(checker: Dependent[bool]) -> None:
|
||||||
|
nonlocal result
|
||||||
|
# calculate the result first to avoid data racing
|
||||||
|
is_passed = await run_coro_with_catch(
|
||||||
checker(
|
checker(
|
||||||
bot=bot,
|
bot=bot, event=event, stack=stack, dependency_cache=dependency_cache
|
||||||
event=event,
|
|
||||||
stack=stack,
|
|
||||||
dependency_cache=dependency_cache,
|
|
||||||
),
|
),
|
||||||
(SkippedException,),
|
(SkippedException,),
|
||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
for checker in self.checkers
|
result |= is_passed
|
||||||
),
|
|
||||||
)
|
async with anyio.create_task_group() as tg:
|
||||||
return any(results)
|
for checker in self.checkers:
|
||||||
|
tg.start_soon(_run_checker, checker)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def __and__(self, other: object) -> NoReturn:
|
def __and__(self, other: object) -> NoReturn:
|
||||||
raise RuntimeError("And operation between Permissions is not allowed.")
|
raise RuntimeError("And operation between Permissions is not allowed.")
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import Union, ClassVar, NoReturn, Optional
|
from typing import Union, ClassVar, NoReturn, Optional
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
|
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
|
||||||
@ -71,22 +73,33 @@ class Rule:
|
|||||||
"""
|
"""
|
||||||
if not self.checkers:
|
if not self.checkers:
|
||||||
return True
|
return True
|
||||||
try:
|
|
||||||
results = await asyncio.gather(
|
result = True
|
||||||
*(
|
|
||||||
checker(
|
def _handle_skipped_exception(
|
||||||
|
exc_group: BaseExceptionGroup[SkippedException],
|
||||||
|
) -> None:
|
||||||
|
nonlocal result
|
||||||
|
result = False
|
||||||
|
|
||||||
|
async def _run_checker(checker: Dependent[bool]) -> None:
|
||||||
|
nonlocal result
|
||||||
|
# calculate the result first to avoid data racing
|
||||||
|
is_passed = await checker(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
state=state,
|
state=state,
|
||||||
stack=stack,
|
stack=stack,
|
||||||
dependency_cache=dependency_cache,
|
dependency_cache=dependency_cache,
|
||||||
)
|
)
|
||||||
for checker in self.checkers
|
result &= is_passed
|
||||||
)
|
|
||||||
)
|
with catch({SkippedException: _handle_skipped_exception}):
|
||||||
except SkippedException:
|
async with anyio.create_task_group() as tg:
|
||||||
return False
|
for checker in self.checkers:
|
||||||
return all(results)
|
tg.start_soon(_run_checker, checker)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
|
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
|
||||||
if other is None:
|
if other is None:
|
||||||
|
@ -9,23 +9,30 @@ FrontMatter:
|
|||||||
description: nonebot.message 模块
|
description: nonebot.message 模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import contextlib
|
import contextlib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from exceptiongroup import BaseExceptionGroup, catch
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.rule import TrieRule
|
from nonebot.rule import TrieRule
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.matcher import Matcher, matchers
|
from nonebot.matcher import Matcher, matchers
|
||||||
from nonebot.utils import escape_tag, run_coro_with_catch
|
|
||||||
from nonebot.exception import (
|
from nonebot.exception import (
|
||||||
NoLogException,
|
NoLogException,
|
||||||
StopPropagation,
|
StopPropagation,
|
||||||
IgnoredException,
|
IgnoredException,
|
||||||
SkippedException,
|
SkippedException,
|
||||||
)
|
)
|
||||||
|
from nonebot.utils import (
|
||||||
|
escape_tag,
|
||||||
|
run_coro_with_catch,
|
||||||
|
run_coro_with_shield,
|
||||||
|
flatten_exception_group,
|
||||||
|
)
|
||||||
from nonebot.typing import (
|
from nonebot.typing import (
|
||||||
T_State,
|
T_State,
|
||||||
T_DependencyCache,
|
T_DependencyCache,
|
||||||
@ -125,6 +132,21 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_ignored_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
|
||||||
|
def _handle(exc_group: BaseExceptionGroup[IgnoredException]) -> None:
|
||||||
|
logger.opt(colors=True).info(msg)
|
||||||
|
|
||||||
|
return _handle
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_exception(msg: str) -> Callable[[BaseExceptionGroup], None]:
|
||||||
|
def _handle(exc_group: BaseExceptionGroup[Exception]) -> None:
|
||||||
|
for exc in flatten_exception_group(exc_group):
|
||||||
|
logger.opt(colors=True, exception=exc).error(msg)
|
||||||
|
|
||||||
|
return _handle
|
||||||
|
|
||||||
|
|
||||||
async def _apply_event_preprocessors(
|
async def _apply_event_preprocessors(
|
||||||
bot: "Bot",
|
bot: "Bot",
|
||||||
event: "Event",
|
event: "Event",
|
||||||
@ -152,10 +174,21 @@ async def _apply_event_preprocessors(
|
|||||||
if show_log:
|
if show_log:
|
||||||
logger.debug("Running PreProcessors...")
|
logger.debug("Running PreProcessors...")
|
||||||
|
|
||||||
try:
|
with catch(
|
||||||
await asyncio.gather(
|
{
|
||||||
*(
|
IgnoredException: _handle_ignored_exception(
|
||||||
run_coro_with_catch(
|
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
|
||||||
|
),
|
||||||
|
Exception: _handle_exception(
|
||||||
|
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
|
||||||
|
"Event ignored!</bg #f8bbd0></r>"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
):
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for proc in _event_preprocessors:
|
||||||
|
tg.start_soon(
|
||||||
|
run_coro_with_catch,
|
||||||
proc(
|
proc(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
@ -165,23 +198,11 @@ async def _apply_event_preprocessors(
|
|||||||
),
|
),
|
||||||
(SkippedException,),
|
(SkippedException,),
|
||||||
)
|
)
|
||||||
for proc in _event_preprocessors
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except IgnoredException:
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running EventPreProcessors. "
|
|
||||||
"Event ignored!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _apply_event_postprocessors(
|
async def _apply_event_postprocessors(
|
||||||
bot: "Bot",
|
bot: "Bot",
|
||||||
@ -207,10 +228,17 @@ async def _apply_event_postprocessors(
|
|||||||
if show_log:
|
if show_log:
|
||||||
logger.debug("Running PostProcessors...")
|
logger.debug("Running PostProcessors...")
|
||||||
|
|
||||||
try:
|
with catch(
|
||||||
await asyncio.gather(
|
{
|
||||||
*(
|
Exception: _handle_exception(
|
||||||
run_coro_with_catch(
|
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
):
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for proc in _event_postprocessors:
|
||||||
|
tg.start_soon(
|
||||||
|
run_coro_with_catch,
|
||||||
proc(
|
proc(
|
||||||
bot=bot,
|
bot=bot,
|
||||||
event=event,
|
event=event,
|
||||||
@ -220,13 +248,6 @@ async def _apply_event_postprocessors(
|
|||||||
),
|
),
|
||||||
(SkippedException,),
|
(SkippedException,),
|
||||||
)
|
)
|
||||||
for proc in _event_postprocessors
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _apply_run_preprocessors(
|
async def _apply_run_preprocessors(
|
||||||
@ -254,11 +275,24 @@ async def _apply_run_preprocessors(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# ensure matcher function can be correctly called
|
# ensure matcher function can be correctly called
|
||||||
with matcher.ensure_context(bot, event):
|
with (
|
||||||
try:
|
matcher.ensure_context(bot, event),
|
||||||
await asyncio.gather(
|
catch(
|
||||||
*(
|
{
|
||||||
run_coro_with_catch(
|
IgnoredException: _handle_ignored_exception(
|
||||||
|
f"{matcher} running is <b>cancelled</b>"
|
||||||
|
),
|
||||||
|
Exception: _handle_exception(
|
||||||
|
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
|
||||||
|
"Running cancelled!</bg #f8bbd0></r>"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
):
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
for proc in _run_preprocessors:
|
||||||
|
tg.start_soon(
|
||||||
|
run_coro_with_catch,
|
||||||
proc(
|
proc(
|
||||||
matcher=matcher,
|
matcher=matcher,
|
||||||
bot=bot,
|
bot=bot,
|
||||||
@ -269,21 +303,11 @@ async def _apply_run_preprocessors(
|
|||||||
),
|
),
|
||||||
(SkippedException,),
|
(SkippedException,),
|
||||||
)
|
)
|
||||||
for proc in _run_preprocessors
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except IgnoredException:
|
|
||||||
logger.opt(colors=True).info(f"{matcher} running is <b>cancelled</b>")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running RunPreProcessors. "
|
|
||||||
"Running cancelled!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
 async def _apply_run_postprocessors(
     bot: "Bot",
@@ -306,11 +330,21 @@ async def _apply_run_postprocessors(
     if not _run_postprocessors:
         return

-    with matcher.ensure_context(bot, event):
-        try:
-            await asyncio.gather(
-                *(
-                    run_coro_with_catch(
+    with (
+        matcher.ensure_context(bot, event),
+        catch(
+            {
+                Exception: _handle_exception(
+                    "<r><bg #f8bbd0>Error when running RunPostProcessors"
+                    "</bg #f8bbd0></r>"
+                )
+            }
+        ),
+    ):
+        async with anyio.create_task_group() as tg:
+            for proc in _run_postprocessors:
+                tg.start_soon(
+                    run_coro_with_catch,
                     proc(
                         matcher=matcher,
                         exception=exception,
@@ -322,13 +356,6 @@ async def _apply_run_postprocessors(
                     ),
                     (SkippedException,),
                 )
-                for proc in _run_postprocessors
-            )
-        )
-    except Exception as e:
-        logger.opt(colors=True, exception=e).error(
-            "<r><bg #f8bbd0>Error when running RunPostProcessors</bg #f8bbd0></r>"
-        )


 async def _check_matcher(
@@ -425,8 +452,9 @@ async def _run_matcher(

     exception = None

-    try:
-        logger.debug(f"Running {matcher}")
+    logger.debug(f"Running {matcher}")
+
+    try:
         await matcher.run(bot, event, state, stack, dependency_cache)
     except Exception as e:
         logger.opt(colors=True, exception=e).error(
@@ -494,8 +522,7 @@ async def handle_event(bot: "Bot", event: "Event") -> None:

    Usage:
        ```python
-       import asyncio
-       asyncio.create_task(handle_event(bot, event))
+       driver.task_group.start_soon(handle_event, bot, event)
        ```
    """
    show_log = True
@@ -530,6 +557,13 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
         )

     break_flag = False

+    def _handle_stop_propagation(exc_group: BaseExceptionGroup) -> None:
+        nonlocal break_flag
+        break_flag = True
+        logger.debug("Stop event propagation")
+
     # iterate through all priority until stop propagation
     for priority in sorted(matchers.keys()):
         if break_flag:
@@ -538,22 +572,29 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
         if show_log:
             logger.debug(f"Checking for matchers in priority {priority}...")

-        pending_tasks = [
-            check_and_run_matcher(
-                matcher, bot, event, state.copy(), stack, dependency_cache
-            )
-            for matcher in matchers[priority]
-        ]
-        results = await asyncio.gather(*pending_tasks, return_exceptions=True)
-        for result in results:
-            if not isinstance(result, Exception):
-                continue
-            if isinstance(result, StopPropagation):
-                break_flag = True
-                logger.debug("Stop event propagation")
-            else:
-                logger.opt(colors=True, exception=result).error(
-                    "<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
-                )
+        if not (priority_matchers := matchers[priority]):
+            continue
+
+        with catch(
+            {
+                StopPropagation: _handle_stop_propagation,
+                Exception: _handle_exception(
+                    "<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
+                ),
+            }
+        ):
+            async with anyio.create_task_group() as tg:
+                for matcher in priority_matchers:
+                    tg.start_soon(
+                        run_coro_with_shield,
+                        check_and_run_matcher(
+                            matcher,
+                            bot,
+                            event,
+                            state.copy(),
+                            stack,
+                            dependency_cache,
+                        ),
+                    )

         if show_log:
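One consequence of the task-group rewrite of `handle_event` is that a `StopPropagation` raised by any matcher now surfaces wrapped in an exception group, which is why it is routed through a `catch` handler instead of an `isinstance` check on gathered results. A minimal sketch of that flow, using a stand-in `StopPropagation` class rather than NoneBot's real exception:

import anyio
from exceptiongroup import catch

class StopPropagation(Exception):
    """Stand-in for nonebot.exception.StopPropagation in this sketch."""

def _on_stop(excs) -> None:
    print("stop propagation requested by", len(excs.exceptions), "matcher(s)")

async def main() -> None:
    async def matcher_a():
        raise StopPropagation()

    async def matcher_b():
        await anyio.sleep(0.01)

    # the task group re-raises failures as an exception group;
    # catch() routes the StopPropagation sub-group to the handler above
    with catch({StopPropagation: _on_stop}):
        async with anyio.create_task_group() as tg:
            tg.start_soon(matcher_a)
            tg.start_soon(matcher_b)

anyio.run(main)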
@@ -22,7 +22,7 @@ from . import _managers, get_plugin, _module_name_to_plugin_id
 try:  # pragma: py-gte-311
     import tomllib  # pyright: ignore[reportMissingImports]
 except ModuleNotFoundError:  # pragma: py-lt-311
-    import tomli as tomllib
+    import tomli as tomllib  # pyright: ignore[reportMissingImports]


 def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:
@@ -21,10 +21,9 @@ from typing import TYPE_CHECKING, TypeVar
 from typing_extensions import ParamSpec, TypeAlias, get_args, override, get_origin

 if TYPE_CHECKING:
-    from asyncio import Task
-
     from nonebot.adapters import Bot
     from nonebot.permission import Permission
+    from nonebot.internal.params import DependencyCache

 T = TypeVar("T")
 P = ParamSpec("P")
@@ -258,5 +257,5 @@ T_PermissionUpdater: TypeAlias = _DependentCallable["Permission"]
    - MatcherParam: Matcher object
    - DefaultParam: parameter with a default value
    """
-T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "Task[t.Any]"]
+T_DependencyCache: TypeAlias = dict[_DependentCallable[t.Any], "DependencyCache"]
"""Dependency cache, used to store the return values of dependency functions"""
@@ -9,21 +9,22 @@ FrontMatter:

 import re
 import json
-import asyncio
 import inspect
 import importlib
 import contextlib
 import dataclasses
 from pathlib import Path
 from collections import deque
-from contextvars import copy_context
 from functools import wraps, partial
 from contextlib import AbstractContextManager, asynccontextmanager
 from typing_extensions import ParamSpec, get_args, override, get_origin
-from collections.abc import Mapping, Sequence, Coroutine, AsyncGenerator
 from typing import Any, Union, Generic, TypeVar, Callable, Optional, overload
+from collections.abc import Mapping, Sequence, Coroutine, Generator, AsyncGenerator

+import anyio
+import anyio.to_thread
 from pydantic import BaseModel
+from exceptiongroup import BaseExceptionGroup, catch

 from nonebot.log import logger
 from nonebot.typing import (
@@ -39,6 +40,7 @@ R = TypeVar("R")
 T = TypeVar("T")
 K = TypeVar("K")
 V = TypeVar("V")
+E = TypeVar("E", bound=BaseException)


 def escape_tag(s: str) -> str:
@@ -178,11 +180,9 @@ def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]:

     @wraps(call)
     async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
-        loop = asyncio.get_running_loop()
-        pfunc = partial(call, *args, **kwargs)
-        context = copy_context()
-        result = await loop.run_in_executor(None, partial(context.run, pfunc))
-        return result
+        return await anyio.to_thread.run_sync(
+            partial(call, *args, **kwargs), abandon_on_cancel=True
+        )

     return _wrapper

@@ -234,12 +234,36 @@ async def run_coro_with_catch(
        The coroutine's return value, or the given value when the exception occurs
    """

-    try:
+    with catch({exc: lambda exc_group: None}):
         return await coro
-    except exc:
-        return return_on_err
+
+    return return_on_err
+
+
+async def run_coro_with_shield(coro: Coroutine[Any, Any, T]) -> T:
+    """Run a coroutine while shielding it from cancellation.
+
+    Args:
+        coro: the coroutine to run
+
+    Returns:
+        The coroutine's return value
+    """
+
+    with anyio.CancelScope(shield=True):
+        return await coro
+
+
+def flatten_exception_group(
+    exc_group: BaseExceptionGroup[E],
+) -> Generator[E, None, None]:
+    for exc in exc_group.exceptions:
+        if isinstance(exc, BaseExceptionGroup):
+            yield from flatten_exception_group(exc)
+        else:
+            yield exc


 def get_name(obj: Any) -> str:
    """Get the name of an object"""
    if inspect.isfunction(obj) or inspect.isclass(obj):
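Two of the new helpers above are easy to try in isolation: `flatten_exception_group` walks arbitrarily nested exception groups, and a shielded `CancelScope` is what lets `run_coro_with_shield` finish even when the surrounding task is being cancelled. The sketch below reimplements the flattening locally (the `flatten` and `cleanup` names are illustrative, not NoneBot APIs):

import anyio
from exceptiongroup import BaseExceptionGroup

def flatten(group):
    # same idea as flatten_exception_group above
    for exc in group.exceptions:
        if isinstance(exc, BaseExceptionGroup):
            yield from flatten(exc)
        else:
            yield exc

group = BaseExceptionGroup(
    "outer",
    [ValueError("a"), BaseExceptionGroup("inner", [KeyError("b")])],
)
print([type(e).__name__ for e in flatten(group)])  # ['ValueError', 'KeyError']

async def cleanup() -> None:
    # a shielded scope keeps running even if the enclosing task is cancelled
    with anyio.CancelScope(shield=True):
        await anyio.sleep(0.01)

anyio.run(cleanup)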
2274 poetry.lock (generated): file diff suppressed because it is too large

@@ -27,7 +27,9 @@ include = ["nonebot/py.typed"]
 [tool.poetry.dependencies]
 python = "^3.9"
 yarl = "^1.7.2"
+anyio = "^4.4.0"
 pygtrie = "^2.4.1"
+exceptiongroup = "^1.2.2"
 loguru = ">=0.6.0,<1.0.0"
 python-dotenv = ">=0.21.0,<2.0.0"
 typing-extensions = ">=4.4.0,<5.0.0"
@@ -65,7 +67,6 @@ fastapi = ["fastapi", "uvicorn"]
 all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"]

 [tool.pytest.ini_options]
-asyncio_mode = "strict"
 addopts = "--cov=nonebot --cov-report=term-missing"
 filterwarnings = ["error", "ignore::DeprecationWarning"]

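With `asyncio_mode = "strict"` gone, async tests are collected by AnyIO's bundled pytest plugin instead of pytest-asyncio. A minimal test written against that plugin looks like the sketch below; the file and test names are made up, and `anyio_backend_name` is the plugin's built-in fixture reporting which backend the test is currently running on (both asyncio and trio in this repository, per the session fixture in the test suite):

# test_sketch.py (hypothetical file)
import anyio
import pytest

@pytest.mark.anyio
async def test_sleep_runs_on_the_selected_backend(anyio_backend_name: str) -> None:
    # the marked coroutine is executed once per backend provided by the
    # `anyio_backend` fixture
    await anyio.sleep(0)
    assert anyio_backend_name in {"asyncio", "trio"}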
@@ -1,8 +1,10 @@
 import os
 import threading
 from pathlib import Path
-from typing import TYPE_CHECKING
+from functools import wraps
 from collections.abc import Generator
+from typing_extensions import ParamSpec
+from typing import TYPE_CHECKING, TypeVar, Callable

 import pytest
 from nonebug import NONEBOT_INIT_KWARGS
@@ -20,6 +22,9 @@ os.environ["CONFIG_OVERRIDE"] = "new"
 if TYPE_CHECKING:
     from nonebot.plugin import Plugin

+P = ParamSpec("P")
+R = TypeVar("R")
+
 collect_ignore = ["plugins/", "dynamic/", "bad_plugins/"]


@@ -38,14 +43,36 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
     return DriverClass(Env(environment=global_driver.env), global_driver.config)


+@pytest.fixture(scope="session", params=[pytest.param("asyncio"), pytest.param("trio")])
+def anyio_backend(request: pytest.FixtureRequest):
+    return request.param
+
+
+def run_once(func: Callable[P, R]) -> Callable[P, R]:
+    result = ...
+
+    @wraps(func)
+    def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        nonlocal result
+        if result is not Ellipsis:
+            return result
+
+        result = func(*args, **kwargs)
+        return result
+
+    return _wrapper
+
+
 @pytest.fixture(scope="session", autouse=True)
-def load_plugin(nonebug_init: None) -> set["Plugin"]:
+@run_once
+def load_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
     # preload global plugins
     return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))


 @pytest.fixture(scope="session", autouse=True)
-def load_builtin_plugin(nonebug_init: None) -> set["Plugin"]:
+@run_once
+def load_builtin_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
     # preload builtin plugins
     return nonebot.load_builtin_plugins("echo", "single_session")

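Because the session-scoped plugin fixtures now also depend on the parametrized `anyio_backend` fixture, they would otherwise be invoked once per backend; the `run_once` wrapper caches the first result so the plugins are only loaded a single time. A standalone sketch of the same caching behaviour, with illustrative names:

from functools import wraps

def run_once(func):
    result = ...

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal result
        if result is Ellipsis:
            result = func(*args, **kwargs)
        return result

    return wrapper

@run_once
def expensive_setup() -> int:
    print("loading plugins ...")
    return 42

assert expensive_setup() == 42
assert expensive_setup() == 42  # second call reuses the cached value, no print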
@@ -17,7 +17,7 @@ from nonebot.drivers import (
 )


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_adapter_connect(app: App, driver: Driver):
     last_connect_bot: Optional[Bot] = None
     last_disconnect_bot: Optional[Bot] = None
@@ -45,7 +45,6 @@ async def test_adapter_connect(app: App, driver: Driver):
     assert bot.self_id not in adapter.bots


-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "driver",
     [
@@ -75,7 +74,7 @@ async def test_adapter_connect(app: App, driver: Driver):
     ],
     indirect=True,
 )
-async def test_adapter_server(driver: Driver):
+def test_adapter_server(driver: Driver):
     last_http_setup: Optional[HTTPServerSetup] = None
     last_ws_setup: Optional[WebSocketServerSetup] = None

@@ -112,7 +111,7 @@ async def test_adapter_server(driver: Driver):
     assert last_ws_setup is setup


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     "driver",
     [
@@ -159,7 +158,7 @@ async def test_adapter_http_client(driver: Driver):
     assert last_request is request


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     "driver",
     [
@@ -1,5 +1,6 @@
 from typing import Any, Optional

+import anyio
 import pytest
 from nonebug import App

@@ -7,7 +8,7 @@ from nonebot.adapters import Bot
 from nonebot.exception import MockApiException


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_bot_call_api(app: App):
     async with app.test_api() as ctx:
         bot = ctx.create_bot()
@@ -23,7 +24,7 @@ async def test_bot_call_api(app: App):
         await bot.call_api("test")


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_bot_calling_api_hook_simple(app: App):
     runned: bool = False

@@ -49,7 +50,7 @@ async def test_bot_calling_api_hook_simple(app: App):
     assert result is True


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_bot_calling_api_hook_mock(app: App):
     runned: bool = False

@@ -76,7 +77,47 @@ async def test_bot_calling_api_hook_mock(app: App):
     assert result is False


-@pytest.mark.asyncio
+@pytest.mark.anyio
+async def test_bot_calling_api_hook_multi_mock(app: App):
+    runned1: bool = False
+    runned2: bool = False
+    event = anyio.Event()
+
+    async def calling_api_hook1(bot: Bot, api: str, data: dict[str, Any]):
+        nonlocal runned1
+        runned1 = True
+        event.set()
+
+        raise MockApiException(1)
+
+    async def calling_api_hook2(bot: Bot, api: str, data: dict[str, Any]):
+        nonlocal runned2
+        runned2 = True
+        with anyio.fail_after(1):
+            await event.wait()
+
+        raise MockApiException(2)
+
+    hooks = set()
+
+    with pytest.MonkeyPatch.context() as m:
+        m.setattr(Bot, "_calling_api_hook", hooks)
+
+        Bot.on_calling_api(calling_api_hook1)
+        Bot.on_calling_api(calling_api_hook2)
+
+        assert hooks == {calling_api_hook1, calling_api_hook2}
+
+        async with app.test_api() as ctx:
+            bot = ctx.create_bot()
+            result = await bot.call_api("test")
+
+        assert runned1 is True
+        assert runned2 is True
+        assert result == 1
+
+
+@pytest.mark.anyio
 async def test_bot_called_api_hook_simple(app: App):
     runned: bool = False

@@ -108,7 +149,7 @@ async def test_bot_called_api_hook_simple(app: App):
     assert result is True


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_bot_called_api_hook_mock(app: App):
     runned: bool = False

@@ -150,3 +191,56 @@ async def test_bot_called_api_hook_mock(app: App):

     assert runned is True
     assert result is False
+
+
+@pytest.mark.anyio
+async def test_bot_called_api_hook_multi_mock(app: App):
+    runned1: bool = False
+    runned2: bool = False
+    event = anyio.Event()
+
+    async def called_api_hook1(
+        bot: Bot,
+        exception: Optional[Exception],
+        api: str,
+        data: dict[str, Any],
+        result: Any,
+    ):
+        nonlocal runned1
+        runned1 = True
+        event.set()
+
+        raise MockApiException(1)
+
+    async def called_api_hook2(
+        bot: Bot,
+        exception: Optional[Exception],
+        api: str,
+        data: dict[str, Any],
+        result: Any,
+    ):
+        nonlocal runned2
+        runned2 = True
+        with anyio.fail_after(1):
+            await event.wait()
+
+        raise MockApiException(2)
+
+    hooks = set()
+
+    with pytest.MonkeyPatch.context() as m:
+        m.setattr(Bot, "_called_api_hook", hooks)
+
+        Bot.on_called_api(called_api_hook1)
+        Bot.on_called_api(called_api_hook2)
+
+        assert hooks == {called_api_hook1, called_api_hook2}
+
+        async with app.test_api() as ctx:
+            bot = ctx.create_bot()
+            ctx.should_call_api("test", {}, True)
+            result = await bot.call_api("test")
+
+        assert runned1 is True
+        assert runned2 is True
+        assert result == 1
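The new multi-hook tests above coordinate two concurrently running hooks with `anyio.Event` and bound the wait with `anyio.fail_after` so a mis-ordered hook cannot hang the suite. The same pattern in isolation (illustrative `producer`/`consumer` names):

import anyio

async def main() -> None:
    event = anyio.Event()

    async def producer() -> None:
        await anyio.sleep(0.01)
        event.set()

    async def consumer() -> None:
        # fail_after raises TimeoutError if the event is not set within 1 second
        with anyio.fail_after(1):
            await event.wait()

    async with anyio.create_task_group() as tg:
        tg.start_soon(producer)
        tg.start_soon(consumer)

anyio.run(main)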
@@ -25,7 +25,7 @@ async def _dependency() -> int:
     return 1


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(message, "_event_preprocessors", set())
@@ -58,7 +58,7 @@ async def test_event_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     assert runned, "event_preprocessor should runned"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(message, "_event_preprocessors", set())
@@ -88,7 +88,7 @@ async def test_event_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPat
     assert not runned, "matcher should not runned"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_event_preprocessor_exception(
     app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
 ):
@@ -132,7 +132,7 @@ async def test_event_preprocessor_exception(
     assert "RuntimeError: test" in capsys.readouterr().out


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(message, "_event_postprocessors", set())
@@ -165,7 +165,7 @@ async def test_event_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     assert runned, "event_postprocessor should runned"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_event_postprocessor_exception(
     app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
 ):
@@ -202,7 +202,7 @@ async def test_event_postprocessor_exception(
     assert "RuntimeError: test" in capsys.readouterr().out


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(message, "_run_preprocessors", set())
@@ -239,7 +239,7 @@ async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     assert runned, "run_preprocessor should runned"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(message, "_run_preprocessors", set())
@@ -269,7 +269,7 @@ async def test_run_preprocessor_ignore(app: App, monkeypatch: pytest.MonkeyPatch
     assert not runned, "matcher should not runned"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_run_preprocessor_exception(
     app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
 ):
@@ -313,7 +313,7 @@ async def test_run_preprocessor_exception(
     assert "RuntimeError: test" in capsys.readouterr().out


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(message, "_run_postprocessors", set())
@@ -351,7 +351,7 @@ async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
     assert runned, "run_postprocessor should runned"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_run_postprocessor_exception(
     app: App, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
 ):
@@ -17,14 +17,12 @@ from nonebot.compat import (
 )


-@pytest.mark.asyncio
-async def test_default_config():
+def test_default_config():
     assert DEFAULT_CONFIG.get("extra") == "allow"
     assert DEFAULT_CONFIG.get("arbitrary_types_allowed") is True


-@pytest.mark.asyncio
-async def test_field_info():
+def test_field_info():
     # required should be convert to PydanticUndefined
     assert FieldInfo(Required).default is PydanticUndefined

@@ -32,8 +30,7 @@ async def test_field_info():
     assert FieldInfo(test="test").extra["test"] == "test"


-@pytest.mark.asyncio
-async def test_type_adapter():
+def test_type_adapter():
     t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])

     assert t.validate_python(2) == 2
@@ -47,8 +44,7 @@ async def test_type_adapter():
         t.validate_json("0")


-@pytest.mark.asyncio
-async def test_model_dump():
+def test_model_dump():
     class TestModel(BaseModel):
         test1: int
         test2: int
@@ -57,8 +53,7 @@ async def test_model_dump():
     assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2}


-@pytest.mark.asyncio
-async def test_custom_validation():
+def test_custom_validation():
     called = []

     @custom_validation
@@ -85,8 +80,7 @@ async def test_custom_validation():
     assert called == [1, 2]


-@pytest.mark.asyncio
-async def test_validate_json():
+def test_validate_json():
     class TestModel(BaseModel):
         test1: int
         test2: str
@@ -50,16 +50,14 @@ class ExampleWithoutDelimiter(Example):
     env_nested_delimiter = None


-@pytest.mark.asyncio
-async def test_config_no_env():
+def test_config_no_env():
     config = Example(_env_file=None)
     assert config.simple == ""
     with pytest.raises(AttributeError):
         config.common_config


-@pytest.mark.asyncio
-async def test_config_with_env():
+def test_config_with_env():
     config = Example(_env_file=(".env", ".env.example"))
     assert config.simple == "simple"

@@ -102,8 +100,7 @@ async def test_config_with_env():
         config.other_nested_inner__b


-@pytest.mark.asyncio
-async def test_config_error_env():
+def test_config_error_env():
     with pytest.MonkeyPatch().context() as m:
         m.setenv("COMPLEX", "not json")

@@ -111,8 +108,7 @@ async def test_config_error_env():
         Example(_env_file=(".env", ".env.example"))


-@pytest.mark.asyncio
-async def test_config_without_delimiter():
+def test_config_without_delimiter():
     config = ExampleWithoutDelimiter()
     assert config.nested.a == 1
     assert config.nested.b == 0
assert config.nested.b == 0
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from http.cookies import SimpleCookie
|
from http.cookies import SimpleCookie
|
||||||
|
|
||||||
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
from nonebug import App
|
from nonebug import App
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ from nonebot.drivers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
|
"driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
|
||||||
)
|
)
|
||||||
@ -59,22 +59,22 @@ async def test_lifespan(driver: Driver):
|
|||||||
|
|
||||||
@driver.on_shutdown
|
@driver.on_shutdown
|
||||||
async def _shutdown1():
|
async def _shutdown1():
|
||||||
assert shutdown_log == []
|
assert shutdown_log == [2]
|
||||||
shutdown_log.append(1)
|
shutdown_log.append(1)
|
||||||
|
|
||||||
@driver.on_shutdown
|
@driver.on_shutdown
|
||||||
async def _shutdown2():
|
async def _shutdown2():
|
||||||
assert shutdown_log == [1]
|
assert shutdown_log == []
|
||||||
shutdown_log.append(2)
|
shutdown_log.append(2)
|
||||||
|
|
||||||
async with driver._lifespan:
|
async with driver._lifespan:
|
||||||
assert start_log == [1, 2]
|
assert start_log == [1, 2]
|
||||||
assert ready_log == [1, 2]
|
assert ready_log == [1, 2]
|
||||||
|
|
||||||
assert shutdown_log == [1, 2]
|
assert shutdown_log == [2, 1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
[
|
[
|
||||||
@ -99,10 +99,10 @@ async def test_http_server(app: App, driver: Driver):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "test"
|
assert response.text == "test"
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
[
|
[
|
||||||
@ -155,10 +155,10 @@ async def test_websocket_server(app: App, driver: Driver):
|
|||||||
|
|
||||||
await ws.close(code=1000)
|
await ws.close(code=1000)
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
[
|
[
|
||||||
@ -171,9 +171,10 @@ async def test_cross_context(app: App, driver: Driver):
|
|||||||
assert isinstance(driver, ASGIMixin)
|
assert isinstance(driver, ASGIMixin)
|
||||||
|
|
||||||
ws: Optional[WebSocket] = None
|
ws: Optional[WebSocket] = None
|
||||||
ws_ready = asyncio.Event()
|
ws_ready = anyio.Event()
|
||||||
ws_should_close = asyncio.Event()
|
ws_should_close = anyio.Event()
|
||||||
|
|
||||||
|
# create a background task before the ws connection established
|
||||||
async def background_task():
|
async def background_task():
|
||||||
try:
|
try:
|
||||||
await ws_ready.wait()
|
await ws_ready.wait()
|
||||||
@ -185,8 +186,6 @@ async def test_cross_context(app: App, driver: Driver):
|
|||||||
finally:
|
finally:
|
||||||
ws_should_close.set()
|
ws_should_close.set()
|
||||||
|
|
||||||
task = asyncio.create_task(background_task())
|
|
||||||
|
|
||||||
async def _handle_ws(websocket: WebSocket) -> None:
|
async def _handle_ws(websocket: WebSocket) -> None:
|
||||||
nonlocal ws
|
nonlocal ws
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
@ -199,7 +198,9 @@ async def test_cross_context(app: App, driver: Driver):
|
|||||||
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
|
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
|
||||||
driver.setup_websocket_server(ws_setup)
|
driver.setup_websocket_server(ws_setup)
|
||||||
|
|
||||||
async with app.test_server(driver.asgi) as ctx:
|
async with anyio.create_task_group() as tg, app.test_server(driver.asgi) as ctx:
|
||||||
|
tg.start_soon(background_task)
|
||||||
|
|
||||||
client = ctx.get_client()
|
client = ctx.get_client()
|
||||||
|
|
||||||
async with client.websocket_connect("/ws_test") as websocket:
|
async with client.websocket_connect("/ws_test") as websocket:
|
||||||
@ -211,11 +212,10 @@ async def test_cross_context(app: App, driver: Driver):
|
|||||||
if not e.args or "websocket.close" not in str(e.args[0]):
|
if not e.args or "websocket.close" not in str(e.args[0]):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
await task
|
await anyio.sleep(1)
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
[
|
[
|
||||||
@ -304,10 +304,10 @@ async def test_http_client(driver: Driver, server_url: URL):
|
|||||||
"test3": "test",
|
"test3": "test",
|
||||||
}, "file parsing error"
|
}, "file parsing error"
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
[
|
[
|
||||||
@ -419,10 +419,10 @@ async def test_http_client_session(driver: Driver, server_url: URL):
|
|||||||
"test3": "test",
|
"test3": "test",
|
||||||
}, "file parsing error"
|
}, "file parsing error"
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
[
|
[
|
||||||
@ -452,10 +452,9 @@ async def test_websocket_client(driver: Driver, server_url: URL):
|
|||||||
with pytest.raises(WebSocketClosed, match=r"code=1000"):
|
with pytest.raises(WebSocketClosed, match=r"code=1000"):
|
||||||
await ws.receive()
|
await ws.receive()
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await anyio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("driver", "driver_type"),
|
("driver", "driver_type"),
|
||||||
[
|
[
|
||||||
@ -472,11 +471,11 @@ async def test_websocket_client(driver: Driver, server_url: URL):
|
|||||||
],
|
],
|
||||||
indirect=["driver"],
|
indirect=["driver"],
|
||||||
)
|
)
|
||||||
async def test_combine_driver(driver: Driver, driver_type: str):
|
def test_combine_driver(driver: Driver, driver_type: str):
|
||||||
assert driver.type == driver_type
|
assert driver.type == driver_type
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
async def test_bot_connect_hook(app: App, driver: Driver):
|
async def test_bot_connect_hook(app: App, driver: Driver):
|
||||||
with pytest.MonkeyPatch.context() as m:
|
with pytest.MonkeyPatch.context() as m:
|
||||||
conn_hooks: set[Dependent[Any]] = set()
|
conn_hooks: set[Dependent[Any]] = set()
|
||||||
@ -533,7 +532,7 @@ async def test_bot_connect_hook(app: App, driver: Driver):
|
|||||||
async with app.test_api() as ctx:
|
async with app.test_api() as ctx:
|
||||||
bot = ctx.create_bot()
|
bot = ctx.create_bot()
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await anyio.sleep(1)
|
||||||
|
|
||||||
if not conn_should_be_called:
|
if not conn_should_be_called:
|
||||||
pytest.fail("on_bot_connect hook not called")
|
pytest.fail("on_bot_connect hook not called")
@@ -4,7 +4,7 @@ from nonebug import App
 from utils import FakeMessage, FakeMessageSegment, make_fake_event


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_echo(app: App):
     from nonebot.plugins.echo import echo

@@ -14,8 +14,7 @@ from nonebot import (
 )


-@pytest.mark.asyncio
-async def test_init():
+def test_init():
     env = nonebot.get_driver().env
     assert env == "test"

@@ -35,31 +34,28 @@ async def test_init():
     assert config.not_nested == "some string"


-@pytest.mark.asyncio
-async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
+def test_get_driver(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setattr(nonebot, "_driver", None)
         with pytest.raises(ValueError, match="initialized"):
             get_driver()


-@pytest.mark.asyncio
-async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
+def test_get_asgi():
     driver = get_driver()
     assert isinstance(driver, ReverseDriver)
     assert isinstance(driver, ASGIMixin)
     assert get_asgi() == driver.asgi


-@pytest.mark.asyncio
-async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
+def test_get_app():
     driver = get_driver()
     assert isinstance(driver, ReverseDriver)
     assert isinstance(driver, ASGIMixin)
     assert get_app() == driver.server_app


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
     async with app.test_api() as ctx:
         adapter = ctx.create_adapter()
@@ -74,8 +70,7 @@ async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
         get_adapter("not exist")


-@pytest.mark.asyncio
-async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
+def test_run(monkeypatch: pytest.MonkeyPatch):
     runned = False

     def mock_run(*args, **kwargs):
@@ -93,8 +88,7 @@ async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
     assert runned


-@pytest.mark.asyncio
-async def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
+def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
     driver = get_driver()

     with pytest.raises(ValueError, match="no bots"):
@@ -12,8 +12,7 @@ from nonebot.permission import User, Permission
 from nonebot.message import _check_matcher, check_and_run_matcher


-@pytest.mark.asyncio
-async def test_matcher_info(app: App):
+def test_matcher_info(app: App):
     from plugins.matcher.matcher_info import matcher

     assert issubclass(matcher, Matcher)
@@ -43,7 +42,7 @@ async def test_matcher_info(app: App):
     assert matcher._source.lineno == 3


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_check(app: App):
     async def falsy():
         return False
@@ -87,7 +86,7 @@ async def test_matcher_check(app: App):
     assert await _check_matcher(test_rule_error, bot, event, {}) is False


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_handle(app: App):
     from plugins.matcher.matcher_process import test_handle

@@ -102,7 +101,7 @@ async def test_matcher_handle(app: App):
         ctx.should_finished()


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_got(app: App):
     from plugins.matcher.matcher_process import test_got

@@ -124,7 +123,7 @@ async def test_matcher_got(app: App):
         ctx.receive_event(bot, event_next)


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_receive(app: App):
     from plugins.matcher.matcher_process import test_receive

@@ -141,7 +140,7 @@ async def test_matcher_receive(app: App):
         ctx.should_paused()


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_combine(app: App):
     from plugins.matcher.matcher_process import test_combine

@@ -164,7 +163,7 @@ async def test_matcher_combine(app: App):
         ctx.receive_event(bot, event_next)


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_preset(app: App):
     from plugins.matcher.matcher_process import test_preset

@@ -182,7 +181,7 @@ async def test_matcher_preset(app: App):
         ctx.receive_event(bot, event_next)


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_overload(app: App):
     from plugins.matcher.matcher_process import test_overload

@@ -196,7 +195,7 @@ async def test_matcher_overload(app: App):
         ctx.should_finished()


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_destroy(app: App):
     from plugins.matcher.matcher_process import test_destroy

@@ -210,7 +209,7 @@ async def test_matcher_destroy(app: App):
     assert len(matchers[test_destroy.priority]) == 0


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_type_updater(app: App):
     from plugins.matcher.matcher_type import test_type_updater, test_custom_updater

@@ -231,7 +230,7 @@ async def test_type_updater(app: App):
     assert new_type == "custom"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_default_permission_updater(app: App):
     from plugins.matcher.matcher_permission import (
         default_permission,
@@ -252,7 +251,7 @@ async def test_default_permission_updater(app: App):
     assert checker.perm is default_permission


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_user_permission_updater(app: App):
     from plugins.matcher.matcher_permission import (
         default_permission,
@@ -274,7 +273,7 @@ async def test_user_permission_updater(app: App):
     assert checker.perm is default_permission


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_custom_permission_updater(app: App):
     from plugins.matcher.matcher_permission import (
         new_permission,
@@ -291,7 +290,7 @@ async def test_custom_permission_updater(app: App):
     assert new_perm is new_permission


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_run(app: App):
     with app.provider.context({}):
         assert not matchers
@@ -322,11 +321,12 @@ async def test_run(app: App):
     assert len(matchers[0][0].handlers) == 0


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_temp(app: App):
     from plugins.matcher.matcher_expire import test_temp_matcher

     event = make_fake_event(_type="test")()
+    with app.provider.context({test_temp_matcher.priority: [test_temp_matcher]}):
         async with app.test_api() as ctx:
             bot = ctx.create_bot()
             assert test_temp_matcher in matchers[test_temp_matcher.priority]
@@ -334,25 +334,33 @@ async def test_temp(app: App):
             assert test_temp_matcher not in matchers[test_temp_matcher.priority]


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_datetime_expire(app: App):
     from plugins.matcher.matcher_expire import test_datetime_matcher

     event = make_fake_event()()
-    async with app.test_api() as ctx:
+    with app.provider.context(
+        {test_datetime_matcher.priority: [test_datetime_matcher]}
+    ):
+        async with app.test_matcher(test_datetime_matcher) as ctx:
             bot = ctx.create_bot()
             assert test_datetime_matcher in matchers[test_datetime_matcher.priority]
             await check_and_run_matcher(test_datetime_matcher, bot, event, {})
             assert test_datetime_matcher not in matchers[test_datetime_matcher.priority]


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_timedelta_expire(app: App):
     from plugins.matcher.matcher_expire import test_timedelta_matcher

     event = make_fake_event()()
+    with app.provider.context(
+        {test_timedelta_matcher.priority: [test_timedelta_matcher]}
+    ):
         async with app.test_api() as ctx:
             bot = ctx.create_bot()
             assert test_timedelta_matcher in matchers[test_timedelta_matcher.priority]
             await check_and_run_matcher(test_timedelta_matcher, bot, event, {})
-            assert test_timedelta_matcher not in matchers[test_timedelta_matcher.priority]
+            assert (
+                test_timedelta_matcher not in matchers[test_timedelta_matcher.priority]
+            )
@@ -1,11 +1,9 @@
-import pytest
 from nonebug import App

 from nonebot.matcher import DEFAULT_PROVIDER_CLASS, matchers


-@pytest.mark.asyncio
-async def test_manager(app: App):
+def test_manager(app: App):
     try:
         default_provider = matchers.provider
         matchers.set_provider(DEFAULT_PROVIDER_CLASS)
@@ -2,6 +2,7 @@ import re

 import pytest
 from nonebug import App
+from exceptiongroup import BaseExceptionGroup

 from nonebot.matcher import Matcher
 from nonebot.dependencies import Dependent
@@ -36,7 +37,7 @@ from nonebot.consts import (
 UNKNOWN_PARAM = "Unknown parameter"


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_depend(app: App):
     from plugins.param.param_depend import (
         ClassDependency,
@@ -90,36 +91,47 @@ async def test_depend(app: App):

     assert runned == [1, 1, 1]

+    runned.clear()
+
     async with app.test_dependent(
         annotated_class_depend, allow_types=[DependParam]
     ) as ctx:
         ctx.should_return(ClassDependency(x=1, y=2))

-    with pytest.raises(TypeMisMatch):  # noqa: PT012
+    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:  # noqa: PT012
         async with app.test_dependent(
             sub_type_mismatch, allow_types=[DependParam, BotParam]
         ) as ctx:
             bot = ctx.create_bot()
             ctx.pass_params(bot=bot)

+    if isinstance(exc_info.value, BaseExceptionGroup):
+        assert exc_info.group_contains(TypeMisMatch)
+
     async with app.test_dependent(validate, allow_types=[DependParam]) as ctx:
         ctx.should_return(1)

-    with pytest.raises(TypeMisMatch):
+    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
         async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx:
             ...

+    if isinstance(exc_info.value, BaseExceptionGroup):
+        assert exc_info.group_contains(TypeMisMatch)
+
     async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx:
         ctx.should_return(1)

-    with pytest.raises(TypeMisMatch):
+    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
         async with app.test_dependent(
             validate_field_fail, allow_types=[DependParam]
         ) as ctx:
             ...

+    if isinstance(exc_info.value, BaseExceptionGroup):
+        assert exc_info.group_contains(TypeMisMatch)

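The two-headed `pytest.raises((TypeMisMatch, BaseExceptionGroup))` pattern above exists because, under AnyIO, a dependency failure may arrive wrapped in an exception group; `ExceptionInfo.group_contains` (available in recent pytest releases) then checks inside the group. A standalone sketch with an illustrative `boom` function:

import pytest
from exceptiongroup import ExceptionGroup

def boom() -> None:
    raise ExceptionGroup("wrapped", [ValueError("bad value")])

def test_group_contains() -> None:
    with pytest.raises((ValueError, ExceptionGroup)) as exc_info:
        boom()

    if isinstance(exc_info.value, ExceptionGroup):
        # pytest can look inside exception groups directly
        assert exc_info.group_contains(ValueError)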
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_bot(app: App):
     from plugins.param.param_bot import (
         FooBot,
@@ -157,11 +169,14 @@ async def test_bot(app: App):
         ctx.pass_params(bot=bot)
         ctx.should_return(bot)

-    with pytest.raises(TypeMisMatch):  # noqa: PT012
+    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:  # noqa: PT012
         async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
             bot = ctx.create_bot()
             ctx.pass_params(bot=bot)

+    if isinstance(exc_info.value, BaseExceptionGroup):
+        assert exc_info.group_contains(TypeMisMatch)
+
     async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx:
         bot = ctx.create_bot(base=FooBot)
         ctx.pass_params(bot=bot)
@@ -181,7 +196,7 @@ async def test_bot(app: App):
         app.test_dependent(not_bot, allow_types=[BotParam])


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_event(app: App):
     from plugins.param.param_event import (
         FooEvent,
@@ -223,10 +238,13 @@ async def test_event(app: App):
         ctx.pass_params(event=fake_fooevent)
         ctx.should_return(fake_fooevent)

-    with pytest.raises(TypeMisMatch):
+    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
         async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
             ctx.pass_params(event=fake_event)

+    if isinstance(exc_info.value, BaseExceptionGroup):
+        assert exc_info.group_contains(TypeMisMatch)
+
     async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx:
         ctx.pass_params(event=fake_fooevent)
         ctx.should_return(fake_fooevent)
@@ -267,7 +285,7 @@ async def test_event(app: App):
         ctx.should_return(fake_event.is_tome())


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_state(app: App):
     from plugins.param.param_state import (
         state,
@@ -418,7 +436,7 @@ async def test_state(app: App):
         ctx.should_return(fake_state[KEYWORD_KEY])


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher(app: App):
     from plugins.param.param_matcher import (
         FooMatcher,
@@ -457,10 +475,13 @@ async def test_matcher(app: App):
         ctx.pass_params(matcher=foo_matcher)
         ctx.should_return(foo_matcher)

-    with pytest.raises(TypeMisMatch):
+    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
         async with app.test_dependent(sub_matcher, allow_types=[MatcherParam]) as ctx:
             ctx.pass_params(matcher=fake_matcher)

+    if isinstance(exc_info.value, BaseExceptionGroup):
|
||||||
|
assert exc_info.group_contains(TypeMisMatch)
|
||||||
|
|
||||||
async with app.test_dependent(union_matcher, allow_types=[MatcherParam]) as ctx:
|
async with app.test_dependent(union_matcher, allow_types=[MatcherParam]) as ctx:
|
||||||
ctx.pass_params(matcher=foo_matcher)
|
ctx.pass_params(matcher=foo_matcher)
|
||||||
ctx.should_return(foo_matcher)
|
ctx.should_return(foo_matcher)
|
||||||
@ -496,7 +517,7 @@ async def test_matcher(app: App):
|
|||||||
ctx.should_return(event_next)
|
ctx.should_return(event_next)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
async def test_arg(app: App):
|
async def test_arg(app: App):
|
||||||
from plugins.param.param_arg import (
|
from plugins.param.param_arg import (
|
||||||
arg,
|
arg,
|
||||||
@ -548,7 +569,7 @@ async def test_arg(app: App):
|
|||||||
ctx.should_return(message.extract_plain_text())
|
ctx.should_return(message.extract_plain_text())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
async def test_exception(app: App):
|
async def test_exception(app: App):
|
||||||
from plugins.param.param_exception import exc, legacy_exc
|
from plugins.param.param_exception import exc, legacy_exc
|
||||||
|
|
||||||
@ -562,7 +583,7 @@ async def test_exception(app: App):
|
|||||||
ctx.should_return(exception)
|
ctx.should_return(exception)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.anyio
|
||||||
async def test_default(app: App):
|
async def test_default(app: App):
|
||||||
from plugins.param.param_default import default
|
from plugins.param.param_default import default
|
||||||
|
|
||||||
@ -570,8 +591,7 @@ async def test_default(app: App):
|
|||||||
ctx.should_return(1)
|
ctx.should_return(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_priority():
|
||||||
async def test_priority():
|
|
||||||
from plugins.param.priority import complex_priority
|
from plugins.param.priority import complex_priority
|
||||||
|
|
||||||
dependent = Dependent[None].parse(
|
dependent = Dependent[None].parse(
|
||||||
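
The assertion pattern repeated in the hunks above — accept either the bare TypeMisMatch or a BaseExceptionGroup, then check exc_info.group_contains(TypeMisMatch) — can be exercised on its own. A minimal sketch, assuming pytest 8+ for ExceptionInfo.group_contains; the TypeMisMatch class below is a local stand-in, not the framework's exception:

# Sketch of the "bare exception or exception group" assertion used above.
# Under a structured-concurrency runner, a failure inside a task group
# surfaces wrapped in BaseExceptionGroup, so the tests accept both shapes.
import pytest
from exceptiongroup import BaseExceptionGroup  # built-in name on Python 3.11+


class TypeMisMatch(Exception):
    """Hypothetical stand-in for the framework's TypeMisMatch."""


def raise_plain():
    raise TypeMisMatch()


def raise_grouped():
    raise BaseExceptionGroup("dependency failed", [TypeMisMatch()])


@pytest.mark.parametrize("raiser", [raise_plain, raise_grouped])
def test_either_shape(raiser):
    with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
        raiser()
    if isinstance(exc_info.value, BaseExceptionGroup):
        assert exc_info.group_contains(TypeMisMatch)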
@@ -22,7 +22,7 @@ from nonebot.permission import (
 )


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_permission(app: App):
     async def falsy():
         return False
@@ -54,7 +54,7 @@ async def test_permission(app: App):
     assert await Permission(truthy, skipped)(bot, event) is True


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(("type", "expected"), [("message", True), ("notice", False)])
 async def test_message(type: str, expected: bool):
     dependent = next(iter(MESSAGE.checkers))
@@ -66,7 +66,7 @@ async def test_message(type: str, expected: bool):
     assert await dependent(event=event) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(("type", "expected"), [("message", False), ("notice", True)])
 async def test_notice(type: str, expected: bool):
     dependent = next(iter(NOTICE.checkers))
@@ -78,7 +78,7 @@ async def test_notice(type: str, expected: bool):
     assert await dependent(event=event) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(("type", "expected"), [("message", False), ("request", True)])
 async def test_request(type: str, expected: bool):
     dependent = next(iter(REQUEST.checkers))
@@ -90,7 +90,7 @@ async def test_request(type: str, expected: bool):
     assert await dependent(event=event) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("type", "expected"), [("message", False), ("meta_event", True)]
 )
@@ -104,7 +104,7 @@ async def test_metaevent(type: str, expected: bool):
     assert await dependent(event=event) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("type", "user_id", "expected"),
     [
@@ -128,7 +128,7 @@ async def test_superuser(app: App, type: str, user_id: str, expected: bool):
     assert await dependent(bot=bot, event=event) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("session_ids", "session_id", "expected"),
     [
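
Every @pytest.mark.asyncio → @pytest.mark.anyio swap above hands the test to the AnyIO pytest plugin, which runs the coroutine on whatever backend the anyio_backend fixture returns. A small sketch of how that fixture is commonly parametrized; this is an assumption about conftest.py, not something shown in this diff:

# Sketch: parametrizing anyio_backend runs each @pytest.mark.anyio test once
# per backend (asyncio and trio). Assumed conftest detail, not from the diff.
import anyio
import pytest


@pytest.fixture(params=["asyncio", "trio"])
def anyio_backend(request):
    return request.param


@pytest.mark.anyio
async def test_sleep_on_both_backends():
    await anyio.sleep(0)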
@@ -1,12 +1,10 @@
-import pytest
 from pydantic import BaseModel

 import nonebot
 from nonebot.plugin import PluginManager, _managers


-@pytest.mark.asyncio
-async def test_get_plugin():
+def test_get_plugin():
     # check simple plugin
     plugin = nonebot.get_plugin("export")
     assert plugin
@@ -22,8 +20,7 @@ async def test_get_plugin():
     assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"


-@pytest.mark.asyncio
-async def test_get_plugin_by_module_name():
+def test_get_plugin_by_module_name():
     # check get plugin by exact module name
     plugin = nonebot.get_plugin_by_module_name("plugins.nested")
     assert plugin
@@ -48,8 +45,7 @@ async def test_get_plugin_by_module_name():
     assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"


-@pytest.mark.asyncio
-async def test_get_available_plugin():
+def test_get_available_plugin():
     old_managers = _managers.copy()
     _managers.clear()
     try:
@@ -63,8 +59,7 @@ async def test_get_available_plugin():
     _managers.extend(old_managers)


-@pytest.mark.asyncio
-async def test_get_plugin_config():
+def test_get_plugin_config():
     class Config(BaseModel):
         plugin_config: int

@@ -1,15 +1,44 @@
 import sys
 from pathlib import Path
+from functools import wraps
 from dataclasses import asdict
+from typing import TypeVar, Callable
+from typing_extensions import ParamSpec

 import pytest

 import nonebot
-from nonebot.plugin import Plugin, PluginManager, _managers, inherit_supported_adapters
+from nonebot.plugin import (
+    Plugin,
+    PluginManager,
+    _plugins,
+    _managers,
+    inherit_supported_adapters,
+)

+P = ParamSpec("P")
+R = TypeVar("R")


-@pytest.mark.asyncio
-async def test_load_plugin():
+def _recover(func: Callable[P, R]) -> Callable[P, R]:
+    @wraps(func)
+    def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        origin_managers = _managers.copy()
+        origin_plugins = _plugins.copy()
+        try:
+            return func(*args, **kwargs)
+        finally:
+            _managers.clear()
+            _managers.extend(origin_managers)
+            _plugins.clear()
+            _plugins.update(origin_plugins)
+
+    return _wrapper
+
+
+@_recover
+def test_load_plugin():
     # check regular
     assert nonebot.load_plugin("dynamic.simple")

@@ -20,8 +49,7 @@ async def test_load_plugin():
     assert nonebot.load_plugin("some_plugin_not_exist") is None


-@pytest.mark.asyncio
-async def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[Plugin]):
+def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[Plugin]):
     loaded_plugins = {
         plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
     }
@@ -44,8 +72,7 @@ async def test_load_plugins(load_plugin: set[Plugin], load_builtin_plugin: set[P
         PluginManager(search_path=["plugins"]).load_all_plugins()


-@pytest.mark.asyncio
-async def test_load_nested_plugin():
+def test_load_nested_plugin():
     parent_plugin = nonebot.get_plugin("nested")
     sub_plugin = nonebot.get_plugin("nested:nested_subplugin")
     sub_plugin2 = nonebot.get_plugin("nested:nested_subplugin2")
@@ -57,16 +84,16 @@ async def test_load_nested_plugin():
     assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2}


-@pytest.mark.asyncio
-async def test_load_json():
+@_recover
+def test_load_json():
     nonebot.load_from_json("./plugins.json")

     with pytest.raises(TypeError):
         nonebot.load_from_json("./plugins.invalid.json")


-@pytest.mark.asyncio
-async def test_load_toml():
+@_recover
+def test_load_toml():
     nonebot.load_from_toml("./plugins.toml")

     with pytest.raises(ValueError, match="Cannot find"):
@@ -76,19 +103,20 @@ async def test_load_toml():
         nonebot.load_from_toml("./plugins.invalid.toml")


-@pytest.mark.asyncio
-async def test_bad_plugin():
+@_recover
+def test_bad_plugin():
     nonebot.load_plugins("bad_plugins")

     assert nonebot.get_plugin("bad_plugin") is None


-@pytest.mark.asyncio
-async def test_require_loaded(monkeypatch: pytest.MonkeyPatch):
+@_recover
+def test_require_loaded(monkeypatch: pytest.MonkeyPatch):
     def _patched_find(name: str):
         pytest.fail("require existing plugin should not call find_manager_by_name")

-    monkeypatch.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find)
+    with monkeypatch.context() as m:
+        m.setattr("nonebot.plugin.load._find_manager_by_name", _patched_find)

         # require use module name
         nonebot.require("plugins.export")
@@ -97,19 +125,20 @@ async def test_require_loaded(monkeypatch: pytest.MonkeyPatch):
         nonebot.require("nested:nested_subplugin")


-@pytest.mark.asyncio
-async def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch):
-    m = PluginManager(["dynamic.require_not_loaded"], ["dynamic/require_not_loaded/"])
-    _managers.append(m)
+@_recover
+def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch):
+    pm = PluginManager(["dynamic.require_not_loaded"], ["dynamic/require_not_loaded/"])
+    _managers.append(pm)
     num_managers = len(_managers)

     origin_load = PluginManager.load_plugin

     def _patched_load(self: PluginManager, name: str):
-        assert self is m
+        assert self is pm
         return origin_load(self, name)

-    monkeypatch.setattr(PluginManager, "load_plugin", _patched_load)
+    with monkeypatch.context() as m:
+        m.setattr(PluginManager, "load_plugin", _patched_load)

         # require standalone plugin
         nonebot.require("dynamic.require_not_loaded")
@@ -120,8 +149,8 @@ async def test_require_not_loaded(monkeypatch: pytest.MonkeyPatch):
     assert len(_managers) == num_managers


-@pytest.mark.asyncio
-async def test_require_not_declared():
+@_recover
+def test_require_not_declared():
     num_managers = len(_managers)

     nonebot.require("dynamic.require_not_declared")
@@ -130,14 +159,13 @@ async def test_require_not_declared():
     assert _managers[-1].plugins == {"dynamic.require_not_declared"}


-@pytest.mark.asyncio
-async def test_require_not_found():
+@_recover
+def test_require_not_found():
     with pytest.raises(RuntimeError):
         nonebot.require("some_plugin_not_exist")


-@pytest.mark.asyncio
-async def test_plugin_metadata():
+def test_plugin_metadata():
     from plugins.metadata import Config, FakeAdapter

     plugin = nonebot.get_plugin("metadata")
@@ -157,8 +185,7 @@ async def test_plugin_metadata():
     assert plugin.metadata.get_supported_adapters() == {FakeAdapter}


-@pytest.mark.asyncio
-async def test_inherit_supported_adapters_not_found():
+def test_inherit_supported_adapters_not_found():
     with pytest.raises(RuntimeError):
         inherit_supported_adapters("some_plugin_not_exist")

@@ -166,7 +193,6 @@ async def test_inherit_supported_adapters_not_found():
         inherit_supported_adapters("export")


-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     ("inherit_plugins", "expected"),
     [
@@ -233,7 +259,7 @@ async def test_inherit_supported_adapters_not_found():
         ),
     ],
 )
-async def test_inherit_supported_adapters_combine(
+def test_inherit_supported_adapters_combine(
     inherit_plugins: tuple[str], expected: set[str]
 ):
     assert inherit_supported_adapters(*inherit_plugins) == expected
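
The monkeypatch.setattr calls above were narrowed to monkeypatch.context(), which undoes a patch as soon as its with block exits instead of at test teardown. A self-contained sketch of that behaviour, independent of the plugin code itself:

# Sketch: monkeypatch.context() scopes the patch to the enclosed block; the
# original attribute is restored when the block exits, not at teardown.
import math

import pytest


def test_context_scoped_patch(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setattr(math, "pi", 3.0)
        assert math.pi == 3.0
    # outside the block the patch has already been undone
    assert math.pi != 3.0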
@@ -1,11 +1,9 @@
-import pytest
-
 from nonebot.plugin import PluginManager, _managers


-@pytest.mark.asyncio
-async def test_load_plugin_name():
+def test_load_plugin_name():
     m = PluginManager(plugins=["dynamic.manager"])
+    try:
         _managers.append(m)

         # load by plugin id
@@ -15,3 +13,5 @@ async def test_load_plugin_name():
         assert module1
         assert module2
         assert module1 is module2
+    finally:
+        _managers.remove(m)
@@ -18,7 +18,6 @@ from nonebot.rule import (
 )


-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     ("matcher_name", "pre_rule_factory", "has_permission"),
     [
@@ -102,7 +101,7 @@ from nonebot.rule import (
         pytest.param("matcher_group_on_type", lambda e: IsTypeRule(e), True),
     ],
 )
-async def test_on(
+def test_on(
     matcher_name: str,
     pre_rule_factory: Optional[Callable[[type[Event]], T_RuleChecker]],
     has_permission: bool,
@@ -150,8 +149,7 @@ async def test_on(
     assert matcher.module_name == "plugins.plugin.matchers"


-@pytest.mark.asyncio
-async def test_runtime_on():
+def test_runtime_on():
     import plugins.plugin.matchers as module
     from plugins.plugin.matchers import matcher_on_factory

@@ -49,7 +49,7 @@ from nonebot.rule import (
 )


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_rule(app: App):
     async def falsy():
         return False
@@ -81,7 +81,7 @@ async def test_rule(app: App):
     assert await Rule(truthy, skipped)(bot, event, {}) is False


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_trie(app: App):
     TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",)))

@@ -146,7 +146,7 @@ async def test_trie(app: App):
     del TrieRule.prefix["/fake-prefix"]


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("msg", "ignorecase", "type", "text", "expected"),
     [
@@ -186,7 +186,7 @@ async def test_startswith(
     assert await dependent(event=event, state=state) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("msg", "ignorecase", "type", "text", "expected"),
     [
@@ -226,7 +226,7 @@ async def test_endswith(
     assert await dependent(event=event, state=state) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("msg", "ignorecase", "type", "text", "expected"),
     [
@@ -266,7 +266,7 @@ async def test_fullmatch(
     assert await dependent(event=event, state=state) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("kws", "type", "text", "expected"),
     [
@@ -298,7 +298,7 @@ async def test_keyword(
     assert await dependent(event=event, state=state) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("cmds", "force_whitespace", "cmd", "whitespace", "arg_text", "expected"),
     [
@@ -344,7 +344,7 @@ async def test_command(
     assert await dependent(state=state) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_shell_command():
     state: T_State
     CMD = ("test",)
@@ -451,7 +451,7 @@ async def test_shell_command():
     assert state[SHELL_ARGS].status != 0


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize(
     ("pattern", "type", "text", "expected", "matched"),
     [
@@ -494,7 +494,7 @@ async def test_regex(
         assert result.span() == matched.span()


-@pytest.mark.asyncio
+@pytest.mark.anyio
 @pytest.mark.parametrize("expected", [True, False])
 async def test_to_me(expected: bool):
     test_to_me = to_me()
@@ -507,7 +507,7 @@ async def test_to_me(expected: bool):
     assert await dependent(event=event) == expected


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_is_type():
     Event1 = make_fake_event()
     Event2 = make_fake_event()
@@ -5,7 +5,7 @@ import pytest
 from utils import make_fake_event


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_matcher_mutex():
     from nonebot.plugins.single_session import matcher_mutex, _running_matcher

@@ -14,7 +14,7 @@ NoneBot2 is a modern, cross-platform, extensible Python chatbot framework

 ### Async First

-NoneBot is written on top of Python [asyncio](https://docs.python.org/zh-cn/3/library/asyncio.html) and provides a degree of compatibility with synchronous functions on top of its asynchronous core.
+NoneBot is written on top of Python [asyncio](https://docs.python.org/zh-cn/3/library/asyncio.html) / [trio](https://trio.readthedocs.io/en/stable/) and provides a degree of compatibility with synchronous functions on top of its asynchronous core.

 ### Complete Type Annotations

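
To make the "async first" claim above concrete: handler code stays an ordinary coroutine function and does not change when the framework is started on a different backend. A minimal sketch, assuming a project where nonebot.init() and a driver/adapter are already configured (none of which appears in this diff):

# Minimal sketch of the async-first style: a command handler is just an
# async function. Driver/adapter setup is assumed to exist elsewhere.
from nonebot import on_command

ping = on_command("ping")


@ping.handle()
async def handle_ping():
    await ping.finish("pong")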