nonebot2/tests/test_driver.py

541 lines
16 KiB
Python
Raw Normal View History

import json
import asyncio
from http.cookies import SimpleCookie
from typing import Any, Set, Optional
2021-12-22 16:53:55 +08:00
import pytest
from nonebug import App
2021-12-22 16:53:55 +08:00
from utils import FakeAdapter
from nonebot.adapters import Bot
from nonebot.params import Depends
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers import (
URL,
Driver,
Request,
Response,
ASGIMixin,
WebSocket,
HTTPClientMixin,
HTTPServerSetup,
WebSocketClientMixin,
WebSocketServerSetup,
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver", [pytest.param("nonebot.drivers.none:Driver", id="none")], indirect=True
)
async def test_lifespan(driver: Driver):
    """Lifespan hooks fire in order: startup, adapter ready, then shutdown.

    Each hook asserts the state left by the hooks expected to run before it,
    so an out-of-order invocation fails inside the hook itself.
    """
    adapter = FakeAdapter(driver)

    startup_calls = []
    ready_calls = []
    shutdown_calls = []

    @driver.on_startup
    async def _first_startup():
        assert startup_calls == []
        startup_calls.append(1)

    @driver.on_startup
    async def _second_startup():
        assert startup_calls == [1]
        startup_calls.append(2)

    @adapter.on_ready
    def _first_ready():
        # Ready hooks must only run once both startup hooks have finished.
        assert startup_calls == [1, 2]
        assert ready_calls == []
        ready_calls.append(1)

    @adapter.on_ready
    def _second_ready():
        assert ready_calls == [1]
        ready_calls.append(2)

    @driver.on_shutdown
    async def _first_shutdown():
        assert shutdown_calls == []
        shutdown_calls.append(1)

    @driver.on_shutdown
    async def _second_shutdown():
        assert shutdown_calls == [1]
        shutdown_calls.append(2)

    async with driver._lifespan:
        assert startup_calls == [1, 2]
        assert ready_calls == [1, 2]

    # Shutdown hooks are checked only after the lifespan context has exited.
    assert shutdown_calls == [1, 2]
2021-12-22 16:53:55 +08:00
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver",
    [
        pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
        pytest.param("nonebot.drivers.quart:Driver", id="quart"),
    ],
    indirect=True,
)
async def test_http_server(app: App, driver: Driver):
    """POST to a route registered on an ASGI driver and check the reply."""
    assert isinstance(driver, ASGIMixin)

    async def _reply_test(request: Request) -> Response:
        # The body is accepted as either bytes or str.
        assert request.content in (b"test", "test")
        return Response(200, content="test")

    driver.setup_http_server(
        HTTPServerSetup(URL("/http_test"), "POST", "http_test", _reply_test)
    )

    async with app.test_server(driver.asgi) as ctx:
        http_client = ctx.get_client()
        resp = await http_client.post("/http_test", data="test")
        assert resp.status_code == 200
        assert resp.text == "test"

    await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver",
    [
        pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
        pytest.param("nonebot.drivers.quart:Driver", id="quart"),
    ],
    indirect=True,
)
async def test_websocket_server(app: App, driver: Driver):
    """Exchange text and binary frames with a WebSocket route on an ASGI driver.

    The handler uses ``receive``, ``receive_text`` and ``receive_bytes`` in
    turn, then expects ``WebSocketClosed`` with code 1000 once the client
    closes the connection.
    """
    assert isinstance(driver, ASGIMixin)

    async def _handle_ws(ws: WebSocket) -> None:
        await ws.accept()

        # (receiver, expected frame, reply) for each round trip, in order.
        exchanges = (
            (ws.receive, "ping", "pong"),
            (ws.receive, b"ping", b"pong"),
            (ws.receive_text, "ping", "pong"),
            (ws.receive_bytes, b"ping", b"pong"),
        )
        for receive, expected, reply in exchanges:
            data = await receive()
            assert data == expected
            await ws.send(reply)

        with pytest.raises(WebSocketClosed, match=r"code=1000"):
            await ws.receive()

    driver.setup_websocket_server(
        WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
    )

    async with app.test_server(driver.asgi) as ctx:
        client = ctx.get_client()

        async with client.websocket_connect("/ws_test") as ws:
            # Two rounds of text-then-bytes match the four server exchanges.
            for _ in range(2):
                await ws.send_text("ping")
                assert await ws.receive_text() == "pong"
                await ws.send_bytes(b"ping")
                assert await ws.receive_bytes() == b"pong"

            await ws.close(code=1000)

    await asyncio.sleep(1)
