⬆️ upgrade dependencies

This commit is contained in:
yanyongyu 2021-12-22 16:53:55 +08:00
parent 9b2fa46921
commit fecdb5367a
11 changed files with 417 additions and 508 deletions

View File

@ -80,7 +80,7 @@ class Driver(abc.ABC):
"""
return self._clients
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
"""
:说明:
@ -105,7 +105,7 @@ class Driver(abc.ABC):
@property
@abc.abstractmethod
def type(self):
def type(self) -> str:
"""驱动类型名称"""
raise NotImplementedError
@ -204,10 +204,11 @@ class Driver(abc.ABC):
asyncio.create_task(_run_hook(bot))
class ForwardDriver(Driver):
"""
Forward Driver 基类将客户端框架封装以满足适配器使用
"""
class ForwardMixin(abc.ABC):
@property
@abc.abstractmethod
def type(self) -> str:
raise NotImplementedError
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
@ -218,6 +219,12 @@ class ForwardDriver(Driver):
raise NotImplementedError
class ForwardDriver(Driver, ForwardMixin):
"""
Forward Driver 基类将客户端框架封装以满足适配器使用
"""
class ReverseDriver(Driver):
"""
Reverse Driver 基类将后端框架封装以满足适配器使用
@ -244,6 +251,19 @@ class ReverseDriver(Driver):
raise NotImplementedError
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
class CombinedDriver(driver, *mixins): # type: ignore
@property
def type(self) -> str:
return (
driver.type.__get__(self)
+ "+"
+ "+".join(map(lambda x: x.type.__get__(self), mixins))
)
return CombinedDriver
@dataclass
class HTTPServerSetup:
path: URL # path should not be absolute, check it by URL.is_absolute() == False

View File

@ -0,0 +1,158 @@
import signal
import asyncio
import threading
from typing import Set, Callable, Awaitable
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.config import Env, Config
from nonebot.drivers import ForwardDriver
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
class BlockDriver(ForwardDriver):
"""
AIOHTTP 驱动框架
"""
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
@property
@overrides(ForwardDriver)
def type(self) -> str:
"""驱动名称: ``block_driver``"""
return "block_driver"
@property
@overrides(ForwardDriver)
def logger(self):
"""block driver 使用的 logger"""
return logger
@overrides(ForwardDriver)
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
"""
:说明:
注册一个启动时执行的函数
:参数:
* ``func: Callable[[], Awaitable[None]]``
"""
self.startup_funcs.add(func)
return func
@overrides(ForwardDriver)
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
"""
:说明:
注册一个停止时执行的函数
:参数:
* ``func: Callable[[], Awaitable[None]]``
"""
self.shutdown_funcs.add(func)
return func
@overrides(ForwardDriver)
def run(self, *args, **kwargs):
"""启动 block driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve())
async def serve(self):
self.install_signal_handlers()
await self.startup()
if self.should_exit.is_set():
return
await self.main_loop()
await self.shutdown()
async def startup(self):
# run startup
cors = [startup() for startup in self.startup_funcs]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
logger.info("Application startup completed.")
async def main_loop(self):
await self.should_exit.wait()
async def shutdown(self):
logger.info("Shutting down")
logger.info("Waiting for application shutdown.")
# run shutdown
cors = [shutdown() for shutdown in self.shutdown_funcs]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("Application shutdown complete.")
loop = asyncio.get_event_loop()
loop.stop()
def install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
loop = asyncio.get_event_loop()
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
def handle_exit(self, sig, frame):
if self.should_exit.is_set():
self.force_exit = True
else:
self.should_exit.set()

View File

