👷 Test: 移除 httpbin 并整理测试 (#2110)

This commit is contained in:
Ju4tCode 2023-06-19 17:48:59 +08:00 committed by GitHub
parent 27a3d1f0bb
commit 080b876d93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 592 additions and 309 deletions

12
poetry.lock generated
View File

@ -1215,7 +1215,7 @@ dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegu
name = "markupsafe"
version = "2.1.2"
description = "Safely add untrusted strings to HTML/XML markup."
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "MarkupSafe-2.1.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:665a36ae6f8f20a4676b53224e33d456a6f5a72657d9c83c2aa00765072f31f7"},
@ -2239,13 +2239,13 @@ files = [
[[package]]
name = "werkzeug"
version = "2.3.4"
version = "2.3.6"
description = "The comprehensive WSGI web application library."
optional = true
optional = false
python-versions = ">=3.8"
files = [
{file = "Werkzeug-2.3.4-py3-none-any.whl", hash = "sha256:48e5e61472fee0ddee27ebad085614ebedb7af41e88f687aaf881afb723a162f"},
{file = "Werkzeug-2.3.4.tar.gz", hash = "sha256:1d5a58e0377d1fe39d061a5de4469e414e78ccb1e1e59c0f5ad6fa1c36c52b76"},
{file = "Werkzeug-2.3.6-py3-none-any.whl", hash = "sha256:935539fa1413afbb9195b24880778422ed620c0fc09670945185cce4d91a8890"},
{file = "Werkzeug-2.3.6.tar.gz", hash = "sha256:98c774df2f91b05550078891dee5f0eb0cb797a522c757a2452b9cee5b202330"},
]
[package.dependencies]
@ -2395,4 +2395,4 @@ websockets = ["websockets"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
content-hash = "be276204946fdf3338fb06cc5391aafc2f2fc89fc5c0a91bf09590f1eee40aa4"
content-hash = "cf0a75f5173e6eb0fa3cd9d8901d4149e104094d08bce7e298d6f8c92954803d"

View File

@ -51,6 +51,7 @@ pre-commit = "^3.0.0"
[tool.poetry.group.test.dependencies]
nonebug = "^0.3.0"
werkzeug = "^2.3.6"
pytest-cov = "^4.0.0"
pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.21.0"
@ -68,7 +69,7 @@ fastapi = ["fastapi", "uvicorn"]
all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_mode = "strict"
addopts = "--cov=nonebot --cov-append --cov-report=term-missing"
filterwarnings = [
"error",

View File

@ -1,11 +1,15 @@
import os
import threading
from pathlib import Path
from typing import TYPE_CHECKING, Set
from typing import TYPE_CHECKING, Set, Generator
import pytest
from nonebug import NONEBOT_INIT_KWARGS
from werkzeug.serving import BaseWSGIServer, make_server
import nonebot
from nonebot.drivers import URL
from fake_server import request_handler
os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
os.environ["CONFIG_OVERRIDE"] = "new"
@ -28,3 +32,20 @@ def load_plugin(nonebug_init: None) -> Set["Plugin"]:
def load_builtin_plugin(nonebug_init: None) -> Set["Plugin"]:
# preload builtin plugins
return nonebot.load_builtin_plugins("echo", "single_session")
@pytest.fixture(scope="session", autouse=True)
def server() -> Generator[BaseWSGIServer, None, None]:
server = make_server("127.0.0.1", 0, app=request_handler)
thread = threading.Thread(target=server.serve_forever)
thread.start()
try:
yield server
finally:
server.shutdown()
thread.join()
@pytest.fixture(scope="session")
def server_url(server: BaseWSGIServer) -> URL:
return URL(f"http://{server.host}:{server.port}")

69
tests/fake_server.py Normal file
View File

@ -0,0 +1,69 @@
import json
import base64
from typing import Dict, List, Union, TypeVar
from werkzeug import Request, Response
from werkzeug.datastructures import MultiDict
K = TypeVar("K")
V = TypeVar("V")
def json_safe(string, content_type="application/octet-stream") -> str:
try:
string = string.decode("utf-8")
json.dumps(string)
return string
except (ValueError, TypeError):
return b"".join(
[
b"data:",
content_type.encode("utf-8"),
b";base64,",
base64.b64encode(string),
]
).decode("utf-8")
def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]:
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}
@Request.application
def request_handler(request: Request) -> Response:
try:
_json = json.loads(request.data.decode("utf-8"))
except (ValueError, TypeError):
_json = None
return Response(
json.dumps(
{
"url": request.url,
"method": request.method,
"origin": request.headers.get("X-Forwarded-For", request.remote_addr),
"headers": flattern(
MultiDict((k, v) for k, v in request.headers.items())
),
"args": flattern(request.args),
"form": flattern(request.form),
"data": json_safe(request.data),
"json": _json,
"files": flattern(
MultiDict(
(
k,
json_safe(
v.read(),
request.files[k].content_type
or "application/octet-stream",
),
)
for k, v in request.files.items()
)
),
}
),
status=200,
content_type="application/json",
)

View File

@ -2,6 +2,7 @@ from nonebot.matcher import Matcher
from nonebot.permission import USER, Permission
default_permission = Permission()
new_permission = Permission()
test_permission_updater = Matcher.new(permission=default_permission)
@ -14,4 +15,4 @@ test_custom_updater = Matcher.new(permission=default_permission)
@test_custom_updater.permission_updater
async def _() -> Permission:
return default_permission
return new_permission

View File

@ -1,135 +1,136 @@
import pytest
from pydantic import ValidationError, parse_obj_as
from utils import make_fake_message
from nonebot.adapters import Message
from utils import FakeMessage, FakeMessageSegment
def test_segment_add():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert MessageSegment.text("text") + MessageSegment.text("text") == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
)
assert MessageSegment.text("text") + "text" == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
)
assert (
MessageSegment.text("text") + Message([MessageSegment.text("text")])
) == Message([MessageSegment.text("text"), MessageSegment.text("text")])
assert "text" + MessageSegment.text("text") == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
)
def test_segment_validate():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert parse_obj_as(
MessageSegment,
{"type": "text", "data": {"text": "text"}, "extra": "should be ignored"},
) == MessageSegment.text("text")
with pytest.raises(ValidationError):
parse_obj_as(MessageSegment, "some str")
with pytest.raises(ValidationError):
parse_obj_as(MessageSegment, {"data": {}})
def test_segment_join():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
seg = MessageSegment.text("test")
iterable = [
MessageSegment.text("first"),
Message([MessageSegment.text("second"), MessageSegment.text("third")]),
]
assert seg.join(iterable) == Message(
[
MessageSegment.text("first"),
MessageSegment.text("test"),
MessageSegment.text("second"),
MessageSegment.text("third"),
]
)
def test_segment():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert len(MessageSegment.text("text")) == 4
assert MessageSegment.text("text") != MessageSegment.text("other")
assert MessageSegment.text("text").get("data") == {"text": "text"}
assert list(MessageSegment.text("text").keys()) == ["type", "data"]
assert list(MessageSegment.text("text").values()) == ["text", {"text": "text"}]
assert list(MessageSegment.text("text").items()) == [
def test_segment_data():
assert len(FakeMessageSegment.text("text")) == 4
assert FakeMessageSegment.text("text").get("data") == {"text": "text"}
assert list(FakeMessageSegment.text("text").keys()) == ["type", "data"]
assert list(FakeMessageSegment.text("text").values()) == ["text", {"text": "text"}]
assert list(FakeMessageSegment.text("text").items()) == [
("type", "text"),
("data", {"text": "text"}),
]
origin = MessageSegment.text("text")
def test_segment_equal():
assert FakeMessageSegment("text", {"text": "text"}) == FakeMessageSegment(
"text", {"text": "text"}
)
assert FakeMessageSegment("text", {"text": "text"}) != FakeMessageSegment(
"text", {"text": "other"}
)
assert FakeMessageSegment("text", {"text": "text"}) != FakeMessageSegment(
"other", {"text": "text"}
)
def test_segment_add():
assert FakeMessageSegment.text("text") + FakeMessageSegment.text(
"text"
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert FakeMessageSegment.text("text") + "text" == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
assert (
FakeMessageSegment.text("text") + FakeMessage([FakeMessageSegment.text("text")])
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert "text" + FakeMessageSegment.text("text") == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
def test_segment_validate():
assert parse_obj_as(
FakeMessageSegment,
{"type": "text", "data": {"text": "text"}, "extra": "should be ignored"},
) == FakeMessageSegment.text("text")
with pytest.raises(ValidationError):
parse_obj_as(FakeMessageSegment, "some str")
with pytest.raises(ValidationError):
parse_obj_as(FakeMessageSegment, {"data": {}})
def test_segment_join():
seg = FakeMessageSegment.text("test")
iterable = [
FakeMessageSegment.text("first"),
FakeMessage(
[FakeMessageSegment.text("second"), FakeMessageSegment.text("third")]
),
]
assert seg.join(iterable) == FakeMessage(
[
FakeMessageSegment.text("first"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("second"),
FakeMessageSegment.text("third"),
]
)
def test_segment_copy():
origin = FakeMessageSegment.text("text")
copy = origin.copy()
assert origin is not copy
assert origin == copy
def test_message_add():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert (
Message([MessageSegment.text("text")]) + MessageSegment.text("text")
) == Message([MessageSegment.text("text"), MessageSegment.text("text")])
FakeMessage([FakeMessageSegment.text("text")]) + FakeMessageSegment.text("text")
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert Message([MessageSegment.text("text")]) + "text" == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
assert FakeMessage([FakeMessageSegment.text("text")]) + "text" == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
assert (
Message([MessageSegment.text("text")]) + Message([MessageSegment.text("text")])
) == Message([MessageSegment.text("text"), MessageSegment.text("text")])
FakeMessage([FakeMessageSegment.text("text")])
+ FakeMessage([FakeMessageSegment.text("text")])
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert "text" + Message([MessageSegment.text("text")]) == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
assert "text" + FakeMessage([FakeMessageSegment.text("text")]) == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
msg = Message([MessageSegment.text("text")])
msg += MessageSegment.text("text")
assert msg == Message([MessageSegment.text("text"), MessageSegment.text("text")])
msg = FakeMessage([FakeMessageSegment.text("text")])
msg += FakeMessageSegment.text("text")
assert msg == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
def test_message_getitem():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.image("test2"),
FakeMessageSegment.image("test3"),
FakeMessageSegment.text("test4"),
]
)
assert message[0] == MessageSegment.text("test")
assert message[0] == FakeMessageSegment.text("test")
assert message[:2] == Message(
[MessageSegment.text("test"), MessageSegment.image("test2")]
assert message[:2] == FakeMessage(
[FakeMessageSegment.text("test"), FakeMessageSegment.image("test2")]
)
assert message["image"] == Message(
[MessageSegment.image("test2"), MessageSegment.image("test3")]
assert message["image"] == FakeMessage(
[FakeMessageSegment.image("test2"), FakeMessageSegment.image("test3")]
)
assert message["image", 0] == MessageSegment.image("test2")
assert message["image", 0] == FakeMessageSegment.image("test2")
assert message["image", 0:2] == message["image"]
assert message.index(message[0]) == 0
@ -137,153 +138,137 @@ def test_message_getitem():
assert message.get("image") == message["image"]
assert message.get("image", 114514) == message["image"]
assert message.get("image", 1) == Message([message["image", 0]])
assert message.get("image", 1) == FakeMessage([message["image", 0]])
assert message.count("image") == 2
def test_message_validate():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
Message_ = make_fake_message()
assert parse_obj_as(Message, Message([])) == Message([])
assert parse_obj_as(FakeMessage, FakeMessage([])) == FakeMessage([])
with pytest.raises(ValidationError):
parse_obj_as(Message, Message_([]))
parse_obj_as(type("FakeMessage2", (Message,), {}), FakeMessage([]))
assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")])
assert parse_obj_as(Message, {"type": "text", "data": {"text": "text"}}) == Message(
[MessageSegment.text("text")]
assert parse_obj_as(FakeMessage, "text") == FakeMessage(
[FakeMessageSegment.text("text")]
)
assert parse_obj_as(
Message,
[MessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}],
) == Message([MessageSegment.text("text"), MessageSegment.text("text")])
FakeMessage, {"type": "text", "data": {"text": "text"}}
) == FakeMessage([FakeMessageSegment.text("text")])
assert parse_obj_as(
FakeMessage,
[FakeMessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}],
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
with pytest.raises(ValidationError):
parse_obj_as(Message, object())
parse_obj_as(FakeMessage, object())
def test_message_contains():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.image("test2"),
FakeMessageSegment.image("test3"),
FakeMessageSegment.text("test4"),
]
)
assert message.has(MessageSegment.text("test")) is True
assert MessageSegment.text("test") in message
assert message.has(FakeMessageSegment.text("test")) is True
assert FakeMessageSegment.text("test") in message
assert message.has("image") is True
assert "image" in message
assert message.has(MessageSegment.text("foo")) is False
assert MessageSegment.text("foo") not in message
assert message.has(FakeMessageSegment.text("foo")) is False
assert FakeMessageSegment.text("foo") not in message
assert message.has("foo") is False
assert "foo" not in message
def test_message_only():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.text("test2"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("test2"),
]
)
assert message.only("text") is True
assert message.only(MessageSegment.text("test")) is False
assert message.only(FakeMessageSegment.text("test")) is False
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.image("test2"),
FakeMessageSegment.image("test3"),
FakeMessageSegment.text("test4"),
]
)
assert message.only("text") is False
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.text("test"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("test"),
]
)
assert message.only(MessageSegment.text("test")) is True
assert message.only(FakeMessageSegment.text("test")) is True
def test_message_join():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
msg = Message([MessageSegment.text("test")])
msg = FakeMessage([FakeMessageSegment.text("test")])
iterable = [
MessageSegment.text("first"),
Message([MessageSegment.text("second"), MessageSegment.text("third")]),
FakeMessageSegment.text("first"),
FakeMessage(
[FakeMessageSegment.text("second"), FakeMessageSegment.text("third")]
),
]
assert msg.join(iterable) == Message(
assert msg.join(iterable) == FakeMessage(
[
MessageSegment.text("first"),
MessageSegment.text("test"),
MessageSegment.text("second"),
MessageSegment.text("third"),
FakeMessageSegment.text("first"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("second"),
FakeMessageSegment.text("third"),
]
)
def test_message_include():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.image("test2"),
FakeMessageSegment.image("test3"),
FakeMessageSegment.text("test4"),
]
)
assert message.include("text") == Message(
assert message.include("text") == FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("test4"),
]
)
def test_message_exclude():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
message = FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.image("test2"),
FakeMessageSegment.image("test3"),
FakeMessageSegment.text("test4"),
]
)
assert message.exclude("image") == Message(
assert message.exclude("image") == FakeMessage(
[
MessageSegment.text("test"),
MessageSegment.text("test4"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("test4"),
]
)

View File

@ -1,5 +1,5 @@
from nonebot.adapters import MessageTemplate
from utils import escape_text, make_fake_message
from utils import FakeMessage, FakeMessageSegment, escape_text
def test_template_basis():
@ -9,8 +9,7 @@ def test_template_basis():
def test_template_message():
Message = make_fake_message()
template = Message.template("{a:custom}{b:text}{c:image}/{d}")
template = FakeMessage.template("{a:custom}{b:text}{c:image}/{d}")
@template.add_format_spec
def custom(input: str) -> str:
@ -37,29 +36,24 @@ def test_template_message():
def test_rich_template_message():
Message = make_fake_message()
MS = Message.get_segment_class()
pic1, pic2, pic3 = (
MS.image("file:///pic1.jpg"),
MS.image("file:///pic2.jpg"),
MS.image("file:///pic3.jpg"),
FakeMessageSegment.image("file:///pic1.jpg"),
FakeMessageSegment.image("file:///pic2.jpg"),
FakeMessageSegment.image("file:///pic3.jpg"),
)
template = Message.template("{}{}" + pic2 + "{}")
template = FakeMessage.template("{}{}" + pic2 + "{}")
result = template.format(pic1, "[fake:image]", pic3)
assert result["image"] == Message([pic1, pic2, pic3])
assert result["image"] == FakeMessage([pic1, pic2, pic3])
assert str(result) == (
"[fake:image]" + escape_text("[fake:image]") + "[fake:image]" + "[fake:image]"
)
def test_message_injection():
Message = make_fake_message()
template = Message.template("{name}Is Bad")
template = FakeMessage.template("{name}Is Bad")
message = template.format(name="[fake:image]")
assert message.extract_plain_text() == escape_text("[fake:image]Is Bad")

View File

@ -79,13 +79,37 @@ async def test_lifespan():
],
indirect=True,
)
async def test_reverse_driver(app: App, driver: Driver):
async def test_http_server(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
async def _handle_http(request: Request) -> Response:
assert request.content in (b"test", "test")
return Response(200, content="test")
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
driver.setup_http_server(http_setup)
async with app.test_server(driver.asgi) as ctx:
client = ctx.get_client()
response = await client.post("/http_test", data="test")
assert response.status_code == 200
assert response.text == "test"
await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
pytest.param("nonebot.drivers.quart:Driver", id="quart"),
],
indirect=True,
)
async def test_websocket_server(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
async def _handle_ws(ws: WebSocket) -> None:
await ws.accept()
data = await ws.receive()
@ -107,17 +131,11 @@ async def test_reverse_driver(app: App, driver: Driver):
with pytest.raises(WebSocketClosed):
await ws.receive()
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
driver.setup_http_server(http_setup)
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
driver.setup_websocket_server(ws_setup)
async with app.test_server(driver.asgi) as ctx:
client = ctx.get_client()
response = await client.post("/http_test", data="test")
assert response.status_code == 200
assert response.text == "test"
async with client.websocket_connect("/ws_test") as ws:
await ws.send_text("ping")
@ -145,12 +163,13 @@ async def test_reverse_driver(app: App, driver: Driver):
],
indirect=True,
)
async def test_http_driver(driver: Driver):
async def test_http_client(driver: Driver, server_url: URL):
driver = cast(ForwardDriver, driver)
# simple post with query, headers, cookies and content
request = Request(
"POST",
"https://httpbin.org/post",
server_url,
params={"param": "test"},
headers={"X-Test": "test"},
cookies={"session": "test"},
@ -159,32 +178,39 @@ async def test_http_driver(driver: Driver):
response = await driver.request(request)
assert response.status_code == 200 and response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["args"] == {"param": "test"}
assert data["headers"].get("X-Test") == "test"
assert data["headers"].get("Cookie") == "session=test"
assert data["data"] == "test"
request = Request("POST", "https://httpbin.org/post", data={"form": "test"})
# post with data body
request = Request("POST", server_url, data={"form": "test"})
response = await driver.request(request)
assert response.status_code == 200 and response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["form"] == {"form": "test"}
request = Request("POST", "https://httpbin.org/post", json={"json": "test"})
# post with json body
request = Request("POST", server_url, json={"json": "test"})
response = await driver.request(request)
assert response.status_code == 200 and response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["json"] == {"json": "test"}
# post with files and form data
request = Request(
"POST",
"https://httpbin.org/post",
server_url,
data={"form": "test"},
files={"test": ("test.txt", b"test")},
)
response = await driver.request(request)
assert response.status_code == 200 and response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["form"] == {"form": "test"}
assert data["files"] == {"test": "test"}
@ -236,7 +262,6 @@ async def test_bot_connect_hook(app: App, driver: Driver):
@driver.on_bot_connect
async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
nonlocal conn_should_be_called
conn_should_be_called = True
if foo is not bot:
pytest.fail("on_bot_connect hook called with wrong bot")
@ -245,12 +270,13 @@ async def test_bot_connect_hook(app: App, driver: Driver):
if default != 1:
pytest.fail("on_bot_connect hook called with wrong default value")
conn_should_be_called = True
@driver.on_bot_disconnect
async def disconn_hook(
foo: Bot, dep: int = Depends(dependency), default: int = 1
):
nonlocal disconn_should_be_called
disconn_should_be_called = True
if foo is not bot:
pytest.fail("on_bot_disconnect hook called with wrong bot")
@ -259,6 +285,8 @@ async def test_bot_connect_hook(app: App, driver: Driver):
if default != 1:
pytest.fail("on_bot_connect hook called with wrong default value")
disconn_should_be_called = True
if conn_hook not in {hook.call for hook in conn_hooks}:
pytest.fail("on_bot_connect hook not registered")
if disconn_hook not in {hook.call for hook in disconn_hooks}:
@ -268,6 +296,7 @@ async def test_bot_connect_hook(app: App, driver: Driver):
bot = ctx.create_bot()
await asyncio.sleep(1)
if not conn_should_be_called:
pytest.fail("on_bot_connect hook not called")
if not disconn_should_be_called:

View File

@ -31,17 +31,29 @@ async def test_init():
@pytest.mark.asyncio
async def test_get(app: App, monkeypatch: pytest.MonkeyPatch):
async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setattr(nonebot, "_driver", None)
with pytest.raises(ValueError):
get_driver()
@pytest.mark.asyncio
async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert get_asgi() == driver.asgi
@pytest.mark.asyncio
async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert get_app() == driver.server_app
@pytest.mark.asyncio
async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
async with app.test_api() as ctx:
adapter = ctx.create_adapter()
adapter_name = adapter.get_name()
@ -54,6 +66,9 @@ async def test_get(app: App, monkeypatch: pytest.MonkeyPatch):
with pytest.raises(ValueError):
get_adapter("not exist")
@pytest.mark.asyncio
async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
runned = False
def mock_run(*args, **kwargs):
@ -61,14 +76,24 @@ async def test_get(app: App, monkeypatch: pytest.MonkeyPatch):
runned = True
assert args == ("arg",) and kwargs == {"kwarg": "kwarg"}
monkeypatch.setattr(driver, "run", mock_run)
nonebot.run("arg", kwarg="kwarg")
driver = get_driver()
with monkeypatch.context() as m:
m.setattr(driver, "run", mock_run)
nonebot.run("arg", kwarg="kwarg")
assert runned
@pytest.mark.asyncio
async def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
with pytest.raises(ValueError):
get_bot()
monkeypatch.setattr(driver, "_bots", {"test": "test"})
assert get_bot() == "test"
assert get_bot("test") == "test"
assert get_bots() == {"test": "test"}
with monkeypatch.context() as m:
m.setattr(driver, "_bots", {"test": "test"})
assert get_bot() == "test"
assert get_bot("test") == "test"
assert get_bots() == {"test": "test"}

View File

@ -3,25 +3,16 @@ from nonebug import App
from nonebot.permission import User
from nonebot.matcher import Matcher, matchers
from utils import FakeMessage, make_fake_event
from nonebot.message import check_and_run_matcher
from utils import make_fake_event, make_fake_message
@pytest.mark.asyncio
async def test_matcher(app: App):
from plugins.matcher.matcher_process import (
test_got,
test_handle,
test_preset,
test_combine,
test_receive,
test_overload,
)
async def test_matcher_handle(app: App):
from plugins.matcher.matcher_process import test_handle
message = make_fake_message()("text")
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = make_fake_message()("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_handle.handlers) == 1
async with app.test_matcher(test_handle) as ctx:
@ -30,6 +21,16 @@ async def test_matcher(app: App):
ctx.should_call_send(event, "send", "result", at_sender=True)
ctx.should_finished()
@pytest.mark.asyncio
async def test_matcher_got(app: App):
from plugins.matcher.matcher_process import test_got
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = FakeMessage("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_got.handlers) == 1
async with app.test_matcher(test_got) as ctx:
bot = ctx.create_bot()
@ -42,6 +43,14 @@ async def test_matcher(app: App):
ctx.should_rejected()
ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
async def test_matcher_receive(app: App):
from plugins.matcher.matcher_process import test_receive
message = FakeMessage("text")
event = make_fake_event(_message=message)()
assert len(test_receive.handlers) == 1
async with app.test_matcher(test_receive) as ctx:
bot = ctx.create_bot()
@ -51,7 +60,17 @@ async def test_matcher(app: App):
ctx.should_call_send(event, "pause", "result", at_sender=True)
ctx.should_paused()
assert len(test_receive.handlers) == 1
@pytest.mark.asyncio
async def test_matcher_(app: App):
from plugins.matcher.matcher_process import test_combine
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = FakeMessage("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_combine.handlers) == 1
async with app.test_matcher(test_combine) as ctx:
bot = ctx.create_bot()
ctx.receive_event(bot, event)
@ -64,6 +83,16 @@ async def test_matcher(app: App):
ctx.should_rejected()
ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
async def test_matcher_preset(app: App):
from plugins.matcher.matcher_process import test_preset
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = FakeMessage("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_preset.handlers) == 2
async with app.test_matcher(test_preset) as ctx:
bot = ctx.create_bot()
@ -72,6 +101,14 @@ async def test_matcher(app: App):
ctx.should_rejected()
ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
async def test_matcher_overload(app: App):
from plugins.matcher.matcher_process import test_overload
message = FakeMessage("text")
event = make_fake_event(_message=message)()
assert len(test_overload.handlers) == 2
async with app.test_matcher(test_overload) as ctx:
bot = ctx.create_bot()
@ -115,12 +152,10 @@ async def test_type_updater(app: App):
@pytest.mark.asyncio
async def test_permission_updater(app: App):
async def test_default_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
default_permission,
test_custom_updater,
test_permission_updater,
test_user_permission_updater,
)
event = make_fake_event(_session_id="test")()
@ -136,6 +171,15 @@ async def test_permission_updater(app: App):
assert checker.users == ("test",)
assert checker.perm is default_permission
@pytest.mark.asyncio
async def test_user_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
default_permission,
test_user_permission_updater,
)
event = make_fake_event(_session_id="test")()
user_permission = list(test_user_permission_updater.permission.checkers)[0].call
assert isinstance(user_permission, User)
assert user_permission.perm is default_permission
@ -149,12 +193,22 @@ async def test_permission_updater(app: App):
assert checker.users == ("test",)
assert checker.perm is default_permission
@pytest.mark.asyncio
async def test_custom_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
new_permission,
default_permission,
test_custom_updater,
)
event = make_fake_event(_session_id="test")()
assert test_custom_updater.permission is default_permission
async with app.test_api() as ctx:
bot = ctx.create_bot()
matcher = test_custom_updater()
new_perm = await matcher.update_permission(bot, event)
assert new_perm is default_permission
assert new_perm is new_permission
@pytest.mark.asyncio
@ -189,12 +243,8 @@ async def test_run(app: App):
@pytest.mark.asyncio
async def test_expire(app: App):
from plugins.matcher.matcher_expire import (
test_temp_matcher,
test_datetime_matcher,
test_timedelta_matcher,
)
async def test_temp(app: App):
from plugins.matcher.matcher_expire import test_temp_matcher
event = make_fake_event(_type="test")()
async with app.test_api() as ctx:
@ -203,6 +253,11 @@ async def test_expire(app: App):
await check_and_run_matcher(test_temp_matcher, bot, event, {})
assert test_temp_matcher not in matchers[test_temp_matcher.priority]
@pytest.mark.asyncio
async def test_datetime_expire(app: App):
from plugins.matcher.matcher_expire import test_datetime_matcher
event = make_fake_event()()
async with app.test_api() as ctx:
bot = ctx.create_bot()
@ -210,6 +265,11 @@ async def test_expire(app: App):
await check_and_run_matcher(test_datetime_matcher, bot, event, {})
assert test_datetime_matcher not in matchers[test_datetime_matcher.priority]
@pytest.mark.asyncio
async def test_timedelta_expire(app: App):
from plugins.matcher.matcher_expire import test_timedelta_matcher
event = make_fake_event()()
async with app.test_api() as ctx:
bot = ctx.create_bot()

View File

@ -6,7 +6,7 @@ from nonebug import App
from nonebot.matcher import Matcher
from nonebot.dependencies import Dependent
from nonebot.exception import TypeMisMatch
from utils import make_fake_event, make_fake_message
from utils import FakeMessage, make_fake_event
from nonebot.params import (
ArgParam,
BotParam,
@ -157,7 +157,7 @@ async def test_event(app: App):
generic_event_none,
)
fake_message = make_fake_message()("text")
fake_message = FakeMessage("text")
fake_event = make_fake_event(_message=fake_message)()
fake_fooevent = make_fake_event(_base=FooEvent)()
@ -247,7 +247,7 @@ async def test_state(app: App):
shell_command_argv,
)
fake_message = make_fake_message()("text")
fake_message = FakeMessage("text")
fake_matched = re.match(r"\[cq:(?P<type>.*?),(?P<arg>.*?)\]", "[cq:test,arg=value]")
fake_state = {
PREFIX_KEY: {
@ -453,7 +453,7 @@ async def test_arg(app: App):
from plugins.param.param_arg import arg, arg_str, arg_plain_text
matcher = Matcher()
message = make_fake_message()("text")
message = FakeMessage("text")
matcher.set_arg("key", message)
async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx:

View File

@ -6,8 +6,8 @@ import pytest
from nonebug import App
from nonebot.typing import T_State
from utils import make_fake_event, make_fake_message
from nonebot.exception import ParserExit, SkippedException
from utils import FakeMessage, FakeMessageSegment, make_fake_event
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
@ -85,24 +85,21 @@ async def test_rule(app: App):
async def test_trie(app: App):
TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",)))
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
async with app.test_api() as ctx:
bot = ctx.create_bot()
message = Message("/fake-prefix some args")
message = FakeMessage("/fake-prefix some args")
event = make_fake_event(_message=message)()
state = {}
TrieRule.get_value(bot, event, state)
assert state[PREFIX_KEY] == CMD_RESULT(
command=("fake-prefix",),
raw_command="/fake-prefix",
command_arg=Message("some args"),
command_arg=FakeMessage("some args"),
command_start="/",
command_whitespace=" ",
)
message = MessageSegment.text("/fake-prefix ") + MessageSegment.image(
message = FakeMessageSegment.text("/fake-prefix ") + FakeMessageSegment.image(
"fake url"
)
event = make_fake_event(_message=message)()
@ -111,7 +108,7 @@ async def test_trie(app: App):
assert state[PREFIX_KEY] == CMD_RESULT(
command=("fake-prefix",),
raw_command="/fake-prefix",
command_arg=Message(MessageSegment.image("fake url")),
command_arg=FakeMessage(FakeMessageSegment.image("fake url")),
command_start="/",
command_whitespace=" ",
)
@ -152,7 +149,7 @@ async def test_startswith(
assert checker.msg == msg
assert checker.ignorecase == ignorecase
message = text if text is None else make_fake_message()(text)
message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)()
for prefix in msg:
state = {STARTSWITH_KEY: prefix}
@ -192,7 +189,7 @@ async def test_endswith(
assert checker.msg == msg
assert checker.ignorecase == ignorecase
message = text if text is None else make_fake_message()(text)
message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)()
for suffix in msg:
state = {ENDSWITH_KEY: suffix}
@ -232,7 +229,7 @@ async def test_fullmatch(
assert checker.msg == msg
assert checker.ignorecase == ignorecase
message = text if text is None else make_fake_message()(text)
message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)()
for full in msg:
state = {FULLMATCH_KEY: full}
@ -264,7 +261,7 @@ async def test_keyword(
assert isinstance(checker, KeywordsRule)
assert checker.keywords == kws
message = text if text is None else make_fake_message()(text)
message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)()
for kw in kws:
state = {KEYWORD_KEY: kw}
@ -310,7 +307,7 @@ async def test_command(
assert isinstance(checker, CommandRule)
assert checker.cmds == cmds
arg = arg_text if arg_text is None else make_fake_message()(arg_text)
arg = arg_text if arg_text is None else FakeMessage(arg_text)
state = {
PREFIX_KEY: {CMD_KEY: cmd, CMD_WHITESPACE_KEY: whitespace, CMD_ARG_KEY: arg}
}
@ -321,7 +318,7 @@ async def test_command(
async def test_shell_command():
state: T_State
CMD = ("test",)
Message = make_fake_message()
Message = FakeMessage
MessageSegment = Message.get_segment_class()
test_not_cmd = shell_command(CMD)
@ -455,7 +452,7 @@ async def test_regex(
assert isinstance(checker, RegexRule)
assert checker.regex == pattern
message = text if text is None else make_fake_message()(text)
message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)()
state = {}
assert await dependent(event=event, state=state) == expected

View File

@ -1,16 +1,122 @@
import json
from typing import Dict, List, Union, TypeVar
from utils import make_fake_message
from nonebot.utils import DataclassEncoder
from utils import FakeMessage, FakeMessageSegment
from nonebot.utils import (
DataclassEncoder,
escape_tag,
is_gen_callable,
is_async_gen_callable,
is_coroutine_callable,
generic_check_issubclass,
)
def test_loguru_escape_tag():
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"
assert escape_tag("<fg\n#fff>white</fg\n#fff>") == "\\<fg\n#fff>white\\</fg\n#fff>"
assert escape_tag("<bg #fff>white</bg #fff>") == r"\<bg #fff>white\</bg #fff>"
assert escape_tag("<bg\n#fff>white</bg\n#fff>") == "\\<bg\n#fff>white\\</bg\n#fff>"
def test_generic_check_issubclass():
assert generic_check_issubclass(int, (int, float))
assert not generic_check_issubclass(str, (int, float))
assert generic_check_issubclass(Union[int, float, None], (int, float))
assert generic_check_issubclass(List[int], list)
assert generic_check_issubclass(Dict[str, int], dict)
assert generic_check_issubclass(TypeVar("T", int, float), (int, float))
assert generic_check_issubclass(TypeVar("T", bound=int), (int, float))
def test_is_coroutine_callable():
async def test1():
...
def test2():
...
class TestClass1:
async def __call__(self):
...
class TestClass2:
def __call__(self):
...
assert is_coroutine_callable(test1)
assert not is_coroutine_callable(test2)
assert not is_coroutine_callable(TestClass1)
assert is_coroutine_callable(TestClass1())
assert not is_coroutine_callable(TestClass2)
def test_is_gen_callable():
def test1():
yield
async def test2():
yield
def test3():
...
class TestClass1:
def __call__(self):
yield
class TestClass2:
async def __call__(self):
yield
class TestClass3:
def __call__(self):
...
assert is_gen_callable(test1)
assert not is_gen_callable(test2)
assert not is_gen_callable(test3)
assert is_gen_callable(TestClass1())
assert not is_gen_callable(TestClass2())
assert not is_gen_callable(TestClass3())
def test_is_async_gen_callable():
async def test1():
yield
def test2():
yield
async def test3():
...
class TestClass1:
async def __call__(self):
yield
class TestClass2:
def __call__(self):
yield
class TestClass3:
async def __call__(self):
...
assert is_async_gen_callable(test1)
assert not is_async_gen_callable(test2)
assert not is_async_gen_callable(test3)
assert is_async_gen_callable(TestClass1())
assert not is_async_gen_callable(TestClass2())
assert not is_async_gen_callable(TestClass3())
def test_dataclass_encoder():
simple = json.dumps("123", cls=DataclassEncoder)
assert simple == '"123"'
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
ms = MessageSegment.nested(Message(MessageSegment.text("text")))
ms = FakeMessageSegment.nested(FakeMessage(FakeMessageSegment.text("text")))
s = json.dumps(ms, cls=DataclassEncoder)
assert (
s

View File

@ -1,6 +1,6 @@
from typing import Type, Union, Mapping, Iterable, Optional
from pydantic import create_model
from pydantic import Extra, create_model
from nonebot.adapters import Event, Message, MessageSegment
@ -12,51 +12,49 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str:
return s
def make_fake_message():
class FakeMessageSegment(MessageSegment["FakeMessage"]):
@classmethod
def get_message_class(cls):
return FakeMessage
class FakeMessageSegment(MessageSegment["FakeMessage"]):
@classmethod
def get_message_class(cls):
return FakeMessage
def __str__(self) -> str:
return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
def __str__(self) -> str:
return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
@classmethod
def text(cls, text: str):
return cls("text", {"text": text})
@classmethod
def text(cls, text: str):
return cls("text", {"text": text})
@staticmethod
def image(url: str):
return FakeMessageSegment("image", {"url": url})
@staticmethod
def image(url: str):
return FakeMessageSegment("image", {"url": url})
@staticmethod
def nested(content: "FakeMessage"):
return FakeMessageSegment("node", {"content": content})
@staticmethod
def nested(content: "FakeMessage"):
return FakeMessageSegment("node", {"content": content})
def is_text(self) -> bool:
return self.type == "text"
def is_text(self) -> bool:
return self.type == "text"
class FakeMessage(Message[FakeMessageSegment]):
@classmethod
def get_segment_class(cls):
return FakeMessageSegment
@staticmethod
def _construct(msg: Union[str, Iterable[Mapping]]):
if isinstance(msg, str):
yield FakeMessageSegment.text(msg)
else:
for seg in msg:
yield FakeMessageSegment(**seg)
return
class FakeMessage(Message[FakeMessageSegment]):
@classmethod
def get_segment_class(cls):
return FakeMessageSegment
def __add__(
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
):
other = escape_text(other) if isinstance(other, str) else other
return super().__add__(other)
@staticmethod
def _construct(msg: Union[str, Iterable[Mapping]]):
if isinstance(msg, str):
yield FakeMessageSegment.text(msg)
else:
for seg in msg:
yield FakeMessageSegment(**seg)
return
return FakeMessage
def __add__(
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
):
other = escape_text(other) if isinstance(other, str) else other
return super().__add__(other)
def make_fake_event(
@ -70,9 +68,9 @@ def make_fake_event(
_to_me: bool = True,
**fields,
) -> Type[Event]:
_Fake = create_model("_Fake", __base__=_base or Event, **fields)
Base = _base or Event
class FakeEvent(_Fake):
class FakeEvent(Base, extra=Extra.forbid):
def get_type(self) -> str:
return _type
@ -100,7 +98,4 @@ def make_fake_event(
def is_tome(self) -> bool:
return _to_me
class Config:
extra = "forbid"
return FakeEvent
return create_model("FakeEvent", __base__=FakeEvent, **fields)