diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py
index a3080971..2bae9974 100644
--- a/nonebot/adapters/__init__.py
+++ b/nonebot/adapters/__init__.py
@@ -22,6 +22,7 @@ except Exception:
from ._bot import Bot as Bot
from ._event import Event as Event
+from ._adapter import Adapter as Adapter
from ._message import Message as Message
from ._message import MessageSegment as MessageSegment
from ._template import MessageTemplate as MessageTemplate
diff --git a/nonebot/adapters/_adapter.py b/nonebot/adapters/_adapter.py
new file mode 100644
index 00000000..92038d78
--- /dev/null
+++ b/nonebot/adapters/_adapter.py
@@ -0,0 +1,59 @@
+import abc
+from typing import Any, Dict
+
+from ._bot import Bot
+from nonebot.config import Config
+from nonebot.drivers import (
+ Driver,
+ ForwardDriver,
+ ReverseDriver,
+ HTTPServerSetup,
+ WebSocketServerSetup,
+)
+
+
+class Adapter(abc.ABC):
+ def __init__(self, driver: Driver, **kwargs: Any):
+ self.driver = driver
+ self.bots: Dict[str, Bot] = {}
+
+ @classmethod
+ @abc.abstractmethod
+ def get_name(cls) -> str:
+ raise NotImplementedError
+
+ @property
+ def config(self) -> Config:
+ return self.driver.config
+
+ def bot_connect(self, bot: Bot):
+ self.driver._bot_connect(bot)
+ self.bots[bot.self_id] = bot
+
+ def bot_disconnect(self, bot: Bot):
+ self.driver._bot_disconnect(bot)
+ self.bots.pop(bot.self_id, None)
+
+ def setup_http_server(self, setup: HTTPServerSetup):
+ if not isinstance(self.driver, ReverseDriver):
+ raise TypeError("Current driver does not support http server")
+ self.driver.setup_http_server(setup)
+
+ def setup_websocket_server(self, setup: WebSocketServerSetup):
+ if not isinstance(self.driver, ReverseDriver):
+ raise TypeError("Current driver does not support websocket server")
+ self.driver.setup_websocket_server(setup)
+
+ @abc.abstractmethod
+ async def _call_api(self, api: str, **data) -> Any:
+ """
+ :说明:
+
+ ``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
+
+ :参数:
+
+ * ``api: str``: API 名称
+ * ``**data``: API 数据
+ """
+ raise NotImplementedError
diff --git a/nonebot/adapters/_bot.py b/nonebot/adapters/_bot.py
index 5e3aec7f..5aba33ef 100644
--- a/nonebot/adapters/_bot.py
+++ b/nonebot/adapters/_bot.py
@@ -12,6 +12,7 @@ from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
if TYPE_CHECKING:
from ._event import Event
+ from ._adapter import Adapter
from ._message import Message, MessageSegment
@@ -25,10 +26,6 @@ class Bot(abc.ABC):
Bot 基类。用于处理上报消息,并提供 API 调用接口。
"""
- driver: Driver
- """Driver 对象"""
- config: Config
- """Config 配置对象"""
_calling_api_hook: Set[T_CallingAPIHook] = set()
"""
:类型: ``Set[T_CallingAPIHook]``
@@ -40,36 +37,27 @@ class Bot(abc.ABC):
:说明: call_api 后执行的函数
"""
- def __init__(self, self_id: str, request: HTTPConnection):
+ def __init__(self, adapter: "Adapter", self_id: str):
"""
:参数:
* ``self_id: str``: 机器人 ID
* ``request: HTTPConnection``: request 连接对象
"""
+ self.adapter = adapter
self.self_id: str = self_id
"""机器人 ID"""
- self.request: HTTPConnection = request
- """连接信息"""
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
- @abc.abstractmethod
def type(self) -> str:
- """Adapter 类型"""
- raise NotImplementedError
+ return self.adapter.get_name()
- @classmethod
- def register(cls, driver: Driver, config: Config, **kwargs):
- """
- :说明:
-
- ``register`` 方法会在 ``driver.register_adapter`` 时被调用,用于初始化相关配置
- """
- cls.driver = driver
- cls.config = config
+ @property
+ def config(self) -> Config:
+ return self.adapter.config
@classmethod
@abc.abstractmethod
@@ -106,20 +94,6 @@ class Bot(abc.ABC):
"""
raise NotImplementedError
- @abc.abstractmethod
- async def _call_api(self, api: str, **data) -> Any:
- """
- :说明:
-
- ``adapter`` 实际调用 api 的逻辑实现函数,实现该方法以调用 api。
-
- :参数:
-
- * ``api: str``: API 名称
- * ``**data``: API 数据
- """
- raise NotImplementedError
-
async def call_api(self, api: str, **data: Any) -> Any:
"""
:说明:
@@ -162,7 +136,7 @@ class Bot(abc.ABC):
if not skip_calling_api:
try:
- result = await self._call_api(api, **data)
+ result = await self.adapter._call_api(api, **data)
except Exception as e:
exception = e
diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py
index f5e87fca..1785e1b6 100644
--- a/nonebot/drivers/__init__.py
+++ b/nonebot/drivers/__init__.py
@@ -26,7 +26,7 @@ from nonebot.config import Env, Config
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
- from nonebot.adapters import Bot
+ from nonebot.adapters import Bot, Adapter
class Driver(abc.ABC):
@@ -34,9 +34,9 @@ class Driver(abc.ABC):
Driver 基类。
"""
- _adapters: Dict[str, Type["Bot"]] = {}
+ _adapters: Dict[str, "Adapter"] = {}
"""
- :类型: ``Dict[str, Type[Bot]]``
+ :类型: ``Dict[str, Adapter]``
:说明: 已注册的适配器列表
"""
_bot_connection_hook: Set[T_BotConnectionHook] = set()
@@ -85,7 +85,7 @@ class Driver(abc.ABC):
"""
return self._clients
- def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
+ def register_adapter(self, adapter: Type["Adapter"], **kwargs):
"""
:说明:
@@ -97,13 +97,13 @@ class Driver(abc.ABC):
* ``adapter: Type[Bot]``: 适配器 Class
* ``**kwargs``: 其他传递给适配器的参数
"""
+ name = adapter.get_name()
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "{escape_tag(name)}" already exists'
)
return
- self._adapters[name] = adapter
- adapter.register(self, self.config, **kwargs)
+ self._adapters[name] = adapter(self, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "{escape_tag(name)}"'
)
@@ -213,34 +213,11 @@ class ForwardDriver(Driver):
"""
@abc.abstractmethod
- def setup_http_polling(
- self,
- setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
- ) -> None:
- """
- :说明:
-
- 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
-
- :参数:
-
- * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
- """
+ async def request(self, setup: "HTTPRequest") -> Any:
raise NotImplementedError
@abc.abstractmethod
- def setup_websocket(
- self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
- ) -> None:
- """
- :说明:
-
- 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
-
- :参数:
-
- * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
- """
+ async def websocket(self, setup: "HTTPConnection") -> Any:
raise NotImplementedError
@@ -261,7 +238,16 @@ class ReverseDriver(Driver):
"""驱动 ASGI 对象"""
raise NotImplementedError
+ @abc.abstractmethod
+ def setup_http_server(self, setup: "HTTPServerSetup") -> None:
+ raise NotImplementedError
+ @abc.abstractmethod
+ def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
+ raise NotImplementedError
+
+
+# TODO: repack dataclass
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
@@ -401,36 +387,13 @@ class WebSocket(HTTPConnection, abc.ABC):
@dataclass
-class HTTPPollingSetup:
- adapter: str
- """协议适配器名称"""
- self_id: str
- """机器人 ID"""
- url: str
- """URL"""
+class HTTPServerSetup:
+ path: str
method: str
- """HTTP method"""
- body: bytes
- """HTTP body"""
- headers: Dict[str, str]
- """HTTP headers"""
- http_version: str
- """HTTP version"""
- poll_interval: float
- """HTTP 轮询间隔"""
+ handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
@dataclass
-class WebSocketSetup:
- adapter: str
- """协议适配器名称"""
- self_id: str
- """机器人 ID"""
- url: str
- """URL"""
- headers: Dict[str, str] = field(default_factory=dict)
- """HTTP headers"""
- reconnect: bool = True
- """WebSocket 是否重连"""
- reconnect_interval: float = 3.0
- """WebSocket 重连间隔"""
+class WebSocketServerSetup:
+ path: str
+ handle_func: Callable[[WebSocket], Awaitable[Any]]
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index ac23081b..9fea8d95 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -12,38 +12,35 @@ FastAPI 驱动适配
import asyncio
import logging
+from functools import partial
from dataclasses import dataclass
-from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
+from typing import Any, List, Union, Callable, Optional, Awaitable
import httpx
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
-from websockets.exceptions import ConnectionClosed
-from fastapi import FastAPI, Request, HTTPException, status
+from starlette.websockets import WebSocketState
+from fastapi import Depends, FastAPI, Request, status
from starlette.websockets import WebSocket as FastAPIWebSocket
-from starlette.websockets import WebSocketState, WebSocketDisconnect
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env
-from nonebot.log import logger
-from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.utils import escape_tag
+from nonebot.drivers import WebSocket
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (
HTTPRequest,
+ HTTPResponse,
ForwardDriver,
ReverseDriver,
- WebSocketSetup,
- HTTPPollingSetup,
+ HTTPConnection,
+ HTTPServerSetup,
+ WebSocketServerSetup,
)
-S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
-HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
-WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
-
class Config(BaseSettings):
"""
@@ -136,16 +133,7 @@ class Config(BaseSettings):
class Driver(ReverseDriver):
- """
- FastAPI 驱动框架。包含反向 Server 功能。
-
- :上报地址:
-
- * ``/{adapter name}/``: HTTP POST 上报
- * ``/{adapter name}/http/``: HTTP POST 上报
- * ``/{adapter name}/ws``: WebSocket 上报
- * ``/{adapter name}/ws/``: WebSocket 上报
- """
+ """FastAPI 驱动框架。包含反向 Server 功能。"""
def __init__(self, env: Env, config: NoneBotConfig):
super(Driver, self).__init__(env, config)
@@ -159,11 +147,6 @@ class Driver(ReverseDriver):
redoc_url=self.fastapi_config.fastapi_redoc_url,
)
- self._server_app.post("/{adapter}/")(self._handle_http)
- self._server_app.post("/{adapter}/http")(self._handle_http)
- self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
- self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
-
@property
@overrides(ReverseDriver)
def type(self) -> str:
@@ -188,6 +171,30 @@ class Driver(ReverseDriver):
"""fastapi 使用的 logger"""
return logging.getLogger("fastapi")
+ @overrides(ReverseDriver)
+ def setup_http_server(self, setup: HTTPServerSetup):
+ def _get_handle_func():
+ return setup.handle_func
+
+ self._server_app.add_api_route(
+ setup.path,
+ partial(self._handle_http, handle_func=Depends(_get_handle_func)),
+ methods=[setup.method],
+ )
+
+ @overrides(ReverseDriver)
+ def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
+ def _get_handle_func():
+ return setup.handle_func
+
+ self._server_app.add_api_websocket_route(
+ setup.path,
+ partial(
+ self._handle_ws,
+ handle_func=Depends(_get_handle_func),
+ ),
+ )
+
@overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events `_"""
@@ -241,19 +248,11 @@ class Driver(ReverseDriver):
**kwargs,
)
- async def _handle_http(self, adapter: str, request: Request):
- data = await request.body()
-
- if adapter not in self._adapters:
- logger.warning(
- f"Unknown adapter {adapter}. Please register the adapter before use."
- )
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
- )
-
- # 创建 Bot 对象
- BotClass = self._adapters[adapter]
+ async def _handle_http(
+ self,
+ request: Request,
+ handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
+ ):
http_request = HTTPRequest(
request.scope["http_version"],
request.url.scheme,
@@ -261,28 +260,17 @@ class Driver(ReverseDriver):
request.scope["query_string"],
dict(request.headers),
request.method,
- data,
+ await request.body(),
)
- x_self_id, response = await BotClass.check_permission(self, http_request)
- if not x_self_id:
- raise HTTPException(
- response and response.status or 401,
- response and response.body and response.body.decode("utf-8"),
- )
+ response = await handle_func(http_request)
+ return Response(response.body, response.status, response.headers)
- if x_self_id in self._clients:
- logger.warning(
- "There's already a reverse websocket connection,"
- "so the event may be handled twice."
- )
-
- bot = BotClass(x_self_id, http_request)
-
- asyncio.create_task(bot.handle_message(data))
- return Response(response and response.body, response and response.status or 200)
-
- async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
+ async def _handle_ws(
+ self,
+ websocket: FastAPIWebSocket,
+ handle_func: Callable[[WebSocket], Awaitable[Any]],
+ ):
ws = WebSocket(
websocket.scope.get("http_version", "1.1"),
websocket.url.scheme,
@@ -292,55 +280,7 @@ class Driver(ReverseDriver):
websocket,
)
- if adapter not in self._adapters:
- logger.warning(
- f"Unknown adapter {adapter}. Please register the adapter before use."
- )
- await ws.close(code=status.WS_1008_POLICY_VIOLATION)
- return
-
- # Create Bot Object
- BotClass = self._adapters[adapter]
- self_id, _ = await BotClass.check_permission(self, ws)
-
- if not self_id:
- await ws.close(code=status.WS_1008_POLICY_VIOLATION)
- return
-
- if self_id in self._clients:
- logger.opt(colors=True).warning(
- "There's already a websocket connection, "
- f"{escape_tag(adapter.upper())} Bot {escape_tag(self_id)} ignored."
- )
- await ws.close(code=status.WS_1008_POLICY_VIOLATION)
- return
-
- bot = BotClass(self_id, ws)
-
- await ws.accept()
- logger.opt(colors=True).info(
- f"WebSocket Connection from {escape_tag(adapter.upper())} "
- f"Bot {escape_tag(self_id)} Accepted!"
- )
-
- self._bot_connect(bot)
-
- try:
- while not ws.closed:
- try:
- data = await ws.receive()
- except WebSocketDisconnect:
- logger.error("WebSocket disconnected by peer.")
- break
- except Exception as e:
- logger.opt(exception=e).error(
- "Error when receiving data from websocket."
- )
- break
-
- asyncio.create_task(bot.handle_message(data.encode()))
- finally:
- self._bot_disconnect(bot)
+ await handle_func(ws)
class FullDriver(ForwardDriver, Driver):
@@ -354,17 +294,6 @@ class FullDriver(ForwardDriver, Driver):
DRIVER=nonebot.drivers.fastapi:FullDriver
"""
- def __init__(self, env: Env, config: NoneBotConfig):
- super(FullDriver, self).__init__(env, config)
-
- self.http_pollings: List[HTTPPOLLING_SETUP] = []
- self.websockets: List[WEBSOCKET_SETUP] = []
- self.shutdown: asyncio.Event = asyncio.Event()
- self.connections: List[asyncio.Task] = []
-
- self.on_startup(self._run_forward)
- self.on_shutdown(self._shutdown_forward)
-
@property
@overrides(ForwardDriver)
def type(self) -> str:
@@ -372,217 +301,25 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full"
@overrides(ForwardDriver)
- def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
- """
- :说明:
-
- 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
-
- :参数:
-
- * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
- """
- self.http_pollings.append(setup)
+ async def request(self, setup: "HTTPRequest") -> Any:
+ async with httpx.AsyncClient(
+ http2=setup.http_version == "2", follow_redirects=True
+ ) as client:
+ response = await client.request(
+ setup.method,
+ setup.url,
+ content=setup.body,
+ headers=setup.headers,
+ timeout=30.0,
+ )
+ return HTTPResponse(
+ response.status_code, response.content, response.headers
+ )
@overrides(ForwardDriver)
- def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
- """
- :说明:
-
- 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
-
- :参数:
-
- * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
- """
- self.websockets.append(setup)
-
- def _run_forward(self):
- for setup in self.http_pollings:
- self.connections.append(asyncio.create_task(self._http_loop(setup)))
- for setup in self.websockets:
- self.connections.append(asyncio.create_task(self._ws_loop(setup)))
-
- def _shutdown_forward(self):
- self.shutdown.set()
- for task in self.connections:
- if not task.done():
- task.cancel()
-
- async def _prepare_setup(
- self, setup: Union[S, Callable[[], Awaitable[S]]]
- ) -> Optional[S]:
- try:
- if callable(setup):
- return await setup()
- else:
- return setup
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error while parsing setup "
- f"{escape_tag(repr(setup))}."
- )
- return
-
- def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
- url = httpx.URL(setup.url)
- if not url.netloc:
- logger.opt(colors=True).error(
- f"Error parsing url {escape_tag(str(url))}"
- )
- return
- return HTTPRequest(
- setup.http_version,
- url.scheme,
- url.path,
- url.query,
- setup.headers,
- setup.method,
- setup.body,
- )
-
- async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
-
- http2: bool = False
- bot: Optional[Bot] = None
- request: Optional[HTTPRequest] = None
- client: Optional[httpx.AsyncClient] = None
-
- # FIXME: seperate const values from setup (self_id, adapter)
- # logger.opt(colors=True).info(
- # f"Start http polling for {escape_tag(_setup.adapter.upper())} "
- # f"Bot {escape_tag(_setup.self_id)}"
- # )
-
- try:
- while not self.shutdown.is_set():
-
- setup = await self._prepare_setup(_setup)
- if not setup:
- await asyncio.sleep(3)
- continue
- request = self._build_http_request(setup)
- if not request:
- await asyncio.sleep(setup.poll_interval)
- continue
-
- if not client:
- client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
- elif http2 != (setup.http_version == "2"):
- await client.aclose()
- client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
- http2 = setup.http_version == "2"
-
- if not bot:
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
- self._bot_connect(bot)
- else:
- bot.request = request
-
- logger.debug(
- f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
- )
- try:
- response = await client.request(
- request.method,
- setup.url,
- content=request.body,
- headers=request.headers,
- timeout=30.0,
- )
- response.raise_for_status()
- data = response.read()
- asyncio.create_task(bot.handle_message(data))
- except httpx.HTTPError as e:
- logger.opt(colors=True, exception=e).error(
- f"Error occurred while requesting {escape_tag(setup.url)}. "
- "Try to reconnect..."
- )
-
- await asyncio.sleep(setup.poll_interval)
-
- except asyncio.CancelledError:
- pass
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Unexpected exception occurred "
- "while http polling"
- )
- finally:
- if bot:
- self._bot_disconnect(bot)
- if client:
- await client.aclose()
-
- async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
- bot: Optional[Bot] = None
-
- try:
- while True:
-
- setup = await self._prepare_setup(_setup)
- if not setup:
- await asyncio.sleep(3)
- continue
-
- url = httpx.URL(setup.url)
- if not url.netloc:
- logger.opt(colors=True).error(
- f"Error parsing url {escape_tag(str(url))}"
- )
- return
-
- logger.debug(
- f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
- )
- try:
- connection = Connect(setup.url, extra_headers=setup.headers)
- async with connection as ws:
- logger.opt(colors=True).info(
- f"WebSocket Connection to {escape_tag(setup.adapter.upper())} "
- f"Bot {escape_tag(setup.self_id)} succeeded!"
- )
- request = WebSocket(
- "1.1", url.scheme, url.path, url.query, setup.headers, ws
- )
-
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
- self._bot_connect(bot)
- while not self.shutdown.is_set():
- # use try except instead of "request.closed" because of queued message
- try:
- msg = await request.receive_bytes()
- asyncio.create_task(bot.handle_message(msg))
- except ConnectionClosed:
- logger.opt(colors=True).error(
- "WebSocket connection closed. "
- "Try to reconnect..."
- )
- break
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- f"Error while connecting to {url}. "
- "Try to reconnect..."
- )
- finally:
- if bot:
- self._bot_disconnect(bot)
- bot = None
-
- if not setup.reconnect:
- logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
- break
- await asyncio.sleep(setup.reconnect_interval)
-
- except asyncio.CancelledError:
- pass
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Unexpected exception occurred "
- "while websocket loop"
- )
+ async def websocket(self, setup: "HTTPConnection") -> Any:
+ ws = await Connect(setup.url, extra_headers=setup.headers)
+ return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
@dataclass
diff --git a/nonebot/message.py b/nonebot/message.py
index b36b3049..02863aa5 100644
--- a/nonebot/message.py
+++ b/nonebot/message.py
@@ -111,9 +111,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func
-# FIXME: run handler with try/except skipped exception
-
-
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
try:
return await coro
diff --git a/nonebot/permission.py b/nonebot/permission.py
index f1f863db..3f33dd5d 100644
--- a/nonebot/permission.py
+++ b/nonebot/permission.py
@@ -14,13 +14,12 @@ from contextlib import AsyncExitStack
from typing import (
Any,
Dict,
- List,
- Type,
Tuple,
Union,
Callable,
NoReturn,
Optional,
+ Coroutine,
)
from nonebot import params
@@ -30,6 +29,13 @@ from nonebot.exception import SkippedException
from nonebot.typing import T_PermissionChecker
+async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
+ try:
+ return await coro
+ except SkippedException:
+ return False
+
+
class Permission:
"""
:说明:
@@ -100,20 +106,18 @@ class Permission:
return True
results = await asyncio.gather(
*(
- checker(
- bot=bot,
- event=event,
- _stack=stack,
- _dependency_cache=dependency_cache,
+ _run_coro_with_catch(
+ checker(
+ bot=bot,
+ event=event,
+ _stack=stack,
+ _dependency_cache=dependency_cache,
+ )
)
for checker in self.checkers
),
- return_exceptions=True,
- )
- return next(
- filter(lambda x: bool(x) and not isinstance(x, SkippedException), results),
- False,
)
+ return any(results)
def __and__(self, other) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")