@ -5,428 +5,79 @@ AIOHTTP 驱动适配
本驱动仅支持客户端连接
"""
import signal
import asyncio
import threading
from dataclasses import dataclass
from typing import Set, List, Union, Callable, Optional, Awaitable, cast
import aiohttp
from yarl import URL
from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from nonebot.drivers import Request, Response
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (
HTTPRequest,
ForwardDriver,
WebSocketSetup,
HTTPPollingSetup,
)
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
class Driver(ForwardDriver):
"""
AIOHTTP 驱动框架
"""
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WEBSOCKET_SETUP] = []
self.connections: List[asyncio.Task] = []
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
class AiohttpMixin(ForwardMixin):
@property
@overrides(ForwardDriver)
@overrides(ForwardMixin)
def type(self) -> str:
"""驱动名称: ``aiohttp``"""
return "aiohttp"
@property
@overrides(ForwardDriver)
def logger(self):
"""aiohttp driver 使用的 logger"""
return logger
@overrides(ForwardDriver)
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
"""
:说明:
注册一个启动时执行的函数
:参数:
* ``func: Callable[[], Awaitable[None]]``
"""
self.startup_funcs.add(func)
return func
@overrides(ForwardDriver)
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
"""
:说明:
注册一个停止时执行的函数
:参数:
* ``func: Callable[[], Awaitable[None]]``
"""
self.shutdown_funcs.add(func)
return func
@overrides(ForwardDriver)
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
"""
:说明:
注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
self.http_pollings.append(setup)
@overrides(ForwardDriver)
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
"""
:说明:
注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
self.websockets.append(setup)
@overrides(ForwardDriver)
def run(self, *args, **kwargs):
"""启动 aiohttp driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve())
async def serve(self):
self.install_signal_handlers()
await self.startup()
if self.should_exit.is_set():
return
await self.main_loop()
await self.shutdown()
async def startup(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
logger.info("Application startup completed.")
# run startup
cors = [startup() for startup in self.startup_funcs]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
async def main_loop(self):
await self.should_exit.wait()
async def shutdown(self):
# run shutdown
cors = [shutdown() for shutdown in self.shutdown_funcs]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)
for task in self.connections:
if not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
loop = asyncio.get_event_loop()
loop.stop()
def install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
loop = asyncio.get_event_loop()
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
def handle_exit(self, sig, frame):
if self.should_exit.is_set():
self.force_exit = True
@overrides(ForwardMixin)
async def request(self, setup: Request) -> Response:
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11:
version = aiohttp.HttpVersion11
else:
self.should_exit.set()
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = URL(setup.url)
if not url.is_absolute() or not url.host:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
host = f"{url.host}:{url.port}" if url.port else url.host
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.raw_query_string.encode("latin-1"),
{**setup.headers, "host": host},
timeout = aiohttp.ClientTimeout(setup.timeout)
async with aiohttp.ClientSession(version=version) as session:
async with session.request(
setup.method,
setup.body,
)
setup.url,
data=setup.content,
headers=setup.headers,
timeout=timeout,
) as response:
res = Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
return res
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
setup_: Optional[HTTPPollingSetup] = None
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket":
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11:
version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>"
session = aiohttp.ClientSession(version=version)
ws = await session.ws_connect(
setup.url,
method=setup.method,
timeout=setup.timeout or 10,
headers=setup.headers,
)
try:
async with aiohttp.ClientSession() as session:
while not self.should_exit.is_set():
try:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
setup_ = cast(HTTPPollingSetup, setup_)
if not bot:
request = await _build_request(setup_)
if not request:
return
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
elif callable(setup):
request = await _build_request(setup_)
if not request:
await asyncio.sleep(setup_.poll_interval)
continue
bot.request = request
request = cast(HTTPRequest, request)
headers = request.headers
timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion
if request.http_version == "1.0":
version = aiohttp.HttpVersion10
elif request.http_version == "1.1":
version = aiohttp.HttpVersion11
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Unsupported HTTP Version "
f"{escape_tag(request.http_version)}</bg #f8bbd0></r>"
)
return
logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
)
try:
async with session.request(
request.method,
setup_.url,
data=request.body,
headers=headers,
timeout=timeout,
version=version,
) as response:
response.raise_for_status()
data = await response.read()
asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
async with aiohttp.ClientSession() as session:
while True:
try:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
url = URL(setup_.url)
if not url.is_absolute() or not url.host:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.reconnect_interval)
continue
host = f"{url.host}:{url.port}" if url.port else url.host
headers = {**setup_.headers, "host": host}
logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
)
try:
async with session.ws_connect(
url, headers=headers, timeout=30.0
) as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1",
url.scheme,
url.path,
url.raw_query_string.encode("latin-1"),
headers,
ws,
)
BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request)
self._bot_connect(bot)
while not self.should_exit.is_set():
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task(
bot.handle_message(msg.data.encode())
)
elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task(bot.handle_message(msg.data))
elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg #f8bbd0></r>"
)
break
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>"
)
break
except (
aiohttp.ClientResponseError,
aiohttp.ClientConnectionError,
) as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
bot = None
if not setup_.reconnect:
logger.info(
f"WebSocket reconnect disabled for bot {setup_.self_id}"
)
break
await asyncio.sleep(setup_.reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>"
)
websocket = WebSocket(request=setup, session=session, websocket=ws)
return websocket
@dataclass
class WebSocket(BaseWebSocket):
websocket: aiohttp.ClientWebSocketResponse = None # type: ignore
def __init__(
self,
*,
request: Request,
session: aiohttp.ClientSession,
websocket: aiohttp.ClientWebSocketResponse,
):
super().__init__(request=request)
self.session = session
self.websocket = websocket
@property
@overrides(BaseWebSocket)
@ -440,6 +91,7 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def close(self, code: int = 1000):
await self.websocket.close(code=code)
await self.session.close()
@overrides(BaseWebSocket)
async def receive(self) -> str:
@ -456,3 +108,6 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send_bytes(data)
Driver = combine_driver(BlockDriver, AiohttpMixin)

