diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 92d63d75..78f7fc39 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -37,7 +37,7 @@ from nonebot.adapters import Bot from nonebot.utils import escape_tag from nonebot.config import Env, Config from nonebot.log import logger, default_filter -from nonebot.drivers import Driver, ReverseDriver +from nonebot.drivers import Driver, ReverseDriver, combine_driver try: _dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2") @@ -195,6 +195,35 @@ def get_bots() -> Dict[str, Bot]: return driver.bots +def _resolve_dot_notation( + obj_str: str, default_attr: str, default_prefix: Optional[str] = None +) -> Any: + modulename, _, cls = obj_str.partition(":") + if default_prefix is not None and modulename.startswith("~"): + modulename = default_prefix + modulename[1:] + module = importlib.import_module(modulename) + if not cls: + return getattr(module, default_attr) + instance = module + for attr_str in cls.split("."): + instance = getattr(instance, attr_str) + return instance + + +def _resolve_combine_expr(obj_str: str) -> Type[Driver]: + drivers = obj_str.split("+") + DriverClass = _resolve_dot_notation( + drivers[0], "Driver", default_prefix="nonebot.drivers." + ) + if len(drivers) == 1: + return DriverClass + mixins = [ + _resolve_dot_notation(mixin, "Mixin", default_prefix="nonebot.drivers.") + for mixin in drivers[1:] + ] + return combine_driver(DriverClass, *mixins) + + def init(*, _env_file: Optional[str] = None, **kwargs): """ :说明: @@ -243,12 +272,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs): f"Loaded Config: {escape_tag(str(config.dict()))}" ) - modulename, _, cls = config.driver.partition(":") - module = importlib.import_module(modulename) - instance = module - for attr_str in (cls or "Driver").split("."): - instance = getattr(instance, attr_str) - DriverClass: Type[Driver] = instance # type: ignore + DriverClass: Type[Driver] = _resolve_combine_expr(config.driver) _driver = DriverClass(env, config) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 44868897..079534ce 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -252,7 +252,13 @@ class ReverseDriver(Driver): def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]: - class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore + # check first + assert issubclass(driver, Driver), "`driver` must be subclass of Driver" + assert all( + map(lambda m: issubclass(m, ForwardMixin), mixins) + ), "`mixins` must be subclass of ForwardMixin" + + class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore @property def type(self) -> str: return ( diff --git a/nonebot/drivers/_block_driver.py b/nonebot/drivers/_block_driver.py index afd7f8b3..0e8f34f8 100644 --- a/nonebot/drivers/_block_driver.py +++ b/nonebot/drivers/_block_driver.py @@ -4,9 +4,9 @@ import threading from typing import Set, Callable, Awaitable from nonebot.log import logger +from nonebot.drivers import Driver from nonebot.typing import overrides from nonebot.config import Env, Config -from nonebot.drivers import ForwardDriver STARTUP_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]] @@ -16,11 +16,7 @@ HANDLED_SIGNALS = ( ) -class BlockDriver(ForwardDriver): - """ - AIOHTTP 驱动框架 - """ - +class BlockDriver(Driver): def __init__(self, env: Env, config: Config): super().__init__(env, config) self.startup_funcs: Set[STARTUP_FUNC] = set() @@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver): self.force_exit: bool = False @property - @overrides(ForwardDriver) + @overrides(Driver) def type(self) -> str: """驱动名称: ``block_driver``""" return "block_driver" @property - @overrides(ForwardDriver) + @overrides(Driver) def logger(self): """block driver 使用的 logger""" return logger - @overrides(ForwardDriver) + @overrides(Driver) def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC: """ :说明: @@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver): self.startup_funcs.add(func) return func - @overrides(ForwardDriver) + @overrides(Driver) def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC: """ :说明: @@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver): self.shutdown_funcs.add(func) return func - @overrides(ForwardDriver) + @overrides(Driver) def run(self, *args, **kwargs): """启动 block driver""" super().run(*args, **kwargs) diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 031e3b57..f935342c 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -19,7 +19,7 @@ except ImportError: ) from None -class AiohttpMixin(ForwardMixin): +class Mixin(ForwardMixin): @property @overrides(ForwardMixin) def type(self) -> str: @@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket): await self.websocket.send_bytes(data) -Driver = combine_driver(BlockDriver, AiohttpMixin) +Driver = combine_driver(BlockDriver, Mixin) diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 0e36bdda..41bd6c71 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState from nonebot.config import Env from nonebot.typing import overrides from nonebot.utils import escape_tag -from nonebot.drivers.httpx import HttpxMixin from nonebot.config import Config as NoneBotConfig from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket -from nonebot.drivers.websockets import WebSocketsMixin -from nonebot.drivers import ( - ReverseDriver, - HTTPServerSetup, - WebSocketServerSetup, - combine_driver, -) - -try: - from nonebot.drivers.aiohttp import AiohttpMixin -except ImportError: - AiohttpMixin = None +from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup class Config(BaseSettings): @@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: await self.websocket.send({"type": "websocket.send", "bytes": data}) - - -FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin) -if AiohttpMixin: - AiohttpDriver = combine_driver(Driver, AiohttpMixin) diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 1965c439..09234710 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -1,5 +1,3 @@ -import httpx - from nonebot.typing import overrides from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers import ( @@ -11,8 +9,13 @@ from nonebot.drivers import ( combine_driver, ) +try: + import httpx +except ImportError: + raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`") -class HttpxMixin(ForwardMixin): + +class Mixin(ForwardMixin): @property @overrides(ForwardMixin) def type(self) -> str: @@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin): @overrides(ForwardMixin) async def websocket(self, setup: Request) -> WebSocket: - return await super(HttpxMixin, self).websocket(setup) + return await super(Mixin, self).websocket(setup) -Driver = combine_driver(BlockDriver, HttpxMixin) +Driver = combine_driver(BlockDriver, Mixin) diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index a56d4767..89822ee6 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -17,17 +17,10 @@ from nonebot.config import Env from nonebot.log import logger from nonebot.typing import overrides from nonebot.utils import escape_tag -from nonebot.drivers.httpx import HttpxMixin from nonebot.config import Config as NoneBotConfig from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket -from nonebot.drivers.websockets import WebSocketsMixin -from nonebot.drivers import ( - ReverseDriver, - HTTPServerSetup, - WebSocketServerSetup, - combine_driver, -) +from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup try: from quart import request as _request @@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def send_bytes(self, data: bytes): await self.websocket.send(data) - - -FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin) diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 7dbd6399..81903698 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -1,7 +1,5 @@ import logging -from websockets.legacy.client import Connect, WebSocketClientProtocol - from nonebot.typing import overrides from nonebot.log import LoguruHandler from nonebot.drivers import Request, Response @@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import ForwardMixin, combine_driver +try: + from websockets.legacy.client import Connect, WebSocketClientProtocol +except ImportError: + raise ImportError( + "Please install websockets by using `pip install nonebot2[websockets]`" + ) + logger = logging.Logger("websockets.client", "INFO") logger.addHandler(LoguruHandler()) -class WebSocketsMixin(ForwardMixin): +class Mixin(ForwardMixin): @property @overrides(ForwardMixin) def type(self) -> str: @@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin): @overrides(ForwardMixin) async def request(self, setup: Request) -> Response: - return await super(WebSocketsMixin, self).request(setup) + return await super(Mixin, self).request(setup) @overrides(ForwardMixin) async def websocket(self, setup: Request) -> "WebSocket": @@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket): await self.websocket.send(data) -Driver = combine_driver(BlockDriver, WebSocketsMixin) +Driver = combine_driver(BlockDriver, Mixin) diff --git a/nonebot/rule.py b/nonebot/rule.py index 48c60963..8224e31f 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -422,7 +422,7 @@ def shell_command( 命令内容与后续消息间无需空格! \:\:\: """ - if not isinstance(parser, ArgumentParser): + if parser is not None and not isinstance(parser, ArgumentParser): raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser") config = get_driver().config diff --git a/pyproject.toml b/pyproject.toml index b6d94cb6..c687ce18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,13 +28,13 @@ loguru = "^0.5.1" pygtrie = "^2.4.1" tomlkit = "^0.7.0" fastapi = "^0.70.0" -websockets = ">=9.1" typing-extensions = ">=3.10.0,<5.0.0" Quart = { version = "^0.16.0", optional = true } -httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] } +websockets = { version=">=9.1", optional = true } pydantic = { version = "~1.8.0", extras = ["dotenv"] } uvicorn = { version = "^0.15.0", extras = ["standard"] } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } +httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true } [tool.poetry.dev-dependencies] sphinx = "^4.1.1" @@ -46,8 +46,10 @@ sphinx-markdown-builder = { git = "https://github.com/nonebot/sphinx-markdown-bu [tool.poetry.extras] quart = ["quart"] +httpx = ["httpx"] aiohttp = ["aiohttp"] -all = ["quart", "aiohttp"] +websockets = ["websockets"] +all = ["quart", "aiohttp", "httpx", "websockets"] # [[tool.poetry.source]] # name = "aliyun" diff --git a/tests/test_init.py b/tests/test_init.py index 2018fb75..44411fb4 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -15,18 +15,24 @@ os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}' @pytest.mark.asyncio @pytest.mark.parametrize( "nonebug_init", - [{"config_from_init": "init", "driver": "nonebot.drivers.fastapi:FullDriver"}], + [ + { + "config_from_init": "init", + "driver": "nonebot.drivers.fastapi+nonebot.drivers.httpx+nonebot.drivers.websockets", + }, + { + "config_from_init": "init", + "driver": "~fastapi+~httpx+~websockets", + }, + ], indirect=True, ) async def test_init(nonebug_init): from nonebot import get_driver - from nonebot.drivers.fastapi import FullDriver env = get_driver().env assert env == "test" - assert isinstance(get_driver(), FullDriver) - config = get_driver().config assert config.config_from_env == {"test": "test"} assert config.config_from_init == "init"