♻️ rewrite adapter abc class

This commit is contained in:
yanyongyu 2021-12-06 22:19:05 +08:00
parent 180aaadda9
commit d80c02ae46
7 changed files with 172 additions and 437 deletions

View File

@ -22,6 +22,7 @@ except Exception:
from ._bot import Bot as Bot from ._bot import Bot as Bot
from ._event import Event as Event from ._event import Event as Event
from ._adapter import Adapter as Adapter
from ._message import Message as Message from ._message import Message as Message
from ._message import MessageSegment as MessageSegment from ._message import MessageSegment as MessageSegment
from ._template import MessageTemplate as MessageTemplate from ._template import MessageTemplate as MessageTemplate

View File

@ -0,0 +1,59 @@
import abc
from typing import Any, Dict
from ._bot import Bot
from nonebot.config import Config
from nonebot.drivers import (
Driver,
ForwardDriver,
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
)
class Adapter(abc.ABC):
def __init__(self, driver: Driver, **kwargs: Any):
self.driver = driver
self.bots: Dict[str, Bot] = {}
@classmethod
@abc.abstractmethod
def get_name(cls) -> str:
raise NotImplementedError
@property
def config(self) -> Config:
return self.driver.config
def bot_connect(self, bot: Bot):
self.driver._bot_connect(bot)
self.bots[bot.self_id] = bot
def bot_disconnect(self, bot: Bot):
self.driver._bot_disconnect(bot)
self.bots.pop(bot.self_id, None)
def setup_http_server(self, setup: HTTPServerSetup):
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support http server")
self.driver.setup_http_server(setup)
def setup_websocket_server(self, setup: WebSocketServerSetup):
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support websocket server")
self.driver.setup_websocket_server(setup)
@abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数实现该方法以调用 api
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError

View File

