♻️ rewrite adapter abc class

This commit is contained in:
yanyongyu 2021-12-06 22:19:05 +08:00
parent 180aaadda9
commit d80c02ae46
7 changed files with 172 additions and 437 deletions

View File

@ -22,6 +22,7 @@ except Exception:
from ._bot import Bot as Bot
from ._event import Event as Event
from ._adapter import Adapter as Adapter
from ._message import Message as Message
from ._message import MessageSegment as MessageSegment
from ._template import MessageTemplate as MessageTemplate

View File

@ -0,0 +1,59 @@
import abc
from typing import Any, Dict
from ._bot import Bot
from nonebot.config import Config
from nonebot.drivers import (
Driver,
ForwardDriver,
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
)
class Adapter(abc.ABC):
def __init__(self, driver: Driver, **kwargs: Any):
self.driver = driver
self.bots: Dict[str, Bot] = {}
@classmethod
@abc.abstractmethod
def get_name(cls) -> str:
raise NotImplementedError
@property
def config(self) -> Config:
return self.driver.config
def bot_connect(self, bot: Bot):
self.driver._bot_connect(bot)
self.bots[bot.self_id] = bot
def bot_disconnect(self, bot: Bot):
self.driver._bot_disconnect(bot)
self.bots.pop(bot.self_id, None)
def setup_http_server(self, setup: HTTPServerSetup):
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support http server")
self.driver.setup_http_server(setup)
def setup_websocket_server(self, setup: WebSocketServerSetup):
if not isinstance(self.driver, ReverseDriver):
raise TypeError("Current driver does not support websocket server")
self.driver.setup_websocket_server(setup)
@abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数实现该方法以调用 api
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError

View File

@ -12,6 +12,7 @@ from nonebot.drivers import Driver, HTTPResponse, HTTPConnection
if TYPE_CHECKING:
from ._event import Event
from ._adapter import Adapter
from ._message import Message, MessageSegment
@ -25,10 +26,6 @@ class Bot(abc.ABC):
Bot 基类用于处理上报消息并提供 API 调用接口
"""
driver: Driver
"""Driver 对象"""
config: Config
"""Config 配置对象"""
_calling_api_hook: Set[T_CallingAPIHook] = set()
"""
:类型: ``Set[T_CallingAPIHook]``
@ -40,36 +37,27 @@ class Bot(abc.ABC):
:说明: call_api 后执行的函数
"""
def __init__(self, self_id: str, request: HTTPConnection):
def __init__(self, adapter: "Adapter", self_id: str):
"""
:参数:
* ``self_id: str``: 机器人 ID
* ``request: HTTPConnection``: request 连接对象
"""
self.adapter = adapter
self.self_id: str = self_id
"""机器人 ID"""
self.request: HTTPConnection = request
"""连接信息"""
def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name)
@property
@abc.abstractmethod
def type(self) -> str:
"""Adapter 类型"""
raise NotImplementedError
return self.adapter.get_name()
@classmethod
def register(cls, driver: Driver, config: Config, **kwargs):
"""
:说明:
``register`` 方法会在 ``driver.register_adapter`` 时被调用用于初始化相关配置
"""
cls.driver = driver
cls.config = config
@property
def config(self) -> Config:
return self.adapter.config
@classmethod
@abc.abstractmethod
@ -106,20 +94,6 @@ class Bot(abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any:
"""
:说明:
``adapter`` 实际调用 api 的逻辑实现函数实现该方法以调用 api
:参数:
* ``api: str``: API 名称
* ``**data``: API 数据
"""
raise NotImplementedError
async def call_api(self, api: str, **data: Any) -> Any:
"""
:说明:
@ -162,7 +136,7 @@ class Bot(abc.ABC):
if not skip_calling_api:
try:
result = await self._call_api(api, **data)
result = await self.adapter._call_api(api, **data)
except Exception as e:
exception = e

View File

