mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
♻️ rewrite adapter abc class
This commit is contained in:
parent
180aaadda9
commit
d80c02ae46
@ -22,6 +22,7 @@ except Exception:
|
||||
|
||||
from ._bot import Bot as Bot
|
||||
from ._event import Event as Event
|
||||
from ._adapter import Adapter as Adapter
|
||||
from ._message import Message as Message
|
||||
from ._message import MessageSegment as MessageSegment
|
||||
from ._template import MessageTemplate as MessageTemplate
|
||||
|
59
nonebot/adapters/_adapter.py
Normal file
59
nonebot/adapters/_adapter.py
Normal file
@ -0,0 +1,59 @@
|
||||
import abc
|
||||
from typing import Any, Dict
|
||||
|
||||
from ._bot import Bot
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import (
|
||||
Driver,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
)
|
||||
|
||||
|
||||
class Adapter(abc.ABC):
|
||||
def __init__(self, driver: Driver, **kwargs: Any):
|
||||
self.driver = driver
|
||||
self.bots: Dict[str, Bot] = {}
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_name(cls) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self.driver.config
|
||||
|
||||
def bot_connect(self, bot: Bot):
|
||||
self.driver._bot_connect(bot)
|
||||
self.bots[bot.self_id] = bot
|
||||
|
||||
def bot_disconnect(self, bot: Bot):
|
||||
self.driver._bot_disconnect(bot)
|
||||
self.bots.pop(bot.self_id, None)
|
||||
|
||||
def setup_http_server(self, setup: HTTPServerSetup):
|
||||
if not isinstance(self.driver, ReverseDriver):
|
||||
raise TypeError("Current driver does not support http server")
|
||||
self.driver.setup_http_server(setup)
|
||||
|
||||
def setup_websocket_server(self, setup: WebSocketServerSetup):
|
||||
if not isinstance(self.driver, ReverseDriver):
|
||||
raise TypeError("Current driver does not support websocket server")
|
||||
self.driver.setup_websocket_server(setup)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data``: API 数据
|
||||
"""
|
||||
raise NotImplementedError
|
@ -12,6 +12,7 @@ from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._event import Event
|
||||
from ._adapter import Adapter
|
||||
from ._message import Message, MessageSegment
|
||||
|
||||
|
||||
@ -25,10 +26,6 @@ class Bot(abc.ABC):
|
||||
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
||||
"""
|
||||
|
||||
driver: Driver
|
||||
"""Driver 对象"""
|
||||
config: Config
|
||||
"""Config 配置对象"""
|
||||
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
||||
"""
|
||||
:类型: ``Set[T_CallingAPIHook]``
|
||||
@ -40,36 +37,27 @@ class Bot(abc.ABC):
|
||||
:说明: call_api 后执行的函数
|
||||
"""
|
||||
|
||||
def __init__(self, self_id: str, request: HTTPConnection):
|
||||
def __init__(self, adapter: "Adapter", self_id: str):
|
||||
"""
|
||||
:参数:
|
||||
|
||||
* ``self_id: str``: 机器人 ID
|
||||
* ``request: HTTPConnection``: request 连接对象
|
||||
"""
|
||||
self.adapter = adapter
|
||||
self.self_id: str = self_id
|
||||
"""机器人 ID"""
|
||||
self.request: HTTPConnection = request
|
||||
"""连接信息"""
|
||||
|
||||
def __getattr__(self, name: str) -> _ApiCall:
|
||||
return partial(self.call_api, name)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
"""Adapter 类型"""
|
||||
raise NotImplementedError
|
||||
return self.adapter.get_name()
|
||||
|
||||
@classmethod
|
||||
def register(cls, driver: Driver, config: Config, **kwargs):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``register`` 方法会在 ``driver.register_adapter`` 时被调用,用于初始化相关配置
|
||||
"""
|
||||
cls.driver = driver
|
||||
cls.config = config
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
return self.adapter.config
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
@ -106,20 +94,6 @@ class Bot(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _call_api(self, api: str, **data) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``api: str``: API 名称
|
||||
* ``**data``: API 数据
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def call_api(self, api: str, **data: Any) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
@ -162,7 +136,7 @@ class Bot(abc.ABC):
|
||||
|
||||
if not skip_calling_api:
|
||||
try:
|
||||
result = await self._call_api(api, **data)
|
||||
result = await self.adapter._call_api(api, **data)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
|
@ -26,7 +26,7 @@ from nonebot.config import Env, Config
|
||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters import Bot, Adapter
|
||||
|
||||
|
||||
class Driver(abc.ABC):
|
||||
@ -34,9 +34,9 @@ class Driver(abc.ABC):
|
||||
Driver 基类。
|
||||
"""
|
||||
|
||||
_adapters: Dict[str, Type["Bot"]] = {}
|
||||
_adapters: Dict[str, "Adapter"] = {}
|
||||
"""
|
||||
:类型: ``Dict[str, Type[Bot]]``
|
||||
:类型: ``Dict[str, Adapter]``
|
||||
:说明: 已注册的适配器列表
|
||||
"""
|
||||
_bot_connection_hook: Set[T_BotConnectionHook] = set()
|
||||
@ -85,7 +85,7 @@ class Driver(abc.ABC):
|
||||
"""
|
||||
return self._clients
|
||||
|
||||
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
|
||||
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -97,13 +97,13 @@ class Driver(abc.ABC):
|
||||
* ``adapter: Type[Bot]``: 适配器 Class
|
||||
* ``**kwargs``: 其他传递给适配器的参数
|
||||
"""
|
||||
name = adapter.get_name()
|
||||
if name in self._adapters:
|
||||
logger.opt(colors=True).debug(
|
||||
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
|
||||
)
|
||||
return
|
||||
self._adapters[name] = adapter
|
||||
adapter.register(self, self.config, **kwargs)
|
||||
self._adapters[name] = adapter(self, **kwargs)
|
||||
logger.opt(colors=True).debug(
|
||||
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
|
||||
)
|
||||
@ -213,34 +213,11 @@ class ForwardDriver(Driver):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_http_polling(
|
||||
self,
|
||||
setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
|
||||
) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
async def request(self, setup: "HTTPRequest") -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_websocket(
|
||||
self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
|
||||
) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
async def websocket(self, setup: "HTTPConnection") -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -261,7 +238,16 @@ class ReverseDriver(Driver):
|
||||
"""驱动 ASGI 对象"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: repack dataclass
|
||||
@dataclass
|
||||
class HTTPConnection(abc.ABC):
|
||||
http_version: str
|
||||
@ -401,36 +387,13 @@ class WebSocket(HTTPConnection, abc.ABC):
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPPollingSetup:
|
||||
adapter: str
|
||||
"""协议适配器名称"""
|
||||
self_id: str
|
||||
"""机器人 ID"""
|
||||
url: str
|
||||
"""URL"""
|
||||
class HTTPServerSetup:
|
||||
path: str
|
||||
method: str
|
||||
"""HTTP method"""
|
||||
body: bytes
|
||||
"""HTTP body"""
|
||||
headers: Dict[str, str]
|
||||
"""HTTP headers"""
|
||||
http_version: str
|
||||
"""HTTP version"""
|
||||
poll_interval: float
|
||||
"""HTTP 轮询间隔"""
|
||||
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketSetup:
|
||||
adapter: str
|
||||
"""协议适配器名称"""
|
||||
self_id: str
|
||||
"""机器人 ID"""
|
||||
url: str
|
||||
"""URL"""
|
||||
headers: Dict[str, str] = field(default_factory=dict)
|
||||
"""HTTP headers"""
|
||||
reconnect: bool = True
|
||||
"""WebSocket 是否重连"""
|
||||
reconnect_interval: float = 3.0
|
||||
"""WebSocket 重连间隔"""
|
||||
class WebSocketServerSetup:
|
||||
path: str
|
||||
handle_func: Callable[[WebSocket], Awaitable[Any]]
|
||||
|
@ -12,38 +12,35 @@ FastAPI 驱动适配
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
|
||||
from typing import Any, List, Union, Callable, Optional, Awaitable
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from pydantic import BaseSettings
|
||||
from fastapi.responses import Response
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from starlette.websockets import WebSocketState
|
||||
from fastapi import Depends, FastAPI, Request, status
|
||||
from starlette.websockets import WebSocket as FastAPIWebSocket
|
||||
from starlette.websockets import WebSocketState, WebSocketDisconnect
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
|
||||
from nonebot.config import Env
|
||||
from nonebot.log import logger
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers import WebSocket
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import (
|
||||
HTTPRequest,
|
||||
HTTPResponse,
|
||||
ForwardDriver,
|
||||
ReverseDriver,
|
||||
WebSocketSetup,
|
||||
HTTPPollingSetup,
|
||||
HTTPConnection,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
)
|
||||
|
||||
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
|
||||
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
|
||||
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""
|
||||
@ -136,16 +133,7 @@ class Config(BaseSettings):
|
||||
|
||||
|
||||
class Driver(ReverseDriver):
|
||||
"""
|
||||
FastAPI 驱动框架。包含反向 Server 功能。
|
||||
|
||||
:上报地址:
|
||||
|
||||
* ``/{adapter name}/``: HTTP POST 上报
|
||||
* ``/{adapter name}/http/``: HTTP POST 上报
|
||||
* ``/{adapter name}/ws``: WebSocket 上报
|
||||
* ``/{adapter name}/ws/``: WebSocket 上报
|
||||
"""
|
||||
"""FastAPI 驱动框架。包含反向 Server 功能。"""
|
||||
|
||||
def __init__(self, env: Env, config: NoneBotConfig):
|
||||
super(Driver, self).__init__(env, config)
|
||||
@ -159,11 +147,6 @@ class Driver(ReverseDriver):
|
||||
redoc_url=self.fastapi_config.fastapi_redoc_url,
|
||||
)
|
||||
|
||||
self._server_app.post("/{adapter}/")(self._handle_http)
|
||||
self._server_app.post("/{adapter}/http")(self._handle_http)
|
||||
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
|
||||
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
def type(self) -> str:
|
||||
@ -188,6 +171,30 @@ class Driver(ReverseDriver):
|
||||
"""fastapi 使用的 logger"""
|
||||
return logging.getLogger("fastapi")
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def setup_http_server(self, setup: HTTPServerSetup):
|
||||
def _get_handle_func():
|
||||
return setup.handle_func
|
||||
|
||||
self._server_app.add_api_route(
|
||||
setup.path,
|
||||
partial(self._handle_http, handle_func=Depends(_get_handle_func)),
|
||||
methods=[setup.method],
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
|
||||
def _get_handle_func():
|
||||
return setup.handle_func
|
||||
|
||||
self._server_app.add_api_websocket_route(
|
||||
setup.path,
|
||||
partial(
|
||||
self._handle_ws,
|
||||
handle_func=Depends(_get_handle_func),
|
||||
),
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
|
||||
@ -241,19 +248,11 @@ class Driver(ReverseDriver):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _handle_http(self, adapter: str, request: Request):
|
||||
data = await request.body()
|
||||
|
||||
if adapter not in self._adapters:
|
||||
logger.warning(
|
||||
f"Unknown adapter {adapter}. Please register the adapter before use."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
|
||||
)
|
||||
|
||||
# 创建 Bot 对象
|
||||
BotClass = self._adapters[adapter]
|
||||
async def _handle_http(
|
||||
self,
|
||||
request: Request,
|
||||
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
|
||||
):
|
||||
http_request = HTTPRequest(
|
||||
request.scope["http_version"],
|
||||
request.url.scheme,
|
||||
@ -261,28 +260,17 @@ class Driver(ReverseDriver):
|
||||
request.scope["query_string"],
|
||||
dict(request.headers),
|
||||
request.method,
|
||||
data,
|
||||
await request.body(),
|
||||
)
|
||||
x_self_id, response = await BotClass.check_permission(self, http_request)
|
||||
|
||||
if not x_self_id:
|
||||
raise HTTPException(
|
||||
response and response.status or 401,
|
||||
response and response.body and response.body.decode("utf-8"),
|
||||
)
|
||||
response = await handle_func(http_request)
|
||||
return Response(response.body, response.status, response.headers)
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.warning(
|
||||
"There's already a reverse websocket connection,"
|
||||
"so the event may be handled twice."
|
||||
)
|
||||
|
||||
bot = BotClass(x_self_id, http_request)
|
||||
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
return Response(response and response.body, response and response.status or 200)
|
||||
|
||||
async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
|
||||
async def _handle_ws(
|
||||
self,
|
||||
websocket: FastAPIWebSocket,
|
||||
handle_func: Callable[[WebSocket], Awaitable[Any]],
|
||||
):
|
||||
ws = WebSocket(
|
||||
websocket.scope.get("http_version", "1.1"),
|
||||
websocket.url.scheme,
|
||||
@ -292,55 +280,7 @@ class Driver(ReverseDriver):
|
||||
websocket,
|
||||
)
|
||||
|
||||
if adapter not in self._adapters:
|
||||
logger.warning(
|
||||
f"Unknown adapter {adapter}. Please register the adapter before use."
|
||||
)
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
# Create Bot Object
|
||||
BotClass = self._adapters[adapter]
|
||||
self_id, _ = await BotClass.check_permission(self, ws)
|
||||
|
||||
if not self_id:
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
if self_id in self._clients:
|
||||
logger.opt(colors=True).warning(
|
||||
"There's already a websocket connection, "
|
||||
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
|
||||
)
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
bot = BotClass(self_id, ws)
|
||||
|
||||
await ws.accept()
|
||||
logger.opt(colors=True).info(
|
||||
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
|
||||
f"Bot {escape_tag(self_id)}</y> Accepted!"
|
||||
)
|
||||
|
||||
self._bot_connect(bot)
|
||||
|
||||
try:
|
||||
while not ws.closed:
|
||||
try:
|
||||
data = await ws.receive()
|
||||
except WebSocketDisconnect:
|
||||
logger.error("WebSocket disconnected by peer.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error(
|
||||
"Error when receiving data from websocket."
|
||||
)
|
||||
break
|
||||
|
||||
asyncio.create_task(bot.handle_message(data.encode()))
|
||||
finally:
|
||||
self._bot_disconnect(bot)
|
||||
await handle_func(ws)
|
||||
|
||||
|
||||
class FullDriver(ForwardDriver, Driver):
|
||||
@ -354,17 +294,6 @@ class FullDriver(ForwardDriver, Driver):
|
||||
DRIVER=nonebot.drivers.fastapi:FullDriver
|
||||
"""
|
||||
|
||||
def __init__(self, env: Env, config: NoneBotConfig):
|
||||
super(FullDriver, self).__init__(env, config)
|
||||
|
||||
self.http_pollings: List[HTTPPOLLING_SETUP] = []
|
||||
self.websockets: List[WEBSOCKET_SETUP] = []
|
||||
self.shutdown: asyncio.Event = asyncio.Event()
|
||||
self.connections: List[asyncio.Task] = []
|
||||
|
||||
self.on_startup(self._run_forward)
|
||||
self.on_shutdown(self._shutdown_forward)
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
def type(self) -> str:
|
||||
@ -372,217 +301,25 @@ class FullDriver(ForwardDriver, Driver):
|
||||
return "fastapi_full"
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
|
||||
"""
|
||||
self.http_pollings.append(setup)
|
||||
async def request(self, setup: "HTTPRequest") -> Any:
|
||||
async with httpx.AsyncClient(
|
||||
http2=setup.http_version == "2", follow_redirects=True
|
||||
) as client:
|
||||
response = await client.request(
|
||||
setup.method,
|
||||
setup.url,
|
||||
content=setup.body,
|
||||
headers=setup.headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return HTTPResponse(
|
||||
response.status_code, response.content, response.headers
|
||||
)
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
|
||||
|
||||
:参数:
|
||||
|
||||
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
|
||||
"""
|
||||
self.websockets.append(setup)
|
||||
|
||||
def _run_forward(self):
|
||||
for setup in self.http_pollings:
|
||||
self.connections.append(asyncio.create_task(self._http_loop(setup)))
|
||||
for setup in self.websockets:
|
||||
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
|
||||
|
||||
def _shutdown_forward(self):
|
||||
self.shutdown.set()
|
||||
for task in self.connections:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
async def _prepare_setup(
|
||||
self, setup: Union[S, Callable[[], Awaitable[S]]]
|
||||
) -> Optional[S]:
|
||||
try:
|
||||
if callable(setup):
|
||||
return await setup()
|
||||
else:
|
||||
return setup
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Error while parsing setup "
|
||||
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
|
||||
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
return HTTPRequest(
|
||||
setup.http_version,
|
||||
url.scheme,
|
||||
url.path,
|
||||
url.query,
|
||||
setup.headers,
|
||||
setup.method,
|
||||
setup.body,
|
||||
)
|
||||
|
||||
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
|
||||
|
||||
http2: bool = False
|
||||
bot: Optional[Bot] = None
|
||||
request: Optional[HTTPRequest] = None
|
||||
client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
# FIXME: seperate const values from setup (self_id, adapter)
|
||||
# logger.opt(colors=True).info(
|
||||
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
|
||||
# f"Bot {escape_tag(_setup.self_id)}</y>"
|
||||
# )
|
||||
|
||||
try:
|
||||
while not self.shutdown.is_set():
|
||||
|
||||
setup = await self._prepare_setup(_setup)
|
||||
if not setup:
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
request = self._build_http_request(setup)
|
||||
if not request:
|
||||
await asyncio.sleep(setup.poll_interval)
|
||||
continue
|
||||
|
||||
if not client:
|
||||
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
|
||||
elif http2 != (setup.http_version == "2"):
|
||||
await client.aclose()
|
||||
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
|
||||
http2 = setup.http_version == "2"
|
||||
|
||||
if not bot:
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
else:
|
||||
bot.request = request
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
|
||||
)
|
||||
try:
|
||||
response = await client.request(
|
||||
request.method,
|
||||
setup.url,
|
||||
content=request.body,
|
||||
headers=request.headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.read()
|
||||
asyncio.create_task(bot.handle_message(data))
|
||||
except httpx.HTTPError as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
|
||||
await asyncio.sleep(setup.poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while http polling</bg #f8bbd0></r>"
|
||||
)
|
||||
finally:
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
if client:
|
||||
await client.aclose()
|
||||
|
||||
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
|
||||
bot: Optional[Bot] = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
||||
setup = await self._prepare_setup(_setup)
|
||||
if not setup:
|
||||
await asyncio.sleep(3)
|
||||
continue
|
||||
|
||||
url = httpx.URL(setup.url)
|
||||
if not url.netloc:
|
||||
logger.opt(colors=True).error(
|
||||
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
|
||||
)
|
||||
try:
|
||||
connection = Connect(setup.url, extra_headers=setup.headers)
|
||||
async with connection as ws:
|
||||
logger.opt(colors=True).info(
|
||||
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
|
||||
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
|
||||
)
|
||||
request = WebSocket(
|
||||
"1.1", url.scheme, url.path, url.query, setup.headers, ws
|
||||
)
|
||||
|
||||
BotClass = self._adapters[setup.adapter]
|
||||
bot = BotClass(setup.self_id, request)
|
||||
self._bot_connect(bot)
|
||||
while not self.shutdown.is_set():
|
||||
# use try except instead of "request.closed" because of queued message
|
||||
try:
|
||||
msg = await request.receive_bytes()
|
||||
asyncio.create_task(bot.handle_message(msg))
|
||||
except ConnectionClosed:
|
||||
logger.opt(colors=True).error(
|
||||
"<r><bg #f8bbd0>WebSocket connection closed. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f"<r><bg #f8bbd0>Error while connecting to {url}. "
|
||||
"Try to reconnect...</bg #f8bbd0></r>"
|
||||
)
|
||||
finally:
|
||||
if bot:
|
||||
self._bot_disconnect(bot)
|
||||
bot = None
|
||||
|
||||
if not setup.reconnect:
|
||||
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
|
||||
break
|
||||
await asyncio.sleep(setup.reconnect_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
"<r><bg #f8bbd0>Unexpected exception occurred "
|
||||
"while websocket loop</bg #f8bbd0></r>"
|
||||
)
|
||||
async def websocket(self, setup: "HTTPConnection") -> Any:
|
||||
ws = await Connect(setup.url, extra_headers=setup.headers)
|
||||
return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -111,9 +111,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
|
||||
return func
|
||||
|
||||
|
||||
# FIXME: run handler with try/except skipped exception
|
||||
|
||||
|
||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
|
||||
try:
|
||||
return await coro
|
||||
|
@ -14,13 +14,12 @@ from contextlib import AsyncExitStack
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Type,
|
||||
Tuple,
|
||||
Union,
|
||||
Callable,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Coroutine,
|
||||
)
|
||||
|
||||
from nonebot import params
|
||||
@ -30,6 +29,13 @@ from nonebot.exception import SkippedException
|
||||
from nonebot.typing import T_PermissionChecker
|
||||
|
||||
|
||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
||||
try:
|
||||
return await coro
|
||||
except SkippedException:
|
||||
return False
|
||||
|
||||
|
||||
class Permission:
|
||||
"""
|
||||
:说明:
|
||||
@ -100,20 +106,18 @@ class Permission:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
checker(
|
||||
bot=bot,
|
||||
event=event,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache,
|
||||
_run_coro_with_catch(
|
||||
checker(
|
||||
bot=bot,
|
||||
event=event,
|
||||
_stack=stack,
|
||||
_dependency_cache=dependency_cache,
|
||||
)
|
||||
)
|
||||
for checker in self.checkers
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
return next(
|
||||
filter(lambda x: bool(x) and not isinstance(x, SkippedException), results),
|
||||
False,
|
||||
)
|
||||
return any(results)
|
||||
|
||||
def __and__(self, other) -> NoReturn:
|
||||
raise RuntimeError("And operation between Permissions is not allowed.")
|
||||
|
Loading…
Reference in New Issue
Block a user