mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
⚗️ new driver combine expr support
This commit is contained in:
parent
b9f1890d80
commit
8fb394e4c3
@ -37,7 +37,7 @@ from nonebot.adapters import Bot
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.log import logger, default_filter
|
||||
from nonebot.drivers import Driver, ReverseDriver
|
||||
from nonebot.drivers import Driver, ReverseDriver, combine_driver
|
||||
|
||||
try:
|
||||
_dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
|
||||
@ -195,6 +195,35 @@ def get_bots() -> Dict[str, Bot]:
|
||||
return driver.bots
|
||||
|
||||
|
||||
def _resolve_dot_notation(
|
||||
obj_str: str, default_attr: str, default_prefix: Optional[str] = None
|
||||
) -> Any:
|
||||
modulename, _, cls = obj_str.partition(":")
|
||||
if default_prefix is not None and modulename.startswith("~"):
|
||||
modulename = default_prefix + modulename[1:]
|
||||
module = importlib.import_module(modulename)
|
||||
if not cls:
|
||||
return getattr(module, default_attr)
|
||||
instance = module
|
||||
for attr_str in cls.split("."):
|
||||
instance = getattr(instance, attr_str)
|
||||
return instance
|
||||
|
||||
|
||||
def _resolve_combine_expr(obj_str: str) -> Type[Driver]:
|
||||
drivers = obj_str.split("+")
|
||||
DriverClass = _resolve_dot_notation(
|
||||
drivers[0], "Driver", default_prefix="nonebot.drivers."
|
||||
)
|
||||
if len(drivers) == 1:
|
||||
return DriverClass
|
||||
mixins = [
|
||||
_resolve_dot_notation(mixin, "Mixin", default_prefix="nonebot.drivers.")
|
||||
for mixin in drivers[1:]
|
||||
]
|
||||
return combine_driver(DriverClass, *mixins)
|
||||
|
||||
|
||||
def init(*, _env_file: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
:说明:
|
||||
@ -243,12 +272,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
|
||||
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
|
||||
)
|
||||
|
||||
modulename, _, cls = config.driver.partition(":")
|
||||
module = importlib.import_module(modulename)
|
||||
instance = module
|
||||
for attr_str in (cls or "Driver").split("."):
|
||||
instance = getattr(instance, attr_str)
|
||||
DriverClass: Type[Driver] = instance # type: ignore
|
||||
DriverClass: Type[Driver] = _resolve_combine_expr(config.driver)
|
||||
_driver = DriverClass(env, config)
|
||||
|
||||
|
||||
|
@ -252,7 +252,13 @@ class ReverseDriver(Driver):
|
||||
|
||||
|
||||
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||
class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore
|
||||
# check first
|
||||
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
|
||||
assert all(
|
||||
map(lambda m: issubclass(m, ForwardMixin), mixins)
|
||||
), "`mixins` must be subclass of ForwardMixin"
|
||||
|
||||
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return (
|
||||
|
@ -4,9 +4,9 @@ import threading
|
||||
from typing import Set, Callable, Awaitable
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.drivers import Driver
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import ForwardDriver
|
||||
|
||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||
@ -16,11 +16,7 @@ HANDLED_SIGNALS = (
|
||||
)
|
||||
|
||||
|
||||
class BlockDriver(ForwardDriver):
|
||||
"""
|
||||
AIOHTTP 驱动框架
|
||||
"""
|
||||
|
||||
class BlockDriver(Driver):
|
||||
def __init__(self, env: Env, config: Config):
|
||||
super().__init__(env, config)
|
||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||
@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver):
|
||||
self.force_exit: bool = False
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def type(self) -> str:
|
||||
"""驱动名称: ``block_driver``"""
|
||||
return "block_driver"
|
||||
|
||||
@property
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def logger(self):
|
||||
"""block driver 使用的 logger"""
|
||||
return logger
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver):
|
||||
self.startup_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||
"""
|
||||
:说明:
|
||||
@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver):
|
||||
self.shutdown_funcs.add(func)
|
||||
return func
|
||||
|
||||
@overrides(ForwardDriver)
|
||||
@overrides(Driver)
|
||||
def run(self, *args, **kwargs):
|
||||
"""启动 block driver"""
|
||||
super().run(*args, **kwargs)
|
||||
|
@ -19,7 +19,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
class AiohttpMixin(ForwardMixin):
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.send_bytes(data)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, AiohttpMixin)
|
||||
Driver = combine_driver(BlockDriver, Mixin)
|
||||
|
@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers.httpx import HttpxMixin
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.websockets import WebSocketsMixin
|
||||
from nonebot.drivers import (
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
try:
|
||||
from nonebot.drivers.aiohttp import AiohttpMixin
|
||||
except ImportError:
|
||||
AiohttpMixin = None
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
|
||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||
if AiohttpMixin:
|
||||
AiohttpDriver = combine_driver(Driver, AiohttpMixin)
|
||||
|
@ -1,5 +1,3 @@
|
||||
import httpx
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import (
|
||||
@ -11,8 +9,13 @@ from nonebot.drivers import (
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`")
|
||||
|
||||
class HttpxMixin(ForwardMixin):
|
||||
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin):
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> WebSocket:
|
||||
return await super(HttpxMixin, self).websocket(setup)
|
||||
return await super(Mixin, self).websocket(setup)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, HttpxMixin)
|
||||
Driver = combine_driver(BlockDriver, Mixin)
|
||||
|
@ -17,17 +17,10 @@ from nonebot.config import Env
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
from nonebot.drivers.httpx import HttpxMixin
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers.websockets import WebSocketsMixin
|
||||
from nonebot.drivers import (
|
||||
ReverseDriver,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
combine_driver,
|
||||
)
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
try:
|
||||
from quart import request as _request
|
||||
@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
async def send_bytes(self, data: bytes):
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
||||
|
@ -1,7 +1,5 @@
|
||||
import logging
|
||||
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.log import LoguruHandler
|
||||
from nonebot.drivers import Request, Response
|
||||
@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ForwardMixin, combine_driver
|
||||
|
||||
try:
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install websockets by using `pip install nonebot2[websockets]`"
|
||||
)
|
||||
|
||||
logger = logging.Logger("websockets.client", "INFO")
|
||||
logger.addHandler(LoguruHandler())
|
||||
|
||||
|
||||
class WebSocketsMixin(ForwardMixin):
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
def type(self) -> str:
|
||||
@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin):
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def request(self, setup: Request) -> Response:
|
||||
return await super(WebSocketsMixin, self).request(setup)
|
||||
return await super(Mixin, self).request(setup)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
async def websocket(self, setup: Request) -> "WebSocket":
|
||||
@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
||||
Driver = combine_driver(BlockDriver, WebSocketsMixin)
|
||||
Driver = combine_driver(BlockDriver, Mixin)
|
||||
|
@ -422,7 +422,7 @@ def shell_command(
|
||||
命令内容与后续消息间无需空格!
|
||||
\:\:\:
|
||||
"""
|
||||
if not isinstance(parser, ArgumentParser):
|
||||
if parser is not None and not isinstance(parser, ArgumentParser):
|
||||
raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
|
||||
|
||||
config = get_driver().config
|
||||
|
@ -28,13 +28,13 @@ loguru = "^0.5.1"
|
||||
pygtrie = "^2.4.1"
|
||||
tomlkit = "^0.7.0"
|
||||
fastapi = "^0.70.0"
|
||||
websockets = ">=9.1"
|
||||
typing-extensions = ">=3.10.0,<5.0.0"
|
||||
Quart = { version = "^0.16.0", optional = true }
|
||||
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
|
||||
websockets = { version=">=9.1", optional = true }
|
||||
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
||||
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
||||
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
||||
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
sphinx = "^4.1.1"
|
||||
@ -46,8 +46,10 @@ sphinx-markdown-builder = { git = "https://github.com/nonebot/sphinx-markdown-bu
|
||||
|
||||
[tool.poetry.extras]
|
||||
quart = ["quart"]
|
||||
httpx = ["httpx"]
|
||||
aiohttp = ["aiohttp"]
|
||||
all = ["quart", "aiohttp"]
|
||||
websockets = ["websockets"]
|
||||
all = ["quart", "aiohttp", "httpx", "websockets"]
|
||||
|
||||
# [[tool.poetry.source]]
|
||||
# name = "aliyun"
|
||||
|
@ -15,18 +15,24 @@ os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"nonebug_init",
|
||||
[{"config_from_init": "init", "driver": "nonebot.drivers.fastapi:FullDriver"}],
|
||||
[
|
||||
{
|
||||
"config_from_init": "init",
|
||||
"driver": "nonebot.drivers.fastapi+nonebot.drivers.httpx+nonebot.drivers.websockets",
|
||||
},
|
||||
{
|
||||
"config_from_init": "init",
|
||||
"driver": "~fastapi+~httpx+~websockets",
|
||||
},
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_init(nonebug_init):
|
||||
from nonebot import get_driver
|
||||
from nonebot.drivers.fastapi import FullDriver
|
||||
|
||||
env = get_driver().env
|
||||
assert env == "test"
|
||||
|
||||
assert isinstance(get_driver(), FullDriver)
|
||||
|
||||
config = get_driver().config
|
||||
assert config.config_from_env == {"test": "test"}
|
||||
assert config.config_from_init == "init"
|
||||
|
Loading…
Reference in New Issue
Block a user