@ -12,6 +12,7 @@ from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
if TYPE_CHECKING: if TYPE_CHECKING:
from ._event import Event from ._event import Event
from ._adapter import Adapter
from ._message import Message, MessageSegment from ._message import Message, MessageSegment
@ -25,10 +26,6 @@ class Bot(abc.ABC):
Bot 基类用于处理上报消息并提供 API 调用接口 Bot 基类用于处理上报消息并提供 API 调用接口
""" """
driver: Driver
"""Driver 对象"""
config: Config
"""Config 配置对象"""
_calling_api_hook: Set[T_CallingAPIHook] = set() _calling_api_hook: Set[T_CallingAPIHook] = set()
""" """
:类型: ``Set[T_CallingAPIHook]`` :类型: ``Set[T_CallingAPIHook]``
@ -40,36 +37,27 @@ class Bot(abc.ABC):
:说明: call_api 后执行的函数 :说明: call_api 后执行的函数
""" """
def __init__(self, self_id: str, request: HTTPConnection): def __init__(self, adapter: "Adapter", self_id: str):
""" """
:参数: :参数:
* ``self_id: str``: 机器人 ID * ``self_id: str``: 机器人 ID
* ``request: HTTPConnection``: request 连接对象 * ``request: HTTPConnection``: request 连接对象
""" """
self.adapter = adapter
self.self_id: str = self_id self.self_id: str = self_id
"""机器人 ID""" """机器人 ID"""
self.request: HTTPConnection = request
"""连接信息"""
def __getattr__(self, name: str) -> _ApiCall: def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name) return partial(self.call_api, name)
@property @property
@abc.abstractmethod
def type(self) -> str: def type(self) -> str:
"""Adapter 类型""" return self.adapter.get_name()
raise NotImplementedError
@classmethod @property
def register(cls, driver: Driver, config: Config, **kwargs): def config(self) -> Config:
""" return self.adapter.config
:说明:
``register`` 方法会在 ``driver.register_adapter`` 时被调用用于初始化相关配置
"""
cls.driver = driver
cls.config = config
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
@ -106,20 +94,6 @@ class Bot(abc.ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数实现该方法以调用 api
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError
async def call_api(self, api: str, **data: Any) -> Any: async def call_api(self, api: str, **data: Any) -> Any:
""" """
:说明: :说明:
@ -162,7 +136,7 @@ class Bot(abc.ABC):
if not skip_calling_api: if not skip_calling_api:
try: try:
result = await self._call_api(api, **data) result = await self.adapter._call_api(api, **data)
except Exception as e: except Exception as e:
exception = e exception = e

View File

@ -26,7 +26,7 @@ from nonebot.config import Env, Config
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot from nonebot.adapters import Bot, Adapter
class Driver(abc.ABC): class Driver(abc.ABC):
@ -34,9 +34,9 @@ class Driver(abc.ABC):
Driver 基类 Driver 基类
""" """
_adapters: Dict[str, Type["Bot"]] = {} _adapters: Dict[str, "Adapter"] = {}
""" """
:类型: ``Dict[str, Type[Bot]]`` :类型: ``Dict[str, Adapter]``
:说明: 已注册的适配器列表 :说明: 已注册的适配器列表
""" """
_bot_connection_hook: Set[T_BotConnectionHook] = set() _bot_connection_hook: Set[T_BotConnectionHook] = set()
@ -85,7 +85,7 @@ class Driver(abc.ABC):
""" """
return self._clients return self._clients
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs): def register_adapter(self, adapter: Type["Adapter"], **kwargs):
""" """
:说明: :说明:
@ -97,13 +97,13 @@ class Driver(abc.ABC):
* ``adapter: Type[Bot]``: 适配器 Class * ``adapter: Type[Bot]``: 适配器 Class
* ``**kwargs``: 其他传递给适配器的参数 * ``**kwargs``: 其他传递给适配器的参数
""" """
name = adapter.get_name()
if name in self._adapters: if name in self._adapters:
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists' f'Adapter "<y>{escape_tag(name)}</y>" already exists'
) )
return return
self._adapters[name] = adapter self._adapters[name] = adapter(self, **kwargs)
adapter.register(self, self.config, **kwargs)
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"' f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
) )
@ -213,34 +213,11 @@ class ForwardDriver(Driver):
""" """
@abc.abstractmethod @abc.abstractmethod
def setup_http_polling( async def request(self, setup: "HTTPRequest") -> Any:
self,
setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
) -> None:
"""
:说明:
注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def setup_websocket( async def websocket(self, setup: "HTTPConnection") -> Any:
self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
) -> None:
"""
:说明:
注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
raise NotImplementedError raise NotImplementedError
@ -261,7 +238,16 @@ class ReverseDriver(Driver):
"""驱动 ASGI 对象""" """驱动 ASGI 对象"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
raise NotImplementedError
# TODO: repack dataclass
@dataclass @dataclass
class HTTPConnection(abc.ABC): class HTTPConnection(abc.ABC):
http_version: str http_version: str
@ -401,36 +387,13 @@ class WebSocket(HTTPConnection, abc.ABC):
@dataclass @dataclass
class HTTPPollingSetup: class HTTPServerSetup:
adapter: str path: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
method: str method: str
"""HTTP method""" handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
body: bytes
"""HTTP body"""
headers: Dict[str, str]
"""HTTP headers"""
http_version: str
"""HTTP version"""
poll_interval: float
"""HTTP 轮询间隔"""
@dataclass @dataclass
class WebSocketSetup: class WebSocketServerSetup:
adapter: str path: str
"""协议适配器名称""" handle_func: Callable[[WebSocket], Awaitable[Any]]
self_id: str
"""机器人 ID"""
url: str
"""URL"""
headers: Dict[str, str] = field(default_factory=dict)
"""HTTP headers"""
reconnect: bool = True
"""WebSocket 是否重连"""
reconnect_interval: float = 3.0
"""WebSocket 重连间隔"""

View File

@ -12,38 +12,35 @@ FastAPI 驱动适配
import asyncio import asyncio
import logging import logging
from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast from typing import Any, List, Union, Callable, Optional, Awaitable
import httpx import httpx
import uvicorn import uvicorn
from pydantic import BaseSettings from pydantic import BaseSettings
from fastapi.responses import Response from fastapi.responses import Response
from websockets.exceptions import ConnectionClosed from starlette.websockets import WebSocketState
from fastapi import FastAPI, Request, HTTPException, status from fastapi import Depends, FastAPI, Request, status
from starlette.websockets import WebSocket as FastAPIWebSocket from starlette.websockets import WebSocket as FastAPIWebSocket
from starlette.websockets import WebSocketState, WebSocketDisconnect
from websockets.legacy.client import Connect, WebSocketClientProtocol from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env from nonebot.config import Env
from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.drivers import WebSocket
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ( from nonebot.drivers import (
HTTPRequest, HTTPRequest,
HTTPResponse,
ForwardDriver, ForwardDriver,
ReverseDriver, ReverseDriver,
WebSocketSetup, HTTPConnection,
HTTPPollingSetup, HTTPServerSetup,
WebSocketServerSetup,
) )
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
class Config(BaseSettings): class Config(BaseSettings):
""" """
@ -136,16 +133,7 @@ class Config(BaseSettings):
class Driver(ReverseDriver): class Driver(ReverseDriver):
""" """FastAPI 驱动框架。包含反向 Server 功能。"""
FastAPI 驱动框架包含反向 Server 功能
:上报地址:
* ``/{adapter name}/``: HTTP POST 上报
* ``/{adapter name}/http/``: HTTP POST 上报
* ``/{adapter name}/ws``: WebSocket 上报
* ``/{adapter name}/ws/``: WebSocket 上报
"""
def __init__(self, env: Env, config: NoneBotConfig): def __init__(self, env: Env, config: NoneBotConfig):
super(Driver, self).__init__(env, config) super(Driver, self).__init__(env, config)
@ -159,11 +147,6 @@ class Driver(ReverseDriver):
redoc_url=self.fastapi_config.fastapi_redoc_url, redoc_url=self.fastapi_config.fastapi_redoc_url,
) )
self._server_app.post("/{adapter}/")(self._handle_http)
self._server_app.post("/{adapter}/http")(self._handle_http)
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
@property @property
@overrides(ReverseDriver) @overrides(ReverseDriver)
def type(self) -> str: def type(self) -> str:
@ -188,6 +171,30 @@ class Driver(ReverseDriver):
"""fastapi 使用的 logger""" """fastapi 使用的 logger"""
return logging.getLogger("fastapi") return logging.getLogger("fastapi")
@overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup):
def _get_handle_func():
return setup.handle_func
self._server_app.add_api_route(
setup.path,
partial(self._handle_http, handle_func=Depends(_get_handle_func)),
methods=[setup.method],
)
@overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
def _get_handle_func():
return setup.handle_func
self._server_app.add_api_websocket_route(
setup.path,
partial(
self._handle_ws,
handle_func=Depends(_get_handle_func),
),
)
@overrides(ReverseDriver) @overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable: def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_""" """参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
@ -241,19 +248,11 @@ class Driver(ReverseDriver):
**kwargs, **kwargs,
) )
async def _handle_http(self, adapter: str, request: Request): async def _handle_http(
data = await request.body() self,
request: Request,
if adapter not in self._adapters: handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
logger.warning( ):
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
)
# 创建 Bot 对象
BotClass = self._adapters[adapter]
http_request = HTTPRequest( http_request = HTTPRequest(
request.scope["http_version"], request.scope["http_version"],
request.url.scheme, request.url.scheme,
@ -261,28 +260,17 @@ class Driver(ReverseDriver):
request.scope["query_string"], request.scope["query_string"],
dict(request.headers), dict(request.headers),
request.method, request.method,
data, await request.body(),
)
x_self_id, response = await BotClass.check_permission(self, http_request)
if not x_self_id:
raise HTTPException(
response and response.status or 401,
response and response.body and response.body.decode("utf-8"),
) )
if x_self_id in self._clients: response = await handle_func(http_request)
logger.warning( return Response(response.body, response.status, response.headers)
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(x_self_id, http_request) async def _handle_ws(
self,
asyncio.create_task(bot.handle_message(data)) websocket: FastAPIWebSocket,
return Response(response and response.body, response and response.status or 200) handle_func: Callable[[WebSocket], Awaitable[Any]],
):
async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
ws = WebSocket( ws = WebSocket(
websocket.scope.get("http_version", "1.1"), websocket.scope.get("http_version", "1.1"),
websocket.url.scheme, websocket.url.scheme,
@ -292,55 +280,7 @@ class Driver(ReverseDriver):
websocket, websocket,
) )
if adapter not in self._adapters: await handle_func(ws)
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
# Create Bot Object
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, ws)
if not self_id:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if self_id in self._clients:
logger.opt(colors=True).warning(
"There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
)
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
try:
while not ws.closed:
try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
class FullDriver(ForwardDriver, Driver): class FullDriver(ForwardDriver, Driver):
@ -354,17 +294,6 @@ class FullDriver(ForwardDriver, Driver):
DRIVER=nonebot.drivers.fastapi:FullDriver DRIVER=nonebot.drivers.fastapi:FullDriver
""" """
def __init__(self, env: Env, config: NoneBotConfig):
super(FullDriver, self).__init__(env, config)
self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
self.on_startup(self._run_forward)
self.on_shutdown(self._shutdown_forward)
@property @property
@overrides(ForwardDriver) @overrides(ForwardDriver)
def type(self) -> str: def type(self) -> str:
@ -372,217 +301,25 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full" return "fastapi_full"
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None: async def request(self, setup: "HTTPRequest") -> Any:
""" async with httpx.AsyncClient(
:说明: http2=setup.http_version == "2", follow_redirects=True
) as client:
注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
self.http_pollings.append(setup)
@overrides(ForwardDriver)
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
"""
:说明:
注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
self.websockets.append(setup)
def _run_forward(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
def _shutdown_forward(self):
self.shutdown.set()
for task in self.connections:
if not task.done():
task.cancel()
async def _prepare_setup(
self, setup: Union[S, Callable[[], Awaitable[S]]]
) -> Optional[S]:
try:
if callable(setup):
return await setup()
else:
return setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
return
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.query,
setup.headers,
setup.method,
setup.body,
)
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
http2: bool = False
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
client: Optional[httpx.AsyncClient] = None
# FIXME: seperate const values from setup (self_id, adapter)
# logger.opt(colors=True).info(
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
# f"Bot {escape_tag(_setup.self_id)}</y>"
# )
try:
while not self.shutdown.is_set():
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
request = self._build_http_request(setup)
if not request:
await asyncio.sleep(setup.poll_interval)
continue
if not client:
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
elif http2 != (setup.http_version == "2"):
await client.aclose()
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
http2 = setup.http_version == "2"
if not bot:
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
else:
bot.request = request
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
)
try:
response = await client.request( response = await client.request(
request.method, setup.method,
setup.url, setup.url,
content=request.body, content=setup.body,
headers=request.headers, headers=setup.headers,
timeout=30.0, timeout=30.0,
) )
response.raise_for_status() return HTTPResponse(
data = response.read() response.status_code, response.content, response.headers
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
) )
await asyncio.sleep(setup.poll_interval) @overrides(ForwardDriver)
async def websocket(self, setup: "HTTPConnection") -> Any:
except asyncio.CancelledError: ws = await Connect(setup.url, extra_headers=setup.headers)
pass return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
if client:
await client.aclose()
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
while True:
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path, url.query, setup.headers, ws
)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
try:
msg = await request.receive_bytes()
asyncio.create_task(bot.handle_message(msg))
except ConnectionClosed:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed. "
"Try to reconnect...</bg #f8bbd0></r>"
)
break
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
bot = None
if not setup.reconnect:
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
break
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>"
)
@dataclass @dataclass

View File

@ -111,9 +111,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func return func
# FIXME: run handler with try/except skipped exception
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any: async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
try: try:
return await coro return await coro

View File

@ -14,13 +14,12 @@ from contextlib import AsyncExitStack
from typing import ( from typing import (
Any, Any,
Dict, Dict,
List,
Type,
Tuple, Tuple,
Union, Union,
Callable, Callable,
NoReturn, NoReturn,
Optional, Optional,
Coroutine,
) )
from nonebot import params from nonebot import params
@ -30,6 +29,13 @@ from nonebot.exception import SkippedException
from nonebot.typing import T_PermissionChecker from nonebot.typing import T_PermissionChecker
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
try:
return await coro
except SkippedException:
return False
class Permission: class Permission:
""" """
:说明: :说明:
@ -100,20 +106,18 @@ class Permission:
return True return True
results = await asyncio.gather( results = await asyncio.gather(
*( *(
_run_coro_with_catch(
checker( checker(
bot=bot, bot=bot,
event=event, event=event,
_stack=stack, _stack=stack,
_dependency_cache=dependency_cache, _dependency_cache=dependency_cache,
) )
)
for checker in self.checkers for checker in self.checkers
), ),
return_exceptions=True,
)
return next(
filter(lambda x: bool(x) and not isinstance(x, SkippedException), results),
False,
) )
return any(results)
def __and__(self, other) -> NoReturn: def __and__(self, other) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.") raise RuntimeError("And operation between Permissions is not allowed.")