mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
⬆️ upgrade dependencies
This commit is contained in:
parent
9b2fa46921
commit
fecdb5367a
@ -80,7 +80,7 @@ class Driver(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
return self._clients
|
return self._clients
|
||||||
|
|
||||||
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
|
def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ class Driver(abc.ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def type(self):
|
def type(self) -> str:
|
||||||
"""驱动类型名称"""
|
"""驱动类型名称"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -204,10 +204,11 @@ class Driver(abc.ABC):
|
|||||||
asyncio.create_task(_run_hook(bot))
|
asyncio.create_task(_run_hook(bot))
|
||||||
|
|
||||||
|
|
||||||
class ForwardDriver(Driver):
|
class ForwardMixin(abc.ABC):
|
||||||
"""
|
@property
|
||||||
Forward Driver 基类。将客户端框架封装,以满足适配器使用。
|
@abc.abstractmethod
|
||||||
"""
|
def type(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
@ -218,6 +219,12 @@ class ForwardDriver(Driver):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardDriver(Driver, ForwardMixin):
|
||||||
|
"""
|
||||||
|
Forward Driver 基类。将客户端框架封装,以满足适配器使用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ReverseDriver(Driver):
|
class ReverseDriver(Driver):
|
||||||
"""
|
"""
|
||||||
Reverse Driver 基类。将后端框架封装,以满足适配器使用。
|
Reverse Driver 基类。将后端框架封装,以满足适配器使用。
|
||||||
@ -244,6 +251,19 @@ class ReverseDriver(Driver):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||||
|
class CombinedDriver(driver, *mixins): # type: ignore
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
return (
|
||||||
|
driver.type.__get__(self)
|
||||||
|
+ "+"
|
||||||
|
+ "+".join(map(lambda x: x.type.__get__(self), mixins))
|
||||||
|
)
|
||||||
|
|
||||||
|
return CombinedDriver
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HTTPServerSetup:
|
class HTTPServerSetup:
|
||||||
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
||||||
|
158
nonebot/drivers/_block_driver.py
Normal file
158
nonebot/drivers/_block_driver.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import signal
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from typing import Set, Callable, Awaitable
|
||||||
|
|
||||||
|
from nonebot.log import logger
|
||||||
|
from nonebot.typing import overrides
|
||||||
|
from nonebot.config import Env, Config
|
||||||
|
from nonebot.drivers import ForwardDriver
|
||||||
|
|
||||||
|
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||||
|
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||||
|
HANDLED_SIGNALS = (
|
||||||
|
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||||
|
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BlockDriver(ForwardDriver):
|
||||||
|
"""
|
||||||
|
AIOHTTP 驱动框架
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env: Env, config: Config):
|
||||||
|
super().__init__(env, config)
|
||||||
|
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||||
|
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
|
||||||
|
self.should_exit: asyncio.Event = asyncio.Event()
|
||||||
|
self.force_exit: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
@overrides(ForwardDriver)
|
||||||
|
def type(self) -> str:
|
||||||
|
"""驱动名称: ``block_driver``"""
|
||||||
|
return "block_driver"
|
||||||
|
|
||||||
|
@property
|
||||||
|
@overrides(ForwardDriver)
|
||||||
|
def logger(self):
|
||||||
|
"""block driver 使用的 logger"""
|
||||||
|
return logger
|
||||||
|
|
||||||
|
@overrides(ForwardDriver)
|
||||||
|
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
注册一个启动时执行的函数
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``func: Callable[[], Awaitable[None]]``
|
||||||
|
"""
|
||||||
|
self.startup_funcs.add(func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
@overrides(ForwardDriver)
|
||||||
|
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
注册一个停止时执行的函数
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``func: Callable[[], Awaitable[None]]``
|
||||||
|
"""
|
||||||
|
self.shutdown_funcs.add(func)
|
||||||
|
return func
|
||||||
|
|
||||||
|
@overrides(ForwardDriver)
|
||||||
|
def run(self, *args, **kwargs):
|
||||||
|
"""启动 block driver"""
|
||||||
|
super().run(*args, **kwargs)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(self.serve())
|
||||||
|
|
||||||
|
async def serve(self):
|
||||||
|
self.install_signal_handlers()
|
||||||
|
await self.startup()
|
||||||
|
if self.should_exit.is_set():
|
||||||
|
return
|
||||||
|
await self.main_loop()
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
async def startup(self):
|
||||||
|
# run startup
|
||||||
|
cors = [startup() for startup in self.startup_funcs]
|
||||||
|
if cors:
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*cors)
|
||||||
|
except Exception as e:
|
||||||
|
logger.opt(colors=True, exception=e).error(
|
||||||
|
"<r><bg #f8bbd0>Error when running startup function. "
|
||||||
|
"Ignored!</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Application startup completed.")
|
||||||
|
|
||||||
|
async def main_loop(self):
|
||||||
|
await self.should_exit.wait()
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
logger.info("Shutting down")
|
||||||
|
|
||||||
|
logger.info("Waiting for application shutdown.")
|
||||||
|
# run shutdown
|
||||||
|
cors = [shutdown() for shutdown in self.shutdown_funcs]
|
||||||
|
if cors:
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*cors)
|
||||||
|
except Exception as e:
|
||||||
|
logger.opt(colors=True, exception=e).error(
|
||||||
|
"<r><bg #f8bbd0>Error when running shutdown function. "
|
||||||
|
"Ignored!</bg #f8bbd0></r>"
|
||||||
|
)
|
||||||
|
|
||||||
|
for task in asyncio.all_tasks():
|
||||||
|
if task is not asyncio.current_task() and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
|
if tasks and not self.force_exit:
|
||||||
|
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
|
||||||
|
while tasks and not self.force_exit:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
logger.info("Application shutdown complete.")
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.stop()
|
||||||
|
|
||||||
|
def install_signal_handlers(self) -> None:
|
||||||
|
if threading.current_thread() is not threading.main_thread():
|
||||||
|
# Signals can only be listened to from the main thread.
|
||||||
|
return
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for sig in HANDLED_SIGNALS:
|
||||||
|
loop.add_signal_handler(sig, self.handle_exit, sig, None)
|
||||||
|
except NotImplementedError:
|
||||||
|
# Windows
|
||||||
|
for sig in HANDLED_SIGNALS:
|
||||||
|
signal.signal(sig, self.handle_exit)
|
||||||
|
|
||||||
|
def handle_exit(self, sig, frame):
|
||||||
|
if self.should_exit.is_set():
|
||||||
|
self.force_exit = True
|
||||||
|
else:
|
||||||
|
self.should_exit.set()
|
@ -5,428 +5,79 @@ AIOHTTP 驱动适配
|
|||||||
本驱动仅支持客户端连接
|
本驱动仅支持客户端连接
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import signal
|
|
||||||
import asyncio
|
|
||||||
import threading
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Set, List, Union, Callable, Optional, Awaitable, cast
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from yarl import URL
|
|
||||||
|
|
||||||
from nonebot.log import logger
|
|
||||||
from nonebot.adapters import Bot
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.drivers import Request, Response
|
||||||
from nonebot.config import Env, Config
|
from nonebot.drivers._block_driver import BlockDriver
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
|
||||||
HTTPRequest,
|
|
||||||
ForwardDriver,
|
|
||||||
WebSocketSetup,
|
|
||||||
HTTPPollingSetup,
|
|
||||||
)
|
|
||||||
|
|
||||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
|
||||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
|
||||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
|
|
||||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
|
||||||
HANDLED_SIGNALS = (
|
|
||||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
|
||||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Driver(ForwardDriver):
|
class AiohttpMixin(ForwardMixin):
|
||||||
"""
|
|
||||||
AIOHTTP 驱动框架
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: Env, config: Config):
|
|
||||||
super().__init__(env, config)
|
|
||||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
|
||||||
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
|
|
||||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
|
||||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
|
||||||
self.connections: List[asyncio.Task] = []
|
|
||||||
self.should_exit: asyncio.Event = asyncio.Event()
|
|
||||||
self.force_exit: bool = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardDriver)
|
@overrides(ForwardMixin)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
"""驱动名称: ``aiohttp``"""
|
|
||||||
return "aiohttp"
|
return "aiohttp"
|
||||||
|
|
||||||
@property
|
@overrides(ForwardMixin)
|
||||||
@overrides(ForwardDriver)
|
async def request(self, setup: Request) -> Response:
|
||||||
def logger(self):
|
if setup.version == HTTPVersion.H10:
|
||||||
"""aiohttp driver 使用的 logger"""
|
|
||||||
return logger
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个启动时执行的函数
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``func: Callable[[], Awaitable[None]]``
|
|
||||||
"""
|
|
||||||
self.startup_funcs.add(func)
|
|
||||||
return func
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个停止时执行的函数
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``func: Callable[[], Awaitable[None]]``
|
|
||||||
"""
|
|
||||||
self.shutdown_funcs.add(func)
|
|
||||||
return func
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
|
||||||
"""
|
|
||||||
self.http_pollings.append(setup)
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
|
||||||
"""
|
|
||||||
:说明:
|
|
||||||
|
|
||||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
|
||||||
|
|
||||||
:参数:
|
|
||||||
|
|
||||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
|
||||||
"""
|
|
||||||
self.websockets.append(setup)
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
def run(self, *args, **kwargs):
|
|
||||||
"""启动 aiohttp driver"""
|
|
||||||
super().run(*args, **kwargs)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.run_until_complete(self.serve())
|
|
||||||
|
|
||||||
async def serve(self):
|
|
||||||
self.install_signal_handlers()
|
|
||||||
await self.startup()
|
|
||||||
if self.should_exit.is_set():
|
|
||||||
return
|
|
||||||
await self.main_loop()
|
|
||||||
await self.shutdown()
|
|
||||||
|
|
||||||
async def startup(self):
|
|
||||||
for setup in self.http_pollings:
|
|
||||||
self.connections.append(asyncio.create_task(self._http_loop(setup)))
|
|
||||||
for setup in self.websockets:
|
|
||||||
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
|
|
||||||
|
|
||||||
logger.info("Application startup completed.")
|
|
||||||
|
|
||||||
# run startup
|
|
||||||
cors = [startup() for startup in self.startup_funcs]
|
|
||||||
if cors:
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*cors)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running startup function. "
|
|
||||||
"Ignored!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def main_loop(self):
|
|
||||||
await self.should_exit.wait()
|
|
||||||
|
|
||||||
async def shutdown(self):
|
|
||||||
# run shutdown
|
|
||||||
cors = [shutdown() for shutdown in self.shutdown_funcs]
|
|
||||||
if cors:
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*cors)
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
|
||||||
"Ignored!</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
|
|
||||||
for task in self.connections:
|
|
||||||
if not task.done():
|
|
||||||
task.cancel()
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
|
||||||
if tasks and not self.force_exit:
|
|
||||||
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
|
|
||||||
while tasks and not self.force_exit:
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.stop()
|
|
||||||
|
|
||||||
def install_signal_handlers(self) -> None:
|
|
||||||
if threading.current_thread() is not threading.main_thread():
|
|
||||||
# Signals can only be listened to from the main thread.
|
|
||||||
return
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
try:
|
|
||||||
for sig in HANDLED_SIGNALS:
|
|
||||||
loop.add_signal_handler(sig, self.handle_exit, sig, None)
|
|
||||||
except NotImplementedError:
|
|
||||||
# Windows
|
|
||||||
for sig in HANDLED_SIGNALS:
|
|
||||||
signal.signal(sig, self.handle_exit)
|
|
||||||
|
|
||||||
def handle_exit(self, sig, frame):
|
|
||||||
if self.should_exit.is_set():
|
|
||||||
self.force_exit = True
|
|
||||||
else:
|
|
||||||
self.should_exit.set()
|
|
||||||
|
|
||||||
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
|
|
||||||
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
|
||||||
url = URL(setup.url)
|
|
||||||
if not url.is_absolute() or not url.host:
|
|
||||||
logger.opt(colors=True).error(
|
|
||||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
|
||||||
return HTTPRequest(
|
|
||||||
setup.http_version,
|
|
||||||
url.scheme,
|
|
||||||
url.path,
|
|
||||||
url.raw_query_string.encode("latin-1"),
|
|
||||||
{**setup.headers, "host": host},
|
|
||||||
setup.method,
|
|
||||||
setup.body,
|
|
||||||
)
|
|
||||||
|
|
||||||
bot: Optional[Bot] = None
|
|
||||||
request: Optional[HTTPRequest] = None
|
|
||||||
setup_: Optional[HTTPPollingSetup] = None
|
|
||||||
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
|
|
||||||
f"Bot {escape_tag(setup.self_id)}</y>"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
while not self.should_exit.is_set():
|
|
||||||
|
|
||||||
try:
|
|
||||||
if callable(setup):
|
|
||||||
setup_ = await setup()
|
|
||||||
else:
|
|
||||||
setup_ = setup
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error while parsing setup "
|
|
||||||
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
continue
|
|
||||||
|
|
||||||
setup_ = cast(HTTPPollingSetup, setup_)
|
|
||||||
|
|
||||||
if not bot:
|
|
||||||
request = await _build_request(setup_)
|
|
||||||
if not request:
|
|
||||||
return
|
|
||||||
|
|
||||||
BotClass = self._adapters[setup.adapter]
|
|
||||||
bot = BotClass(setup.self_id, request)
|
|
||||||
self._bot_connect(bot)
|
|
||||||
elif callable(setup):
|
|
||||||
request = await _build_request(setup_)
|
|
||||||
if not request:
|
|
||||||
await asyncio.sleep(setup_.poll_interval)
|
|
||||||
continue
|
|
||||||
bot.request = request
|
|
||||||
|
|
||||||
request = cast(HTTPRequest, request)
|
|
||||||
|
|
||||||
headers = request.headers
|
|
||||||
timeout = aiohttp.ClientTimeout(30)
|
|
||||||
version: aiohttp.HttpVersion
|
|
||||||
if request.http_version == "1.0":
|
|
||||||
version = aiohttp.HttpVersion10
|
version = aiohttp.HttpVersion10
|
||||||
elif request.http_version == "1.1":
|
elif setup.version == HTTPVersion.H11:
|
||||||
version = aiohttp.HttpVersion11
|
version = aiohttp.HttpVersion11
|
||||||
else:
|
else:
|
||||||
logger.opt(colors=True).error(
|
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||||
"<r><bg #f8bbd0>Unsupported HTTP Version "
|
|
||||||
f"{escape_tag(request.http_version)}</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(
|
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
|
async with aiohttp.ClientSession(version=version) as session:
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with session.request(
|
async with session.request(
|
||||||
request.method,
|
setup.method,
|
||||||
setup_.url,
|
setup.url,
|
||||||
data=request.body,
|
data=setup.content,
|
||||||
headers=headers,
|
headers=setup.headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
version=version,
|
|
||||||
) as response:
|
) as response:
|
||||||
response.raise_for_status()
|
res = Response(
|
||||||
data = await response.read()
|
response.status,
|
||||||
asyncio.create_task(bot.handle_message(data))
|
headers=response.headers.copy(),
|
||||||
except aiohttp.ClientResponseError as e:
|
content=await response.read(),
|
||||||
logger.opt(colors=True, exception=e).error(
|
request=setup,
|
||||||
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
await asyncio.sleep(setup_.poll_interval)
|
@overrides(ForwardMixin)
|
||||||
|
async def websocket(self, setup: Request) -> "WebSocket":
|
||||||
except asyncio.CancelledError:
|
if setup.version == HTTPVersion.H10:
|
||||||
pass
|
version = aiohttp.HttpVersion10
|
||||||
except Exception as e:
|
elif setup.version == HTTPVersion.H11:
|
||||||
logger.opt(colors=True, exception=e).error(
|
version = aiohttp.HttpVersion11
|
||||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
|
||||||
"while http polling</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if bot:
|
|
||||||
self._bot_disconnect(bot)
|
|
||||||
|
|
||||||
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
|
|
||||||
bot: Optional[Bot] = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
while True:
|
|
||||||
|
|
||||||
try:
|
|
||||||
if callable(setup):
|
|
||||||
setup_ = await setup()
|
|
||||||
else:
|
else:
|
||||||
setup_ = setup
|
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Error while parsing setup "
|
|
||||||
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
continue
|
|
||||||
|
|
||||||
url = URL(setup_.url)
|
session = aiohttp.ClientSession(version=version)
|
||||||
if not url.is_absolute() or not url.host:
|
ws = await session.ws_connect(
|
||||||
logger.opt(colors=True).error(
|
setup.url,
|
||||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
method=setup.method,
|
||||||
)
|
timeout=setup.timeout or 10,
|
||||||
await asyncio.sleep(setup_.reconnect_interval)
|
headers=setup.headers,
|
||||||
continue
|
|
||||||
|
|
||||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
|
||||||
headers = {**setup_.headers, "host": host}
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
async with session.ws_connect(
|
|
||||||
url, headers=headers, timeout=30.0
|
|
||||||
) as ws:
|
|
||||||
logger.opt(colors=True).info(
|
|
||||||
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
|
|
||||||
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
|
|
||||||
)
|
|
||||||
request = WebSocket(
|
|
||||||
"1.1",
|
|
||||||
url.scheme,
|
|
||||||
url.path,
|
|
||||||
url.raw_query_string.encode("latin-1"),
|
|
||||||
headers,
|
|
||||||
ws,
|
|
||||||
)
|
|
||||||
|
|
||||||
BotClass = self._adapters[setup_.adapter]
|
|
||||||
bot = BotClass(setup_.self_id, request)
|
|
||||||
self._bot_connect(bot)
|
|
||||||
while not self.should_exit.is_set():
|
|
||||||
msg = await ws.receive()
|
|
||||||
if msg.type == aiohttp.WSMsgType.text:
|
|
||||||
asyncio.create_task(
|
|
||||||
bot.handle_message(msg.data.encode())
|
|
||||||
)
|
|
||||||
elif msg.type == aiohttp.WSMsgType.binary:
|
|
||||||
asyncio.create_task(bot.handle_message(msg.data))
|
|
||||||
elif msg.type == aiohttp.WSMsgType.error:
|
|
||||||
logger.opt(colors=True).error(
|
|
||||||
"<r><bg #f8bbd0>Error while handling websocket frame. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logger.opt(colors=True).error(
|
|
||||||
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except (
|
|
||||||
aiohttp.ClientResponseError,
|
|
||||||
aiohttp.ClientConnectionError,
|
|
||||||
) as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. "
|
|
||||||
"Try to reconnect...</bg #f8bbd0></r>"
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if bot:
|
|
||||||
self._bot_disconnect(bot)
|
|
||||||
bot = None
|
|
||||||
|
|
||||||
if not setup_.reconnect:
|
|
||||||
logger.info(
|
|
||||||
f"WebSocket reconnect disabled for bot {setup_.self_id}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
await asyncio.sleep(setup_.reconnect_interval)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.opt(colors=True, exception=e).error(
|
|
||||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
|
||||||
"while websocket loop</bg #f8bbd0></r>"
|
|
||||||
)
|
)
|
||||||
|
websocket = WebSocket(request=setup, session=session, websocket=ws)
|
||||||
|
return websocket
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
websocket: aiohttp.ClientWebSocketResponse = None # type: ignore
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
request: Request,
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
websocket: aiohttp.ClientWebSocketResponse,
|
||||||
|
):
|
||||||
|
super().__init__(request=request)
|
||||||
|
self.session = session
|
||||||
|
self.websocket = websocket
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@ -440,6 +91,7 @@ class WebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def close(self, code: int = 1000):
|
async def close(self, code: int = 1000):
|
||||||
await self.websocket.close(code=code)
|
await self.websocket.close(code=code)
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> str:
|
||||||
@ -456,3 +108,6 @@ class WebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send_bytes(self, data: bytes) -> None:
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
await self.websocket.send_bytes(data)
|
await self.websocket.send_bytes(data)
|
||||||
|
|
||||||
|
|
||||||
|
Driver = combine_driver(BlockDriver, AiohttpMixin)
|
||||||
|
@ -11,31 +11,28 @@ FastAPI 驱动适配
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from typing import List, Callable, Optional
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, List, Union, Callable, Optional, Awaitable
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from fastapi import FastAPI, Request, status
|
from fastapi import FastAPI, Request, status
|
||||||
from starlette.websockets import WebSocket, WebSocketState
|
from starlette.websockets import WebSocket, WebSocketState
|
||||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
|
||||||
|
|
||||||
from nonebot.config import Env
|
from nonebot.config import Env
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
|
from nonebot.drivers.httpx import HttpxMixin
|
||||||
|
from nonebot.drivers.aiohttp import AiohttpMixin
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import Request as BaseRequest
|
from nonebot.drivers import Request as BaseRequest
|
||||||
from nonebot.drivers import Response as BaseResponse
|
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
|
from nonebot.drivers.websockets import WebSocketsMixin
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
HTTPVersion,
|
|
||||||
ForwardDriver,
|
|
||||||
ReverseDriver,
|
ReverseDriver,
|
||||||
HTTPServerSetup,
|
HTTPServerSetup,
|
||||||
WebSocketServerSetup,
|
WebSocketServerSetup,
|
||||||
|
combine_driver,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -246,7 +243,7 @@ class Driver(ReverseDriver):
|
|||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
setup: HTTPServerSetup,
|
setup: HTTPServerSetup,
|
||||||
):
|
) -> Response:
|
||||||
http_request = BaseRequest(
|
http_request = BaseRequest(
|
||||||
request.method,
|
request.method,
|
||||||
str(request.url),
|
str(request.url),
|
||||||
@ -265,7 +262,7 @@ class Driver(ReverseDriver):
|
|||||||
str(websocket.url),
|
str(websocket.url),
|
||||||
headers=websocket.headers.items(),
|
headers=websocket.headers.items(),
|
||||||
cookies=websocket.cookies,
|
cookies=websocket.cookies,
|
||||||
version=websocket.scope["http_version"],
|
version=websocket.scope.get("http_version", "1.1"),
|
||||||
)
|
)
|
||||||
ws = FastAPIWebSocket(
|
ws = FastAPIWebSocket(
|
||||||
request=request,
|
request=request,
|
||||||
@ -275,90 +272,6 @@ class Driver(ReverseDriver):
|
|||||||
await setup.handle_func(ws)
|
await setup.handle_func(ws)
|
||||||
|
|
||||||
|
|
||||||
class FullDriver(ForwardDriver, Driver):
|
|
||||||
"""
|
|
||||||
完整的 FastAPI 驱动框架,包含正向 Client 支持和反向 Server 支持。
|
|
||||||
|
|
||||||
:使用方法:
|
|
||||||
|
|
||||||
.. code-block:: dotenv
|
|
||||||
|
|
||||||
DRIVER=nonebot.drivers.fastapi:FullDriver
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@overrides(Driver)
|
|
||||||
def type(self) -> str:
|
|
||||||
"""驱动名称: ``fastapi_full``"""
|
|
||||||
return "fastapi_full"
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
async def request(self, setup: "BaseRequest") -> Any:
|
|
||||||
async with httpx.AsyncClient(
|
|
||||||
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
|
||||||
) as client:
|
|
||||||
response = await client.request(
|
|
||||||
setup.method,
|
|
||||||
str(setup.url),
|
|
||||||
content=setup.content,
|
|
||||||
headers=tuple(setup.headers.items()),
|
|
||||||
timeout=30.0,
|
|
||||||
)
|
|
||||||
return BaseResponse(
|
|
||||||
response.status_code,
|
|
||||||
headers=response.headers,
|
|
||||||
content=response.content,
|
|
||||||
request=setup,
|
|
||||||
)
|
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
|
||||||
async def websocket(self, setup: "BaseRequest") -> Any:
|
|
||||||
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
|
|
||||||
return WebSocketsWS(request=setup, websocket=ws)
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsWS(BaseWebSocket):
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
|
|
||||||
super().__init__(request=request)
|
|
||||||
self.websocket = websocket
|
|
||||||
|
|
||||||
@property
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
def closed(self) -> bool:
|
|
||||||
return self.websocket.closed
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def accept(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def close(self, code: int = 1000, reason: str = ""):
|
|
||||||
await self.websocket.close(code, reason)
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def receive(self) -> str:
|
|
||||||
msg = await self.websocket.recv()
|
|
||||||
if isinstance(msg, bytes):
|
|
||||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
|
||||||
return msg
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def receive_bytes(self) -> bytes:
|
|
||||||
msg = await self.websocket.recv()
|
|
||||||
if isinstance(msg, str):
|
|
||||||
raise TypeError("WebSocket received unexpected frame type: str")
|
|
||||||
return msg
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def send(self, data: str) -> None:
|
|
||||||
await self.websocket.send(data)
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
|
||||||
async def send_bytes(self, data: bytes) -> None:
|
|
||||||
await self.websocket.send(data)
|
|
||||||
|
|
||||||
|
|
||||||
class FastAPIWebSocket(BaseWebSocket):
|
class FastAPIWebSocket(BaseWebSocket):
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
|
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
|
||||||
@ -398,3 +311,7 @@ class FastAPIWebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send_bytes(self, data: bytes) -> None:
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||||
|
|
||||||
|
|
||||||
|
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||||
|
AiohttpDriver = combine_driver(Driver, AiohttpMixin)
|
||||||
|
45
nonebot/drivers/httpx.py
Normal file
45
nonebot/drivers/httpx.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import httpx
|
||||||
|
|
||||||
|
from nonebot.typing import overrides
|
||||||
|
from nonebot.drivers._block_driver import BlockDriver
|
||||||
|
from nonebot.drivers import (
|
||||||
|
Request,
|
||||||
|
Response,
|
||||||
|
WebSocket,
|
||||||
|
HTTPVersion,
|
||||||
|
ForwardMixin,
|
||||||
|
combine_driver,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpxMixin(ForwardMixin):
|
||||||
|
@property
|
||||||
|
@overrides(ForwardMixin)
|
||||||
|
def type(self) -> str:
|
||||||
|
return "httpx"
|
||||||
|
|
||||||
|
@overrides(ForwardMixin)
|
||||||
|
async def request(self, setup: Request) -> Response:
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
||||||
|
) as client:
|
||||||
|
response = await client.request(
|
||||||
|
setup.method,
|
||||||
|
str(setup.url),
|
||||||
|
content=setup.content,
|
||||||
|
headers=tuple(setup.headers.items()),
|
||||||
|
timeout=setup.timeout,
|
||||||
|
)
|
||||||
|
return Response(
|
||||||
|
response.status_code,
|
||||||
|
headers=response.headers,
|
||||||
|
content=response.content,
|
||||||
|
request=setup,
|
||||||
|
)
|
||||||
|
|
||||||
|
@overrides(ForwardMixin)
|
||||||
|
async def websocket(self, setup: Request) -> WebSocket:
|
||||||
|
return await super(HttpxMixin, self).websocket(setup)
|
||||||
|
|
||||||
|
|
||||||
|
Driver = combine_driver(BlockDriver, HttpxMixin)
|
@ -18,10 +18,17 @@ from nonebot.config import Env
|
|||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
|
from nonebot.drivers.httpx import HttpxMixin
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import Request as BaseRequest
|
from nonebot.drivers import Request as BaseRequest
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
from nonebot.drivers.websockets import WebSocketsMixin
|
||||||
|
from nonebot.drivers import (
|
||||||
|
ReverseDriver,
|
||||||
|
HTTPServerSetup,
|
||||||
|
WebSocketServerSetup,
|
||||||
|
combine_driver,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from quart import request as _request
|
from quart import request as _request
|
||||||
@ -281,3 +288,6 @@ class WebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send_bytes(self, data: bytes):
|
async def send_bytes(self, data: bytes):
|
||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
|
|
||||||
|
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||||
|
78
nonebot/drivers/websockets.py
Normal file
78
nonebot/drivers/websockets.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||||
|
|
||||||
|
from nonebot.typing import overrides
|
||||||
|
from nonebot.log import LoguruHandler
|
||||||
|
from nonebot.drivers import Request, Response
|
||||||
|
from nonebot.drivers._block_driver import BlockDriver
|
||||||
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
|
from nonebot.drivers import ForwardMixin, combine_driver
|
||||||
|
|
||||||
|
logger = logging.Logger("websockets.client", "INFO")
|
||||||
|
logger.addHandler(LoguruHandler())
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketsMixin(ForwardMixin):
|
||||||
|
@property
|
||||||
|
@overrides(ForwardMixin)
|
||||||
|
def type(self) -> str:
|
||||||
|
return "websockets"
|
||||||
|
|
||||||
|
@overrides(ForwardMixin)
|
||||||
|
async def request(self, setup: Request) -> Response:
|
||||||
|
return await super(WebSocketsMixin, self).request(setup)
|
||||||
|
|
||||||
|
@overrides(ForwardMixin)
|
||||||
|
async def websocket(self, setup: Request) -> "WebSocket":
|
||||||
|
ws = await Connect(
|
||||||
|
str(setup.url),
|
||||||
|
extra_headers=setup.headers.items(),
|
||||||
|
open_timeout=setup.timeout,
|
||||||
|
)
|
||||||
|
return WebSocket(request=setup, websocket=ws)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocket(BaseWebSocket):
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
|
||||||
|
super().__init__(request=request)
|
||||||
|
self.websocket = websocket
|
||||||
|
|
||||||
|
@property
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self.websocket.closed
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def accept(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def close(self, code: int = 1000, reason: str = ""):
|
||||||
|
await self.websocket.close(code, reason)
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def receive(self) -> str:
|
||||||
|
msg = await self.websocket.recv()
|
||||||
|
if isinstance(msg, bytes):
|
||||||
|
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def receive_bytes(self) -> bytes:
|
||||||
|
msg = await self.websocket.recv()
|
||||||
|
if isinstance(msg, str):
|
||||||
|
raise TypeError("WebSocket received unexpected frame type: str")
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def send(self, data: str) -> None:
|
||||||
|
await self.websocket.send(data)
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
|
await self.websocket.send(data)
|
||||||
|
|
||||||
|
|
||||||
|
Driver = combine_driver(BlockDriver, WebSocketsMixin)
|
@ -12,9 +12,8 @@ include = ["nonebot_plugin_docs/dist/**/*"]
|
|||||||
|
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.7"
|
python = "^3.7.3"
|
||||||
aiofiles = "^0.7.0"
|
nonebot2 = "^2.0.0-beta.1"
|
||||||
nonebot2 = "^2.0.0-alpha.1"
|
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
|
|
||||||
|
18
poetry.lock
generated
18
poetry.lock
generated
@ -93,6 +93,18 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"]
|
tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-asgi-testclient"
|
||||||
|
version = "1.4.8"
|
||||||
|
description = "Async client for testing ASGI web applications"
|
||||||
|
category = "dev"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
multidict = ">=4.0,<6.0"
|
||||||
|
requests = ">=2.21,<3.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-timeout"
|
name = "async-timeout"
|
||||||
version = "4.0.2"
|
version = "4.0.2"
|
||||||
@ -534,6 +546,7 @@ python-versions = "^3.7.3"
|
|||||||
develop = false
|
develop = false
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
async-asgi-testclient = "^1.4.8"
|
||||||
nonebot2 = "^2.0.0-beta.1"
|
nonebot2 = "^2.0.0-beta.1"
|
||||||
pytest = "^6.2.5"
|
pytest = "^6.2.5"
|
||||||
pytest-asyncio = "^0.16.0"
|
pytest-asyncio = "^0.16.0"
|
||||||
@ -543,7 +556,7 @@ pytest-order = "^1.0.0"
|
|||||||
type = "git"
|
type = "git"
|
||||||
url = "https://github.com/nonebot/nonebug.git"
|
url = "https://github.com/nonebot/nonebug.git"
|
||||||
reference = "master"
|
reference = "master"
|
||||||
resolved_reference = "0a1132e9dc1803517ded0d485bfbe8c47a1d8585"
|
resolved_reference = "4af5bd99c3eb0f63f4619620461b16de6c96b227"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "packaging"
|
name = "packaging"
|
||||||
@ -1281,6 +1294,9 @@ asgiref = [
|
|||||||
{file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"},
|
{file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"},
|
||||||
{file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"},
|
{file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"},
|
||||||
]
|
]
|
||||||
|
async-asgi-testclient = [
|
||||||
|
{file = "async-asgi-testclient-1.4.8.tar.gz", hash = "sha256:52d666ea75971c8a825befd34a5684414578f3c5bfa5a90e7eb7de924f447aae"},
|
||||||
|
]
|
||||||
async-timeout = [
|
async-timeout = [
|
||||||
{file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
|
{file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
|
||||||
{file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"},
|
{file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"},
|
||||||
|
@ -23,6 +23,7 @@ include = ["nonebot/py.typed"]
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.7.3"
|
python = "^3.7.3"
|
||||||
|
yarl = "^1.7.2"
|
||||||
loguru = "^0.5.1"
|
loguru = "^0.5.1"
|
||||||
pygtrie = "^2.4.1"
|
pygtrie = "^2.4.1"
|
||||||
tomlkit = "^0.7.0"
|
tomlkit = "^0.7.0"
|
||||||
@ -34,7 +35,6 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
|
|||||||
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
||||||
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
||||||
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
||||||
yarl = "^1.7.2"
|
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
sphinx = "^4.1.1"
|
sphinx = "^4.1.1"
|
||||||
|
11
tests/test_driver.py
Normal file
11
tests/test_driver.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"nonebug_init",
|
||||||
|
[{"driver": "nonebot.drivers.fastapi"}],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
async def test_driver(nonebug_init):
|
||||||
|
...
|
Loading…
Reference in New Issue
Block a user