👷 Test: 移除 httpbin 并整理测试 (#2110)

This commit is contained in:
Ju4tCode 2023-06-19 17:48:59 +08:00 committed by GitHub
parent 27a3d1f0bb
commit 080b876d93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 592 additions and 309 deletions

12
poetry.lock generated
View File

@ -1215,7 +1215,7 @@ dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegu
name = "markupsafe" name = "markupsafe"
version = "2.1.2" version = "2.1.2"
description = "Safely add untrusted strings to HTML/XML markup." description = "Safely add untrusted strings to HTML/XML markup."
optional = true optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "MarkupSafe-2.1.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:665a36ae6f8f20a4676b53224e33d456a6f5a72657d9c83c2aa00765072f31f7"}, {file = "MarkupSafe-2.1.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:665a36ae6f8f20a4676b53224e33d456a6f5a72657d9c83c2aa00765072f31f7"},
@ -2239,13 +2239,13 @@ files = [
[[package]] [[package]]
name = "werkzeug" name = "werkzeug"
version = "2.3.4" version = "2.3.6"
description = "The comprehensive WSGI web application library." description = "The comprehensive WSGI web application library."
optional = true optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "Werkzeug-2.3.4-py3-none-any.whl", hash = "sha256:48e5e61472fee0ddee27ebad085614ebedb7af41e88f687aaf881afb723a162f"}, {file = "Werkzeug-2.3.6-py3-none-any.whl", hash = "sha256:935539fa1413afbb9195b24880778422ed620c0fc09670945185cce4d91a8890"},
{file = "Werkzeug-2.3.4.tar.gz", hash = "sha256:1d5a58e0377d1fe39d061a5de4469e414e78ccb1e1e59c0f5ad6fa1c36c52b76"}, {file = "Werkzeug-2.3.6.tar.gz", hash = "sha256:98c774df2f91b05550078891dee5f0eb0cb797a522c757a2452b9cee5b202330"},
] ]
[package.dependencies] [package.dependencies]
@ -2395,4 +2395,4 @@ websockets = ["websockets"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "be276204946fdf3338fb06cc5391aafc2f2fc89fc5c0a91bf09590f1eee40aa4" content-hash = "cf0a75f5173e6eb0fa3cd9d8901d4149e104094d08bce7e298d6f8c92954803d"

View File

@ -51,6 +51,7 @@ pre-commit = "^3.0.0"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
nonebug = "^0.3.0" nonebug = "^0.3.0"
werkzeug = "^2.3.6"
pytest-cov = "^4.0.0" pytest-cov = "^4.0.0"
pytest-xdist = "^3.0.2" pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.21.0" pytest-asyncio = "^0.21.0"
@ -68,7 +69,7 @@ fastapi = ["fastapi", "uvicorn"]
all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"] all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "auto" asyncio_mode = "strict"
addopts = "--cov=nonebot --cov-append --cov-report=term-missing" addopts = "--cov=nonebot --cov-append --cov-report=term-missing"
filterwarnings = [ filterwarnings = [
"error", "error",

View File

@ -1,11 +1,15 @@
import os import os
import threading
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Set from typing import TYPE_CHECKING, Set, Generator
import pytest import pytest
from nonebug import NONEBOT_INIT_KWARGS from nonebug import NONEBOT_INIT_KWARGS
from werkzeug.serving import BaseWSGIServer, make_server
import nonebot import nonebot
from nonebot.drivers import URL
from fake_server import request_handler
os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}' os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
os.environ["CONFIG_OVERRIDE"] = "new" os.environ["CONFIG_OVERRIDE"] = "new"
@ -28,3 +32,20 @@ def load_plugin(nonebug_init: None) -> Set["Plugin"]:
def load_builtin_plugin(nonebug_init: None) -> Set["Plugin"]: def load_builtin_plugin(nonebug_init: None) -> Set["Plugin"]:
# preload builtin plugins # preload builtin plugins
return nonebot.load_builtin_plugins("echo", "single_session") return nonebot.load_builtin_plugins("echo", "single_session")
@pytest.fixture(scope="session", autouse=True)
def server() -> Generator[BaseWSGIServer, None, None]:
server = make_server("127.0.0.1", 0, app=request_handler)
thread = threading.Thread(target=server.serve_forever)
thread.start()
try:
yield server
finally:
server.shutdown()
thread.join()
@pytest.fixture(scope="session")
def server_url(server: BaseWSGIServer) -> URL:
return URL(f"http://{server.host}:{server.port}")

69
tests/fake_server.py Normal file
View File

@ -0,0 +1,69 @@
import json
import base64
from typing import Dict, List, Union, TypeVar
from werkzeug import Request, Response
from werkzeug.datastructures import MultiDict
K = TypeVar("K")
V = TypeVar("V")
def json_safe(string, content_type="application/octet-stream") -> str:
try:
string = string.decode("utf-8")
json.dumps(string)
return string
except (ValueError, TypeError):
return b"".join(
[
b"data:",
content_type.encode("utf-8"),
b";base64,",
base64.b64encode(string),
]
).decode("utf-8")
def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]:
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}
@Request.application
def request_handler(request: Request) -> Response:
try:
_json = json.loads(request.data.decode("utf-8"))
except (ValueError, TypeError):
_json = None
return Response(
json.dumps(
{
"url": request.url,
"method": request.method,
"origin": request.headers.get("X-Forwarded-For", request.remote_addr),
"headers": flattern(
MultiDict((k, v) for k, v in request.headers.items())
),
"args": flattern(request.args),
"form": flattern(request.form),
"data": json_safe(request.data),
"json": _json,
"files": flattern(
MultiDict(
(
k,
json_safe(
v.read(),
request.files[k].content_type
or "application/octet-stream",
),
)
for k, v in request.files.items()
)
),
}
),
status=200,
content_type="application/json",
)

View File

@ -2,6 +2,7 @@ from nonebot.matcher import Matcher
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
default_permission = Permission() default_permission = Permission()
new_permission = Permission()
test_permission_updater = Matcher.new(permission=default_permission) test_permission_updater = Matcher.new(permission=default_permission)
@ -14,4 +15,4 @@ test_custom_updater = Matcher.new(permission=default_permission)
@test_custom_updater.permission_updater @test_custom_updater.permission_updater
async def _() -> Permission: async def _() -> Permission:
return default_permission return new_permission

