add test cases

This commit is contained in:
yanyongyu 2021-12-20 00:28:02 +08:00
parent ca045b2f73
commit c2c3d5ef4b
17 changed files with 432 additions and 55 deletions

View File

@ -5,6 +5,9 @@ from ._bot import Bot
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import ( from nonebot.drivers import (
Driver, Driver,
Request,
Response,
WebSocket,
ForwardDriver, ForwardDriver,
ReverseDriver, ReverseDriver,
HTTPServerSetup, HTTPServerSetup,
@ -44,6 +47,16 @@ class Adapter(abc.ABC):
raise TypeError("Current driver does not support websocket server") raise TypeError("Current driver does not support websocket server")
self.driver.setup_websocket_server(setup) self.driver.setup_websocket_server(setup)
async def request(self, setup: Request) -> Response:
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support http client")
return await self.driver.request(setup)
async def websocket(self, setup: Request) -> WebSocket:
if not isinstance(self.driver, ForwardDriver):
raise TypeError("Current driver does not support websocket client")
return await self.driver.websocket(setup)
@abc.abstractmethod @abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any: async def _call_api(self, api: str, **data) -> Any:
""" """

View File

@ -8,7 +8,6 @@ from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.exception import MockApiException from nonebot.exception import MockApiException
from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook from nonebot.typing import T_CalledAPIHook, T_CallingAPIHook
from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
if TYPE_CHECKING: if TYPE_CHECKING:
from ._event import Event from ._event import Event

View File

@ -151,6 +151,9 @@ class Dependent(Generic[R]):
) -> Dict[str, Any]: ) -> Dict[str, Any]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
for param in self.parameterless:
await param._solve(**params)
for field in self.params: for field in self.params:
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, Param), "Params must be subclasses of Param" assert isinstance(field_info, Param), "Params must be subclasses of Param"
@ -168,7 +171,4 @@ class Dependent(Generic[R]):
else: else:
values[field.name] = value values[field.name] = value
for param in self.parameterless:
await param._solve(**params)
return values return values

View File

@ -10,10 +10,14 @@ import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from ._model import URL as URL
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from ._model import Request as Request
from nonebot.config import Env, Config from nonebot.config import Env, Config
from ._model import URL, Request, Response, WebSocket, HTTPVersion from ._model import Response as Response
from ._model import WebSocket as WebSocket
from ._model import HTTPVersion as HTTPVersion
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING: if TYPE_CHECKING:
@ -204,11 +208,11 @@ class ForwardDriver(Driver):
""" """
@abc.abstractmethod @abc.abstractmethod
async def request(self, setup: "Request") -> Any: async def request(self, setup: Request) -> Response:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def websocket(self, setup: "Request") -> Any: async def websocket(self, setup: Request) -> WebSocket:
raise NotImplementedError raise NotImplementedError

View File

