mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
⬆️ upgrade dependencies
This commit is contained in:
parent
9b2fa46921
commit
fecdb5367a
@ -80,7 +80,7 @@ class Driver(abc.ABC):
|
||||
"""
|
||||
return self._clients
|
||||
|
||||
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
|
||||
def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -105,7 +105,7 @@ class Driver(abc.ABC):
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self):
|
||||
def type(self) -> str:
|
||||
"""驱动类型名称"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -204,10 +204,11 @@ class Driver(abc.ABC):
|
||||
asyncio.create_task(_run_hook(bot))
|
||||
|
||||
|
||||
class ForwardDriver(Driver):
|
||||
"""
|
||||
Forward Driver 基类。将客户端框架封装,以满足适配器使用。
|
||||
"""
|
||||
class ForwardMixin(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(self, setup: Request) -> Response:
|
||||
@ -218,6 +219,12 @@ class ForwardDriver(Driver):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ForwardDriver(Driver, ForwardMixin):
|
||||
"""
|
||||
Forward Driver 基类。将客户端框架封装,以满足适配器使用。
|
||||
"""
|
||||
|
||||
|
||||
class ReverseDriver(Driver):
|
||||
"""
|
||||
Reverse Driver 基类。将后端框架封装,以满足适配器使用。
|
||||
@ -244,6 +251,19 @@ class ReverseDriver(Driver):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||
class CombinedDriver(driver, *mixins): # type: ignore
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return (
|
||||
driver.type.__get__(self)
|
||||
+ "+"
|
||||
+ "+".join(map(lambda x: x.type.__get__(self), mixins))
|
||||
)
|
||||
|
||||
return CombinedDriver
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPServerSetup:
|
||||
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
||||
|
158
nonebot/drivers/_block_driver.py
Normal file
158
nonebot/drivers/_block_driver.py
Normal file
@ -0,0 +1,158 @@
|
||||
import signal
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Set, Callable, Awaitable
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import ForwardDriver
|
||||
|
||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
|
||||
|
||||
class BlockDriver(ForwardDriver):
|
||||
"""
|
||||
AIOHTTP 驱动框架
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
|
||||
self.should_exit: asyncio.Event = asyncio.Event()
|
||||
self.force_exit: bool = False
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
def type(self) -> str:
|
||||
"""驱动名称: ``block_driver``"""
|
||||
return "block_driver"
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
def logger(self):
|
||||
"""block driver 使用的 logger"""
|
||||
return logger
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个启动时执行的函数
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[[], Awaitable[None]]``
|
||||
"""
|
||||
self.startup_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个停止时执行的函数
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[[], Awaitable[None]]``
|
||||
"""
|
||||
self.shutdown_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def run(self, *args, **kwargs):
|
||||
"""启动 block driver"""
|
||||
super().run(*args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.serve())
|
||||
|
||||
async def serve(self):
|
||||
self.install_signal_handlers()
|
||||
await self.startup()
|
||||
if self.should_exit.is_set():
|
||||
return
|
||||
await self.main_loop()
|
||||
await self.shutdown()
|
||||
|
||||
async def startup(self):
|
||||
# run startup
|
||||
cors = [startup() for startup in self.startup_funcs]
|
||||
if cors:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running startup function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
logger.info("Application startup completed.")
|
||||
|
||||
async def main_loop(self):
|
||||
await self.should_exit.wait()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Shutting down")
|
||||
|
||||
logger.info("Waiting for application shutdown.")
|
||||
# run shutdown
|
||||
cors = [shutdown() for shutdown in self.shutdown_funcs]
|
||||
if cors:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
for task in asyncio.all_tasks():
|
||||
if task is not asyncio.current_task() and not task.done():
|
||||
task.cancel()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
if tasks and not self.force_exit:
|
||||
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
|
||||
while tasks and not self.force_exit:
|
||||
await asyncio.sleep(0.1)
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logger.info("Application shutdown complete.")
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.stop()
|
||||
|
||||
def install_signal_handlers(self) -> None:
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# Signals can only be listened to from the main thread.
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
for sig in HANDLED_SIGNALS:
|
||||
loop.add_signal_handler(sig, self.handle_exit, sig, None)
|
||||
except NotImplementedError:
|
||||
# Windows
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.handle_exit)
|
||||
|
||||
def handle_exit(self, sig, frame):
|
||||
if self.should_exit.is_set():
|
||||
self.force_exit = True
|
||||
else:
|
||||
self.should_exit.set()
|
@ -5,428 +5,79 @@ AIOHTTP 驱动适配
|
||||
本驱动仅支持客户端连接
|
||||
"""
|
||||
|
||||
import signal
|
||||
import asyncio
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Set, List, Union, Callable, Optional, Awaitable, cast
|
||||
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import (
|
||||
HTTPRequest,
|
||||
ForwardDriver,
|
||||
WebSocketSetup,
|
||||
HTTPPollingSetup,
|
||||
)
|
||||
|
||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
|
||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
||||
HANDLED_SIGNALS = (
|
||||
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
|
||||
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
|
||||
)
|
||||
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
|
||||
|
||||
|
||||
class Driver(ForwardDriver):
|
||||
"""
|
||||
AIOHTTP 驱动框架
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
|
||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
||||
self.connections: List[asyncio.Task] = []
|
||||
self.should_exit: asyncio.Event = asyncio.Event()
|
||||
self.force_exit: bool = False
|
||||
|
||||
class AiohttpMixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
"""驱动名称: ``aiohttp``"""
|
||||
return "aiohttp"
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
def logger(self):
|
||||
"""aiohttp driver 使用的 logger"""
|
||||
return logger
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个启动时执行的函数
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[[], Awaitable[None]]``
|
||||
"""
|
||||
self.startup_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个停止时执行的函数
|
||||
|
||||
:参数:
|
||||
|
||||
* ``func: Callable[[], Awaitable[None]]``
|
||||
"""
|
||||
self.shutdown_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
self.http_pollings.append(setup)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
self.websockets.append(setup)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def run(self, *args, **kwargs):
|
||||
"""启动 aiohttp driver"""
|
||||
super().run(*args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.serve())
|
||||
|
||||
async def serve(self):
|
||||
self.install_signal_handlers()
|
||||
await self.startup()
|
||||
if self.should_exit.is_set():
|
||||
return
|
||||
await self.main_loop()
|
||||
await self.shutdown()
|
||||
|
||||
async def startup(self):
|
||||
for setup in self.http_pollings:
|
||||
self.connections.append(asyncio.create_task(self._http_loop(setup)))
|
||||
for setup in self.websockets:
|
||||
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
|
||||
|
||||
logger.info("Application startup completed.")
|
||||
|
||||
# run startup
|
||||
cors = [startup() for startup in self.startup_funcs]
|
||||
if cors:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running startup function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
async def main_loop(self):
|
||||
await self.should_exit.wait()
|
||||
|
||||
async def shutdown(self):
|
||||
# run shutdown
|
||||
cors = [shutdown() for shutdown in self.shutdown_funcs]
|
||||
if cors:
|
||||
try:
|
||||
await asyncio.gather(*cors)
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error when running shutdown function. "
|
||||
"Ignored!</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
for task in self.connections:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
if tasks and not self.force_exit:
|
||||
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
|
||||
while tasks and not self.force_exit:
|
||||
await asyncio.sleep(0.1)
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.stop()
|
||||
|
||||
def install_signal_handlers(self) -> None:
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
# Signals can only be listened to from the main thread.
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
for sig in HANDLED_SIGNALS:
|
||||
loop.add_signal_handler(sig, self.handle_exit, sig, None)
|
||||
except NotImplementedError:
|
||||
# Windows
|
||||
for sig in HANDLED_SIGNALS:
|
||||
signal.signal(sig, self.handle_exit)
|
||||
|
||||
def handle_exit(self, sig, frame):
|
||||
if self.should_exit.is_set():
|
||||
self.force_exit = True
|
||||
@overrides(ForwardMixin)
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if setup.version == HTTPVersion.H10:
|
||||
version = aiohttp.HttpVersion10
|
||||
elif setup.version == HTTPVersion.H11:
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
self.should_exit.set()
|
||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||
|
||||
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
|
||||
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
||||
url = URL(setup.url)
|
||||
if not url.is_absolute() or not url.host:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
||||
return HTTPRequest(
|
||||
setup.http_version,
|
||||
url.scheme,
|
||||
url.path,
|
||||
url.raw_query_string.encode("latin-1"),
|
||||
{**setup.headers, "host": host},
|
||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||
async with aiohttp.ClientSession(version=version) as session:
|
||||
async with session.request(
|
||||
setup.method,
|
||||
setup.body,
|
||||
)
|
||||
setup.url,
|
||||
data=setup.content,
|
||||
headers=setup.headers,
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
res = Response(
|
||||
response.status,
|
||||
headers=response.headers.copy(),
|
||||
content=await response.read(),
|
||||
request=setup,
|
||||
)
|
||||
return res
|
||||
|
||||
bot: Optional[Bot] = None
|
||||
request: Optional[HTTPRequest] = None
|
||||
setup_: Optional[HTTPPollingSetup] = None
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> "WebSocket":
|
||||
if setup.version == HTTPVersion.H10:
|
||||
version = aiohttp.HttpVersion10
|
||||
elif setup.version == HTTPVersion.H11:
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
|
||||
f"Bot {escape_tag(setup.self_id)}</y>"
|
||||
session = aiohttp.ClientSession(version=version)
|
||||
ws = await session.ws_connect(
|
||||
setup.url,
|
||||
method=setup.method,
|
||||
timeout=setup.timeout or 10,
|
||||
headers=setup.headers,
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while not self.should_exit.is_set():
|
||||
|
||||
try:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error while parsing setup "
|
||||
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
|
||||
setup_ = cast(HTTPPollingSetup, setup_)
|
||||
|
||||
if not bot:
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
return
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
elif callable(setup):
|
||||
request = await _build_request(setup_)
|
||||
if not request:
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
continue
|
||||
bot.request = request
|
||||
|
||||
request = cast(HTTPRequest, request)
|
||||
|
||||
headers = request.headers
|
||||
timeout = aiohttp.ClientTimeout(30)
|
||||
version: aiohttp.HttpVersion
|
||||
if request.http_version == "1.0":
|
||||
version = aiohttp.HttpVersion10
|
||||
elif request.http_version == "1.1":
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>Unsupported HTTP Version "
|
||||
f"{escape_tag(request.http_version)}</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.request(
|
||||
request.method,
|
||||
setup_.url,
|
||||
data=request.body,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
version=version,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.read()
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
except aiohttp.ClientResponseError as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
await asyncio.sleep(setup_.poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while http polling</bg #f8bbd0></r>"
|
||||
)
|
||||
finally:
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
|
||||
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
|
||||
bot: Optional[Bot] = None
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
|
||||
try:
|
||||
if callable(setup):
|
||||
setup_ = await setup()
|
||||
else:
|
||||
setup_ = setup
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error while parsing setup "
|
||||
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
|
||||
url = URL(setup_.url)
|
||||
if not url.is_absolute() or not url.host:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
||||
)
|
||||
await asyncio.sleep(setup_.reconnect_interval)
|
||||
continue
|
||||
|
||||
host = f"{url.host}:{url.port}" if url.port else url.host
|
||||
headers = {**setup_.headers, "host": host}
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
|
||||
)
|
||||
try:
|
||||
async with session.ws_connect(
|
||||
url, headers=headers, timeout=30.0
|
||||
) as ws:
|
||||
logger.opt(colors=True).info(
|
||||
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
|
||||
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
|
||||
)
|
||||
request = WebSocket(
|
||||
"1.1",
|
||||
url.scheme,
|
||||
url.path,
|
||||
url.raw_query_string.encode("latin-1"),
|
||||
headers,
|
||||
ws,
|
||||
)
|
||||
|
||||
BotClass = self._adapters[setup_.adapter]
|
||||
bot = BotClass(setup_.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
while not self.should_exit.is_set():
|
||||
msg = await ws.receive()
|
||||
if msg.type == aiohttp.WSMsgType.text:
|
||||
asyncio.create_task(
|
||||
bot.handle_message(msg.data.encode())
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.binary:
|
||||
asyncio.create_task(bot.handle_message(msg.data))
|
||||
elif msg.type == aiohttp.WSMsgType.error:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>Error while handling websocket frame. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
break
|
||||
except (
|
||||
aiohttp.ClientResponseError,
|
||||
aiohttp.ClientConnectionError,
|
||||
) as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
finally:
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
bot = None
|
||||
|
||||
if not setup_.reconnect:
|
||||
logger.info(
|
||||
f"WebSocket reconnect disabled for bot {setup_.self_id}"
|
||||
)
|
||||
break
|
||||
await asyncio.sleep(setup_.reconnect_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while websocket loop</bg #f8bbd0></r>"
|
||||
)
|
||||
websocket = WebSocket(request=setup, session=session, websocket=ws)
|
||||
return websocket
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocket(BaseWebSocket):
|
||||
websocket: aiohttp.ClientWebSocketResponse = None # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request: Request,
|
||||
session: aiohttp.ClientSession,
|
||||
websocket: aiohttp.ClientWebSocketResponse,
|
||||
):
|
||||
super().__init__(request=request)
|
||||
self.session = session
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
@ -440,6 +91,7 @@ class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def close(self, code: int = 1000):
|
||||
await self.websocket.close(code=code)
|
||||
await self.session.close()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> str:
|
||||
@ -456,3 +108,6 @@ class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send_bytes(data)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, AiohttpMixin)
|
||||
|
@ -11,31 +11,28 @@ FastAPI 驱动适配
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Union, Callable, Optional, Awaitable
|
||||
from typing import List, Callable, Optional
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from pydantic import BaseSettings
|
||||
from fastapi.responses import Response
|
||||
from fastapi import FastAPI, Request, status
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers.httpx import HttpxMixin
|
||||
from nonebot.drivers.aiohttp import AiohttpMixin
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import Response as BaseResponse
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.websockets import WebSocketsMixin
|
||||
from nonebot.drivers import (
|
||||
HTTPVersion,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
|
||||
@ -246,7 +243,7 @@ class Driver(ReverseDriver):
|
||||
self,
|
||||
request: Request,
|
||||
setup: HTTPServerSetup,
|
||||
):
|
||||
) -> Response:
|
||||
http_request = BaseRequest(
|
||||
request.method,
|
||||
str(request.url),
|
||||
@ -265,7 +262,7 @@ class Driver(ReverseDriver):
|
||||
str(websocket.url),
|
||||
headers=websocket.headers.items(),
|
||||
cookies=websocket.cookies,
|
||||
version=websocket.scope["http_version"],
|
||||
version=websocket.scope.get("http_version", "1.1"),
|
||||
)
|
||||
ws = FastAPIWebSocket(
|
||||
request=request,
|
||||
@ -275,90 +272,6 @@ class Driver(ReverseDriver):
|
||||
await setup.handle_func(ws)
|
||||
|
||||
|
||||
class FullDriver(ForwardDriver, Driver):
|
||||
"""
|
||||
完整的 FastAPI 驱动框架,包含正向 Client 支持和反向 Server 支持。
|
||||
|
||||
:使用方法:
|
||||
|
||||
.. code-block:: dotenv
|
||||
|
||||
DRIVER=nonebot.drivers.fastapi:FullDriver
|
||||
"""
|
||||
|
||||
@property
|
||||
@overrides(Driver)
|
||||
def type(self) -> str:
|
||||
"""驱动名称: ``fastapi_full``"""
|
||||
return "fastapi_full"
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
async def request(self, setup: "BaseRequest") -> Any:
|
||||
async with httpx.AsyncClient(
|
||||
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
||||
) as client:
|
||||
response = await client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
content=setup.content,
|
||||
headers=tuple(setup.headers.items()),
|
||||
timeout=30.0,
|
||||
)
|
||||
return BaseResponse(
|
||||
response.status_code,
|
||||
headers=response.headers,
|
||||
content=response.content,
|
||||
request=setup,
|
||||
)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
async def websocket(self, setup: "BaseRequest") -> Any:
|
||||
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
|
||||
return WebSocketsWS(request=setup, websocket=ws)
|
||||
|
||||
|
||||
class WebSocketsWS(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
|
||||
super().__init__(request=request)
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def closed(self) -> bool:
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def accept(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def close(self, code: int = 1000, reason: str = ""):
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> str:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, bytes):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, str):
|
||||
raise TypeError("WebSocket received unexpected frame type: str")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
class FastAPIWebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
|
||||
@ -398,3 +311,7 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
|
||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||
AiohttpDriver = combine_driver(Driver, AiohttpMixin)
|
||||
|
45
nonebot/drivers/httpx.py
Normal file
45
nonebot/drivers/httpx.py
Normal file
@ -0,0 +1,45 @@
|
||||
import httpx
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import (
|
||||
Request,
|
||||
Response,
|
||||
WebSocket,
|
||||
HTTPVersion,
|
||||
ForwardMixin,
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
|
||||
class HttpxMixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
return "httpx"
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def request(self, setup: Request) -> Response:
|
||||
async with httpx.AsyncClient(
|
||||
http2=setup.version == HTTPVersion.H2, follow_redirects=True
|
||||
) as client:
|
||||
response = await client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
content=setup.content,
|
||||
headers=tuple(setup.headers.items()),
|
||||
timeout=setup.timeout,
|
||||
)
|
||||
return Response(
|
||||
response.status_code,
|
||||
headers=response.headers,
|
||||
content=response.content,
|
||||
request=setup,
|
||||
)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> WebSocket:
|
||||
return await super(HttpxMixin, self).websocket(setup)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, HttpxMixin)
|
@ -18,10 +18,17 @@ from nonebot.config import Env
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers.httpx import HttpxMixin
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
from nonebot.drivers.websockets import WebSocketsMixin
|
||||
from nonebot.drivers import (
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
try:
|
||||
from quart import request as _request
|
||||
@ -281,3 +288,6 @@ class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes):
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||
|
78
nonebot/drivers/websockets.py
Normal file
78
nonebot/drivers/websockets.py
Normal file
@ -0,0 +1,78 @@
|
||||
import logging
|
||||
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.log import LoguruHandler
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ForwardMixin, combine_driver
|
||||
|
||||
logger = logging.Logger("websockets.client", "INFO")
|
||||
logger.addHandler(LoguruHandler())
|
||||
|
||||
|
||||
class WebSocketsMixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
return "websockets"
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def request(self, setup: Request) -> Response:
|
||||
return await super(WebSocketsMixin, self).request(setup)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> "WebSocket":
|
||||
ws = await Connect(
|
||||
str(setup.url),
|
||||
extra_headers=setup.headers.items(),
|
||||
open_timeout=setup.timeout,
|
||||
)
|
||||
return WebSocket(request=setup, websocket=ws)
|
||||
|
||||
|
||||
class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
|
||||
super().__init__(request=request)
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
def closed(self) -> bool:
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def accept(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def close(self, code: int = 1000, reason: str = ""):
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> str:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, bytes):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, str):
|
||||
raise TypeError("WebSocket received unexpected frame type: str")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, WebSocketsMixin)
|
@ -12,9 +12,8 @@ include = ["nonebot_plugin_docs/dist/**/*"]
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.7"
|
||||
aiofiles = "^0.7.0"
|
||||
nonebot2 = "^2.0.0-alpha.1"
|
||||
python = "^3.7.3"
|
||||
nonebot2 = "^2.0.0-beta.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
|
18
poetry.lock
generated
18
poetry.lock
generated
@ -93,6 +93,18 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"]
|
||||
|
||||
[[package]]
|
||||
name = "async-asgi-testclient"
|
||||
version = "1.4.8"
|
||||
description = "Async client for testing ASGI web applications"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.dependencies]
|
||||
multidict = ">=4.0,<6.0"
|
||||
requests = ">=2.21,<3.0"
|
||||
|
||||
[[package]]
|
||||
name = "async-timeout"
|
||||
version = "4.0.2"
|
||||
@ -534,6 +546,7 @@ python-versions = "^3.7.3"
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
async-asgi-testclient = "^1.4.8"
|
||||
nonebot2 = "^2.0.0-beta.1"
|
||||
pytest = "^6.2.5"
|
||||
pytest-asyncio = "^0.16.0"
|
||||
@ -543,7 +556,7 @@ pytest-order = "^1.0.0"
|
||||
type = "git"
|
||||
url = "https://github.com/nonebot/nonebug.git"
|
||||
reference = "master"
|
||||
resolved_reference = "0a1132e9dc1803517ded0d485bfbe8c47a1d8585"
|
||||
resolved_reference = "4af5bd99c3eb0f63f4619620461b16de6c96b227"
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
@ -1281,6 +1294,9 @@ asgiref = [
|
||||
{file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"},
|
||||
{file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"},
|
||||
]
|
||||
async-asgi-testclient = [
|
||||
{file = "async-asgi-testclient-1.4.8.tar.gz", hash = "sha256:52d666ea75971c8a825befd34a5684414578f3c5bfa5a90e7eb7de924f447aae"},
|
||||
]
|
||||
async-timeout = [
|
||||
{file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
|
||||
{file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"},
|
||||
|
@ -23,6 +23,7 @@ include = ["nonebot/py.typed"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.7.3"
|
||||
yarl = "^1.7.2"
|
||||
loguru = "^0.5.1"
|
||||
pygtrie = "^2.4.1"
|
||||
tomlkit = "^0.7.0"
|
||||
@ -34,7 +35,6 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
|
||||
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
||||
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
||||
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
||||
yarl = "^1.7.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
sphinx = "^4.1.1"
|
||||
|
11
tests/test_driver.py
Normal file
11
tests/test_driver.py
Normal file
@ -0,0 +1,11 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"nonebug_init",
|
||||
[{"driver": "nonebot.drivers.fastapi"}],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_driver(nonebug_init):
|
||||
...
|
Loading…
Reference in New Issue
Block a user