View File

@ -1,135 +1,136 @@
import pytest import pytest
from pydantic import ValidationError, parse_obj_as from pydantic import ValidationError, parse_obj_as
from utils import make_fake_message from nonebot.adapters import Message
from utils import FakeMessage, FakeMessageSegment
def test_segment_add(): def test_segment_data():
Message = make_fake_message() assert len(FakeMessageSegment.text("text")) == 4
MessageSegment = Message.get_segment_class() assert FakeMessageSegment.text("text").get("data") == {"text": "text"}
assert list(FakeMessageSegment.text("text").keys()) == ["type", "data"]
assert MessageSegment.text("text") + MessageSegment.text("text") == Message( assert list(FakeMessageSegment.text("text").values()) == ["text", {"text": "text"}]
[MessageSegment.text("text"), MessageSegment.text("text")] assert list(FakeMessageSegment.text("text").items()) == [
)
assert MessageSegment.text("text") + "text" == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
)
assert (
MessageSegment.text("text") + Message([MessageSegment.text("text")])
) == Message([MessageSegment.text("text"), MessageSegment.text("text")])
assert "text" + MessageSegment.text("text") == Message(
[MessageSegment.text("text"), MessageSegment.text("text")]
)
def test_segment_validate():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert parse_obj_as(
MessageSegment,
{"type": "text", "data": {"text": "text"}, "extra": "should be ignored"},
) == MessageSegment.text("text")
with pytest.raises(ValidationError):
parse_obj_as(MessageSegment, "some str")
with pytest.raises(ValidationError):
parse_obj_as(MessageSegment, {"data": {}})
def test_segment_join():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
seg = MessageSegment.text("test")
iterable = [
MessageSegment.text("first"),
Message([MessageSegment.text("second"), MessageSegment.text("third")]),
]
assert seg.join(iterable) == Message(
[
MessageSegment.text("first"),
MessageSegment.text("test"),
MessageSegment.text("second"),
MessageSegment.text("third"),
]
)
def test_segment():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert len(MessageSegment.text("text")) == 4
assert MessageSegment.text("text") != MessageSegment.text("other")
assert MessageSegment.text("text").get("data") == {"text": "text"}
assert list(MessageSegment.text("text").keys()) == ["type", "data"]
assert list(MessageSegment.text("text").values()) == ["text", {"text": "text"}]
assert list(MessageSegment.text("text").items()) == [
("type", "text"), ("type", "text"),
("data", {"text": "text"}), ("data", {"text": "text"}),
] ]
origin = MessageSegment.text("text")
def test_segment_equal():
assert FakeMessageSegment("text", {"text": "text"}) == FakeMessageSegment(
"text", {"text": "text"}
)
assert FakeMessageSegment("text", {"text": "text"}) != FakeMessageSegment(
"text", {"text": "other"}
)
assert FakeMessageSegment("text", {"text": "text"}) != FakeMessageSegment(
"other", {"text": "text"}
)
def test_segment_add():
assert FakeMessageSegment.text("text") + FakeMessageSegment.text(
"text"
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert FakeMessageSegment.text("text") + "text" == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
assert (
FakeMessageSegment.text("text") + FakeMessage([FakeMessageSegment.text("text")])
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert "text" + FakeMessageSegment.text("text") == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
def test_segment_validate():
assert parse_obj_as(
FakeMessageSegment,
{"type": "text", "data": {"text": "text"}, "extra": "should be ignored"},
) == FakeMessageSegment.text("text")
with pytest.raises(ValidationError):
parse_obj_as(FakeMessageSegment, "some str")
with pytest.raises(ValidationError):
parse_obj_as(FakeMessageSegment, {"data": {}})
def test_segment_join():
seg = FakeMessageSegment.text("test")
iterable = [
FakeMessageSegment.text("first"),
FakeMessage(
[FakeMessageSegment.text("second"), FakeMessageSegment.text("third")]
),
]
assert seg.join(iterable) == FakeMessage(
[
FakeMessageSegment.text("first"),
FakeMessageSegment.text("test"),
FakeMessageSegment.text("second"),
FakeMessageSegment.text("third"),
]
)
def test_segment_copy():
origin = FakeMessageSegment.text("text")
copy = origin.copy() copy = origin.copy()
assert origin is not copy assert origin is not copy
assert origin == copy assert origin == copy
def test_message_add(): def test_message_add():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
assert ( assert (
Message([MessageSegment.text("text")]) + MessageSegment.text("text") FakeMessage([FakeMessageSegment.text("text")]) + FakeMessageSegment.text("text")
) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) ) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert Message([MessageSegment.text("text")]) + "text" == Message( assert FakeMessage([FakeMessageSegment.text("text")]) + "text" == FakeMessage(
[MessageSegment.text("text"), MessageSegment.text("text")] [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
) )
assert ( assert (
Message([MessageSegment.text("text")]) + Message([MessageSegment.text("text")]) FakeMessage([FakeMessageSegment.text("text")])
) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + FakeMessage([FakeMessageSegment.text("text")])
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
assert "text" + Message([MessageSegment.text("text")]) == Message( assert "text" + FakeMessage([FakeMessageSegment.text("text")]) == FakeMessage(
[MessageSegment.text("text"), MessageSegment.text("text")] [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
) )
msg = Message([MessageSegment.text("text")]) msg = FakeMessage([FakeMessageSegment.text("text")])
msg += MessageSegment.text("text") msg += FakeMessageSegment.text("text")
assert msg == Message([MessageSegment.text("text"), MessageSegment.text("text")]) assert msg == FakeMessage(
[FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]
)
def test_message_getitem(): def test_message_getitem():
Message = make_fake_message() message = FakeMessage(
MessageSegment = Message.get_segment_class()
message = Message(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.image("test2"), FakeMessageSegment.image("test2"),
MessageSegment.image("test3"), FakeMessageSegment.image("test3"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )
assert message[0] == MessageSegment.text("test") assert message[0] == FakeMessageSegment.text("test")
assert message[:2] == Message( assert message[:2] == FakeMessage(
[MessageSegment.text("test"), MessageSegment.image("test2")] [FakeMessageSegment.text("test"), FakeMessageSegment.image("test2")]
) )
assert message["image"] == Message( assert message["image"] == FakeMessage(
[MessageSegment.image("test2"), MessageSegment.image("test3")] [FakeMessageSegment.image("test2"), FakeMessageSegment.image("test3")]
) )
assert message["image", 0] == MessageSegment.image("test2") assert message["image", 0] == FakeMessageSegment.image("test2")
assert message["image", 0:2] == message["image"] assert message["image", 0:2] == message["image"]
assert message.index(message[0]) == 0 assert message.index(message[0]) == 0
@ -137,153 +138,137 @@ def test_message_getitem():
assert message.get("image") == message["image"] assert message.get("image") == message["image"]
assert message.get("image", 114514) == message["image"] assert message.get("image", 114514) == message["image"]
assert message.get("image", 1) == Message([message["image", 0]]) assert message.get("image", 1) == FakeMessage([message["image", 0]])
assert message.count("image") == 2 assert message.count("image") == 2
def test_message_validate(): def test_message_validate():
Message = make_fake_message() assert parse_obj_as(FakeMessage, FakeMessage([])) == FakeMessage([])
MessageSegment = Message.get_segment_class()
Message_ = make_fake_message()
assert parse_obj_as(Message, Message([])) == Message([])
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
parse_obj_as(Message, Message_([])) parse_obj_as(type("FakeMessage2", (Message,), {}), FakeMessage([]))
assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")]) assert parse_obj_as(FakeMessage, "text") == FakeMessage(
[FakeMessageSegment.text("text")]
assert parse_obj_as(Message, {"type": "text", "data": {"text": "text"}}) == Message(
[MessageSegment.text("text")]
) )
assert parse_obj_as( assert parse_obj_as(
Message, FakeMessage, {"type": "text", "data": {"text": "text"}}
[MessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}], ) == FakeMessage([FakeMessageSegment.text("text")])
) == Message([MessageSegment.text("text"), MessageSegment.text("text")])
assert parse_obj_as(
FakeMessage,
[FakeMessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}],
) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")])
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
parse_obj_as(Message, object()) parse_obj_as(FakeMessage, object())
def test_message_contains(): def test_message_contains():
Message = make_fake_message() message = FakeMessage(
MessageSegment = Message.get_segment_class()
message = Message(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.image("test2"), FakeMessageSegment.image("test2"),
MessageSegment.image("test3"), FakeMessageSegment.image("test3"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )
assert message.has(MessageSegment.text("test")) is True assert message.has(FakeMessageSegment.text("test")) is True
assert MessageSegment.text("test") in message assert FakeMessageSegment.text("test") in message
assert message.has("image") is True assert message.has("image") is True
assert "image" in message assert "image" in message
assert message.has(MessageSegment.text("foo")) is False assert message.has(FakeMessageSegment.text("foo")) is False
assert MessageSegment.text("foo") not in message assert FakeMessageSegment.text("foo") not in message
assert message.has("foo") is False assert message.has("foo") is False
assert "foo" not in message assert "foo" not in message
def test_message_only(): def test_message_only():
Message = make_fake_message() message = FakeMessage(
MessageSegment = Message.get_segment_class()
message = Message(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.text("test2"), FakeMessageSegment.text("test2"),
] ]
) )
assert message.only("text") is True assert message.only("text") is True
assert message.only(MessageSegment.text("test")) is False assert message.only(FakeMessageSegment.text("test")) is False
message = Message( message = FakeMessage(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.image("test2"), FakeMessageSegment.image("test2"),
MessageSegment.image("test3"), FakeMessageSegment.image("test3"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )
assert message.only("text") is False assert message.only("text") is False
message = Message( message = FakeMessage(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.text("test"), FakeMessageSegment.text("test"),
] ]
) )
assert message.only(MessageSegment.text("test")) is True assert message.only(FakeMessageSegment.text("test")) is True
def test_message_join(): def test_message_join():
Message = make_fake_message() msg = FakeMessage([FakeMessageSegment.text("test")])
MessageSegment = Message.get_segment_class()
msg = Message([MessageSegment.text("test")])
iterable = [ iterable = [
MessageSegment.text("first"), FakeMessageSegment.text("first"),
Message([MessageSegment.text("second"), MessageSegment.text("third")]), FakeMessage(
[FakeMessageSegment.text("second"), FakeMessageSegment.text("third")]
),
] ]
assert msg.join(iterable) == Message( assert msg.join(iterable) == FakeMessage(
[ [
MessageSegment.text("first"), FakeMessageSegment.text("first"),
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.text("second"), FakeMessageSegment.text("second"),
MessageSegment.text("third"), FakeMessageSegment.text("third"),
] ]
) )
def test_message_include(): def test_message_include():
Message = make_fake_message() message = FakeMessage(
MessageSegment = Message.get_segment_class()
message = Message(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.image("test2"), FakeMessageSegment.image("test2"),
MessageSegment.image("test3"), FakeMessageSegment.image("test3"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )
assert message.include("text") == Message( assert message.include("text") == FakeMessage(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )
def test_message_exclude(): def test_message_exclude():
Message = make_fake_message() message = FakeMessage(
MessageSegment = Message.get_segment_class()
message = Message(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.image("test2"), FakeMessageSegment.image("test2"),
MessageSegment.image("test3"), FakeMessageSegment.image("test3"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )
assert message.exclude("image") == Message( assert message.exclude("image") == FakeMessage(
[ [
MessageSegment.text("test"), FakeMessageSegment.text("test"),
MessageSegment.text("test4"), FakeMessageSegment.text("test4"),
] ]
) )

View File

@ -1,5 +1,5 @@
from nonebot.adapters import MessageTemplate from nonebot.adapters import MessageTemplate
from utils import escape_text, make_fake_message from utils import FakeMessage, FakeMessageSegment, escape_text
def test_template_basis(): def test_template_basis():
@ -9,8 +9,7 @@ def test_template_basis():
def test_template_message(): def test_template_message():
Message = make_fake_message() template = FakeMessage.template("{a:custom}{b:text}{c:image}/{d}")
template = Message.template("{a:custom}{b:text}{c:image}/{d}")
@template.add_format_spec @template.add_format_spec
def custom(input: str) -> str: def custom(input: str) -> str:
@ -37,29 +36,24 @@ def test_template_message():
def test_rich_template_message(): def test_rich_template_message():
Message = make_fake_message()
MS = Message.get_segment_class()
pic1, pic2, pic3 = ( pic1, pic2, pic3 = (
MS.image("file:///pic1.jpg"), FakeMessageSegment.image("file:///pic1.jpg"),
MS.image("file:///pic2.jpg"), FakeMessageSegment.image("file:///pic2.jpg"),
MS.image("file:///pic3.jpg"), FakeMessageSegment.image("file:///pic3.jpg"),
) )
template = Message.template("{}{}" + pic2 + "{}") template = FakeMessage.template("{}{}" + pic2 + "{}")
result = template.format(pic1, "[fake:image]", pic3) result = template.format(pic1, "[fake:image]", pic3)
assert result["image"] == Message([pic1, pic2, pic3]) assert result["image"] == FakeMessage([pic1, pic2, pic3])
assert str(result) == ( assert str(result) == (
"[fake:image]" + escape_text("[fake:image]") + "[fake:image]" + "[fake:image]" "[fake:image]" + escape_text("[fake:image]") + "[fake:image]" + "[fake:image]"
) )
def test_message_injection(): def test_message_injection():
Message = make_fake_message() template = FakeMessage.template("{name}Is Bad")
template = Message.template("{name}Is Bad")
message = template.format(name="[fake:image]") message = template.format(name="[fake:image]")
assert message.extract_plain_text() == escape_text("[fake:image]Is Bad") assert message.extract_plain_text() == escape_text("[fake:image]Is Bad")

View File

@ -79,13 +79,37 @@ async def test_lifespan():
], ],
indirect=True, indirect=True,
) )
async def test_reverse_driver(app: App, driver: Driver): async def test_http_server(app: App, driver: Driver):
driver = cast(ReverseDriver, driver) driver = cast(ReverseDriver, driver)
async def _handle_http(request: Request) -> Response: async def _handle_http(request: Request) -> Response:
assert request.content in (b"test", "test") assert request.content in (b"test", "test")
return Response(200, content="test") return Response(200, content="test")
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
driver.setup_http_server(http_setup)
async with app.test_server(driver.asgi) as ctx:
client = ctx.get_client()
response = await client.post("/http_test", data="test")
assert response.status_code == 200
assert response.text == "test"
await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
pytest.param("nonebot.drivers.quart:Driver", id="quart"),
],
indirect=True,
)
async def test_websocket_server(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
async def _handle_ws(ws: WebSocket) -> None: async def _handle_ws(ws: WebSocket) -> None:
await ws.accept() await ws.accept()
data = await ws.receive() data = await ws.receive()
@ -107,17 +131,11 @@ async def test_reverse_driver(app: App, driver: Driver):
with pytest.raises(WebSocketClosed): with pytest.raises(WebSocketClosed):
await ws.receive() await ws.receive()
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
driver.setup_http_server(http_setup)
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
driver.setup_websocket_server(ws_setup) driver.setup_websocket_server(ws_setup)
async with app.test_server(driver.asgi) as ctx: async with app.test_server(driver.asgi) as ctx:
client = ctx.get_client() client = ctx.get_client()
response = await client.post("/http_test", data="test")
assert response.status_code == 200
assert response.text == "test"
async with client.websocket_connect("/ws_test") as ws: async with client.websocket_connect("/ws_test") as ws:
await ws.send_text("ping") await ws.send_text("ping")
@ -145,12 +163,13 @@ async def test_reverse_driver(app: App, driver: Driver):
], ],
indirect=True, indirect=True,
) )
async def test_http_driver(driver: Driver): async def test_http_client(driver: Driver, server_url: URL):
driver = cast(ForwardDriver, driver) driver = cast(ForwardDriver, driver)
# simple post with query, headers, cookies and content
request = Request( request = Request(
"POST", "POST",
"https://httpbin.org/post", server_url,
params={"param": "test"}, params={"param": "test"},
headers={"X-Test": "test"}, headers={"X-Test": "test"},
cookies={"session": "test"}, cookies={"session": "test"},
@ -159,32 +178,39 @@ async def test_http_driver(driver: Driver):
response = await driver.request(request) response = await driver.request(request)
assert response.status_code == 200 and response.content assert response.status_code == 200 and response.content
data = json.loads(response.content) data = json.loads(response.content)
assert data["method"] == "POST"
assert data["args"] == {"param": "test"} assert data["args"] == {"param": "test"}
assert data["headers"].get("X-Test") == "test" assert data["headers"].get("X-Test") == "test"
assert data["headers"].get("Cookie") == "session=test" assert data["headers"].get("Cookie") == "session=test"
assert data["data"] == "test" assert data["data"] == "test"
request = Request("POST", "https://httpbin.org/post", data={"form": "test"}) # post with data body
request = Request("POST", server_url, data={"form": "test"})
response = await driver.request(request) response = await driver.request(request)
assert response.status_code == 200 and response.content assert response.status_code == 200 and response.content
data = json.loads(response.content) data = json.loads(response.content)
assert data["method"] == "POST"
assert data["form"] == {"form": "test"} assert data["form"] == {"form": "test"}
request = Request("POST", "https://httpbin.org/post", json={"json": "test"}) # post with json body
request = Request("POST", server_url, json={"json": "test"})
response = await driver.request(request) response = await driver.request(request)
assert response.status_code == 200 and response.content assert response.status_code == 200 and response.content
data = json.loads(response.content) data = json.loads(response.content)
assert data["method"] == "POST"
assert data["json"] == {"json": "test"} assert data["json"] == {"json": "test"}
# post with files and form data
request = Request( request = Request(
"POST", "POST",
"https://httpbin.org/post", server_url,
data={"form": "test"}, data={"form": "test"},
files={"test": ("test.txt", b"test")}, files={"test": ("test.txt", b"test")},
) )
response = await driver.request(request) response = await driver.request(request)
assert response.status_code == 200 and response.content assert response.status_code == 200 and response.content
data = json.loads(response.content) data = json.loads(response.content)
assert data["method"] == "POST"
assert data["form"] == {"form": "test"} assert data["form"] == {"form": "test"}
assert data["files"] == {"test": "test"} assert data["files"] == {"test": "test"}
@ -236,7 +262,6 @@ async def test_bot_connect_hook(app: App, driver: Driver):
@driver.on_bot_connect @driver.on_bot_connect
async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1): async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
nonlocal conn_should_be_called nonlocal conn_should_be_called
conn_should_be_called = True
if foo is not bot: if foo is not bot:
pytest.fail("on_bot_connect hook called with wrong bot") pytest.fail("on_bot_connect hook called with wrong bot")
@ -245,12 +270,13 @@ async def test_bot_connect_hook(app: App, driver: Driver):
if default != 1: if default != 1:
pytest.fail("on_bot_connect hook called with wrong default value") pytest.fail("on_bot_connect hook called with wrong default value")
conn_should_be_called = True
@driver.on_bot_disconnect @driver.on_bot_disconnect
async def disconn_hook( async def disconn_hook(
foo: Bot, dep: int = Depends(dependency), default: int = 1 foo: Bot, dep: int = Depends(dependency), default: int = 1
): ):
nonlocal disconn_should_be_called nonlocal disconn_should_be_called
disconn_should_be_called = True
if foo is not bot: if foo is not bot:
pytest.fail("on_bot_disconnect hook called with wrong bot") pytest.fail("on_bot_disconnect hook called with wrong bot")
@ -259,6 +285,8 @@ async def test_bot_connect_hook(app: App, driver: Driver):
if default != 1: if default != 1:
pytest.fail("on_bot_connect hook called with wrong default value") pytest.fail("on_bot_connect hook called with wrong default value")
disconn_should_be_called = True
if conn_hook not in {hook.call for hook in conn_hooks}: if conn_hook not in {hook.call for hook in conn_hooks}:
pytest.fail("on_bot_connect hook not registered") pytest.fail("on_bot_connect hook not registered")
if disconn_hook not in {hook.call for hook in disconn_hooks}: if disconn_hook not in {hook.call for hook in disconn_hooks}:
@ -268,6 +296,7 @@ async def test_bot_connect_hook(app: App, driver: Driver):
bot = ctx.create_bot() bot = ctx.create_bot()
await asyncio.sleep(1) await asyncio.sleep(1)
if not conn_should_be_called: if not conn_should_be_called:
pytest.fail("on_bot_connect hook not called") pytest.fail("on_bot_connect hook not called")
if not disconn_should_be_called: if not disconn_should_be_called:

View File

@ -31,17 +31,29 @@ async def test_init():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get(app: App, monkeypatch: pytest.MonkeyPatch): async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setattr(nonebot, "_driver", None) m.setattr(nonebot, "_driver", None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_driver() get_driver()
@pytest.mark.asyncio
async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver() driver = get_driver()
assert isinstance(driver, ReverseDriver) assert isinstance(driver, ReverseDriver)
assert get_asgi() == driver.asgi assert get_asgi() == driver.asgi
@pytest.mark.asyncio
async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
assert isinstance(driver, ReverseDriver)
assert get_app() == driver.server_app assert get_app() == driver.server_app
@pytest.mark.asyncio
async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch):
async with app.test_api() as ctx: async with app.test_api() as ctx:
adapter = ctx.create_adapter() adapter = ctx.create_adapter()
adapter_name = adapter.get_name() adapter_name = adapter.get_name()
@ -54,6 +66,9 @@ async def test_get(app: App, monkeypatch: pytest.MonkeyPatch):
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_adapter("not exist") get_adapter("not exist")
@pytest.mark.asyncio
async def test_run(app: App, monkeypatch: pytest.MonkeyPatch):
runned = False runned = False
def mock_run(*args, **kwargs): def mock_run(*args, **kwargs):
@ -61,14 +76,24 @@ async def test_get(app: App, monkeypatch: pytest.MonkeyPatch):
runned = True runned = True
assert args == ("arg",) and kwargs == {"kwarg": "kwarg"} assert args == ("arg",) and kwargs == {"kwarg": "kwarg"}
monkeypatch.setattr(driver, "run", mock_run) driver = get_driver()
nonebot.run("arg", kwarg="kwarg")
with monkeypatch.context() as m:
m.setattr(driver, "run", mock_run)
nonebot.run("arg", kwarg="kwarg")
assert runned assert runned
@pytest.mark.asyncio
async def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch):
driver = get_driver()
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_bot() get_bot()
monkeypatch.setattr(driver, "_bots", {"test": "test"}) with monkeypatch.context() as m:
assert get_bot() == "test" m.setattr(driver, "_bots", {"test": "test"})
assert get_bot("test") == "test" assert get_bot() == "test"
assert get_bots() == {"test": "test"} assert get_bot("test") == "test"
assert get_bots() == {"test": "test"}