2023-07-17 15:01:21 +08:00
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver",
    [
        pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
        pytest.param("nonebot.drivers.quart:Driver", id="quart"),
    ],
    indirect=True,
)
async def test_cross_context(app: App, driver: Driver):
    """Drive a server-side WebSocket from a task other than its handler.

    The handler only accepts the connection and publishes it; a separately
    created task performs the ping/pong exchange and then lets the handler
    close the socket.
    """
    assert isinstance(driver, ASGIMixin)

    server_ws: Optional[WebSocket] = None
    connected = asyncio.Event()
    may_close = asyncio.Event()

    async def _exchange_from_task():
        try:
            await connected.wait()
            assert server_ws is not None
            await server_ws.send("ping")
            reply = await server_ws.receive()
            assert reply == "pong"
        finally:
            # Always release the handler, even if the exchange failed.
            may_close.set()

    exchange_task = asyncio.create_task(_exchange_from_task())

    async def _handle_ws(websocket: WebSocket) -> None:
        nonlocal server_ws
        await websocket.accept()
        server_ws = websocket
        connected.set()

        await may_close.wait()
        await websocket.close()

    driver.setup_websocket_server(
        WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
    )

    async with app.test_server(driver.asgi) as ctx:
        client = ctx.get_client()

        async with client.websocket_connect("/ws_test") as client_ws:
            try:
                message = await client_ws.receive_text()
                assert message == "ping"
                await client_ws.send_text("pong")
            except Exception as exc:
                # Only an error whose first argument mentions
                # "websocket.close" is tolerated; anything else propagates.
                close_related = bool(exc.args) and "websocket.close" in str(
                    exc.args[0]
                )
                if not close_related:
                    raise

    await exchange_task
    await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver",
    [
        pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
        pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
    ],
    indirect=True,
)
async def test_http_client(driver: Driver, server_url: URL):
    """Send POST requests through an HTTP client driver and inspect the echo.

    Covers raw content with query/headers/cookies, a form body, a JSON body,
    and a multipart upload. Each response body is parsed as JSON describing
    the request the server saw.
    """
    assert isinstance(driver, HTTPClientMixin)

    def _decoded(response: Response) -> Any:
        # Shared success checks, then parse the JSON body.
        assert response.status_code == 200
        assert response.content
        return json.loads(response.content)

    # simple post with query, headers, cookies and content
    plain_request = Request(
        "POST",
        server_url,
        params={"param": "test"},
        headers={"X-Test": "test"},
        cookies={"session": "test"},
        content="test",
    )
    plain_response = await driver.request(plain_request)

    # The same URL given as a raw (scheme, host, port, path) tuple must
    # produce an equal request URL.
    assert server_url.host is not None
    raw_url = (
        server_url.scheme.encode("ascii"),
        server_url.host.encode("ascii"),
        server_url.port,
        server_url.path.encode("ascii"),
    )
    request_raw_url = Request(
        "POST",
        raw_url,
        params={"param": "test"},
        headers={"X-Test": "test"},
        cookies={"session": "test"},
        content="test",
    )
    assert (
        plain_request.url == request_raw_url.url
    ), "request.url should be equal to request_raw_url.url"

    echoed = _decoded(plain_response)
    assert echoed["method"] == "POST"
    assert echoed["args"] == {"param": "test"}
    assert echoed["headers"].get("X-Test") == "test"
    assert echoed["headers"].get("Cookie") == "session=test"
    assert echoed["data"] == "test"

    # post with data body
    form_response = await driver.request(
        Request("POST", server_url, data={"form": "test"})
    )
    echoed = _decoded(form_response)
    assert echoed["method"] == "POST"
    assert echoed["form"] == {"form": "test"}

    # post with json body
    json_response = await driver.request(
        Request("POST", server_url, json={"json": "test"})
    )
    echoed = _decoded(json_response)
    assert echoed["method"] == "POST"
    assert echoed["json"] == {"json": "test"}

    # post with files and form data
    upload_request = Request(
        "POST",
        server_url,
        data={"form": "test"},
        files=[
            ("test1", b"test"),
            ("test2", ("test.txt", b"test")),
            ("test3", ("test.txt", b"test", "text/plain")),
        ],
    )
    upload_response = await driver.request(upload_request)
    echoed = _decoded(upload_response)
    assert echoed["method"] == "POST"
    assert echoed["form"] == {"form": "test"}
    assert echoed["files"] == {
        "test1": "test",
        "test2": "test",
        "test3": "test",
    }, "file parsing error"

    await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver",
    [
        pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
        pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
    ],
    indirect=True,
)
async def test_http_client_session(driver: Driver, server_url: URL):
    """Check that session-level params, headers and cookies reach every request.

    Also checks that using the session before entering its context raises
    ``RuntimeError``.
    """
    assert isinstance(driver, HTTPClientMixin)

    def _decoded(response: Response) -> Any:
        # Shared success checks, then parse the JSON body.
        assert response.status_code == 200
        assert response.content
        return json.loads(response.content)

    def _sent_cookies(payload: Any) -> dict:
        # Parse the echoed Cookie header into a name -> value mapping.
        jar = SimpleCookie(payload["headers"].get("Cookie"))
        return {name: morsel.value for name, morsel in jar.items()}

    session = driver.get_session(
        params={"session": "test"},
        headers={"X-Session": "test"},
        cookies={"session": "test"},
    )

    # A session that has not been entered yet must refuse to send.
    request = Request("GET", server_url)
    with pytest.raises(RuntimeError):
        await session.request(request)

    async with session as session:
        # simple post with query, headers, cookies and content
        request = Request(
            "POST",
            server_url,
            params={"param": "test"},
            headers={"X-Test": "test"},
            cookies={"cookie": "test"},
            content="test",
        )
        echoed = _decoded(await session.request(request))
        assert echoed["method"] == "POST"
        assert echoed["args"] == {"session": "test", "param": "test"}
        assert echoed["headers"].get("X-Session") == "test"
        assert echoed["headers"].get("X-Test") == "test"
        assert _sent_cookies(echoed) == {
            "session": "test",
            "cookie": "test",
        }
        assert echoed["data"] == "test"

        # post with data body
        request = Request("POST", server_url, data={"form": "test"})
        echoed = _decoded(await session.request(request))
        assert echoed["method"] == "POST"
        assert echoed["args"] == {"session": "test"}
        assert echoed["headers"].get("X-Session") == "test"
        assert _sent_cookies(echoed) == {"session": "test"}
        assert echoed["form"] == {"form": "test"}

        # post with json body
        request = Request("POST", server_url, json={"json": "test"})
        echoed = _decoded(await session.request(request))
        assert echoed["method"] == "POST"
        assert echoed["args"] == {"session": "test"}
        assert echoed["headers"].get("X-Session") == "test"
        assert _sent_cookies(echoed) == {"session": "test"}
        assert echoed["json"] == {"json": "test"}

        # post with files and form data
        request = Request(
            "POST",
            server_url,
            data={"form": "test"},
            files=[
                ("test1", b"test"),
                ("test2", ("test.txt", b"test")),
                ("test3", ("test.txt", b"test", "text/plain")),
            ],
        )
        echoed = _decoded(await session.request(request))
        assert echoed["method"] == "POST"
        assert echoed["args"] == {"session": "test"}
        assert echoed["headers"].get("X-Session") == "test"
        assert _sent_cookies(echoed) == {"session": "test"}
        assert echoed["form"] == {"form": "test"}
        assert echoed["files"] == {
            "test1": "test",
            "test2": "test",
            "test3": "test",
        }, "file parsing error"

    await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "driver",
    [
        pytest.param("nonebot.drivers.websockets:Driver", id="websockets"),
        pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
    ],
    indirect=True,
)
async def test_websocket_client(driver: Driver, server_url: URL):
    """Round-trip text and binary frames through a WebSocket client driver.

    Every payload sent is expected back unchanged; after sending ``"quit"``
    the next receive must raise ``WebSocketClosed`` with code 1000.
    """
    assert isinstance(driver, WebSocketClientMixin)

    ws_request = Request("GET", server_url.with_scheme("ws"))
    async with driver.websocket(ws_request) as ws:
        # (sender, receiver, payload) for each round trip, in order.
        round_trips = (
            (ws.send, ws.receive, "test"),
            (ws.send, ws.receive, b"test"),
            (ws.send_text, ws.receive_text, "test"),
            (ws.send_bytes, ws.receive_bytes, b"test"),
        )
        for send, receive, payload in round_trips:
            await send(payload)
            assert await receive() == payload

        await ws.send("quit")
        with pytest.raises(WebSocketClosed, match=r"code=1000"):
            await ws.receive()

    await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("driver", "driver_type"),
    [
        pytest.param(
            "nonebot.drivers.fastapi:Driver+nonebot.drivers.aiohttp:Mixin",
            "fastapi+aiohttp",
            id="fastapi+aiohttp",
        ),
        pytest.param(
            "~httpx:Driver+~websockets",
            "none+httpx+websockets",
            id="httpx+websockets",
        ),
    ],
    indirect=["driver"],
)
async def test_combine_driver(driver: Driver, driver_type: str):
    """A driver built from a ``+``-joined spec reports the expected ``type``."""
    reported_type = driver.type
    assert reported_type == driver_type
