From c2c3d5ef4b6a7bbbdf0289cafd92366422ed9182 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 20 Dec 2021 00:28:02 +0800 Subject: [PATCH] :white_check_mark: add test cases --- nonebot/adapters/_adapter.py | 13 ++ nonebot/adapters/_bot.py | 1 - nonebot/dependencies/__init__.py | 6 +- nonebot/drivers/__init__.py | 10 +- nonebot/matcher.py | 12 +- nonebot/message.py | 6 +- nonebot/params.py | 43 +++++- tests/plugins/depends.py | 24 ---- tests/plugins/param/__init__.py | 7 + tests/plugins/param/param_bot.py | 5 + tests/plugins/param/param_depend.py | 30 ++++ tests/plugins/param/param_event.py | 22 +++ tests/plugins/param/param_matcher.py | 15 ++ tests/plugins/param/param_state.py | 55 ++++++++ tests/test_init.py | 11 +- tests/test_param.py | 197 +++++++++++++++++++++++++-- tests/utils.py | 30 ++++ 17 files changed, 432 insertions(+), 55 deletions(-) delete mode 100644 tests/plugins/depends.py create mode 100644 tests/plugins/param/__init__.py create mode 100644 tests/plugins/param/param_bot.py create mode 100644 tests/plugins/param/param_depend.py create mode 100644 tests/plugins/param/param_event.py create mode 100644 tests/plugins/param/param_matcher.py create mode 100644 tests/plugins/param/param_state.py diff --git a/nonebot/adapters/_adapter.py b/nonebot/adapters/_adapter.py index f22e872c..b373dced 100644 --- a/nonebot/adapters/_adapter.py +++ b/nonebot/adapters/_adapter.py @@ -5,6 +5,9 @@ from ._bot import Bot from nonebot.config import Config from nonebot.drivers import ( Driver, + Request, + Response, + WebSocket, ForwardDriver, ReverseDriver, HTTPServerSetup, @@ -44,6 +47,16 @@ class Adapter(abc.ABC): raise TypeError("Current driver does not support websocket server") self.driver.setup_websocket_server(setup) + async def request(self, setup: Request) -> Response: + if not isinstance(self.driver, ForwardDriver): + raise TypeError("Current driver does not support http client") + return await self.driver.request(setup) + + async def websocket(self, setup: Request) -> WebSocket: + if not isinstance(self.driver, ForwardDriver): + raise TypeError("Current driver does not support websocket client") + return await self.driver.websocket(setup) + @abc.abstractmethod async def _call_api(self, api: str, **data) -> Any: """ diff --git a/nonebot/adapters/_bot.py b/nonebot/adapters/_bot.py index ab6ec419..cfaae3a8 100644 --- a/nonebot/adapters/_bot.py +++ b/nonebot/adapters/_bot.py @@ -8,7 +8,6 @@ from nonebot.log import logger from nonebot.config import Config from nonebot.exception import MockApiException from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook -from nonebot.drivers import Driver, HTTPResponse, HTTPConnection if TYPE_CHECKING: from ._event import Event diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index da5f9273..05786d4b 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -151,6 +151,9 @@ class Dependent(Generic[R]): ) -> Dict[str, Any]: values: Dict[str, Any] = {} + for param in self.parameterless: + await param._solve(**params) + for field in self.params: field_info = field.field_info assert isinstance(field_info, Param), "Params must be subclasses of Param" @@ -168,7 +171,4 @@ class Dependent(Generic[R]): else: values[field.name] = value - for param in self.parameterless: - await param._solve(**params) - return values diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 7511c745..daca96b4 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -10,10 +10,14 @@ import asyncio from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable +from ._model import URL as URL from nonebot.log import logger from nonebot.utils import escape_tag +from ._model import Request as Request from nonebot.config import Env, Config -from ._model import URL, Request, Response, WebSocket, HTTPVersion +from ._model import Response as Response +from ._model import WebSocket as WebSocket +from ._model import HTTPVersion as HTTPVersion from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook if TYPE_CHECKING: @@ -204,11 +208,11 @@ class ForwardDriver(Driver): """ @abc.abstractmethod - async def request(self, setup: "Request") -> Any: + async def request(self, setup: Request) -> Response: raise NotImplementedError @abc.abstractmethod - async def websocket(self, setup: "Request") -> Any: + async def websocket(self, setup: Request) -> WebSocket: raise NotImplementedError diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 382825bd..acdec706 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -193,6 +193,7 @@ class Matcher(metaclass=MatcherMeta): params.BotParam, params.EventParam, params.StateParam, + params.ArgParam, params.MatcherParam, params.DefaultParam, ] @@ -443,10 +444,10 @@ class Matcher(metaclass=MatcherMeta): async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]: if matcher.get_receive(id): return - if matcher.get_target() == RECEIVE_KEY.format(id=id): + if matcher.get_target() == RECEIVE_KEY.format(id=id or ""): matcher.set_receive(id, event) return - matcher.set_target(RECEIVE_KEY.format(id=id)) + matcher.set_target(RECEIVE_KEY.format(id=id or "")) raise RejectedException parameterless = [params.Depends(_receive), *(parameterless or [])] @@ -472,7 +473,6 @@ class Matcher(metaclass=MatcherMeta): cls, key: str, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, - args_parser: Optional[T_ArgsParser] = None, parameterless: Optional[List[Any]] = None, ) -> Callable[[T_Handler], T_Handler]: """ @@ -495,6 +495,8 @@ class Matcher(metaclass=MatcherMeta): matcher.set_arg(key, event) return matcher.set_target(ARG_KEY.format(key=key)) + if prompt is not None: + await matcher.send(prompt) raise RejectedException _parameterless = [ @@ -517,7 +519,9 @@ class Matcher(metaclass=MatcherMeta): @classmethod async def send( - cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs + cls, + message: Union[str, Message, MessageSegment, MessageTemplate], + **kwargs: Any, ) -> Any: """ :说明: diff --git a/nonebot/message.py b/nonebot/message.py index a58c0222..72386b7f 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -58,19 +58,21 @@ EVENT_PCS_PARAMS = [ ] RUN_PREPCS_PARAMS = [ params.DependParam, - params.MatcherParam, params.BotParam, params.EventParam, params.StateParam, + params.ArgParam, + params.MatcherParam, params.DefaultParam, ] RUN_POSTPCS_PARAMS = [ params.DependParam, - params.MatcherParam, params.ExceptionParam, params.BotParam, params.EventParam, params.StateParam, + params.ArgParam, + params.MatcherParam, params.DefaultParam, ] diff --git a/nonebot/params.py b/nonebot/params.py index 822fd2a8..a4a6c5f5 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -1,5 +1,6 @@ import asyncio import inspect +from typing_extensions import Literal from typing import Any, Dict, List, Tuple, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager @@ -200,7 +201,7 @@ async def _event_message(event: Event) -> Message: return event.get_message() -def EventMessage() -> Message: +def EventMessage() -> Any: return Depends(_event_message) @@ -260,7 +261,7 @@ def _command_arg(state=State()) -> Message: return state[PREFIX_KEY][CMD_ARG_KEY] -def CommandArg() -> Message: +def CommandArg() -> Any: return Depends(_command_arg, use_cache=False) @@ -332,6 +333,44 @@ def LastReceived(default: Any = None) -> Any: return Depends(_last_received, use_cache=False) +class ArgInner: + def __init__( + self, key: Optional[str], type: Literal["event", "message", "str"] + ) -> None: + self.key = key + self.type = type + + +def Arg(key: Optional[str] = None) -> Any: + return ArgInner(key, "message") + + +def ArgEvent(key: Optional[str] = None) -> Any: + return ArgInner(key, "event") + + +def ArgStr(key: Optional[str] = None) -> Any: + return ArgInner(key, "str") + + +class ArgParam(Param): + @classmethod + def _check_param( + cls, dependent: Dependent, name: str, param: inspect.Parameter + ) -> Optional["ArgParam"]: + if isinstance(param.default, ArgInner): + return cls(Required, key=param.default.key or name, type=param.default.type) + + async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: + event = matcher.get_arg(self.extra["key"]) + if self.extra["type"] == "event": + return event + elif self.extra["type"] == "message": + return event.get_message() + else: + return matcher.get_arg_str(self.extra["key"]) + + class ExceptionParam(Param): @classmethod def _check_param( diff --git a/tests/plugins/depends.py b/tests/plugins/depends.py deleted file mode 100644 index 51226d20..00000000 --- a/tests/plugins/depends.py +++ /dev/null @@ -1,24 +0,0 @@ -from nonebot import on_message -from nonebot.adapters import Event -from nonebot.params import Depends - -test_depends = on_message() - -runned = [] - - -def dependency(event: Event): - # test cache - runned.append(event) - return event - - -@test_depends.handle() -async def depends(x: Event = Depends(dependency)): - # test dependency - return x - - -@test_depends.handle() -async def depends_cache(y: Event = Depends(dependency, use_cache=True)): - return y diff --git a/tests/plugins/param/__init__.py b/tests/plugins/param/__init__.py new file mode 100644 index 00000000..f084d7f1 --- /dev/null +++ b/tests/plugins/param/__init__.py @@ -0,0 +1,7 @@ +from pathlib import Path + +from nonebot import load_plugins + +_sub_plugins = set() + +_sub_plugins |= load_plugins(str(Path(__file__).parent)) diff --git a/tests/plugins/param/param_bot.py b/tests/plugins/param/param_bot.py new file mode 100644 index 00000000..a6befdda --- /dev/null +++ b/tests/plugins/param/param_bot.py @@ -0,0 +1,5 @@ +from nonebot.adapters import Bot + + +async def get_bot(b: Bot): + return b diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py new file mode 100644 index 00000000..fc13c415 --- /dev/null +++ b/tests/plugins/param/param_depend.py @@ -0,0 +1,30 @@ +from nonebot import on_message +from nonebot.adapters import Event +from nonebot.params import Depends + +test_depends = on_message() + +runned = [] + + +def dependency(): + runned.append(1) + return 1 + + +def parameterless(): + assert len(runned) == 0 + runned.append(1) + + +# test parameterless +@test_depends.handle(parameterless=[Depends(parameterless)]) +async def depends(x: int = Depends(dependency)): + # test dependency + return x + + +@test_depends.handle() +async def depends_cache(y: int = Depends(dependency, use_cache=True)): + # test cache + return y diff --git a/tests/plugins/param/param_event.py b/tests/plugins/param/param_event.py new file mode 100644 index 00000000..3cc04570 --- /dev/null +++ b/tests/plugins/param/param_event.py @@ -0,0 +1,22 @@ +from nonebot.adapters import Event, Message +from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText + + +async def event(e: Event) -> Event: + return e + + +async def event_type(t: str = EventType()) -> str: + return t + + +async def event_message(msg: Message = EventMessage()) -> Message: + return msg + + +async def event_plain_text(text: str = EventPlainText()) -> str: + return text + + +async def event_to_me(to_me: bool = EventToMe()) -> bool: + return to_me diff --git a/tests/plugins/param/param_matcher.py b/tests/plugins/param/param_matcher.py new file mode 100644 index 00000000..ad8d5bd8 --- /dev/null +++ b/tests/plugins/param/param_matcher.py @@ -0,0 +1,15 @@ +from nonebot.adapters import Event +from nonebot.matcher import Matcher +from nonebot.params import Received, LastReceived + + +async def matcher(m: Matcher) -> Matcher: + return m + + +async def receive(e: Event = Received("test")) -> Event: + return e + + +async def last_receive(e: Event = LastReceived()) -> Event: + return e diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py new file mode 100644 index 00000000..beec94b8 --- /dev/null +++ b/tests/plugins/param/param_state.py @@ -0,0 +1,55 @@ +from typing import List, Tuple + +from nonebot.typing import T_State +from nonebot.adapters import Message +from nonebot.params import ( + State, + Command, + RegexDict, + CommandArg, + RawCommand, + RegexGroup, + RegexMatched, + ShellCommandArgs, + ShellCommandArgv, +) + + +async def state(x: T_State = State()) -> T_State: + return x + + +async def command(cmd: Tuple[str, ...] = Command()) -> Tuple[str, ...]: + return cmd + + +async def raw_command(raw_cmd: str = RawCommand()) -> str: + return raw_cmd + + +async def command_arg(cmd_arg: Message = CommandArg()) -> Message: + return cmd_arg + + +async def shell_command_args( + shell_command_args: dict = ShellCommandArgs(), +) -> dict: + return shell_command_args + + +async def shell_command_argv( + shell_command_argv: List[str] = ShellCommandArgv(), +) -> List[str]: + return shell_command_argv + + +async def regex_dict(regex_dict: dict = RegexDict()) -> dict: + return regex_dict + + +async def regex_group(regex_group: Tuple = RegexGroup()) -> Tuple: + return regex_group + + +async def regex_matched(regex_matched: str = RegexMatched()) -> str: + return regex_matched diff --git a/tests/test_init.py b/tests/test_init.py index ef76676c..3ab17710 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -79,8 +79,11 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear): async def test_load_plugin(load_plugin: Set["Plugin"]): import nonebot - assert nonebot.get_loaded_plugins() == load_plugin - plugin = nonebot.get_plugin("depends") + loaded_plugins = set( + plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin + ) + assert loaded_plugins == load_plugin + plugin = nonebot.get_plugin("param_depend") assert plugin - assert plugin.module_name == "plugins.depends" - assert "plugins.depends" in sys.modules + assert plugin.module_name == "plugins.param.param_depend" + assert "plugins.param.param_depend" in sys.modules diff --git a/tests/test_param.py b/tests/test_param.py index 35a67daa..45f8d128 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -1,23 +1,19 @@ import pytest from nonebug import App -from utils import load_plugin, make_fake_event +from utils import load_plugin, make_fake_event, make_fake_message @pytest.mark.asyncio -async def test_depends(app: App, load_plugin): - from nonebot.params import EventParam, DependParam +async def test_depend(app: App, load_plugin): + from nonebot.params import DependParam - from plugins.depends import runned, depends, test_depends + from plugins.param.param_depend import runned, depends, test_depends - async with app.test_dependent( - depends, allow_types=[EventParam, DependParam] - ) as ctx: - event = make_fake_event()() - ctx.pass_params(event=event) - ctx.should_return(event) + async with app.test_dependent(depends, allow_types=[DependParam]) as ctx: + ctx.should_return(1) - assert len(runned) == 1 and runned[0] == event + assert len(runned) == 1 and runned[0] == 1 runned.clear() @@ -26,4 +22,181 @@ async def test_depends(app: App, load_plugin): event_next = make_fake_event()() ctx.receive_event(bot, event_next) - assert len(runned) == 1 and runned[0] == event_next + assert len(runned) == 2 and runned[0] == runned[1] == 1 + + +@pytest.mark.asyncio +async def test_bot(app: App, load_plugin): + from nonebot.params import BotParam + + from plugins.param.param_bot import get_bot + + async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + ctx.should_return(bot) + + +@pytest.mark.asyncio +async def test_event(app: App, load_plugin): + from nonebot.params import EventParam, DependParam + + from plugins.param.param_event import ( + event, + event_type, + event_to_me, + event_message, + event_plain_text, + ) + + fake_message = make_fake_message()("text") + fake_event = make_fake_event(_message=fake_message)() + + async with app.test_dependent(event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event) + + async with app.test_dependent( + event_type, allow_types=[EventParam, DependParam] + ) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event.get_type()) + + async with app.test_dependent( + event_message, allow_types=[EventParam, DependParam] + ) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event.get_message()) + + async with app.test_dependent( + event_plain_text, allow_types=[EventParam, DependParam] + ) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event.get_plaintext()) + + async with app.test_dependent( + event_to_me, allow_types=[EventParam, DependParam] + ) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event.is_tome()) + + +@pytest.mark.asyncio +async def test_state(app: App, load_plugin): + from nonebot.params import StateParam, DependParam + from nonebot.consts import ( + CMD_KEY, + PREFIX_KEY, + REGEX_DICT, + SHELL_ARGS, + SHELL_ARGV, + CMD_ARG_KEY, + RAW_CMD_KEY, + REGEX_GROUP, + REGEX_MATCHED, + ) + + from plugins.param.param_state import ( + state, + command, + regex_dict, + command_arg, + raw_command, + regex_group, + regex_matched, + shell_command_args, + shell_command_argv, + ) + + fake_message = make_fake_message()("text") + fake_state = { + PREFIX_KEY: {CMD_KEY: ("cmd",), RAW_CMD_KEY: "/cmd", CMD_ARG_KEY: fake_message}, + SHELL_ARGV: ["-h"], + SHELL_ARGS: {"help": True}, + REGEX_MATCHED: "[cq:test,arg=value]", + REGEX_GROUP: ("test", "arg=value"), + REGEX_DICT: {"type": "test", "arg": "value"}, + } + + async with app.test_dependent(state, allow_types=[StateParam]) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state) + + async with app.test_dependent( + command, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[PREFIX_KEY][CMD_KEY]) + + async with app.test_dependent( + raw_command, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[PREFIX_KEY][RAW_CMD_KEY]) + + async with app.test_dependent( + command_arg, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[PREFIX_KEY][CMD_ARG_KEY]) + + async with app.test_dependent( + shell_command_argv, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[SHELL_ARGV]) + + async with app.test_dependent( + shell_command_args, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[SHELL_ARGS]) + + async with app.test_dependent( + regex_matched, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[REGEX_MATCHED]) + + async with app.test_dependent( + regex_group, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[REGEX_GROUP]) + + async with app.test_dependent( + regex_dict, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[REGEX_DICT]) + + +@pytest.mark.asyncio +async def test_matcher(app: App, load_plugin): + from nonebot.matcher import Matcher + from nonebot.params import DependParam, MatcherParam + + from plugins.param.param_matcher import matcher, receive, last_receive + + fake_matcher = Matcher() + + async with app.test_dependent(matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(fake_matcher) + + event = make_fake_event()() + fake_matcher.set_receive("test", event) + event_next = make_fake_event()() + fake_matcher.set_receive(None, event_next) + + async with app.test_dependent( + receive, allow_types=[MatcherParam, DependParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(event) + + async with app.test_dependent( + last_receive, allow_types=[MatcherParam, DependParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(event_next) diff --git a/tests/utils.py b/tests/utils.py index 2e90ca7e..d516665f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,36 @@ if TYPE_CHECKING: from nonebot.adapters import Event, Message +def make_fake_message() -> Type["Message"]: + from nonebot.adapters import Message, MessageSegment + + class FakeMessageSegment(MessageSegment): + @classmethod + def get_message_class(cls): + return FakeMessage + + def __str__(self) -> str: + return self.data["text"] + + @classmethod + def text(cls, text: str): + return cls("text", {"text": text}) + + def is_text(self) -> bool: + return True + + class FakeMessage(Message): + @classmethod + def get_segment_class(cls): + return FakeMessageSegment + + @staticmethod + def _construct(msg: str): + yield FakeMessageSegment.text(msg) + + return FakeMessage + + def make_fake_event( _type: str = "message", _name: str = "test",