View File

@ -3,25 +3,16 @@ from nonebug import App
from nonebot.permission import User from nonebot.permission import User
from nonebot.matcher import Matcher, matchers from nonebot.matcher import Matcher, matchers
from utils import FakeMessage, make_fake_event
from nonebot.message import check_and_run_matcher from nonebot.message import check_and_run_matcher
from utils import make_fake_event, make_fake_message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_matcher(app: App): async def test_matcher_handle(app: App):
from plugins.matcher.matcher_process import ( from plugins.matcher.matcher_process import test_handle
test_got,
test_handle,
test_preset,
test_combine,
test_receive,
test_overload,
)
message = make_fake_message()("text") message = FakeMessage("text")
event = make_fake_event(_message=message)() event = make_fake_event(_message=message)()
message_next = make_fake_message()("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_handle.handlers) == 1 assert len(test_handle.handlers) == 1
async with app.test_matcher(test_handle) as ctx: async with app.test_matcher(test_handle) as ctx:
@ -30,6 +21,16 @@ async def test_matcher(app: App):
ctx.should_call_send(event, "send", "result", at_sender=True) ctx.should_call_send(event, "send", "result", at_sender=True)
ctx.should_finished() ctx.should_finished()
@pytest.mark.asyncio
async def test_matcher_got(app: App):
from plugins.matcher.matcher_process import test_got
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = FakeMessage("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_got.handlers) == 1 assert len(test_got.handlers) == 1
async with app.test_matcher(test_got) as ctx: async with app.test_matcher(test_got) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
@ -42,6 +43,14 @@ async def test_matcher(app: App):
ctx.should_rejected() ctx.should_rejected()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
async def test_matcher_receive(app: App):
from plugins.matcher.matcher_process import test_receive
message = FakeMessage("text")
event = make_fake_event(_message=message)()
assert len(test_receive.handlers) == 1 assert len(test_receive.handlers) == 1
async with app.test_matcher(test_receive) as ctx: async with app.test_matcher(test_receive) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
@ -51,7 +60,17 @@ async def test_matcher(app: App):
ctx.should_call_send(event, "pause", "result", at_sender=True) ctx.should_call_send(event, "pause", "result", at_sender=True)
ctx.should_paused() ctx.should_paused()
assert len(test_receive.handlers) == 1
@pytest.mark.asyncio
async def test_matcher_(app: App):
from plugins.matcher.matcher_process import test_combine
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = FakeMessage("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_combine.handlers) == 1
async with app.test_matcher(test_combine) as ctx: async with app.test_matcher(test_combine) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.receive_event(bot, event) ctx.receive_event(bot, event)
@ -64,6 +83,16 @@ async def test_matcher(app: App):
ctx.should_rejected() ctx.should_rejected()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
async def test_matcher_preset(app: App):
from plugins.matcher.matcher_process import test_preset
message = FakeMessage("text")
event = make_fake_event(_message=message)()
message_next = FakeMessage("text_next")
event_next = make_fake_event(_message=message_next)()
assert len(test_preset.handlers) == 2 assert len(test_preset.handlers) == 2
async with app.test_matcher(test_preset) as ctx: async with app.test_matcher(test_preset) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
@ -72,6 +101,14 @@ async def test_matcher(app: App):
ctx.should_rejected() ctx.should_rejected()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
@pytest.mark.asyncio
async def test_matcher_overload(app: App):
from plugins.matcher.matcher_process import test_overload
message = FakeMessage("text")
event = make_fake_event(_message=message)()
assert len(test_overload.handlers) == 2 assert len(test_overload.handlers) == 2
async with app.test_matcher(test_overload) as ctx: async with app.test_matcher(test_overload) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
@ -115,12 +152,10 @@ async def test_type_updater(app: App):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_permission_updater(app: App): async def test_default_permission_updater(app: App):
from plugins.matcher.matcher_permission import ( from plugins.matcher.matcher_permission import (
default_permission, default_permission,
test_custom_updater,
test_permission_updater, test_permission_updater,
test_user_permission_updater,
) )
event = make_fake_event(_session_id="test")() event = make_fake_event(_session_id="test")()
@ -136,6 +171,15 @@ async def test_permission_updater(app: App):
assert checker.users == ("test",) assert checker.users == ("test",)
assert checker.perm is default_permission assert checker.perm is default_permission
@pytest.mark.asyncio
async def test_user_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
default_permission,
test_user_permission_updater,
)
event = make_fake_event(_session_id="test")()
user_permission = list(test_user_permission_updater.permission.checkers)[0].call user_permission = list(test_user_permission_updater.permission.checkers)[0].call
assert isinstance(user_permission, User) assert isinstance(user_permission, User)
assert user_permission.perm is default_permission assert user_permission.perm is default_permission
@ -149,12 +193,22 @@ async def test_permission_updater(app: App):
assert checker.users == ("test",) assert checker.users == ("test",)
assert checker.perm is default_permission assert checker.perm is default_permission
@pytest.mark.asyncio
async def test_custom_permission_updater(app: App):
from plugins.matcher.matcher_permission import (
new_permission,
default_permission,
test_custom_updater,
)
event = make_fake_event(_session_id="test")()
assert test_custom_updater.permission is default_permission assert test_custom_updater.permission is default_permission
async with app.test_api() as ctx: async with app.test_api() as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
matcher = test_custom_updater() matcher = test_custom_updater()
new_perm = await matcher.update_permission(bot, event) new_perm = await matcher.update_permission(bot, event)
assert new_perm is default_permission assert new_perm is new_permission
@pytest.mark.asyncio @pytest.mark.asyncio
@ -189,12 +243,8 @@ async def test_run(app: App):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_expire(app: App): async def test_temp(app: App):
from plugins.matcher.matcher_expire import ( from plugins.matcher.matcher_expire import test_temp_matcher
test_temp_matcher,
test_datetime_matcher,
test_timedelta_matcher,
)
event = make_fake_event(_type="test")() event = make_fake_event(_type="test")()
async with app.test_api() as ctx: async with app.test_api() as ctx:
@ -203,6 +253,11 @@ async def test_expire(app: App):
await check_and_run_matcher(test_temp_matcher, bot, event, {}) await check_and_run_matcher(test_temp_matcher, bot, event, {})
assert test_temp_matcher not in matchers[test_temp_matcher.priority] assert test_temp_matcher not in matchers[test_temp_matcher.priority]
@pytest.mark.asyncio
async def test_datetime_expire(app: App):
from plugins.matcher.matcher_expire import test_datetime_matcher
event = make_fake_event()() event = make_fake_event()()
async with app.test_api() as ctx: async with app.test_api() as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
@ -210,6 +265,11 @@ async def test_expire(app: App):
await check_and_run_matcher(test_datetime_matcher, bot, event, {}) await check_and_run_matcher(test_datetime_matcher, bot, event, {})
assert test_datetime_matcher not in matchers[test_datetime_matcher.priority] assert test_datetime_matcher not in matchers[test_datetime_matcher.priority]
@pytest.mark.asyncio
async def test_timedelta_expire(app: App):
from plugins.matcher.matcher_expire import test_timedelta_matcher
event = make_fake_event()() event = make_fake_event()()
async with app.test_api() as ctx: async with app.test_api() as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()

View File

@ -6,7 +6,7 @@ from nonebug import App
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from utils import make_fake_event, make_fake_message from utils import FakeMessage, make_fake_event
from nonebot.params import ( from nonebot.params import (
ArgParam, ArgParam,
BotParam, BotParam,
@ -157,7 +157,7 @@ async def test_event(app: App):
generic_event_none, generic_event_none,
) )
fake_message = make_fake_message()("text") fake_message = FakeMessage("text")
fake_event = make_fake_event(_message=fake_message)() fake_event = make_fake_event(_message=fake_message)()
fake_fooevent = make_fake_event(_base=FooEvent)() fake_fooevent = make_fake_event(_base=FooEvent)()
@ -247,7 +247,7 @@ async def test_state(app: App):
shell_command_argv, shell_command_argv,
) )
fake_message = make_fake_message()("text") fake_message = FakeMessage("text")
fake_matched = re.match(r"\[cq:(?P<type>.*?),(?P<arg>.*?)\]", "[cq:test,arg=value]") fake_matched = re.match(r"\[cq:(?P<type>.*?),(?P<arg>.*?)\]", "[cq:test,arg=value]")
fake_state = { fake_state = {
PREFIX_KEY: { PREFIX_KEY: {
@ -453,7 +453,7 @@ async def test_arg(app: App):
from plugins.param.param_arg import arg, arg_str, arg_plain_text from plugins.param.param_arg import arg, arg_str, arg_plain_text
matcher = Matcher() matcher = Matcher()
message = make_fake_message()("text") message = FakeMessage("text")
matcher.set_arg("key", message) matcher.set_arg("key", message)
async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx:

View File

@ -6,8 +6,8 @@ import pytest
from nonebug import App from nonebug import App
from nonebot.typing import T_State from nonebot.typing import T_State
from utils import make_fake_event, make_fake_message
from nonebot.exception import ParserExit, SkippedException from nonebot.exception import ParserExit, SkippedException
from utils import FakeMessage, FakeMessageSegment, make_fake_event
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
PREFIX_KEY, PREFIX_KEY,
@ -85,24 +85,21 @@ async def test_rule(app: App):
async def test_trie(app: App): async def test_trie(app: App):
TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",))) TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",)))
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
async with app.test_api() as ctx: async with app.test_api() as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
message = Message("/fake-prefix some args") message = FakeMessage("/fake-prefix some args")
event = make_fake_event(_message=message)() event = make_fake_event(_message=message)()
state = {} state = {}
TrieRule.get_value(bot, event, state) TrieRule.get_value(bot, event, state)
assert state[PREFIX_KEY] == CMD_RESULT( assert state[PREFIX_KEY] == CMD_RESULT(
command=("fake-prefix",), command=("fake-prefix",),
raw_command="/fake-prefix", raw_command="/fake-prefix",
command_arg=Message("some args"), command_arg=FakeMessage("some args"),
command_start="/", command_start="/",
command_whitespace=" ", command_whitespace=" ",
) )
message = MessageSegment.text("/fake-prefix ") + MessageSegment.image( message = FakeMessageSegment.text("/fake-prefix ") + FakeMessageSegment.image(
"fake url" "fake url"
) )
event = make_fake_event(_message=message)() event = make_fake_event(_message=message)()
@ -111,7 +108,7 @@ async def test_trie(app: App):
assert state[PREFIX_KEY] == CMD_RESULT( assert state[PREFIX_KEY] == CMD_RESULT(
command=("fake-prefix",), command=("fake-prefix",),
raw_command="/fake-prefix", raw_command="/fake-prefix",
command_arg=Message(MessageSegment.image("fake url")), command_arg=FakeMessage(FakeMessageSegment.image("fake url")),
command_start="/", command_start="/",
command_whitespace=" ", command_whitespace=" ",
) )
@ -152,7 +149,7 @@ async def test_startswith(
assert checker.msg == msg assert checker.msg == msg
assert checker.ignorecase == ignorecase assert checker.ignorecase == ignorecase
message = text if text is None else make_fake_message()(text) message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
for prefix in msg: for prefix in msg:
state = {STARTSWITH_KEY: prefix} state = {STARTSWITH_KEY: prefix}
@ -192,7 +189,7 @@ async def test_endswith(
assert checker.msg == msg assert checker.msg == msg
assert checker.ignorecase == ignorecase assert checker.ignorecase == ignorecase
message = text if text is None else make_fake_message()(text) message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
for suffix in msg: for suffix in msg:
state = {ENDSWITH_KEY: suffix} state = {ENDSWITH_KEY: suffix}
@ -232,7 +229,7 @@ async def test_fullmatch(
assert checker.msg == msg assert checker.msg == msg
assert checker.ignorecase == ignorecase assert checker.ignorecase == ignorecase
message = text if text is None else make_fake_message()(text) message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
for full in msg: for full in msg:
state = {FULLMATCH_KEY: full} state = {FULLMATCH_KEY: full}
@ -264,7 +261,7 @@ async def test_keyword(
assert isinstance(checker, KeywordsRule) assert isinstance(checker, KeywordsRule)
assert checker.keywords == kws assert checker.keywords == kws
message = text if text is None else make_fake_message()(text) message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
for kw in kws: for kw in kws:
state = {KEYWORD_KEY: kw} state = {KEYWORD_KEY: kw}
@ -310,7 +307,7 @@ async def test_command(
assert isinstance(checker, CommandRule) assert isinstance(checker, CommandRule)
assert checker.cmds == cmds assert checker.cmds == cmds
arg = arg_text if arg_text is None else make_fake_message()(arg_text) arg = arg_text if arg_text is None else FakeMessage(arg_text)
state = { state = {
PREFIX_KEY: {CMD_KEY: cmd, CMD_WHITESPACE_KEY: whitespace, CMD_ARG_KEY: arg} PREFIX_KEY: {CMD_KEY: cmd, CMD_WHITESPACE_KEY: whitespace, CMD_ARG_KEY: arg}
} }
@ -321,7 +318,7 @@ async def test_command(
async def test_shell_command(): async def test_shell_command():
state: T_State state: T_State
CMD = ("test",) CMD = ("test",)
Message = make_fake_message() Message = FakeMessage
MessageSegment = Message.get_segment_class() MessageSegment = Message.get_segment_class()
test_not_cmd = shell_command(CMD) test_not_cmd = shell_command(CMD)
@ -455,7 +452,7 @@ async def test_regex(
assert isinstance(checker, RegexRule) assert isinstance(checker, RegexRule)
assert checker.regex == pattern assert checker.regex == pattern
message = text if text is None else make_fake_message()(text) message = text if text is None else FakeMessage(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
state = {} state = {}
assert await dependent(event=event, state=state) == expected assert await dependent(event=event, state=state) == expected

View File

@ -1,16 +1,122 @@
import json import json
from typing import Dict, List, Union, TypeVar
from utils import make_fake_message from utils import FakeMessage, FakeMessageSegment
from nonebot.utils import DataclassEncoder from nonebot.utils import (
DataclassEncoder,
escape_tag,
is_gen_callable,
is_async_gen_callable,
is_coroutine_callable,
generic_check_issubclass,
)
def test_loguru_escape_tag():
assert escape_tag("<red>red</red>") == r"\<red>red\</red>"
assert escape_tag("<fg #fff>white</fg #fff>") == r"\<fg #fff>white\</fg #fff>"
assert escape_tag("<fg\n#fff>white</fg\n#fff>") == "\\<fg\n#fff>white\\</fg\n#fff>"
assert escape_tag("<bg #fff>white</bg #fff>") == r"\<bg #fff>white\</bg #fff>"
assert escape_tag("<bg\n#fff>white</bg\n#fff>") == "\\<bg\n#fff>white\\</bg\n#fff>"
def test_generic_check_issubclass():
assert generic_check_issubclass(int, (int, float))
assert not generic_check_issubclass(str, (int, float))
assert generic_check_issubclass(Union[int, float, None], (int, float))
assert generic_check_issubclass(List[int], list)
assert generic_check_issubclass(Dict[str, int], dict)
assert generic_check_issubclass(TypeVar("T", int, float), (int, float))
assert generic_check_issubclass(TypeVar("T", bound=int), (int, float))
def test_is_coroutine_callable():
async def test1():
...
def test2():
...
class TestClass1:
async def __call__(self):
...
class TestClass2:
def __call__(self):
...
assert is_coroutine_callable(test1)
assert not is_coroutine_callable(test2)
assert not is_coroutine_callable(TestClass1)
assert is_coroutine_callable(TestClass1())
assert not is_coroutine_callable(TestClass2)
def test_is_gen_callable():
def test1():
yield
async def test2():
yield
def test3():
...
class TestClass1:
def __call__(self):
yield
class TestClass2:
async def __call__(self):
yield
class TestClass3:
def __call__(self):
...
assert is_gen_callable(test1)
assert not is_gen_callable(test2)
assert not is_gen_callable(test3)
assert is_gen_callable(TestClass1())
assert not is_gen_callable(TestClass2())
assert not is_gen_callable(TestClass3())
def test_is_async_gen_callable():
async def test1():
yield
def test2():
yield
async def test3():
...
class TestClass1:
async def __call__(self):
yield
class TestClass2:
def __call__(self):
yield
class TestClass3:
async def __call__(self):
...
assert is_async_gen_callable(test1)
assert not is_async_gen_callable(test2)
assert not is_async_gen_callable(test3)
assert is_async_gen_callable(TestClass1())
assert not is_async_gen_callable(TestClass2())
assert not is_async_gen_callable(TestClass3())
def test_dataclass_encoder(): def test_dataclass_encoder():
simple = json.dumps("123", cls=DataclassEncoder) simple = json.dumps("123", cls=DataclassEncoder)
assert simple == '"123"' assert simple == '"123"'
Message = make_fake_message() ms = FakeMessageSegment.nested(FakeMessage(FakeMessageSegment.text("text")))
MessageSegment = Message.get_segment_class()
ms = MessageSegment.nested(Message(MessageSegment.text("text")))
s = json.dumps(ms, cls=DataclassEncoder) s = json.dumps(ms, cls=DataclassEncoder)
assert ( assert (
s s

View File

@ -1,6 +1,6 @@
from typing import Type, Union, Mapping, Iterable, Optional from typing import Type, Union, Mapping, Iterable, Optional
from pydantic import create_model from pydantic import Extra, create_model
from nonebot.adapters import Event, Message, MessageSegment from nonebot.adapters import Event, Message, MessageSegment
@ -12,51 +12,49 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str:
return s return s
def make_fake_message(): class FakeMessageSegment(MessageSegment["FakeMessage"]):
class FakeMessageSegment(MessageSegment["FakeMessage"]): @classmethod
@classmethod def get_message_class(cls):
def get_message_class(cls): return FakeMessage
return FakeMessage
def __str__(self) -> str: def __str__(self) -> str:
return self.data["text"] if self.type == "text" else f"[fake:{self.type}]" return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
@classmethod @classmethod
def text(cls, text: str): def text(cls, text: str):
return cls("text", {"text": text}) return cls("text", {"text": text})
@staticmethod @staticmethod
def image(url: str): def image(url: str):
return FakeMessageSegment("image", {"url": url}) return FakeMessageSegment("image", {"url": url})
@staticmethod @staticmethod
def nested(content: "FakeMessage"): def nested(content: "FakeMessage"):
return FakeMessageSegment("node", {"content": content}) return FakeMessageSegment("node", {"content": content})
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"
class FakeMessage(Message[FakeMessageSegment]):
@classmethod
def get_segment_class(cls):
return FakeMessageSegment
@staticmethod class FakeMessage(Message[FakeMessageSegment]):
def _construct(msg: Union[str, Iterable[Mapping]]): @classmethod
if isinstance(msg, str): def get_segment_class(cls):
yield FakeMessageSegment.text(msg) return FakeMessageSegment
else:
for seg in msg:
yield FakeMessageSegment(**seg)
return
def __add__( @staticmethod
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]] def _construct(msg: Union[str, Iterable[Mapping]]):
): if isinstance(msg, str):
other = escape_text(other) if isinstance(other, str) else other yield FakeMessageSegment.text(msg)
return super().__add__(other) else:
for seg in msg:
yield FakeMessageSegment(**seg)
return
return FakeMessage def __add__(
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
):
other = escape_text(other) if isinstance(other, str) else other
return super().__add__(other)
def make_fake_event( def make_fake_event(
@ -70,9 +68,9 @@ def make_fake_event(
_to_me: bool = True, _to_me: bool = True,
**fields, **fields,
) -> Type[Event]: ) -> Type[Event]:
_Fake = create_model("_Fake", __base__=_base or Event, **fields) Base = _base or Event
class FakeEvent(_Fake): class FakeEvent(Base, extra=Extra.forbid):
def get_type(self) -> str: def get_type(self) -> str:
return _type return _type
@ -100,7 +98,4 @@ def make_fake_event(
def is_tome(self) -> bool: def is_tome(self) -> bool:
return _to_me return _to_me
class Config: return create_model("FakeEvent", __base__=FakeEvent, **fields)
extra = "forbid"
return FakeEvent