⚗️ new driver combine expr support

This commit is contained in:
yanyongyu 2021-12-23 17:20:26 +08:00
parent b9f1890d80
commit 8fb394e4c3
11 changed files with 83 additions and 68 deletions

View File

@ -37,7 +37,7 @@ from nonebot.adapters import Bot
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.log import logger, default_filter from nonebot.log import logger, default_filter
from nonebot.drivers import Driver, ReverseDriver from nonebot.drivers import Driver, ReverseDriver, combine_driver
try: try:
_dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2") _dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
@ -195,6 +195,35 @@ def get_bots() -> Dict[str, Bot]:
return driver.bots return driver.bots
def _resolve_dot_notation(
obj_str: str, default_attr: str, default_prefix: Optional[str] = None
) -> Any:
modulename, _, cls = obj_str.partition(":")
if default_prefix is not None and modulename.startswith("~"):
modulename = default_prefix + modulename[1:]
module = importlib.import_module(modulename)
if not cls:
return getattr(module, default_attr)
instance = module
for attr_str in cls.split("."):
instance = getattr(instance, attr_str)
return instance
def _resolve_combine_expr(obj_str: str) -> Type[Driver]:
drivers = obj_str.split("+")
DriverClass = _resolve_dot_notation(
drivers[0], "Driver", default_prefix="nonebot.drivers."
)
if len(drivers) == 1:
return DriverClass
mixins = [
_resolve_dot_notation(mixin, "Mixin", default_prefix="nonebot.drivers.")
for mixin in drivers[1:]
]
return combine_driver(DriverClass, *mixins)
def init(*, _env_file: Optional[str] = None, **kwargs): def init(*, _env_file: Optional[str] = None, **kwargs):
""" """
:说明: :说明:
@ -243,12 +272,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}" f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
) )
modulename, _, cls = config.driver.partition(":") DriverClass: Type[Driver] = _resolve_combine_expr(config.driver)
module = importlib.import_module(modulename)
instance = module
for attr_str in (cls or "Driver").split("."):
instance = getattr(instance, attr_str)
DriverClass: Type[Driver] = instance # type: ignore
_driver = DriverClass(env, config) _driver = DriverClass(env, config)

View File

@ -252,7 +252,13 @@ class ReverseDriver(Driver):
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]: def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore # check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
map(lambda m: issubclass(m, ForwardMixin), mixins)
), "`mixins` must be subclass of ForwardMixin"
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
@property @property
def type(self) -> str: def type(self) -> str:
return ( return (

View File

@ -4,9 +4,9 @@ import threading
from typing import Set, Callable, Awaitable from typing import Set, Callable, Awaitable
from nonebot.log import logger from nonebot.log import logger
from nonebot.drivers import Driver
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import ForwardDriver
STARTUP_FUNC = Callable[[], Awaitable[None]] STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
@ -16,11 +16,7 @@ HANDLED_SIGNALS = (
) )
class BlockDriver(ForwardDriver): class BlockDriver(Driver):
"""
AIOHTTP 驱动框架
"""
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(env, config) super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set() self.startup_funcs: Set[STARTUP_FUNC] = set()
@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver):
self.force_exit: bool = False self.force_exit: bool = False
@property @property
@overrides(ForwardDriver) @overrides(Driver)
def type(self) -> str: def type(self) -> str:
"""驱动名称: ``block_driver``""" """驱动名称: ``block_driver``"""
return "block_driver" return "block_driver"
@property @property
@overrides(ForwardDriver) @overrides(Driver)
def logger(self): def logger(self):
"""block driver 使用的 logger""" """block driver 使用的 logger"""
return logger return logger
@overrides(ForwardDriver) @overrides(Driver)
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC: def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
""" """
:说明: :说明:
@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver):
self.startup_funcs.add(func) self.startup_funcs.add(func)
return func return func
@overrides(ForwardDriver) @overrides(Driver)
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC: def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
""" """
:说明: :说明:
@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver):
self.shutdown_funcs.add(func) self.shutdown_funcs.add(func)
return func return func
@overrides(ForwardDriver) @overrides(Driver)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
"""启动 block driver""" """启动 block driver"""
super().run(*args, **kwargs) super().run(*args, **kwargs)

View File

@ -19,7 +19,7 @@ except ImportError:
) from None ) from None
class AiohttpMixin(ForwardMixin): class Mixin(ForwardMixin):
@property @property
@overrides(ForwardMixin) @overrides(ForwardMixin)
def type(self) -> str: def type(self) -> str:
@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket):
await self.websocket.send_bytes(data) await self.websocket.send_bytes(data)
Driver = combine_driver(BlockDriver, AiohttpMixin) Driver = combine_driver(BlockDriver, Mixin)

View File

@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState
from nonebot.config import Env from nonebot.config import Env
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
try:
from nonebot.drivers.aiohttp import AiohttpMixin
except ImportError:
AiohttpMixin = None
class Config(BaseSettings): class Config(BaseSettings):
@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None: async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data}) await self.websocket.send({"type": "websocket.send", "bytes": data})
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
if AiohttpMixin:
AiohttpDriver = combine_driver(Driver, AiohttpMixin)

