⚗️ new driver combine expr support

This commit is contained in:
yanyongyu 2021-12-23 17:20:26 +08:00
parent b9f1890d80
commit 8fb394e4c3
11 changed files with 83 additions and 68 deletions

View File

@ -37,7 +37,7 @@ from nonebot.adapters import Bot
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from nonebot.log import logger, default_filter
from nonebot.drivers import Driver, ReverseDriver
from nonebot.drivers import Driver, ReverseDriver, combine_driver
try:
_dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
@ -195,6 +195,35 @@ def get_bots() -> Dict[str, Bot]:
return driver.bots
def _resolve_dot_notation(
obj_str: str, default_attr: str, default_prefix: Optional[str] = None
) -> Any:
modulename, _, cls = obj_str.partition(":")
if default_prefix is not None and modulename.startswith("~"):
modulename = default_prefix + modulename[1:]
module = importlib.import_module(modulename)
if not cls:
return getattr(module, default_attr)
instance = module
for attr_str in cls.split("."):
instance = getattr(instance, attr_str)
return instance
def _resolve_combine_expr(obj_str: str) -> Type[Driver]:
drivers = obj_str.split("+")
DriverClass = _resolve_dot_notation(
drivers[0], "Driver", default_prefix="nonebot.drivers."
)
if len(drivers) == 1:
return DriverClass
mixins = [
_resolve_dot_notation(mixin, "Mixin", default_prefix="nonebot.drivers.")
for mixin in drivers[1:]
]
return combine_driver(DriverClass, *mixins)
def init(*, _env_file: Optional[str] = None, **kwargs):
"""
:说明:
@ -243,12 +272,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
)
modulename, _, cls = config.driver.partition(":")
module = importlib.import_module(modulename)
instance = module
for attr_str in (cls or "Driver").split("."):
instance = getattr(instance, attr_str)
DriverClass: Type[Driver] = instance # type: ignore
DriverClass: Type[Driver] = _resolve_combine_expr(config.driver)
_driver = DriverClass(env, config)

View File

@ -252,7 +252,13 @@ class ReverseDriver(Driver):
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore
# check first
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
assert all(
map(lambda m: issubclass(m, ForwardMixin), mixins)
), "`mixins` must be subclass of ForwardMixin"
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
@property
def type(self) -> str:
return (

View File

@ -4,9 +4,9 @@ import threading
from typing import Set, Callable, Awaitable
from nonebot.log import logger
from nonebot.drivers import Driver
from nonebot.typing import overrides
from nonebot.config import Env, Config
from nonebot.drivers import ForwardDriver
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
@ -16,11 +16,7 @@ HANDLED_SIGNALS = (
)
class BlockDriver(ForwardDriver):
"""
AIOHTTP 驱动框架
"""
class BlockDriver(Driver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver):
self.force_exit: bool = False
@property
@overrides(ForwardDriver)
@overrides(Driver)
def type(self) -> str:
"""驱动名称: ``block_driver``"""
return "block_driver"
@property
@overrides(ForwardDriver)
@overrides(Driver)
def logger(self):
"""block driver 使用的 logger"""
return logger
@overrides(ForwardDriver)
@overrides(Driver)
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
"""
:说明:
@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver):
self.startup_funcs.add(func)
return func
@overrides(ForwardDriver)
@overrides(Driver)
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
"""
:说明:
@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver):
self.shutdown_funcs.add(func)
return func
@overrides(ForwardDriver)
@overrides(Driver)
def run(self, *args, **kwargs):
"""启动 block driver"""
super().run(*args, **kwargs)

View File

@ -19,7 +19,7 @@ except ImportError:
) from None
class AiohttpMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket):
await self.websocket.send_bytes(data)
Driver = combine_driver(BlockDriver, AiohttpMixin)
Driver = combine_driver(BlockDriver, Mixin)

View File

@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
try:
from nonebot.drivers.aiohttp import AiohttpMixin
except ImportError:
AiohttpMixin = None
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
class Config(BaseSettings):
@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
if AiohttpMixin:
AiohttpDriver = combine_driver(Driver, AiohttpMixin)

View File

@ -1,5 +1,3 @@
import httpx
from nonebot.typing import overrides
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import (
@ -11,8 +9,13 @@ from nonebot.drivers import (
combine_driver,
)
try:
import httpx
except ImportError:
raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`")
class HttpxMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin):
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> WebSocket:
return await super(HttpxMixin, self).websocket(setup)
return await super(Mixin, self).websocket(setup)
Driver = combine_driver(BlockDriver, HttpxMixin)
Driver = combine_driver(BlockDriver, Mixin)

View File

@ -17,17 +17,10 @@ from nonebot.config import Env
from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers.httpx import HttpxMixin
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.websockets import WebSocketsMixin
from nonebot.drivers import (
ReverseDriver,
HTTPServerSetup,
WebSocketServerSetup,
combine_driver,
)
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
try:
from quart import request as _request
@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes):
await self.websocket.send(data)
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)

View File

@ -1,7 +1,5 @@
import logging
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, combine_driver
try:
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError:
raise ImportError(
"Please install websockets by using `pip install nonebot2[websockets]`"
)
logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler())
class WebSocketsMixin(ForwardMixin):
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
def type(self) -> str:
@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin):
@overrides(ForwardMixin)
async def request(self, setup: Request) -> Response:
return await super(WebSocketsMixin, self).request(setup)
return await super(Mixin, self).request(setup)
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket":
@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket):
await self.websocket.send(data)
Driver = combine_driver(BlockDriver, WebSocketsMixin)
Driver = combine_driver(BlockDriver, Mixin)

View File

@ -422,7 +422,7 @@ def shell_command(
命令内容与后续消息间无需空格
\:\:\:
"""
if not isinstance(parser, ArgumentParser):
if parser is not None and not isinstance(parser, ArgumentParser):
raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
config = get_driver().config

View File

@ -28,13 +28,13 @@ loguru = "^0.5.1"
pygtrie = "^2.4.1"
tomlkit = "^0.7.0"
fastapi = "^0.70.0"
websockets = ">=9.1"
typing-extensions = ">=3.10.0,<5.0.0"
Quart = { version = "^0.16.0", optional = true }
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
websockets = { version=">=9.1", optional = true }
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
uvicorn = { version = "^0.15.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }
[tool.poetry.dev-dependencies]
sphinx = "^4.1.1"
@ -46,8 +46,10 @@ sphinx-markdown-builder = { git = "https://github.com/nonebot/sphinx-markdown-bu
[tool.poetry.extras]
quart = ["quart"]
httpx = ["httpx"]
aiohttp = ["aiohttp"]
all = ["quart", "aiohttp"]
websockets = ["websockets"]
all = ["quart", "aiohttp", "httpx", "websockets"]
# [[tool.poetry.source]]
# name = "aliyun"

View File

@ -15,18 +15,24 @@ os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
@pytest.mark.asyncio
@pytest.mark.parametrize(
"nonebug_init",
[{"config_from_init": "init", "driver": "nonebot.drivers.fastapi:FullDriver"}],
[
{
"config_from_init": "init",
"driver": "nonebot.drivers.fastapi+nonebot.drivers.httpx+nonebot.drivers.websockets",
},
{
"config_from_init": "init",
"driver": "~fastapi+~httpx+~websockets",
},
],
indirect=True,
)
async def test_init(nonebug_init):
from nonebot import get_driver
from nonebot.drivers.fastapi import FullDriver
env = get_driver().env
assert env == "test"
assert isinstance(get_driver(), FullDriver)
config = get_driver().config
assert config.config_from_env == {"test": "test"}
assert config.config_from_init == "init"