mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ Feature: 新增 dotenv 嵌套配置项支持 (#1324)
Co-authored-by: hemengyang <hmy0119@hotmail.com>
This commit is contained in:
parent
67b96528af
commit
db534b8824
@ -13,13 +13,13 @@ repos:
|
||||
stages: [commit]
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.8.0
|
||||
rev: 22.10.0
|
||||
hooks:
|
||||
- id: black
|
||||
stages: [commit]
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.0.0-alpha.0
|
||||
rev: v3.0.0-alpha.1
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [javascript, jsx, ts, tsx, markdown, yaml]
|
||||
|
@ -37,10 +37,12 @@ FrontMatter:
|
||||
description: nonebot 模块
|
||||
"""
|
||||
|
||||
import os
|
||||
import importlib
|
||||
from typing import Any, Dict, Type, Optional
|
||||
|
||||
import loguru
|
||||
from pydantic.env_settings import DotenvType
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.adapters import Bot
|
||||
@ -217,7 +219,7 @@ def _log_patcher(record: "loguru.Record"):
|
||||
)
|
||||
|
||||
|
||||
def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None:
|
||||
def init(*, _env_file: Optional[DotenvType] = None, **kwargs: Any) -> None:
|
||||
"""初始化 NoneBot 以及 全局 {ref}`nonebot.drivers.Driver` 对象。
|
||||
|
||||
NoneBot 将会从 .env 文件中读取环境信息,并使用相应的 env 文件配置。
|
||||
@ -237,10 +239,12 @@ def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None:
|
||||
if not _driver:
|
||||
logger.success("NoneBot is initializing...")
|
||||
env = Env()
|
||||
_env_file = _env_file or f".env.{env.environment}"
|
||||
config = Config(
|
||||
**kwargs,
|
||||
_common_config=env.dict(),
|
||||
_env_file=_env_file or f".env.{env.environment}",
|
||||
_env_file=(".env", _env_file)
|
||||
if isinstance(_env_file, (str, os.PathLike))
|
||||
else _env_file,
|
||||
)
|
||||
|
||||
logger.configure(
|
||||
|
@ -14,14 +14,14 @@ from datetime import timedelta
|
||||
from ipaddress import IPv4Address
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Tuple, Union, Mapping, Optional
|
||||
|
||||
from pydantic import BaseSettings, IPvAnyAddress
|
||||
from pydantic.utils import deep_update
|
||||
from pydantic import Extra, BaseSettings, IPvAnyAddress
|
||||
from pydantic.env_settings import (
|
||||
DotenvType,
|
||||
SettingsError,
|
||||
EnvSettingsSource,
|
||||
InitSettingsSource,
|
||||
SettingsSourceCallable,
|
||||
read_env_file,
|
||||
env_file_sentinel,
|
||||
)
|
||||
|
||||
from nonebot.log import logger
|
||||
@ -32,32 +32,14 @@ class CustomEnvSettings(EnvSettingsSource):
|
||||
"""
|
||||
Build environment variables suitable for passing to the Model.
|
||||
"""
|
||||
d: Dict[str, Optional[str]] = {}
|
||||
d: Dict[str, Any] = {}
|
||||
|
||||
if settings.__config__.case_sensitive:
|
||||
env_vars: Mapping[str, Optional[str]] = os.environ # pragma: no cover
|
||||
else:
|
||||
env_vars = {k.lower(): v for k, v in os.environ.items()}
|
||||
|
||||
env_file_vars: Dict[str, Optional[str]] = {}
|
||||
env_file = (
|
||||
self.env_file
|
||||
if self.env_file != env_file_sentinel
|
||||
else settings.__config__.env_file
|
||||
)
|
||||
env_file_encoding = (
|
||||
self.env_file_encoding
|
||||
if self.env_file_encoding is not None
|
||||
else settings.__config__.env_file_encoding
|
||||
)
|
||||
if env_file is not None:
|
||||
env_path = Path(env_file)
|
||||
if env_path.is_file():
|
||||
env_file_vars = read_env_file(
|
||||
env_path,
|
||||
encoding=env_file_encoding, # type: ignore
|
||||
case_sensitive=settings.__config__.case_sensitive,
|
||||
)
|
||||
env_file_vars = self._read_env_files(settings.__config__.case_sensitive)
|
||||
env_vars = {**env_file_vars, **env_vars}
|
||||
|
||||
for field in settings.__fields__.values():
|
||||
@ -69,24 +51,38 @@ class CustomEnvSettings(EnvSettingsSource):
|
||||
if env_val is not None:
|
||||
break
|
||||
|
||||
is_complex, allow_parse_failure = self.field_is_complex(field)
|
||||
if is_complex:
|
||||
if env_val is None:
|
||||
continue
|
||||
|
||||
if field.is_complex():
|
||||
if env_val_built := self.explode_env_vars(field, env_vars):
|
||||
d[field.alias] = env_val_built
|
||||
else:
|
||||
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
|
||||
try:
|
||||
env_val = settings.__config__.json_loads(env_val)
|
||||
except ValueError as e: # pragma: no cover
|
||||
env_val = settings.__config__.parse_env_var(field.name, env_val)
|
||||
except ValueError as e:
|
||||
if not allow_parse_failure:
|
||||
raise SettingsError(
|
||||
f'error parsing JSON for "{env_name}"' # type: ignore
|
||||
f'error parsing env var "{env_name}"' # type: ignore
|
||||
) from e
|
||||
|
||||
if isinstance(env_val, dict):
|
||||
d[field.alias] = deep_update(
|
||||
env_val, self.explode_env_vars(field, env_vars)
|
||||
)
|
||||
else:
|
||||
d[field.alias] = env_val
|
||||
elif env_val is not None:
|
||||
# simplest case, field is not complex, we only need to add the value if it was found
|
||||
d[field.alias] = env_val
|
||||
|
||||
if env_file_vars:
|
||||
for env_name in env_file_vars.keys():
|
||||
# remain user custom config
|
||||
for env_name in env_file_vars:
|
||||
env_val = env_vars[env_name]
|
||||
if env_val and (val_striped := env_val.strip()):
|
||||
# there's a value, decode that as JSON
|
||||
try:
|
||||
env_val = settings.__config__.json_loads(val_striped)
|
||||
env_val = settings.__config__.parse_env_var(env_name, val_striped)
|
||||
except ValueError as e:
|
||||
logger.trace(
|
||||
"Error while parsing JSON for "
|
||||
@ -94,6 +90,17 @@ class CustomEnvSettings(EnvSettingsSource):
|
||||
"Assumed as string."
|
||||
)
|
||||
|
||||
# explode value when it's a nested dict
|
||||
env_name, *nested_keys = env_name.split(self.env_nested_delimiter)
|
||||
if nested_keys and (env_name not in d or isinstance(d[env_name], dict)):
|
||||
result = {}
|
||||
*keys, last_key = nested_keys
|
||||
_tmp = result
|
||||
for key in keys:
|
||||
_tmp = _tmp.setdefault(key, {})
|
||||
_tmp[last_key] = env_val
|
||||
d[env_name] = deep_update(d.get(env_name, {}), result)
|
||||
elif not nested_keys:
|
||||
d[env_name] = env_val
|
||||
|
||||
return d
|
||||
@ -106,6 +113,9 @@ class BaseConfig(BaseSettings):
|
||||
return self.__dict__.get(name)
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
env_nested_delimiter = "__"
|
||||
|
||||
@classmethod
|
||||
def customise_sources(
|
||||
cls,
|
||||
@ -117,7 +127,10 @@ class BaseConfig(BaseSettings):
|
||||
return (
|
||||
init_settings,
|
||||
CustomEnvSettings(
|
||||
env_settings.env_file, env_settings.env_file_encoding
|
||||
env_settings.env_file,
|
||||
env_settings.env_file_encoding,
|
||||
env_settings.env_nested_delimiter,
|
||||
env_settings.env_prefix_len,
|
||||
),
|
||||
InitSettingsSource(common_config),
|
||||
file_secret_settings,
|
||||
@ -137,7 +150,6 @@ class Env(BaseConfig):
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
@ -150,8 +162,7 @@ class Config(BaseConfig):
|
||||
配置方法参考: [配置](https://v2.nonebot.dev/docs/tutorial/configuration)
|
||||
"""
|
||||
|
||||
_env_file: str = ".env"
|
||||
_common_config: Dict[str, Any] = {}
|
||||
_env_file: DotenvType = ".env", ".env.prod"
|
||||
|
||||
# nonebot configs
|
||||
driver: str = "~fastapi"
|
||||
@ -231,8 +242,7 @@ class Config(BaseConfig):
|
||||
# or from env file using json loads
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
env_file = ".env.prod"
|
||||
env_file = ".env", ".env.prod"
|
||||
|
||||
|
||||
__autodoc__ = {
|
||||
|
13
poetry.lock
generated
13
poetry.lock
generated
@ -115,11 +115,11 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "22.8.0"
|
||||
version = "22.10.0"
|
||||
description = "The uncompromising code formatter."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6.2"
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=8.0.0"
|
||||
@ -920,7 +920,7 @@ python-versions = ">=3.6,<4.0"
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.3.0"
|
||||
version = "4.4.0"
|
||||
description = "Backported and Experimental Type Hints for Python 3.7+"
|
||||
category = "main"
|
||||
optional = false
|
||||
@ -1076,7 +1076,7 @@ websockets = ["websockets"]
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "a4a8da0510758a7da5a2c941518505d44099bad4d77ccef30812e5e73e8a9f7f"
|
||||
content-hash = "9a5abe22ecaaa43b8e124bb2c01001da35f6243f3a827282ec4959912eeb5745"
|
||||
|
||||
[metadata.files]
|
||||
aiodns = [
|
||||
@ -1571,10 +1571,7 @@ tomli = [
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
tomlkit = []
|
||||
typing-extensions = [
|
||||
{file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"},
|
||||
{file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"},
|
||||
]
|
||||
typing-extensions = []
|
||||
urllib3 = []
|
||||
uvicorn = []
|
||||
uvloop = []
|
||||
|
@ -31,7 +31,7 @@ tomlkit = ">=0.10.0,<1.0.0"
|
||||
typing-extensions = ">=3.10.0,<5.0.0"
|
||||
Quart = { version = "^0.17.0", optional = true }
|
||||
websockets = { version="^10.0", optional = true }
|
||||
pydantic = { version = "^1.9.0", extras = ["dotenv"] }
|
||||
pydantic = { version = "^1.10.0", extras = ["dotenv"] }
|
||||
uvicorn = { version = "^0.18.0", extras = ["standard"] }
|
||||
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
||||
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }
|
||||
|
@ -1,2 +1,3 @@
|
||||
ENVIRONMENT=test
|
||||
COMMON_CONFIG=common
|
||||
COMMON_OVERRIDE=old
|
||||
|
@ -1,5 +1,13 @@
|
||||
LOG_LEVEL=TRACE
|
||||
NICKNAME=["test"]
|
||||
SUPERUSERS=["test", "fake:faketest"]
|
||||
COMMON_OVERRIDE=new
|
||||
CONFIG_FROM_ENV=
|
||||
CONFIG_OVERRIDE=old
|
||||
NESTED_DICT={"a": 1}
|
||||
NESTED_DICT__B=2
|
||||
NESTED_DICT__C__D=3
|
||||
NESTED_MISSING_DICT__A=1
|
||||
NESTED_MISSING_DICT__B__C=2
|
||||
NOT_NESTED=some string
|
||||
NOT_NESTED__A=1
|
||||
|
@ -29,6 +29,10 @@ async def test_init(nonebug_init):
|
||||
assert config.config_override == "new"
|
||||
assert config.config_from_init == "init"
|
||||
assert config.common_config == "common"
|
||||
assert config.common_override == "new"
|
||||
assert config.nested_dict == {"a": 1, "b": 2, "c": {"d": 3}}
|
||||
assert config.nested_missing_dict == {"a": 1, "b": {"c": 2}}
|
||||
assert config.not_nested == "some string"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user