View File

@ -1,5 +1,3 @@
import httpx
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import ( from nonebot.drivers import (
@ -11,8 +9,13 @@ from nonebot.drivers import (
combine_driver, combine_driver,
) )
try:
import httpx
except ImportError:
raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`")
class HttpxMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property @property
@overrides(ForwardMixin) @overrides(ForwardMixin)
def type(self) -> str: def type(self) -> str:
@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin):
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def websocket(self, setup: Request) -> WebSocket: async def websocket(self, setup: Request) -> WebSocket:
return await super(HttpxMixin, self).websocket(setup) return await super(Mixin, self).websocket(setup)
Driver = combine_driver(BlockDriver, HttpxMixin) Driver = combine_driver(BlockDriver, Mixin)

View File

@ -17,17 +17,10 @@ from nonebot.config import Env
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
try: try:
from quart import request as _request from quart import request as _request
@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send_bytes(self, data: bytes): async def send_bytes(self, data: bytes):
await self.websocket.send(data) await self.websocket.send(data)
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)

View File

@ -1,7 +1,5 @@
import logging import logging
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.log import LoguruHandler from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response from nonebot.drivers import Request, Response
@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, combine_driver from nonebot.drivers import ForwardMixin, combine_driver
try:
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError:
raise ImportError(
"Please install websockets by using `pip install nonebot2[websockets]`"
)
logger = logging.Logger("websockets.client", "INFO") logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler()) logger.addHandler(LoguruHandler())
class WebSocketsMixin(ForwardMixin): class Mixin(ForwardMixin):
@property @property
@overrides(ForwardMixin) @overrides(ForwardMixin)
def type(self) -> str: def type(self) -> str:
@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin):
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def request(self, setup: Request) -> Response: async def request(self, setup: Request) -> Response:
return await super(WebSocketsMixin, self).request(setup) return await super(Mixin, self).request(setup)
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket": async def websocket(self, setup: Request) -> "WebSocket":
@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket):
await self.websocket.send(data) await self.websocket.send(data)
Driver = combine_driver(BlockDriver, WebSocketsMixin) Driver = combine_driver(BlockDriver, Mixin)

View File

@ -422,7 +422,7 @@ def shell_command(
命令内容与后续消息间无需空格 命令内容与后续消息间无需空格
\:\:\: \:\:\:
""" """
if not isinstance(parser, ArgumentParser): if parser is not None and not isinstance(parser, ArgumentParser):
raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser") raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
config = get_driver().config config = get_driver().config

View File

@ -28,13 +28,13 @@ loguru = "^0.5.1"
pygtrie = "^2.4.1" pygtrie = "^2.4.1"
tomlkit = "^0.7.0" tomlkit = "^0.7.0"
fastapi = "^0.70.0" fastapi = "^0.70.0"
websockets = ">=9.1"
typing-extensions = ">=3.10.0,<5.0.0" typing-extensions = ">=3.10.0,<5.0.0"
Quart = { version = "^0.16.0", optional = true } Quart = { version = "^0.16.0", optional = true }
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] } websockets = { version=">=9.1", optional = true }
pydantic = { version = "~1.8.0", extras = ["dotenv"] } pydantic = { version = "~1.8.0", extras = ["dotenv"] }
uvicorn = { version = "^0.15.0", extras = ["standard"] } uvicorn = { version = "^0.15.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
sphinx = "^4.1.1" sphinx = "^4.1.1"
@ -46,8 +46,10 @@ sphinx-markdown-builder = { git = "https://github.com/nonebot/sphinx-markdown-bu
[tool.poetry.extras] [tool.poetry.extras]
quart = ["quart"] quart = ["quart"]
httpx = ["httpx"]
aiohttp = ["aiohttp"] aiohttp = ["aiohttp"]
all = ["quart", "aiohttp"] websockets = ["websockets"]
all = ["quart", "aiohttp", "httpx", "websockets"]
# [[tool.poetry.source]] # [[tool.poetry.source]]
# name = "aliyun" # name = "aliyun"

View File

@ -15,18 +15,24 @@ os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"nonebug_init", "nonebug_init",
[{"config_from_init": "init", "driver": "nonebot.drivers.fastapi:FullDriver"}], [
{
"config_from_init": "init",
"driver": "nonebot.drivers.fastapi+nonebot.drivers.httpx+nonebot.drivers.websockets",
},
{
"config_from_init": "init",
"driver": "~fastapi+~httpx+~websockets",
},
],
indirect=True, indirect=True,
) )
async def test_init(nonebug_init): async def test_init(nonebug_init):
from nonebot import get_driver from nonebot import get_driver
from nonebot.drivers.fastapi import FullDriver
env = get_driver().env env = get_driver().env
assert env == "test" assert env == "test"
assert isinstance(get_driver(), FullDriver)
config = get_driver().config config = get_driver().config
assert config.config_from_env == {"test": "test"} assert config.config_from_env == {"test": "test"}
assert config.config_from_init == "init" assert config.config_from_init == "init"