View File

@ -11,31 +11,28 @@ FastAPI 驱动适配
"""
import logging
from functools import partial
from dataclasses import dataclass
from typing import Any, List, Union, Callable, Optional, Awaitable
from typing import List, Callable, Optional
import httpx
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from fastapi import FastAPI, Request, status
from starlette.websockets import WebSocket, WebSocketState
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.drivers.aiohttp import AiohttpMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import Response as BaseResponse
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
HTTPVersion,
ForwardDriver,
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
@ -246,7 +243,7 @@ class Driver(ReverseDriver):
self,
request: Request,
setup: HTTPServerSetup,
):
) -> Response:
http_request = BaseRequest(
request.method,
str(request.url),
@ -265,7 +262,7 @@ class Driver(ReverseDriver):
str(websocket.url),
headers=websocket.headers.items(),
cookies=websocket.cookies,
version=websocket.scope["http_version"],
version=websocket.scope.get("http_version", "1.1"),
)
ws = FastAPIWebSocket(
request=request,
@ -275,90 +272,6 @@ class Driver(ReverseDriver):
await setup.handle_func(ws)
class FullDriver(ForwardDriver, Driver):
"""
完整的 FastAPI 驱动框架包含正向 Client 支持和反向 Server 支持
:使用方法:
.. code-block:: dotenv
DRIVER=nonebot.drivers.fastapi:FullDriver
"""
@property
@overrides(Driver)
def type(self) -> str:
"""驱动名称: ``fastapi_full``"""
return "fastapi_full"
@overrides(ForwardDriver)
async def request(self, setup: "BaseRequest") -> Any:
async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content,
headers=tuple(setup.headers.items()),
timeout=30.0,
)
return BaseResponse(
response.status_code,
headers=response.headers,
content=response.content,
request=setup,
)
@overrides(ForwardDriver)
async def websocket(self, setup: "BaseRequest") -> Any:
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
return WebSocketsWS(request=setup, websocket=ws)
class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = 1000, reason: str = ""):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
async def receive(self) -> str:
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
msg = await self.websocket.recv()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send(data)
class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
@ -398,3 +311,7 @@ class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
AiohttpDriver = combine_driver(Driver, AiohttpMixin)

45
nonebot/drivers/httpx.py Normal file
View File

@ -0,0 +1,45 @@
import httpx
from nonebot.typing import overrides
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import (
Request,
Response,
WebSocket,
HTTPVersion,
ForwardMixin,
combine_driver,
)
class HttpxMixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
return "httpx"
@overrides(ForwardMixin)
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content,
headers=tuple(setup.headers.items()),
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers,
content=response.content,
request=setup,
)
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> WebSocket:
return await super(HttpxMixin, self).websocket(setup)
Driver = combine_driver(BlockDriver, HttpxMixin)

View File

@ -18,10 +18,17 @@ from nonebot.config import Env
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
try:
from quart import request as _request
@ -281,3 +288,6 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)

View File

@ -0,0 +1,78 @@
import logging
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, combine_driver
logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler())
class WebSocketsMixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
return "websockets"
@overrides(ForwardMixin)
async def request(self, setup: Request) -> Response:
return await super(WebSocketsMixin, self).request(setup)
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket":
ws = await Connect(
str(setup.url),
extra_headers=setup.headers.items(),
open_timeout=setup.timeout,
)
return WebSocket(request=setup, websocket=ws)
class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = 1000, reason: str = ""):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
async def receive(self) -> str:
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
msg = await self.websocket.recv()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send(data)
Driver = combine_driver(BlockDriver, WebSocketsMixin)

View File

@ -12,9 +12,8 @@ include = ["nonebot_plugin_docs/dist/**/*"]
[tool.poetry.dependencies]
python = "^3.7"
aiofiles = "^0.7.0"
nonebot2 = "^2.0.0-alpha.1"
python = "^3.7.3"
nonebot2 = "^2.0.0-beta.1"
[tool.poetry.dev-dependencies]

18
poetry.lock generated
View File

@ -93,6 +93,18 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
[package.extras]
tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"]
[[package]]
name = "async-asgi-testclient"
version = "1.4.8"
description = "Async client for testing ASGI web applications"
category = "dev"
optional = false
python-versions = "*"
[package.dependencies]
multidict = ">=4.0,<6.0"
requests = ">=2.21,<3.0"
[[package]]
name = "async-timeout"
version = "4.0.2"
@ -534,6 +546,7 @@ python-versions = "^3.7.3"
develop = false
[package.dependencies]
async-asgi-testclient = "^1.4.8"
nonebot2 = "^2.0.0-beta.1"
pytest = "^6.2.5"
pytest-asyncio = "^0.16.0"
@ -543,7 +556,7 @@ pytest-order = "^1.0.0"
type = "git"
url = "https://github.com/nonebot/nonebug.git"
reference = "master"
resolved_reference = "0a1132e9dc1803517ded0d485bfbe8c47a1d8585"
resolved_reference = "4af5bd99c3eb0f63f4619620461b16de6c96b227"
[[package]]
name = "packaging"
@ -1281,6 +1294,9 @@ asgiref = [
{file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"},
{file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"},
]
async-asgi-testclient = [
{file = "async-asgi-testclient-1.4.8.tar.gz", hash = "sha256:52d666ea75971c8a825befd34a5684414578f3c5bfa5a90e7eb7de924f447aae"},
]
async-timeout = [
{file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
{file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"},

View File

@ -23,6 +23,7 @@ include = ["nonebot/py.typed"]
[tool.poetry.dependencies]
python = "^3.7.3"
yarl = "^1.7.2"
loguru = "^0.5.1"
pygtrie = "^2.4.1"
tomlkit = "^0.7.0"
@ -34,7 +35,6 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
uvicorn = { version = "^0.15.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
yarl = "^1.7.2"
[tool.poetry.dev-dependencies]
sphinx = "^4.1.1"

11
tests/test_driver.py Normal file
View File

@ -0,0 +1,11 @@
import pytest
@pytest.mark.asyncio
@pytest.mark.parametrize(
"nonebug_init",
[{"driver": "nonebot.drivers.fastapi"}],
indirect=True,
)
async def test_driver(nonebug_init):
...