From 080b876d93c2f87345e24b0ca9b5ef33eb4adcc2 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Mon, 19 Jun 2023 17:48:59 +0800 Subject: [PATCH] =?UTF-8?q?:construction=5Fworker:=20Test:=20=E7=A7=BB?= =?UTF-8?q?=E9=99=A4=20httpbin=20=E5=B9=B6=E6=95=B4=E7=90=86=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=20(#2110)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- poetry.lock | 12 +- pyproject.toml | 3 +- tests/conftest.py | 23 +- tests/fake_server.py | 69 ++++ tests/plugins/matcher/matcher_permission.py | 3 +- tests/test_adapters/test_message.py | 333 ++++++++++---------- tests/test_adapters/test_template.py | 22 +- tests/test_driver.py | 57 +++- tests/test_init.py | 39 ++- tests/test_matcher/test_matcher.py | 108 +++++-- tests/test_param.py | 8 +- tests/test_rule.py | 27 +- tests/test_utils.py | 116 ++++++- tests/utils.py | 81 +++-- 14 files changed, 592 insertions(+), 309 deletions(-) create mode 100644 tests/fake_server.py diff --git a/poetry.lock b/poetry.lock index 7d2bf6dd..0fd7441b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1215,7 +1215,7 @@ dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegu name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "MarkupSafe-2.1.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:665a36ae6f8f20a4676b53224e33d456a6f5a72657d9c83c2aa00765072f31f7"}, @@ -2239,13 +2239,13 @@ files = [ [[package]] name = "werkzeug" -version = "2.3.4" +version = "2.3.6" description = "The comprehensive WSGI web application library." -optional = true +optional = false python-versions = ">=3.8" files = [ - {file = "Werkzeug-2.3.4-py3-none-any.whl", hash = "sha256:48e5e61472fee0ddee27ebad085614ebedb7af41e88f687aaf881afb723a162f"}, - {file = "Werkzeug-2.3.4.tar.gz", hash = "sha256:1d5a58e0377d1fe39d061a5de4469e414e78ccb1e1e59c0f5ad6fa1c36c52b76"}, + {file = "Werkzeug-2.3.6-py3-none-any.whl", hash = "sha256:935539fa1413afbb9195b24880778422ed620c0fc09670945185cce4d91a8890"}, + {file = "Werkzeug-2.3.6.tar.gz", hash = "sha256:98c774df2f91b05550078891dee5f0eb0cb797a522c757a2452b9cee5b202330"}, ] [package.dependencies] @@ -2395,4 +2395,4 @@ websockets = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "be276204946fdf3338fb06cc5391aafc2f2fc89fc5c0a91bf09590f1eee40aa4" +content-hash = "cf0a75f5173e6eb0fa3cd9d8901d4149e104094d08bce7e298d6f8c92954803d" diff --git a/pyproject.toml b/pyproject.toml index c678466b..ad2fc24d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ pre-commit = "^3.0.0" [tool.poetry.group.test.dependencies] nonebug = "^0.3.0" +werkzeug = "^2.3.6" pytest-cov = "^4.0.0" pytest-xdist = "^3.0.2" pytest-asyncio = "^0.21.0" @@ -68,7 +69,7 @@ fastapi = ["fastapi", "uvicorn"] all = ["fastapi", "quart", "aiohttp", "httpx", "websockets", "uvicorn"] [tool.pytest.ini_options] -asyncio_mode = "auto" +asyncio_mode = "strict" addopts = "--cov=nonebot --cov-append --cov-report=term-missing" filterwarnings = [ "error", diff --git a/tests/conftest.py b/tests/conftest.py index a25efd85..56e2b5f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,15 @@ import os +import threading from pathlib import Path -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING, Set, Generator import pytest from nonebug import NONEBOT_INIT_KWARGS +from werkzeug.serving import BaseWSGIServer, make_server import nonebot +from nonebot.drivers import URL +from fake_server import request_handler os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}' os.environ["CONFIG_OVERRIDE"] = "new" @@ -28,3 +32,20 @@ def load_plugin(nonebug_init: None) -> Set["Plugin"]: def load_builtin_plugin(nonebug_init: None) -> Set["Plugin"]: # preload builtin plugins return nonebot.load_builtin_plugins("echo", "single_session") + + +@pytest.fixture(scope="session", autouse=True) +def server() -> Generator[BaseWSGIServer, None, None]: + server = make_server("127.0.0.1", 0, app=request_handler) + thread = threading.Thread(target=server.serve_forever) + thread.start() + try: + yield server + finally: + server.shutdown() + thread.join() + + +@pytest.fixture(scope="session") +def server_url(server: BaseWSGIServer) -> URL: + return URL(f"http://{server.host}:{server.port}") diff --git a/tests/fake_server.py b/tests/fake_server.py new file mode 100644 index 00000000..3f19cf2c --- /dev/null +++ b/tests/fake_server.py @@ -0,0 +1,69 @@ +import json +import base64 +from typing import Dict, List, Union, TypeVar + +from werkzeug import Request, Response +from werkzeug.datastructures import MultiDict + +K = TypeVar("K") +V = TypeVar("V") + + +def json_safe(string, content_type="application/octet-stream") -> str: + try: + string = string.decode("utf-8") + json.dumps(string) + return string + except (ValueError, TypeError): + return b"".join( + [ + b"data:", + content_type.encode("utf-8"), + b";base64,", + base64.b64encode(string), + ] + ).decode("utf-8") + + +def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]: + return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()} + + +@Request.application +def request_handler(request: Request) -> Response: + try: + _json = json.loads(request.data.decode("utf-8")) + except (ValueError, TypeError): + _json = None + + return Response( + json.dumps( + { + "url": request.url, + "method": request.method, + "origin": request.headers.get("X-Forwarded-For", request.remote_addr), + "headers": flattern( + MultiDict((k, v) for k, v in request.headers.items()) + ), + "args": flattern(request.args), + "form": flattern(request.form), + "data": json_safe(request.data), + "json": _json, + "files": flattern( + MultiDict( + ( + k, + json_safe( + v.read(), + request.files[k].content_type + or "application/octet-stream", + ), + ) + for k, v in request.files.items() + ) + ), + } + ), + status=200, + content_type="application/json", + ) diff --git a/tests/plugins/matcher/matcher_permission.py b/tests/plugins/matcher/matcher_permission.py index d4042502..6fb26d04 100644 --- a/tests/plugins/matcher/matcher_permission.py +++ b/tests/plugins/matcher/matcher_permission.py @@ -2,6 +2,7 @@ from nonebot.matcher import Matcher from nonebot.permission import USER, Permission default_permission = Permission() +new_permission = Permission() test_permission_updater = Matcher.new(permission=default_permission) @@ -14,4 +15,4 @@ test_custom_updater = Matcher.new(permission=default_permission) @test_custom_updater.permission_updater async def _() -> Permission: - return default_permission + return new_permission diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index 6e6bc377..2748a7c0 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -1,135 +1,136 @@ import pytest from pydantic import ValidationError, parse_obj_as -from utils import make_fake_message +from nonebot.adapters import Message +from utils import FakeMessage, FakeMessageSegment -def test_segment_add(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - assert MessageSegment.text("text") + MessageSegment.text("text") == Message( - [MessageSegment.text("text"), MessageSegment.text("text")] - ) - - assert MessageSegment.text("text") + "text" == Message( - [MessageSegment.text("text"), MessageSegment.text("text")] - ) - - assert ( - MessageSegment.text("text") + Message([MessageSegment.text("text")]) - ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) - - assert "text" + MessageSegment.text("text") == Message( - [MessageSegment.text("text"), MessageSegment.text("text")] - ) - - -def test_segment_validate(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - assert parse_obj_as( - MessageSegment, - {"type": "text", "data": {"text": "text"}, "extra": "should be ignored"}, - ) == MessageSegment.text("text") - - with pytest.raises(ValidationError): - parse_obj_as(MessageSegment, "some str") - - with pytest.raises(ValidationError): - parse_obj_as(MessageSegment, {"data": {}}) - - -def test_segment_join(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - seg = MessageSegment.text("test") - iterable = [ - MessageSegment.text("first"), - Message([MessageSegment.text("second"), MessageSegment.text("third")]), - ] - - assert seg.join(iterable) == Message( - [ - MessageSegment.text("first"), - MessageSegment.text("test"), - MessageSegment.text("second"), - MessageSegment.text("third"), - ] - ) - - -def test_segment(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - assert len(MessageSegment.text("text")) == 4 - assert MessageSegment.text("text") != MessageSegment.text("other") - assert MessageSegment.text("text").get("data") == {"text": "text"} - assert list(MessageSegment.text("text").keys()) == ["type", "data"] - assert list(MessageSegment.text("text").values()) == ["text", {"text": "text"}] - assert list(MessageSegment.text("text").items()) == [ +def test_segment_data(): + assert len(FakeMessageSegment.text("text")) == 4 + assert FakeMessageSegment.text("text").get("data") == {"text": "text"} + assert list(FakeMessageSegment.text("text").keys()) == ["type", "data"] + assert list(FakeMessageSegment.text("text").values()) == ["text", {"text": "text"}] + assert list(FakeMessageSegment.text("text").items()) == [ ("type", "text"), ("data", {"text": "text"}), ] - origin = MessageSegment.text("text") + +def test_segment_equal(): + assert FakeMessageSegment("text", {"text": "text"}) == FakeMessageSegment( + "text", {"text": "text"} + ) + assert FakeMessageSegment("text", {"text": "text"}) != FakeMessageSegment( + "text", {"text": "other"} + ) + assert FakeMessageSegment("text", {"text": "text"}) != FakeMessageSegment( + "other", {"text": "text"} + ) + + +def test_segment_add(): + assert FakeMessageSegment.text("text") + FakeMessageSegment.text( + "text" + ) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]) + + assert FakeMessageSegment.text("text") + "text" == FakeMessage( + [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")] + ) + + assert ( + FakeMessageSegment.text("text") + FakeMessage([FakeMessageSegment.text("text")]) + ) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]) + + assert "text" + FakeMessageSegment.text("text") == FakeMessage( + [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")] + ) + + +def test_segment_validate(): + assert parse_obj_as( + FakeMessageSegment, + {"type": "text", "data": {"text": "text"}, "extra": "should be ignored"}, + ) == FakeMessageSegment.text("text") + + with pytest.raises(ValidationError): + parse_obj_as(FakeMessageSegment, "some str") + + with pytest.raises(ValidationError): + parse_obj_as(FakeMessageSegment, {"data": {}}) + + +def test_segment_join(): + seg = FakeMessageSegment.text("test") + iterable = [ + FakeMessageSegment.text("first"), + FakeMessage( + [FakeMessageSegment.text("second"), FakeMessageSegment.text("third")] + ), + ] + + assert seg.join(iterable) == FakeMessage( + [ + FakeMessageSegment.text("first"), + FakeMessageSegment.text("test"), + FakeMessageSegment.text("second"), + FakeMessageSegment.text("third"), + ] + ) + + +def test_segment_copy(): + origin = FakeMessageSegment.text("text") copy = origin.copy() assert origin is not copy assert origin == copy def test_message_add(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - assert ( - Message([MessageSegment.text("text")]) + MessageSegment.text("text") - ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + FakeMessage([FakeMessageSegment.text("text")]) + FakeMessageSegment.text("text") + ) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]) - assert Message([MessageSegment.text("text")]) + "text" == Message( - [MessageSegment.text("text"), MessageSegment.text("text")] + assert FakeMessage([FakeMessageSegment.text("text")]) + "text" == FakeMessage( + [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")] ) assert ( - Message([MessageSegment.text("text")]) + Message([MessageSegment.text("text")]) - ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + FakeMessage([FakeMessageSegment.text("text")]) + + FakeMessage([FakeMessageSegment.text("text")]) + ) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]) - assert "text" + Message([MessageSegment.text("text")]) == Message( - [MessageSegment.text("text"), MessageSegment.text("text")] + assert "text" + FakeMessage([FakeMessageSegment.text("text")]) == FakeMessage( + [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")] ) - msg = Message([MessageSegment.text("text")]) - msg += MessageSegment.text("text") - assert msg == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + msg = FakeMessage([FakeMessageSegment.text("text")]) + msg += FakeMessageSegment.text("text") + assert msg == FakeMessage( + [FakeMessageSegment.text("text"), FakeMessageSegment.text("text")] + ) def test_message_getitem(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.image("test2"), - MessageSegment.image("test3"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.image("test2"), + FakeMessageSegment.image("test3"), + FakeMessageSegment.text("test4"), ] ) - assert message[0] == MessageSegment.text("test") + assert message[0] == FakeMessageSegment.text("test") - assert message[:2] == Message( - [MessageSegment.text("test"), MessageSegment.image("test2")] + assert message[:2] == FakeMessage( + [FakeMessageSegment.text("test"), FakeMessageSegment.image("test2")] ) - assert message["image"] == Message( - [MessageSegment.image("test2"), MessageSegment.image("test3")] + assert message["image"] == FakeMessage( + [FakeMessageSegment.image("test2"), FakeMessageSegment.image("test3")] ) - assert message["image", 0] == MessageSegment.image("test2") + assert message["image", 0] == FakeMessageSegment.image("test2") assert message["image", 0:2] == message["image"] assert message.index(message[0]) == 0 @@ -137,153 +138,137 @@ def test_message_getitem(): assert message.get("image") == message["image"] assert message.get("image", 114514) == message["image"] - assert message.get("image", 1) == Message([message["image", 0]]) + assert message.get("image", 1) == FakeMessage([message["image", 0]]) assert message.count("image") == 2 def test_message_validate(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - Message_ = make_fake_message() - - assert parse_obj_as(Message, Message([])) == Message([]) + assert parse_obj_as(FakeMessage, FakeMessage([])) == FakeMessage([]) with pytest.raises(ValidationError): - parse_obj_as(Message, Message_([])) + parse_obj_as(type("FakeMessage2", (Message,), {}), FakeMessage([])) - assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")]) - - assert parse_obj_as(Message, {"type": "text", "data": {"text": "text"}}) == Message( - [MessageSegment.text("text")] + assert parse_obj_as(FakeMessage, "text") == FakeMessage( + [FakeMessageSegment.text("text")] ) assert parse_obj_as( - Message, - [MessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}], - ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + FakeMessage, {"type": "text", "data": {"text": "text"}} + ) == FakeMessage([FakeMessageSegment.text("text")]) + + assert parse_obj_as( + FakeMessage, + [FakeMessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}], + ) == FakeMessage([FakeMessageSegment.text("text"), FakeMessageSegment.text("text")]) with pytest.raises(ValidationError): - parse_obj_as(Message, object()) + parse_obj_as(FakeMessage, object()) def test_message_contains(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.image("test2"), - MessageSegment.image("test3"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.image("test2"), + FakeMessageSegment.image("test3"), + FakeMessageSegment.text("test4"), ] ) - assert message.has(MessageSegment.text("test")) is True - assert MessageSegment.text("test") in message + assert message.has(FakeMessageSegment.text("test")) is True + assert FakeMessageSegment.text("test") in message assert message.has("image") is True assert "image" in message - assert message.has(MessageSegment.text("foo")) is False - assert MessageSegment.text("foo") not in message + assert message.has(FakeMessageSegment.text("foo")) is False + assert FakeMessageSegment.text("foo") not in message assert message.has("foo") is False assert "foo" not in message def test_message_only(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.text("test2"), + FakeMessageSegment.text("test"), + FakeMessageSegment.text("test2"), ] ) assert message.only("text") is True - assert message.only(MessageSegment.text("test")) is False + assert message.only(FakeMessageSegment.text("test")) is False - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.image("test2"), - MessageSegment.image("test3"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.image("test2"), + FakeMessageSegment.image("test3"), + FakeMessageSegment.text("test4"), ] ) assert message.only("text") is False - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.text("test"), + FakeMessageSegment.text("test"), + FakeMessageSegment.text("test"), ] ) - assert message.only(MessageSegment.text("test")) is True + assert message.only(FakeMessageSegment.text("test")) is True def test_message_join(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - msg = Message([MessageSegment.text("test")]) + msg = FakeMessage([FakeMessageSegment.text("test")]) iterable = [ - MessageSegment.text("first"), - Message([MessageSegment.text("second"), MessageSegment.text("third")]), + FakeMessageSegment.text("first"), + FakeMessage( + [FakeMessageSegment.text("second"), FakeMessageSegment.text("third")] + ), ] - assert msg.join(iterable) == Message( + assert msg.join(iterable) == FakeMessage( [ - MessageSegment.text("first"), - MessageSegment.text("test"), - MessageSegment.text("second"), - MessageSegment.text("third"), + FakeMessageSegment.text("first"), + FakeMessageSegment.text("test"), + FakeMessageSegment.text("second"), + FakeMessageSegment.text("third"), ] ) def test_message_include(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.image("test2"), - MessageSegment.image("test3"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.image("test2"), + FakeMessageSegment.image("test3"), + FakeMessageSegment.text("test4"), ] ) - assert message.include("text") == Message( + assert message.include("text") == FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.text("test4"), ] ) def test_message_exclude(): - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - - message = Message( + message = FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.image("test2"), - MessageSegment.image("test3"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.image("test2"), + FakeMessageSegment.image("test3"), + FakeMessageSegment.text("test4"), ] ) - assert message.exclude("image") == Message( + assert message.exclude("image") == FakeMessage( [ - MessageSegment.text("test"), - MessageSegment.text("test4"), + FakeMessageSegment.text("test"), + FakeMessageSegment.text("test4"), ] ) diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 865a03b2..710c556a 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,5 +1,5 @@ from nonebot.adapters import MessageTemplate -from utils import escape_text, make_fake_message +from utils import FakeMessage, FakeMessageSegment, escape_text def test_template_basis(): @@ -9,8 +9,7 @@ def test_template_basis(): def test_template_message(): - Message = make_fake_message() - template = Message.template("{a:custom}{b:text}{c:image}/{d}") + template = FakeMessage.template("{a:custom}{b:text}{c:image}/{d}") @template.add_format_spec def custom(input: str) -> str: @@ -37,29 +36,24 @@ def test_template_message(): def test_rich_template_message(): - Message = make_fake_message() - MS = Message.get_segment_class() - pic1, pic2, pic3 = ( - MS.image("file:///pic1.jpg"), - MS.image("file:///pic2.jpg"), - MS.image("file:///pic3.jpg"), + FakeMessageSegment.image("file:///pic1.jpg"), + FakeMessageSegment.image("file:///pic2.jpg"), + FakeMessageSegment.image("file:///pic3.jpg"), ) - template = Message.template("{}{}" + pic2 + "{}") + template = FakeMessage.template("{}{}" + pic2 + "{}") result = template.format(pic1, "[fake:image]", pic3) - assert result["image"] == Message([pic1, pic2, pic3]) + assert result["image"] == FakeMessage([pic1, pic2, pic3]) assert str(result) == ( "[fake:image]" + escape_text("[fake:image]") + "[fake:image]" + "[fake:image]" ) def test_message_injection(): - Message = make_fake_message() - - template = Message.template("{name}Is Bad") + template = FakeMessage.template("{name}Is Bad") message = template.format(name="[fake:image]") assert message.extract_plain_text() == escape_text("[fake:image]Is Bad") diff --git a/tests/test_driver.py b/tests/test_driver.py index 48b503a0..d09d3bbf 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -79,13 +79,37 @@ async def test_lifespan(): ], indirect=True, ) -async def test_reverse_driver(app: App, driver: Driver): +async def test_http_server(app: App, driver: Driver): driver = cast(ReverseDriver, driver) async def _handle_http(request: Request) -> Response: assert request.content in (b"test", "test") return Response(200, content="test") + http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http) + driver.setup_http_server(http_setup) + + async with app.test_server(driver.asgi) as ctx: + client = ctx.get_client() + response = await client.post("/http_test", data="test") + assert response.status_code == 200 + assert response.text == "test" + + await asyncio.sleep(1) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"), + pytest.param("nonebot.drivers.quart:Driver", id="quart"), + ], + indirect=True, +) +async def test_websocket_server(app: App, driver: Driver): + driver = cast(ReverseDriver, driver) + async def _handle_ws(ws: WebSocket) -> None: await ws.accept() data = await ws.receive() @@ -107,17 +131,11 @@ async def test_reverse_driver(app: App, driver: Driver): with pytest.raises(WebSocketClosed): await ws.receive() - http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http) - driver.setup_http_server(http_setup) - ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) driver.setup_websocket_server(ws_setup) async with app.test_server(driver.asgi) as ctx: client = ctx.get_client() - response = await client.post("/http_test", data="test") - assert response.status_code == 200 - assert response.text == "test" async with client.websocket_connect("/ws_test") as ws: await ws.send_text("ping") @@ -145,12 +163,13 @@ async def test_reverse_driver(app: App, driver: Driver): ], indirect=True, ) -async def test_http_driver(driver: Driver): +async def test_http_client(driver: Driver, server_url: URL): driver = cast(ForwardDriver, driver) + # simple post with query, headers, cookies and content request = Request( "POST", - "https://httpbin.org/post", + server_url, params={"param": "test"}, headers={"X-Test": "test"}, cookies={"session": "test"}, @@ -159,32 +178,39 @@ async def test_http_driver(driver: Driver): response = await driver.request(request) assert response.status_code == 200 and response.content data = json.loads(response.content) + assert data["method"] == "POST" assert data["args"] == {"param": "test"} assert data["headers"].get("X-Test") == "test" assert data["headers"].get("Cookie") == "session=test" assert data["data"] == "test" - request = Request("POST", "https://httpbin.org/post", data={"form": "test"}) + # post with data body + request = Request("POST", server_url, data={"form": "test"}) response = await driver.request(request) assert response.status_code == 200 and response.content data = json.loads(response.content) + assert data["method"] == "POST" assert data["form"] == {"form": "test"} - request = Request("POST", "https://httpbin.org/post", json={"json": "test"}) + # post with json body + request = Request("POST", server_url, json={"json": "test"}) response = await driver.request(request) assert response.status_code == 200 and response.content data = json.loads(response.content) + assert data["method"] == "POST" assert data["json"] == {"json": "test"} + # post with files and form data request = Request( "POST", - "https://httpbin.org/post", + server_url, data={"form": "test"}, files={"test": ("test.txt", b"test")}, ) response = await driver.request(request) assert response.status_code == 200 and response.content data = json.loads(response.content) + assert data["method"] == "POST" assert data["form"] == {"form": "test"} assert data["files"] == {"test": "test"} @@ -236,7 +262,6 @@ async def test_bot_connect_hook(app: App, driver: Driver): @driver.on_bot_connect async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1): nonlocal conn_should_be_called - conn_should_be_called = True if foo is not bot: pytest.fail("on_bot_connect hook called with wrong bot") @@ -245,12 +270,13 @@ async def test_bot_connect_hook(app: App, driver: Driver): if default != 1: pytest.fail("on_bot_connect hook called with wrong default value") + conn_should_be_called = True + @driver.on_bot_disconnect async def disconn_hook( foo: Bot, dep: int = Depends(dependency), default: int = 1 ): nonlocal disconn_should_be_called - disconn_should_be_called = True if foo is not bot: pytest.fail("on_bot_disconnect hook called with wrong bot") @@ -259,6 +285,8 @@ async def test_bot_connect_hook(app: App, driver: Driver): if default != 1: pytest.fail("on_bot_connect hook called with wrong default value") + disconn_should_be_called = True + if conn_hook not in {hook.call for hook in conn_hooks}: pytest.fail("on_bot_connect hook not registered") if disconn_hook not in {hook.call for hook in disconn_hooks}: @@ -268,6 +296,7 @@ async def test_bot_connect_hook(app: App, driver: Driver): bot = ctx.create_bot() await asyncio.sleep(1) + if not conn_should_be_called: pytest.fail("on_bot_connect hook not called") if not disconn_should_be_called: diff --git a/tests/test_init.py b/tests/test_init.py index 8624749e..2257e415 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -31,17 +31,29 @@ async def test_init(): @pytest.mark.asyncio -async def test_get(app: App, monkeypatch: pytest.MonkeyPatch): +async def test_get_driver(app: App, monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setattr(nonebot, "_driver", None) with pytest.raises(ValueError): get_driver() + +@pytest.mark.asyncio +async def test_get_asgi(app: App, monkeypatch: pytest.MonkeyPatch): driver = get_driver() assert isinstance(driver, ReverseDriver) assert get_asgi() == driver.asgi + + +@pytest.mark.asyncio +async def test_get_app(app: App, monkeypatch: pytest.MonkeyPatch): + driver = get_driver() + assert isinstance(driver, ReverseDriver) assert get_app() == driver.server_app + +@pytest.mark.asyncio +async def test_get_adapter(app: App, monkeypatch: pytest.MonkeyPatch): async with app.test_api() as ctx: adapter = ctx.create_adapter() adapter_name = adapter.get_name() @@ -54,6 +66,9 @@ async def test_get(app: App, monkeypatch: pytest.MonkeyPatch): with pytest.raises(ValueError): get_adapter("not exist") + +@pytest.mark.asyncio +async def test_run(app: App, monkeypatch: pytest.MonkeyPatch): runned = False def mock_run(*args, **kwargs): @@ -61,14 +76,24 @@ async def test_get(app: App, monkeypatch: pytest.MonkeyPatch): runned = True assert args == ("arg",) and kwargs == {"kwarg": "kwarg"} - monkeypatch.setattr(driver, "run", mock_run) - nonebot.run("arg", kwarg="kwarg") + driver = get_driver() + + with monkeypatch.context() as m: + m.setattr(driver, "run", mock_run) + nonebot.run("arg", kwarg="kwarg") + assert runned + +@pytest.mark.asyncio +async def test_get_bot(app: App, monkeypatch: pytest.MonkeyPatch): + driver = get_driver() + with pytest.raises(ValueError): get_bot() - monkeypatch.setattr(driver, "_bots", {"test": "test"}) - assert get_bot() == "test" - assert get_bot("test") == "test" - assert get_bots() == {"test": "test"} + with monkeypatch.context() as m: + m.setattr(driver, "_bots", {"test": "test"}) + assert get_bot() == "test" + assert get_bot("test") == "test" + assert get_bots() == {"test": "test"} diff --git a/tests/test_matcher/test_matcher.py b/tests/test_matcher/test_matcher.py index 9433e74b..72f8dc7a 100644 --- a/tests/test_matcher/test_matcher.py +++ b/tests/test_matcher/test_matcher.py @@ -3,25 +3,16 @@ from nonebug import App from nonebot.permission import User from nonebot.matcher import Matcher, matchers +from utils import FakeMessage, make_fake_event from nonebot.message import check_and_run_matcher -from utils import make_fake_event, make_fake_message @pytest.mark.asyncio -async def test_matcher(app: App): - from plugins.matcher.matcher_process import ( - test_got, - test_handle, - test_preset, - test_combine, - test_receive, - test_overload, - ) +async def test_matcher_handle(app: App): + from plugins.matcher.matcher_process import test_handle - message = make_fake_message()("text") + message = FakeMessage("text") event = make_fake_event(_message=message)() - message_next = make_fake_message()("text_next") - event_next = make_fake_event(_message=message_next)() assert len(test_handle.handlers) == 1 async with app.test_matcher(test_handle) as ctx: @@ -30,6 +21,16 @@ async def test_matcher(app: App): ctx.should_call_send(event, "send", "result", at_sender=True) ctx.should_finished() + +@pytest.mark.asyncio +async def test_matcher_got(app: App): + from plugins.matcher.matcher_process import test_got + + message = FakeMessage("text") + event = make_fake_event(_message=message)() + message_next = FakeMessage("text_next") + event_next = make_fake_event(_message=message_next)() + assert len(test_got.handlers) == 1 async with app.test_matcher(test_got) as ctx: bot = ctx.create_bot() @@ -42,6 +43,14 @@ async def test_matcher(app: App): ctx.should_rejected() ctx.receive_event(bot, event_next) + +@pytest.mark.asyncio +async def test_matcher_receive(app: App): + from plugins.matcher.matcher_process import test_receive + + message = FakeMessage("text") + event = make_fake_event(_message=message)() + assert len(test_receive.handlers) == 1 async with app.test_matcher(test_receive) as ctx: bot = ctx.create_bot() @@ -51,7 +60,17 @@ async def test_matcher(app: App): ctx.should_call_send(event, "pause", "result", at_sender=True) ctx.should_paused() - assert len(test_receive.handlers) == 1 + +@pytest.mark.asyncio +async def test_matcher_(app: App): + from plugins.matcher.matcher_process import test_combine + + message = FakeMessage("text") + event = make_fake_event(_message=message)() + message_next = FakeMessage("text_next") + event_next = make_fake_event(_message=message_next)() + + assert len(test_combine.handlers) == 1 async with app.test_matcher(test_combine) as ctx: bot = ctx.create_bot() ctx.receive_event(bot, event) @@ -64,6 +83,16 @@ async def test_matcher(app: App): ctx.should_rejected() ctx.receive_event(bot, event_next) + +@pytest.mark.asyncio +async def test_matcher_preset(app: App): + from plugins.matcher.matcher_process import test_preset + + message = FakeMessage("text") + event = make_fake_event(_message=message)() + message_next = FakeMessage("text_next") + event_next = make_fake_event(_message=message_next)() + assert len(test_preset.handlers) == 2 async with app.test_matcher(test_preset) as ctx: bot = ctx.create_bot() @@ -72,6 +101,14 @@ async def test_matcher(app: App): ctx.should_rejected() ctx.receive_event(bot, event_next) + +@pytest.mark.asyncio +async def test_matcher_overload(app: App): + from plugins.matcher.matcher_process import test_overload + + message = FakeMessage("text") + event = make_fake_event(_message=message)() + assert len(test_overload.handlers) == 2 async with app.test_matcher(test_overload) as ctx: bot = ctx.create_bot() @@ -115,12 +152,10 @@ async def test_type_updater(app: App): @pytest.mark.asyncio -async def test_permission_updater(app: App): +async def test_default_permission_updater(app: App): from plugins.matcher.matcher_permission import ( default_permission, - test_custom_updater, test_permission_updater, - test_user_permission_updater, ) event = make_fake_event(_session_id="test")() @@ -136,6 +171,15 @@ async def test_permission_updater(app: App): assert checker.users == ("test",) assert checker.perm is default_permission + +@pytest.mark.asyncio +async def test_user_permission_updater(app: App): + from plugins.matcher.matcher_permission import ( + default_permission, + test_user_permission_updater, + ) + + event = make_fake_event(_session_id="test")() user_permission = list(test_user_permission_updater.permission.checkers)[0].call assert isinstance(user_permission, User) assert user_permission.perm is default_permission @@ -149,12 +193,22 @@ async def test_permission_updater(app: App): assert checker.users == ("test",) assert checker.perm is default_permission + +@pytest.mark.asyncio +async def test_custom_permission_updater(app: App): + from plugins.matcher.matcher_permission import ( + new_permission, + default_permission, + test_custom_updater, + ) + + event = make_fake_event(_session_id="test")() assert test_custom_updater.permission is default_permission async with app.test_api() as ctx: bot = ctx.create_bot() matcher = test_custom_updater() new_perm = await matcher.update_permission(bot, event) - assert new_perm is default_permission + assert new_perm is new_permission @pytest.mark.asyncio @@ -189,12 +243,8 @@ async def test_run(app: App): @pytest.mark.asyncio -async def test_expire(app: App): - from plugins.matcher.matcher_expire import ( - test_temp_matcher, - test_datetime_matcher, - test_timedelta_matcher, - ) +async def test_temp(app: App): + from plugins.matcher.matcher_expire import test_temp_matcher event = make_fake_event(_type="test")() async with app.test_api() as ctx: @@ -203,6 +253,11 @@ async def test_expire(app: App): await check_and_run_matcher(test_temp_matcher, bot, event, {}) assert test_temp_matcher not in matchers[test_temp_matcher.priority] + +@pytest.mark.asyncio +async def test_datetime_expire(app: App): + from plugins.matcher.matcher_expire import test_datetime_matcher + event = make_fake_event()() async with app.test_api() as ctx: bot = ctx.create_bot() @@ -210,6 +265,11 @@ async def test_expire(app: App): await check_and_run_matcher(test_datetime_matcher, bot, event, {}) assert test_datetime_matcher not in matchers[test_datetime_matcher.priority] + +@pytest.mark.asyncio +async def test_timedelta_expire(app: App): + from plugins.matcher.matcher_expire import test_timedelta_matcher + event = make_fake_event()() async with app.test_api() as ctx: bot = ctx.create_bot() diff --git a/tests/test_param.py b/tests/test_param.py index 6fd930dc..8abf2dde 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -6,7 +6,7 @@ from nonebug import App from nonebot.matcher import Matcher from nonebot.dependencies import Dependent from nonebot.exception import TypeMisMatch -from utils import make_fake_event, make_fake_message +from utils import FakeMessage, make_fake_event from nonebot.params import ( ArgParam, BotParam, @@ -157,7 +157,7 @@ async def test_event(app: App): generic_event_none, ) - fake_message = make_fake_message()("text") + fake_message = FakeMessage("text") fake_event = make_fake_event(_message=fake_message)() fake_fooevent = make_fake_event(_base=FooEvent)() @@ -247,7 +247,7 @@ async def test_state(app: App): shell_command_argv, ) - fake_message = make_fake_message()("text") + fake_message = FakeMessage("text") fake_matched = re.match(r"\[cq:(?P.*?),(?P.*?)\]", "[cq:test,arg=value]") fake_state = { PREFIX_KEY: { @@ -453,7 +453,7 @@ async def test_arg(app: App): from plugins.param.param_arg import arg, arg_str, arg_plain_text matcher = Matcher() - message = make_fake_message()("text") + message = FakeMessage("text") matcher.set_arg("key", message) async with app.test_dependent(arg, allow_types=[ArgParam]) as ctx: diff --git a/tests/test_rule.py b/tests/test_rule.py index 0622215a..aa85725c 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -6,8 +6,8 @@ import pytest from nonebug import App from nonebot.typing import T_State -from utils import make_fake_event, make_fake_message from nonebot.exception import ParserExit, SkippedException +from utils import FakeMessage, FakeMessageSegment, make_fake_event from nonebot.consts import ( CMD_KEY, PREFIX_KEY, @@ -85,24 +85,21 @@ async def test_rule(app: App): async def test_trie(app: App): TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",))) - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - async with app.test_api() as ctx: bot = ctx.create_bot() - message = Message("/fake-prefix some args") + message = FakeMessage("/fake-prefix some args") event = make_fake_event(_message=message)() state = {} TrieRule.get_value(bot, event, state) assert state[PREFIX_KEY] == CMD_RESULT( command=("fake-prefix",), raw_command="/fake-prefix", - command_arg=Message("some args"), + command_arg=FakeMessage("some args"), command_start="/", command_whitespace=" ", ) - message = MessageSegment.text("/fake-prefix ") + MessageSegment.image( + message = FakeMessageSegment.text("/fake-prefix ") + FakeMessageSegment.image( "fake url" ) event = make_fake_event(_message=message)() @@ -111,7 +108,7 @@ async def test_trie(app: App): assert state[PREFIX_KEY] == CMD_RESULT( command=("fake-prefix",), raw_command="/fake-prefix", - command_arg=Message(MessageSegment.image("fake url")), + command_arg=FakeMessage(FakeMessageSegment.image("fake url")), command_start="/", command_whitespace=" ", ) @@ -152,7 +149,7 @@ async def test_startswith( assert checker.msg == msg assert checker.ignorecase == ignorecase - message = text if text is None else make_fake_message()(text) + message = text if text is None else FakeMessage(text) event = make_fake_event(_type=type, _message=message)() for prefix in msg: state = {STARTSWITH_KEY: prefix} @@ -192,7 +189,7 @@ async def test_endswith( assert checker.msg == msg assert checker.ignorecase == ignorecase - message = text if text is None else make_fake_message()(text) + message = text if text is None else FakeMessage(text) event = make_fake_event(_type=type, _message=message)() for suffix in msg: state = {ENDSWITH_KEY: suffix} @@ -232,7 +229,7 @@ async def test_fullmatch( assert checker.msg == msg assert checker.ignorecase == ignorecase - message = text if text is None else make_fake_message()(text) + message = text if text is None else FakeMessage(text) event = make_fake_event(_type=type, _message=message)() for full in msg: state = {FULLMATCH_KEY: full} @@ -264,7 +261,7 @@ async def test_keyword( assert isinstance(checker, KeywordsRule) assert checker.keywords == kws - message = text if text is None else make_fake_message()(text) + message = text if text is None else FakeMessage(text) event = make_fake_event(_type=type, _message=message)() for kw in kws: state = {KEYWORD_KEY: kw} @@ -310,7 +307,7 @@ async def test_command( assert isinstance(checker, CommandRule) assert checker.cmds == cmds - arg = arg_text if arg_text is None else make_fake_message()(arg_text) + arg = arg_text if arg_text is None else FakeMessage(arg_text) state = { PREFIX_KEY: {CMD_KEY: cmd, CMD_WHITESPACE_KEY: whitespace, CMD_ARG_KEY: arg} } @@ -321,7 +318,7 @@ async def test_command( async def test_shell_command(): state: T_State CMD = ("test",) - Message = make_fake_message() + Message = FakeMessage MessageSegment = Message.get_segment_class() test_not_cmd = shell_command(CMD) @@ -455,7 +452,7 @@ async def test_regex( assert isinstance(checker, RegexRule) assert checker.regex == pattern - message = text if text is None else make_fake_message()(text) + message = text if text is None else FakeMessage(text) event = make_fake_event(_type=type, _message=message)() state = {} assert await dependent(event=event, state=state) == expected diff --git a/tests/test_utils.py b/tests/test_utils.py index 95b88f1f..607dbdfa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,16 +1,122 @@ import json +from typing import Dict, List, Union, TypeVar -from utils import make_fake_message -from nonebot.utils import DataclassEncoder +from utils import FakeMessage, FakeMessageSegment +from nonebot.utils import ( + DataclassEncoder, + escape_tag, + is_gen_callable, + is_async_gen_callable, + is_coroutine_callable, + generic_check_issubclass, +) + + +def test_loguru_escape_tag(): + assert escape_tag("red") == r"\red\" + assert escape_tag("white") == r"\white\" + assert escape_tag("white") == "\\white\\" + assert escape_tag("white") == r"\white\" + assert escape_tag("white") == "\\white\\" + + +def test_generic_check_issubclass(): + assert generic_check_issubclass(int, (int, float)) + assert not generic_check_issubclass(str, (int, float)) + assert generic_check_issubclass(Union[int, float, None], (int, float)) + assert generic_check_issubclass(List[int], list) + assert generic_check_issubclass(Dict[str, int], dict) + assert generic_check_issubclass(TypeVar("T", int, float), (int, float)) + assert generic_check_issubclass(TypeVar("T", bound=int), (int, float)) + + +def test_is_coroutine_callable(): + async def test1(): + ... + + def test2(): + ... + + class TestClass1: + async def __call__(self): + ... + + class TestClass2: + def __call__(self): + ... + + assert is_coroutine_callable(test1) + assert not is_coroutine_callable(test2) + assert not is_coroutine_callable(TestClass1) + assert is_coroutine_callable(TestClass1()) + assert not is_coroutine_callable(TestClass2) + + +def test_is_gen_callable(): + def test1(): + yield + + async def test2(): + yield + + def test3(): + ... + + class TestClass1: + def __call__(self): + yield + + class TestClass2: + async def __call__(self): + yield + + class TestClass3: + def __call__(self): + ... + + assert is_gen_callable(test1) + assert not is_gen_callable(test2) + assert not is_gen_callable(test3) + assert is_gen_callable(TestClass1()) + assert not is_gen_callable(TestClass2()) + assert not is_gen_callable(TestClass3()) + + +def test_is_async_gen_callable(): + async def test1(): + yield + + def test2(): + yield + + async def test3(): + ... + + class TestClass1: + async def __call__(self): + yield + + class TestClass2: + def __call__(self): + yield + + class TestClass3: + async def __call__(self): + ... + + assert is_async_gen_callable(test1) + assert not is_async_gen_callable(test2) + assert not is_async_gen_callable(test3) + assert is_async_gen_callable(TestClass1()) + assert not is_async_gen_callable(TestClass2()) + assert not is_async_gen_callable(TestClass3()) def test_dataclass_encoder(): simple = json.dumps("123", cls=DataclassEncoder) assert simple == '"123"' - Message = make_fake_message() - MessageSegment = Message.get_segment_class() - ms = MessageSegment.nested(Message(MessageSegment.text("text"))) + ms = FakeMessageSegment.nested(FakeMessage(FakeMessageSegment.text("text"))) s = json.dumps(ms, cls=DataclassEncoder) assert ( s diff --git a/tests/utils.py b/tests/utils.py index 5367014a..be084ac4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,6 @@ from typing import Type, Union, Mapping, Iterable, Optional -from pydantic import create_model +from pydantic import Extra, create_model from nonebot.adapters import Event, Message, MessageSegment @@ -12,51 +12,49 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str: return s -def make_fake_message(): - class FakeMessageSegment(MessageSegment["FakeMessage"]): - @classmethod - def get_message_class(cls): - return FakeMessage +class FakeMessageSegment(MessageSegment["FakeMessage"]): + @classmethod + def get_message_class(cls): + return FakeMessage - def __str__(self) -> str: - return self.data["text"] if self.type == "text" else f"[fake:{self.type}]" + def __str__(self) -> str: + return self.data["text"] if self.type == "text" else f"[fake:{self.type}]" - @classmethod - def text(cls, text: str): - return cls("text", {"text": text}) + @classmethod + def text(cls, text: str): + return cls("text", {"text": text}) - @staticmethod - def image(url: str): - return FakeMessageSegment("image", {"url": url}) + @staticmethod + def image(url: str): + return FakeMessageSegment("image", {"url": url}) - @staticmethod - def nested(content: "FakeMessage"): - return FakeMessageSegment("node", {"content": content}) + @staticmethod + def nested(content: "FakeMessage"): + return FakeMessageSegment("node", {"content": content}) - def is_text(self) -> bool: - return self.type == "text" + def is_text(self) -> bool: + return self.type == "text" - class FakeMessage(Message[FakeMessageSegment]): - @classmethod - def get_segment_class(cls): - return FakeMessageSegment - @staticmethod - def _construct(msg: Union[str, Iterable[Mapping]]): - if isinstance(msg, str): - yield FakeMessageSegment.text(msg) - else: - for seg in msg: - yield FakeMessageSegment(**seg) - return +class FakeMessage(Message[FakeMessageSegment]): + @classmethod + def get_segment_class(cls): + return FakeMessageSegment - def __add__( - self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]] - ): - other = escape_text(other) if isinstance(other, str) else other - return super().__add__(other) + @staticmethod + def _construct(msg: Union[str, Iterable[Mapping]]): + if isinstance(msg, str): + yield FakeMessageSegment.text(msg) + else: + for seg in msg: + yield FakeMessageSegment(**seg) + return - return FakeMessage + def __add__( + self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]] + ): + other = escape_text(other) if isinstance(other, str) else other + return super().__add__(other) def make_fake_event( @@ -70,9 +68,9 @@ def make_fake_event( _to_me: bool = True, **fields, ) -> Type[Event]: - _Fake = create_model("_Fake", __base__=_base or Event, **fields) + Base = _base or Event - class FakeEvent(_Fake): + class FakeEvent(Base, extra=Extra.forbid): def get_type(self) -> str: return _type @@ -100,7 +98,4 @@ def make_fake_event( def is_tome(self) -> bool: return _to_me - class Config: - extra = "forbid" - - return FakeEvent + return create_model("FakeEvent", __base__=FakeEvent, **fields)