@ -193,6 +193,7 @@ class Matcher(metaclass=MatcherMeta):
params.BotParam, params.BotParam,
params.EventParam, params.EventParam,
params.StateParam, params.StateParam,
params.ArgParam,
params.MatcherParam, params.MatcherParam,
params.DefaultParam, params.DefaultParam,
] ]
@ -443,10 +444,10 @@ class Matcher(metaclass=MatcherMeta):
async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]: async def _receive(event: Event, matcher: "Matcher") -> Union[None, NoReturn]:
if matcher.get_receive(id): if matcher.get_receive(id):
return return
if matcher.get_target() == RECEIVE_KEY.format(id=id): if matcher.get_target() == RECEIVE_KEY.format(id=id or ""):
matcher.set_receive(id, event) matcher.set_receive(id, event)
return return
matcher.set_target(RECEIVE_KEY.format(id=id)) matcher.set_target(RECEIVE_KEY.format(id=id or ""))
raise RejectedException raise RejectedException
parameterless = [params.Depends(_receive), *(parameterless or [])] parameterless = [params.Depends(_receive), *(parameterless or [])]
@ -472,7 +473,6 @@ class Matcher(metaclass=MatcherMeta):
cls, cls,
key: str, key: str,
prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None,
args_parser: Optional[T_ArgsParser] = None,
parameterless: Optional[List[Any]] = None, parameterless: Optional[List[Any]] = None,
) -> Callable[[T_Handler], T_Handler]: ) -> Callable[[T_Handler], T_Handler]:
""" """
@ -495,6 +495,8 @@ class Matcher(metaclass=MatcherMeta):
matcher.set_arg(key, event) matcher.set_arg(key, event)
return return
matcher.set_target(ARG_KEY.format(key=key)) matcher.set_target(ARG_KEY.format(key=key))
if prompt is not None:
await matcher.send(prompt)
raise RejectedException raise RejectedException
_parameterless = [ _parameterless = [
@ -517,7 +519,9 @@ class Matcher(metaclass=MatcherMeta):
@classmethod @classmethod
async def send( async def send(
cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs cls,
message: Union[str, Message, MessageSegment, MessageTemplate],
**kwargs: Any,
) -> Any: ) -> Any:
""" """
:说明: :说明:

View File

@ -58,19 +58,21 @@ EVENT_PCS_PARAMS = [
] ]
RUN_PREPCS_PARAMS = [ RUN_PREPCS_PARAMS = [
params.DependParam, params.DependParam,
params.MatcherParam,
params.BotParam, params.BotParam,
params.EventParam, params.EventParam,
params.StateParam, params.StateParam,
params.ArgParam,
params.MatcherParam,
params.DefaultParam, params.DefaultParam,
] ]
RUN_POSTPCS_PARAMS = [ RUN_POSTPCS_PARAMS = [
params.DependParam, params.DependParam,
params.MatcherParam,
params.ExceptionParam, params.ExceptionParam,
params.BotParam, params.BotParam,
params.EventParam, params.EventParam,
params.StateParam, params.StateParam,
params.ArgParam,
params.MatcherParam,
params.DefaultParam, params.DefaultParam,
] ]

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import inspect import inspect
from typing_extensions import Literal
from typing import Any, Dict, List, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
@ -200,7 +201,7 @@ async def _event_message(event: Event) -> Message:
return event.get_message() return event.get_message()
def EventMessage() -> Message: def EventMessage() -> Any:
return Depends(_event_message) return Depends(_event_message)
@ -260,7 +261,7 @@ def _command_arg(state=State()) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY] return state[PREFIX_KEY][CMD_ARG_KEY]
def CommandArg() -> Message: def CommandArg() -> Any:
return Depends(_command_arg, use_cache=False) return Depends(_command_arg, use_cache=False)
@ -332,6 +333,44 @@ def LastReceived(default: Any = None) -> Any:
return Depends(_last_received, use_cache=False) return Depends(_last_received, use_cache=False)
class ArgInner:
def __init__(
self, key: Optional[str], type: Literal["event", "message", "str"]
) -> None:
self.key = key
self.type = type
def Arg(key: Optional[str] = None) -> Any:
return ArgInner(key, "message")
def ArgEvent(key: Optional[str] = None) -> Any:
return ArgInner(key, "event")
def ArgStr(key: Optional[str] = None) -> Any:
return ArgInner(key, "str")
class ArgParam(Param):
@classmethod
def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["ArgParam"]:
if isinstance(param.default, ArgInner):
return cls(Required, key=param.default.key or name, type=param.default.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
event = matcher.get_arg(self.extra["key"])
if self.extra["type"] == "event":
return event
elif self.extra["type"] == "message":
return event.get_message()
else:
return matcher.get_arg_str(self.extra["key"])
class ExceptionParam(Param): class ExceptionParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(

View File

@ -1,24 +0,0 @@
from nonebot import on_message
from nonebot.adapters import Event
from nonebot.params import Depends
test_depends = on_message()
runned = []
def dependency(event: Event):
# test cache
runned.append(event)
return event
@test_depends.handle()
async def depends(x: Event = Depends(dependency)):
# test dependency
return x
@test_depends.handle()
async def depends_cache(y: Event = Depends(dependency, use_cache=True)):
return y

View File

@ -0,0 +1,7 @@
from pathlib import Path
from nonebot import load_plugins
_sub_plugins = set()
_sub_plugins |= load_plugins(str(Path(__file__).parent))

View File

@ -0,0 +1,5 @@
from nonebot.adapters import Bot
async def get_bot(b: Bot):
return b

View File

@ -0,0 +1,30 @@
from nonebot import on_message
from nonebot.adapters import Event
from nonebot.params import Depends
test_depends = on_message()
runned = []
def dependency():
runned.append(1)
return 1
def parameterless():
assert len(runned) == 0
runned.append(1)
# test parameterless
@test_depends.handle(parameterless=[Depends(parameterless)])
async def depends(x: int = Depends(dependency)):
# test dependency
return x
@test_depends.handle()
async def depends_cache(y: int = Depends(dependency, use_cache=True)):
# test cache
return y

View File

@ -0,0 +1,22 @@
from nonebot.adapters import Event, Message
from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText
async def event(e: Event) -> Event:
return e
async def event_type(t: str = EventType()) -> str:
return t
async def event_message(msg: Message = EventMessage()) -> Message:
return msg
async def event_plain_text(text: str = EventPlainText()) -> str:
return text
async def event_to_me(to_me: bool = EventToMe()) -> bool:
return to_me

View File

@ -0,0 +1,15 @@
from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.params import Received, LastReceived
async def matcher(m: Matcher) -> Matcher:
return m
async def receive(e: Event = Received("test")) -> Event:
return e
async def last_receive(e: Event = LastReceived()) -> Event:
return e

View File

@ -0,0 +1,55 @@
from typing import List, Tuple
from nonebot.typing import T_State
from nonebot.adapters import Message
from nonebot.params import (
State,
Command,
RegexDict,
CommandArg,
RawCommand,
RegexGroup,
RegexMatched,
ShellCommandArgs,
ShellCommandArgv,
)
async def state(x: T_State = State()) -> T_State:
return x
async def command(cmd: Tuple[str, ...] = Command()) -> Tuple[str, ...]:
return cmd
async def raw_command(raw_cmd: str = RawCommand()) -> str:
return raw_cmd
async def command_arg(cmd_arg: Message = CommandArg()) -> Message:
return cmd_arg
async def shell_command_args(
shell_command_args: dict = ShellCommandArgs(),
) -> dict:
return shell_command_args
async def shell_command_argv(
shell_command_argv: List[str] = ShellCommandArgv(),
) -> List[str]:
return shell_command_argv
async def regex_dict(regex_dict: dict = RegexDict()) -> dict:
return regex_dict
async def regex_group(regex_group: Tuple = RegexGroup()) -> Tuple:
return regex_group
async def regex_matched(regex_matched: str = RegexMatched()) -> str:
return regex_matched

View File

@ -79,8 +79,11 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear):
async def test_load_plugin(load_plugin: Set["Plugin"]): async def test_load_plugin(load_plugin: Set["Plugin"]):
import nonebot import nonebot
assert nonebot.get_loaded_plugins() == load_plugin loaded_plugins = set(
plugin = nonebot.get_plugin("depends") plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
)
assert loaded_plugins == load_plugin
plugin = nonebot.get_plugin("param_depend")
assert plugin assert plugin
assert plugin.module_name == "plugins.depends" assert plugin.module_name == "plugins.param.param_depend"
assert "plugins.depends" in sys.modules assert "plugins.param.param_depend" in sys.modules

View File

@ -1,23 +1,19 @@
import pytest import pytest
from nonebug import App from nonebug import App
from utils import load_plugin, make_fake_event from utils import load_plugin, make_fake_event, make_fake_message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_depends(app: App, load_plugin): async def test_depend(app: App, load_plugin):
from nonebot.params import EventParam, DependParam from nonebot.params import DependParam
from plugins.depends import runned, depends, test_depends from plugins.param.param_depend import runned, depends, test_depends
async with app.test_dependent( async with app.test_dependent(depends, allow_types=[DependParam]) as ctx:
depends, allow_types=[EventParam, DependParam] ctx.should_return(1)
) as ctx:
event = make_fake_event()()
ctx.pass_params(event=event)
ctx.should_return(event)
assert len(runned) == 1 and runned[0] == event assert len(runned) == 1 and runned[0] == 1
runned.clear() runned.clear()
@ -26,4 +22,181 @@ async def test_depends(app: App, load_plugin):
event_next = make_fake_event()() event_next = make_fake_event()()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
assert len(runned) == 1 and runned[0] == event_next assert len(runned) == 2 and runned[0] == runned[1] == 1
@pytest.mark.asyncio
async def test_bot(app: App, load_plugin):
from nonebot.params import BotParam
from plugins.param.param_bot import get_bot
async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
ctx.should_return(bot)
@pytest.mark.asyncio
async def test_event(app: App, load_plugin):
from nonebot.params import EventParam, DependParam
from plugins.param.param_event import (
event,
event_type,
event_to_me,
event_message,
event_plain_text,
)
fake_message = make_fake_message()("text")
fake_event = make_fake_event(_message=fake_message)()
async with app.test_dependent(event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)
async with app.test_dependent(
event_type, allow_types=[EventParam, DependParam]
) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event.get_type())
async with app.test_dependent(
event_message, allow_types=[EventParam, DependParam]
) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event.get_message())
async with app.test_dependent(
event_plain_text, allow_types=[EventParam, DependParam]
) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event.get_plaintext())
async with app.test_dependent(
event_to_me, allow_types=[EventParam, DependParam]
) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event.is_tome())
@pytest.mark.asyncio
async def test_state(app: App, load_plugin):
from nonebot.params import StateParam, DependParam
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
REGEX_DICT,
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
REGEX_MATCHED,
)
from plugins.param.param_state import (
state,
command,
regex_dict,
command_arg,
raw_command,
regex_group,
regex_matched,
shell_command_args,
shell_command_argv,
)
fake_message = make_fake_message()("text")
fake_state = {
PREFIX_KEY: {CMD_KEY: ("cmd",), RAW_CMD_KEY: "/cmd", CMD_ARG_KEY: fake_message},
SHELL_ARGV: ["-h"],
SHELL_ARGS: {"help": True},
REGEX_MATCHED: "[cq:test,arg=value]",
REGEX_GROUP: ("test", "arg=value"),
REGEX_DICT: {"type": "test", "arg": "value"},
}
async with app.test_dependent(state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)
async with app.test_dependent(
command, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][CMD_KEY])
async with app.test_dependent(
raw_command, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][RAW_CMD_KEY])
async with app.test_dependent(
command_arg, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][CMD_ARG_KEY])
async with app.test_dependent(
shell_command_argv, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[SHELL_ARGV])
async with app.test_dependent(
shell_command_args, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[SHELL_ARGS])
async with app.test_dependent(
regex_matched, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[REGEX_MATCHED])
async with app.test_dependent(
regex_group, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[REGEX_GROUP])
async with app.test_dependent(
regex_dict, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[REGEX_DICT])
@pytest.mark.asyncio
async def test_matcher(app: App, load_plugin):
from nonebot.matcher import Matcher
from nonebot.params import DependParam, MatcherParam
from plugins.param.param_matcher import matcher, receive, last_receive
fake_matcher = Matcher()
async with app.test_dependent(matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)
event = make_fake_event()()
fake_matcher.set_receive("test", event)
event_next = make_fake_event()()
fake_matcher.set_receive(None, event_next)
async with app.test_dependent(
receive, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(event)
async with app.test_dependent(
last_receive, allow_types=[MatcherParam, DependParam]
) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(event_next)

View File

@ -9,6 +9,36 @@ if TYPE_CHECKING:
from nonebot.adapters import Event, Message from nonebot.adapters import Event, Message
def make_fake_message() -> Type["Message"]:
from nonebot.adapters import Message, MessageSegment
class FakeMessageSegment(MessageSegment):
@classmethod
def get_message_class(cls):
return FakeMessage
def __str__(self) -> str:
return self.data["text"]
@classmethod
def text(cls, text: str):
return cls("text", {"text": text})
def is_text(self) -> bool:
return True
class FakeMessage(Message):
@classmethod
def get_segment_class(cls):
return FakeMessageSegment
@staticmethod
def _construct(msg: str):
yield FakeMessageSegment.text(msg)
return FakeMessage
def make_fake_event( def make_fake_event(
_type: str = "message", _type: str = "message",
_name: str = "test", _name: str = "test",