@pytest.mark.asyncio
async def test_bot_connect_hook(app: App, driver: Driver):
    """Check that bot connect/disconnect hooks are registered and invoked.

    Both hooks take the bot, a generator dependency and a defaulted
    parameter. The test verifies each hook receives the expected values and
    that the dependency is both entered and cleaned up.
    """
    with pytest.MonkeyPatch.context() as m:
        # Swap in empty hook sets on the Driver class for this test only.
        conn_hooks: Set[Dependent[Any]] = set()
        disconn_hooks: Set[Dependent[Any]] = set()
        m.setattr(Driver, "_bot_connection_hook", conn_hooks)
        m.setattr(Driver, "_bot_disconnection_hook", disconn_hooks)

        conn_should_be_called = False
        disconn_should_be_called = False
        dependency_should_be_run = False
        dependency_should_be_cleaned = False

        async def dependency():
            nonlocal dependency_should_be_run, dependency_should_be_cleaned
            dependency_should_be_run = True
            try:
                yield 1
            finally:
                dependency_should_be_cleaned = True

        # NOTE: `bot` is bound further down, before either hook can run.
        @driver.on_bot_connect
        async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
            nonlocal conn_should_be_called

            if foo is not bot:
                pytest.fail("on_bot_connect hook called with wrong bot")
            if dep != 1:
                pytest.fail("on_bot_connect hook called with wrong dependency")
            if default != 1:
                pytest.fail("on_bot_connect hook called with wrong default value")

            conn_should_be_called = True

        @driver.on_bot_disconnect
        async def disconn_hook(
            foo: Bot, dep: int = Depends(dependency), default: int = 1
        ):
            nonlocal disconn_should_be_called

            if foo is not bot:
                pytest.fail("on_bot_disconnect hook called with wrong bot")
            # These messages name the disconnect hook so a failure here is
            # not mistaken for a connect-hook failure.
            if dep != 1:
                pytest.fail("on_bot_disconnect hook called with wrong dependency")
            if default != 1:
                pytest.fail("on_bot_disconnect hook called with wrong default value")

            disconn_should_be_called = True

        if conn_hook not in {hook.call for hook in conn_hooks}:  # type: ignore
            pytest.fail("on_bot_connect hook not registered")
        if disconn_hook not in {hook.call for hook in disconn_hooks}:  # type: ignore
            pytest.fail("on_bot_disconnect hook not registered")

        async with app.test_api() as ctx:
            bot = ctx.create_bot()

        # Presumably gives the hooks time to finish before the flags are
        # checked -- TODO confirm how the driver schedules them.
        await asyncio.sleep(1)

        if not conn_should_be_called:
            pytest.fail("on_bot_connect hook not called")
        if not disconn_should_be_called:
            pytest.fail("on_bot_disconnect hook not called")
        if not dependency_should_be_run:
            pytest.fail("dependency not run")
        if not dependency_should_be_cleaned:
            pytest.fail("dependency not cleaned")