mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-30 17:15:08 +08:00
⚗️ new driver combine expr support
This commit is contained in:
parent
b9f1890d80
commit
8fb394e4c3
@ -37,7 +37,7 @@ from nonebot.adapters import Bot
|
|||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.log import logger, default_filter
|
from nonebot.log import logger, default_filter
|
||||||
from nonebot.drivers import Driver, ReverseDriver
|
from nonebot.drivers import Driver, ReverseDriver, combine_driver
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
|
_dist: pkg_resources.Distribution = pkg_resources.get_distribution("nonebot2")
|
||||||
@ -195,6 +195,35 @@ def get_bots() -> Dict[str, Bot]:
|
|||||||
return driver.bots
|
return driver.bots
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_dot_notation(
|
||||||
|
obj_str: str, default_attr: str, default_prefix: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
|
modulename, _, cls = obj_str.partition(":")
|
||||||
|
if default_prefix is not None and modulename.startswith("~"):
|
||||||
|
modulename = default_prefix + modulename[1:]
|
||||||
|
module = importlib.import_module(modulename)
|
||||||
|
if not cls:
|
||||||
|
return getattr(module, default_attr)
|
||||||
|
instance = module
|
||||||
|
for attr_str in cls.split("."):
|
||||||
|
instance = getattr(instance, attr_str)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_combine_expr(obj_str: str) -> Type[Driver]:
|
||||||
|
drivers = obj_str.split("+")
|
||||||
|
DriverClass = _resolve_dot_notation(
|
||||||
|
drivers[0], "Driver", default_prefix="nonebot.drivers."
|
||||||
|
)
|
||||||
|
if len(drivers) == 1:
|
||||||
|
return DriverClass
|
||||||
|
mixins = [
|
||||||
|
_resolve_dot_notation(mixin, "Mixin", default_prefix="nonebot.drivers.")
|
||||||
|
for mixin in drivers[1:]
|
||||||
|
]
|
||||||
|
return combine_driver(DriverClass, *mixins)
|
||||||
|
|
||||||
|
|
||||||
def init(*, _env_file: Optional[str] = None, **kwargs):
|
def init(*, _env_file: Optional[str] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -243,12 +272,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
|
|||||||
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
|
f"Loaded <y><b>Config</b></y>: {escape_tag(str(config.dict()))}"
|
||||||
)
|
)
|
||||||
|
|
||||||
modulename, _, cls = config.driver.partition(":")
|
DriverClass: Type[Driver] = _resolve_combine_expr(config.driver)
|
||||||
module = importlib.import_module(modulename)
|
|
||||||
instance = module
|
|
||||||
for attr_str in (cls or "Driver").split("."):
|
|
||||||
instance = getattr(instance, attr_str)
|
|
||||||
DriverClass: Type[Driver] = instance # type: ignore
|
|
||||||
_driver = DriverClass(env, config)
|
_driver = DriverClass(env, config)
|
||||||
|
|
||||||
|
|
||||||
|
@ -252,7 +252,13 @@ class ReverseDriver(Driver):
|
|||||||
|
|
||||||
|
|
||||||
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]:
|
||||||
class CombinedDriver(driver, *mixins, ForwardDriver): # type: ignore
|
# check first
|
||||||
|
assert issubclass(driver, Driver), "`driver` must be subclass of Driver"
|
||||||
|
assert all(
|
||||||
|
map(lambda m: issubclass(m, ForwardMixin), mixins)
|
||||||
|
), "`mixins` must be subclass of ForwardMixin"
|
||||||
|
|
||||||
|
class CombinedDriver(*mixins, driver, ForwardDriver): # type: ignore
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
@ -4,9 +4,9 @@ import threading
|
|||||||
from typing import Set, Callable, Awaitable
|
from typing import Set, Callable, Awaitable
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
|
from nonebot.drivers import Driver
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.drivers import ForwardDriver
|
|
||||||
|
|
||||||
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
STARTUP_FUNC = Callable[[], Awaitable[None]]
|
||||||
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
|
||||||
@ -16,11 +16,7 @@ HANDLED_SIGNALS = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BlockDriver(ForwardDriver):
|
class BlockDriver(Driver):
|
||||||
"""
|
|
||||||
AIOHTTP 驱动框架
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: Env, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
super().__init__(env, config)
|
super().__init__(env, config)
|
||||||
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
self.startup_funcs: Set[STARTUP_FUNC] = set()
|
||||||
@ -29,18 +25,18 @@ class BlockDriver(ForwardDriver):
|
|||||||
self.force_exit: bool = False
|
self.force_exit: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardDriver)
|
@overrides(Driver)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
"""驱动名称: ``block_driver``"""
|
"""驱动名称: ``block_driver``"""
|
||||||
return "block_driver"
|
return "block_driver"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardDriver)
|
@overrides(Driver)
|
||||||
def logger(self):
|
def logger(self):
|
||||||
"""block driver 使用的 logger"""
|
"""block driver 使用的 logger"""
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(Driver)
|
||||||
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -54,7 +50,7 @@ class BlockDriver(ForwardDriver):
|
|||||||
self.startup_funcs.add(func)
|
self.startup_funcs.add(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(Driver)
|
||||||
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -68,7 +64,7 @@ class BlockDriver(ForwardDriver):
|
|||||||
self.shutdown_funcs.add(func)
|
self.shutdown_funcs.add(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(Driver)
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
"""启动 block driver"""
|
"""启动 block driver"""
|
||||||
super().run(*args, **kwargs)
|
super().run(*args, **kwargs)
|
||||||
|
@ -19,7 +19,7 @@ except ImportError:
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
class AiohttpMixin(ForwardMixin):
|
class Mixin(ForwardMixin):
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@ -114,4 +114,4 @@ class WebSocket(BaseWebSocket):
|
|||||||
await self.websocket.send_bytes(data)
|
await self.websocket.send_bytes(data)
|
||||||
|
|
||||||
|
|
||||||
Driver = combine_driver(BlockDriver, AiohttpMixin)
|
Driver = combine_driver(BlockDriver, Mixin)
|
||||||
|
@ -22,22 +22,10 @@ from starlette.websockets import WebSocket, WebSocketState
|
|||||||
from nonebot.config import Env
|
from nonebot.config import Env
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
from nonebot.drivers.httpx import HttpxMixin
|
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import Request as BaseRequest
|
from nonebot.drivers import Request as BaseRequest
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers.websockets import WebSocketsMixin
|
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||||
from nonebot.drivers import (
|
|
||||||
ReverseDriver,
|
|
||||||
HTTPServerSetup,
|
|
||||||
WebSocketServerSetup,
|
|
||||||
combine_driver,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from nonebot.drivers.aiohttp import AiohttpMixin
|
|
||||||
except ImportError:
|
|
||||||
AiohttpMixin = None
|
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseSettings):
|
class Config(BaseSettings):
|
||||||
@ -317,8 +305,3 @@ class FastAPIWebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send_bytes(self, data: bytes) -> None:
|
async def send_bytes(self, data: bytes) -> None:
|
||||||
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||||
|
|
||||||
|
|
||||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
|
||||||
if AiohttpMixin:
|
|
||||||
AiohttpDriver = combine_driver(Driver, AiohttpMixin)
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import httpx
|
|
||||||
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.drivers._block_driver import BlockDriver
|
from nonebot.drivers._block_driver import BlockDriver
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
@ -11,8 +9,13 @@ from nonebot.drivers import (
|
|||||||
combine_driver,
|
combine_driver,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install httpx by using `pip install nonebot2[httpx]`")
|
||||||
|
|
||||||
class HttpxMixin(ForwardMixin):
|
|
||||||
|
class Mixin(ForwardMixin):
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@ -39,7 +42,7 @@ class HttpxMixin(ForwardMixin):
|
|||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def websocket(self, setup: Request) -> WebSocket:
|
async def websocket(self, setup: Request) -> WebSocket:
|
||||||
return await super(HttpxMixin, self).websocket(setup)
|
return await super(Mixin, self).websocket(setup)
|
||||||
|
|
||||||
|
|
||||||
Driver = combine_driver(BlockDriver, HttpxMixin)
|
Driver = combine_driver(BlockDriver, Mixin)
|
||||||
|
@ -17,17 +17,10 @@ from nonebot.config import Env
|
|||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
from nonebot.drivers.httpx import HttpxMixin
|
|
||||||
from nonebot.config import Config as NoneBotConfig
|
from nonebot.config import Config as NoneBotConfig
|
||||||
from nonebot.drivers import Request as BaseRequest
|
from nonebot.drivers import Request as BaseRequest
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers.websockets import WebSocketsMixin
|
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||||
from nonebot.drivers import (
|
|
||||||
ReverseDriver,
|
|
||||||
HTTPServerSetup,
|
|
||||||
WebSocketServerSetup,
|
|
||||||
combine_driver,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from quart import request as _request
|
from quart import request as _request
|
||||||
@ -295,6 +288,3 @@ class WebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send_bytes(self, data: bytes):
|
async def send_bytes(self, data: bytes):
|
||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
|
|
||||||
FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin)
|
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
|
||||||
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.log import LoguruHandler
|
from nonebot.log import LoguruHandler
|
||||||
from nonebot.drivers import Request, Response
|
from nonebot.drivers import Request, Response
|
||||||
@ -9,11 +7,18 @@ from nonebot.drivers._block_driver import BlockDriver
|
|||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.drivers import ForwardMixin, combine_driver
|
from nonebot.drivers import ForwardMixin, combine_driver
|
||||||
|
|
||||||
|
try:
|
||||||
|
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install websockets by using `pip install nonebot2[websockets]`"
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.Logger("websockets.client", "INFO")
|
logger = logging.Logger("websockets.client", "INFO")
|
||||||
logger.addHandler(LoguruHandler())
|
logger.addHandler(LoguruHandler())
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsMixin(ForwardMixin):
|
class Mixin(ForwardMixin):
|
||||||
@property
|
@property
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@ -21,7 +26,7 @@ class WebSocketsMixin(ForwardMixin):
|
|||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
return await super(WebSocketsMixin, self).request(setup)
|
return await super(Mixin, self).request(setup)
|
||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def websocket(self, setup: Request) -> "WebSocket":
|
async def websocket(self, setup: Request) -> "WebSocket":
|
||||||
@ -75,4 +80,4 @@ class WebSocket(BaseWebSocket):
|
|||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
|
|
||||||
Driver = combine_driver(BlockDriver, WebSocketsMixin)
|
Driver = combine_driver(BlockDriver, Mixin)
|
||||||
|
@ -422,7 +422,7 @@ def shell_command(
|
|||||||
命令内容与后续消息间无需空格!
|
命令内容与后续消息间无需空格!
|
||||||
\:\:\:
|
\:\:\:
|
||||||
"""
|
"""
|
||||||
if not isinstance(parser, ArgumentParser):
|
if parser is not None and not isinstance(parser, ArgumentParser):
|
||||||
raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
|
raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")
|
||||||
|
|
||||||
config = get_driver().config
|
config = get_driver().config
|
||||||
|
@ -28,13 +28,13 @@ loguru = "^0.5.1"
|
|||||||
pygtrie = "^2.4.1"
|
pygtrie = "^2.4.1"
|
||||||
tomlkit = "^0.7.0"
|
tomlkit = "^0.7.0"
|
||||||
fastapi = "^0.70.0"
|
fastapi = "^0.70.0"
|
||||||
websockets = ">=9.1"
|
|
||||||
typing-extensions = ">=3.10.0,<5.0.0"
|
typing-extensions = ">=3.10.0,<5.0.0"
|
||||||
Quart = { version = "^0.16.0", optional = true }
|
Quart = { version = "^0.16.0", optional = true }
|
||||||
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
|
websockets = { version=">=9.1", optional = true }
|
||||||
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
||||||
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
||||||
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
||||||
|
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
sphinx = "^4.1.1"
|
sphinx = "^4.1.1"
|
||||||
@ -46,8 +46,10 @@ sphinx-markdown-builder = { git = "https://github.com/nonebot/sphinx-markdown-bu
|
|||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
quart = ["quart"]
|
quart = ["quart"]
|
||||||
|
httpx = ["httpx"]
|
||||||
aiohttp = ["aiohttp"]
|
aiohttp = ["aiohttp"]
|
||||||
all = ["quart", "aiohttp"]
|
websockets = ["websockets"]
|
||||||
|
all = ["quart", "aiohttp", "httpx", "websockets"]
|
||||||
|
|
||||||
# [[tool.poetry.source]]
|
# [[tool.poetry.source]]
|
||||||
# name = "aliyun"
|
# name = "aliyun"
|
||||||
|
@ -15,18 +15,24 @@ os.environ["CONFIG_FROM_ENV"] = '{"test": "test"}'
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"nonebug_init",
|
"nonebug_init",
|
||||||
[{"config_from_init": "init", "driver": "nonebot.drivers.fastapi:FullDriver"}],
|
[
|
||||||
|
{
|
||||||
|
"config_from_init": "init",
|
||||||
|
"driver": "nonebot.drivers.fastapi+nonebot.drivers.httpx+nonebot.drivers.websockets",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"config_from_init": "init",
|
||||||
|
"driver": "~fastapi+~httpx+~websockets",
|
||||||
|
},
|
||||||
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
async def test_init(nonebug_init):
|
async def test_init(nonebug_init):
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from nonebot.drivers.fastapi import FullDriver
|
|
||||||
|
|
||||||
env = get_driver().env
|
env = get_driver().env
|
||||||
assert env == "test"
|
assert env == "test"
|
||||||
|
|
||||||
assert isinstance(get_driver(), FullDriver)
|
|
||||||
|
|
||||||
config = get_driver().config
|
config = get_driver().config
|
||||||
assert config.config_from_env == {"test": "test"}
|
assert config.config_from_env == {"test": "test"}
|
||||||
assert config.config_from_init == "init"
|
assert config.config_from_init == "init"
|
||||||
|
Loading…
Reference in New Issue
Block a user