@ -26,7 +26,7 @@ from nonebot.config import Env, Config
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
from nonebot.adapters import Bot
from nonebot.adapters import Bot, Adapter
class Driver(abc.ABC):
@ -34,9 +34,9 @@ class Driver(abc.ABC):
Driver 基类
"""
_adapters: Dict[str, Type["Bot"]] = {}
_adapters: Dict[str, "Adapter"] = {}
"""
:类型: ``Dict[str, Type[Bot]]``
:类型: ``Dict[str, Adapter]``
:说明: 已注册的适配器列表
"""
_bot_connection_hook: Set[T_BotConnectionHook] = set()
@ -85,7 +85,7 @@ class Driver(abc.ABC):
"""
return self._clients
def register_adapter(self, name: str, adapter: Type["Bot"], **kwargs):
def register_adapter(self, adapter: Type["Adapter"], **kwargs):
"""
:说明:
@ -97,13 +97,13 @@ class Driver(abc.ABC):
* ``adapter: Type[Bot]``: 适配器 Class
* ``**kwargs``: 其他传递给适配器的参数
"""
name = adapter.get_name()
if name in self._adapters:
logger.opt(colors=True).debug(
f'Adapter "<y>{escape_tag(name)}</y>" already exists'
)
return
self._adapters[name] = adapter
adapter.register(self, self.config, **kwargs)
self._adapters[name] = adapter(self, **kwargs)
logger.opt(colors=True).debug(
f'Succeeded to load adapter "<y>{escape_tag(name)}</y>"'
)
@ -213,34 +213,11 @@ class ForwardDriver(Driver):
"""
@abc.abstractmethod
def setup_http_polling(
self,
setup: Union["HTTPPollingSetup", Callable[[], Awaitable["HTTPPollingSetup"]]],
) -> None:
"""
:说明:
注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
async def request(self, setup: "HTTPRequest") -> Any:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket(
self, setup: Union["WebSocketSetup", Callable[[], Awaitable["WebSocketSetup"]]]
) -> None:
"""
:说明:
注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
async def websocket(self, setup: "HTTPConnection") -> Any:
raise NotImplementedError
@ -261,7 +238,16 @@ class ReverseDriver(Driver):
"""驱动 ASGI 对象"""
raise NotImplementedError
@abc.abstractmethod
def setup_http_server(self, setup: "HTTPServerSetup") -> None:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket_server(self, setup: "WebSocketServerSetup") -> None:
raise NotImplementedError
# TODO: repack dataclass
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
@ -401,36 +387,13 @@ class WebSocket(HTTPConnection, abc.ABC):
@dataclass
class HTTPPollingSetup:
adapter: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
class HTTPServerSetup:
path: str
method: str
"""HTTP method"""
body: bytes
"""HTTP body"""
headers: Dict[str, str]
"""HTTP headers"""
http_version: str
"""HTTP version"""
poll_interval: float
"""HTTP 轮询间隔"""
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
@dataclass
class WebSocketSetup:
adapter: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
headers: Dict[str, str] = field(default_factory=dict)
"""HTTP headers"""
reconnect: bool = True
"""WebSocket 是否重连"""
reconnect_interval: float = 3.0
"""WebSocket 重连间隔"""
class WebSocketServerSetup:
path: str
handle_func: Callable[[WebSocket], Awaitable[Any]]

View File

@ -12,38 +12,35 @@ FastAPI 驱动适配
import asyncio
import logging
from functools import partial
from dataclasses import dataclass
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
from typing import Any, List, Union, Callable, Optional, Awaitable
import httpx
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from websockets.exceptions import ConnectionClosed
from fastapi import FastAPI, Request, HTTPException, status
from starlette.websockets import WebSocketState
from fastapi import Depends, FastAPI, Request, status
from starlette.websockets import WebSocket as FastAPIWebSocket
from starlette.websockets import WebSocketState, WebSocketDisconnect
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env
from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers import WebSocket
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (
HTTPRequest,
HTTPResponse,
ForwardDriver,
ReverseDriver,
WebSocketSetup,
HTTPPollingSetup,
HTTPConnection,
HTTPServerSetup,
WebSocketServerSetup,
)
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
class Config(BaseSettings):
"""
@ -136,16 +133,7 @@ class Config(BaseSettings):
class Driver(ReverseDriver):
"""
FastAPI 驱动框架包含反向 Server 功能
:上报地址:
* ``/{adapter name}/``: HTTP POST 上报
* ``/{adapter name}/http/``: HTTP POST 上报
* ``/{adapter name}/ws``: WebSocket 上报
* ``/{adapter name}/ws/``: WebSocket 上报
"""
"""FastAPI 驱动框架。包含反向 Server 功能。"""
def __init__(self, env: Env, config: NoneBotConfig):
super(Driver, self).__init__(env, config)
@ -159,11 +147,6 @@ class Driver(ReverseDriver):
redoc_url=self.fastapi_config.fastapi_redoc_url,
)
self._server_app.post("/{adapter}/")(self._handle_http)
self._server_app.post("/{adapter}/http")(self._handle_http)
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
@property
@overrides(ReverseDriver)
def type(self) -> str:
@ -188,6 +171,30 @@ class Driver(ReverseDriver):
"""fastapi 使用的 logger"""
return logging.getLogger("fastapi")
@overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup):
def _get_handle_func():
return setup.handle_func
self._server_app.add_api_route(
setup.path,
partial(self._handle_http, handle_func=Depends(_get_handle_func)),
methods=[setup.method],
)
@overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
def _get_handle_func():
return setup.handle_func
self._server_app.add_api_websocket_route(
setup.path,
partial(
self._handle_ws,
handle_func=Depends(_get_handle_func),
),
)
@overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
@ -241,19 +248,11 @@ class Driver(ReverseDriver):
**kwargs,
)
async def _handle_http(self, adapter: str, request: Request):
data = await request.body()
if adapter not in self._adapters:
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
)
# 创建 Bot 对象
BotClass = self._adapters[adapter]
async def _handle_http(
self,
request: Request,
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
):
http_request = HTTPRequest(
request.scope["http_version"],
request.url.scheme,
@ -261,28 +260,17 @@ class Driver(ReverseDriver):
request.scope["query_string"],
dict(request.headers),
request.method,
data,
)
x_self_id, response = await BotClass.check_permission(self, http_request)
if not x_self_id:
raise HTTPException(
response and response.status or 401,
response and response.body and response.body.decode("utf-8"),
await request.body(),
)
if x_self_id in self._clients:
logger.warning(
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
response = await handle_func(http_request)
return Response(response.body, response.status, response.headers)
bot = BotClass(x_self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response(response and response.body, response and response.status or 200)
async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
async def _handle_ws(
self,
websocket: FastAPIWebSocket,
handle_func: Callable[[WebSocket], Awaitable[Any]],
):
ws = WebSocket(
websocket.scope.get("http_version", "1.1"),
websocket.url.scheme,
@ -292,55 +280,7 @@ class Driver(ReverseDriver):
websocket,
)
if adapter not in self._adapters:
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
# Create Bot Object
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, ws)
if not self_id:
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
if self_id in self._clients:
logger.opt(colors=True).warning(
"There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
)
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
return
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
try:
while not ws.closed:
try:
data = await ws.receive()
except WebSocketDisconnect:
logger.error("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
await handle_func(ws)
class FullDriver(ForwardDriver, Driver):
@ -354,17 +294,6 @@ class FullDriver(ForwardDriver, Driver):
DRIVER=nonebot.drivers.fastapi:FullDriver
"""
def __init__(self, env: Env, config: NoneBotConfig):
super(FullDriver, self).__init__(env, config)
self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
self.on_startup(self._run_forward)
self.on_shutdown(self._shutdown_forward)
@property
@overrides(ForwardDriver)
def type(self) -> str:
@ -372,217 +301,25 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full"
@overrides(ForwardDriver)
def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
"""
:说明:
注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
self.http_pollings.append(setup)
@overrides(ForwardDriver)
def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
"""
:说明:
注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
self.websockets.append(setup)
def _run_forward(self):
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
def _shutdown_forward(self):
self.shutdown.set()
for task in self.connections:
if not task.done():
task.cancel()
async def _prepare_setup(
self, setup: Union[S, Callable[[], Awaitable[S]]]
) -> Optional[S]:
try:
if callable(setup):
return await setup()
else:
return setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
return
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.query,
setup.headers,
setup.method,
setup.body,
)
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
http2: bool = False
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
client: Optional[httpx.AsyncClient] = None
# FIXME: seperate const values from setup (self_id, adapter)
# logger.opt(colors=True).info(
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
# f"Bot {escape_tag(_setup.self_id)}</y>"
# )
try:
while not self.shutdown.is_set():
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
request = self._build_http_request(setup)
if not request:
await asyncio.sleep(setup.poll_interval)
continue
if not client:
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
elif http2 != (setup.http_version == "2"):
await client.aclose()
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
http2 = setup.http_version == "2"
if not bot:
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
else:
bot.request = request
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
)
try:
async def request(self, setup: "HTTPRequest") -> Any:
async with httpx.AsyncClient(
http2=setup.http_version == "2", follow_redirects=True
) as client:
response = await client.request(
request.method,
setup.method,
setup.url,
content=request.body,
headers=request.headers,
content=setup.body,
headers=setup.headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
return HTTPResponse(
response.status_code, response.content, response.headers
)
await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
if client:
await client.aclose()
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
while True:
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path, url.query, setup.headers, ws
)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
try:
msg = await request.receive_bytes()
asyncio.create_task(bot.handle_message(msg))
except ConnectionClosed:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed. "
"Try to reconnect...</bg #f8bbd0></r>"
)
break
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
finally:
if bot:
self._bot_disconnect(bot)
bot = None
if not setup.reconnect:
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
break
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Unexpected exception occurred "
"while websocket loop</bg #f8bbd0></r>"
)
@overrides(ForwardDriver)
async def websocket(self, setup: "HTTPConnection") -> Any:
ws = await Connect(setup.url, extra_headers=setup.headers)
return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
@dataclass

View File

@ -111,9 +111,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func
# FIXME: run handler with try/except skipped exception
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]) -> Any:
try:
return await coro

View File

@ -14,13 +14,12 @@ from contextlib import AsyncExitStack
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Callable,
NoReturn,
Optional,
Coroutine,
)
from nonebot import params
@ -30,6 +29,13 @@ from nonebot.exception import SkippedException
from nonebot.typing import T_PermissionChecker
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
try:
return await coro
except SkippedException:
return False
class Permission:
"""
:说明:
@ -100,20 +106,18 @@ class Permission:
return True
results = await asyncio.gather(
*(
_run_coro_with_catch(
checker(
bot=bot,
event=event,
_stack=stack,
_dependency_cache=dependency_cache,
)
)
for checker in self.checkers
),
return_exceptions=True,
)
return next(
filter(lambda x: bool(x) and not isinstance(x, SkippedException), results),
False,
)
return any(results)
def __and__(self, other) -> NoReturn:
raise RuntimeError("And operation between Permissions is not allowed.")