Feature: 新增 dotenv 嵌套配置项支持 (#1324)

Co-authored-by: hemengyang <hmy0119@hotmail.com>
This commit is contained in:
Ju4tCode 2022-10-14 09:58:44 +08:00 committed by GitHub
parent 67b96528af
commit db534b8824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 90 additions and 66 deletions

View File

@ -13,13 +13,13 @@ repos:
stages: [commit]
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 22.10.0
hooks:
- id: black
stages: [commit]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.0
rev: v3.0.0-alpha.1
hooks:
- id: prettier
types_or: [javascript, jsx, ts, tsx, markdown, yaml]

View File

@ -37,10 +37,12 @@ FrontMatter:
description: nonebot 模块
"""
import os
import importlib
from typing import Any, Dict, Type, Optional
import loguru
from pydantic.env_settings import DotenvType
from nonebot.log import logger
from nonebot.adapters import Bot
@ -217,7 +219,7 @@ def _log_patcher(record: "loguru.Record"):
)
def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None:
def init(*, _env_file: Optional[DotenvType] = None, **kwargs: Any) -> None:
"""初始化 NoneBot 以及 全局 {ref}`nonebot.drivers.Driver` 对象。
NoneBot 将会从 .env 文件中读取环境信息并使用相应的 env 文件配置
@ -237,10 +239,12 @@ def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None:
if not _driver:
logger.success("NoneBot is initializing...")
env = Env()
_env_file = _env_file or f".env.{env.environment}"
config = Config(
**kwargs,
_common_config=env.dict(),
_env_file=_env_file or f".env.{env.environment}",
_env_file=(".env", _env_file)
if isinstance(_env_file, (str, os.PathLike))
else _env_file,
)
logger.configure(

View File

@ -14,14 +14,14 @@ from datetime import timedelta
from ipaddress import IPv4Address
from typing import TYPE_CHECKING, Any, Set, Dict, Tuple, Union, Mapping, Optional
from pydantic import BaseSettings, IPvAnyAddress
from pydantic.utils import deep_update
from pydantic import Extra, BaseSettings, IPvAnyAddress
from pydantic.env_settings import (
DotenvType,
SettingsError,
EnvSettingsSource,
InitSettingsSource,
SettingsSourceCallable,
read_env_file,
env_file_sentinel,
)
from nonebot.log import logger
@ -32,33 +32,15 @@ class CustomEnvSettings(EnvSettingsSource):
"""
Build environment variables suitable for passing to the Model.
"""
d: Dict[str, Optional[str]] = {}
d: Dict[str, Any] = {}
if settings.__config__.case_sensitive:
env_vars: Mapping[str, Optional[str]] = os.environ # pragma: no cover
else:
env_vars = {k.lower(): v for k, v in os.environ.items()}
env_file_vars: Dict[str, Optional[str]] = {}
env_file = (
self.env_file
if self.env_file != env_file_sentinel
else settings.__config__.env_file
)
env_file_encoding = (
self.env_file_encoding
if self.env_file_encoding is not None
else settings.__config__.env_file_encoding
)
if env_file is not None:
env_path = Path(env_file)
if env_path.is_file():
env_file_vars = read_env_file(
env_path,
encoding=env_file_encoding, # type: ignore
case_sensitive=settings.__config__.case_sensitive,
)
env_vars = {**env_file_vars, **env_vars}
env_file_vars = self._read_env_files(settings.__config__.case_sensitive)
env_vars = {**env_file_vars, **env_vars}
for field in settings.__fields__.values():
env_val: Optional[str] = None
@ -69,31 +51,56 @@ class CustomEnvSettings(EnvSettingsSource):
if env_val is not None:
break
if env_val is None:
continue
if field.is_complex():
try:
env_val = settings.__config__.json_loads(env_val)
except ValueError as e: # pragma: no cover
raise SettingsError(
f'error parsing JSON for "{env_name}"' # type: ignore
) from e
d[field.alias] = env_val
if env_file_vars:
for env_name in env_file_vars.keys():
env_val = env_vars[env_name]
if env_val and (val_striped := env_val.strip()):
is_complex, allow_parse_failure = self.field_is_complex(field)
if is_complex:
if env_val is None:
if env_val_built := self.explode_env_vars(field, env_vars):
d[field.alias] = env_val_built
else:
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
try:
env_val = settings.__config__.json_loads(val_striped)
env_val = settings.__config__.parse_env_var(field.name, env_val)
except ValueError as e:
logger.trace(
"Error while parsing JSON for "
f"{env_name!r}={val_striped!r}. "
"Assumed as string."
)
if not allow_parse_failure:
raise SettingsError(
f'error parsing env var "{env_name}"' # type: ignore
) from e
if isinstance(env_val, dict):
d[field.alias] = deep_update(
env_val, self.explode_env_vars(field, env_vars)
)
else:
d[field.alias] = env_val
elif env_val is not None:
# simplest case, field is not complex, we only need to add the value if it was found
d[field.alias] = env_val
# remain user custom config
for env_name in env_file_vars:
env_val = env_vars[env_name]
if env_val and (val_striped := env_val.strip()):
# there's a value, decode that as JSON
try:
env_val = settings.__config__.parse_env_var(env_name, val_striped)
except ValueError as e:
logger.trace(
"Error while parsing JSON for "
f"{env_name!r}={val_striped!r}. "
"Assumed as string."
)
# explode value when it's a nested dict
env_name, *nested_keys = env_name.split(self.env_nested_delimiter)
if nested_keys and (env_name not in d or isinstance(d[env_name], dict)):
result = {}
*keys, last_key = nested_keys
_tmp = result
for key in keys:
_tmp = _tmp.setdefault(key, {})
_tmp[last_key] = env_val
d[env_name] = deep_update(d.get(env_name, {}), result)
elif not nested_keys:
d[env_name] = env_val
return d
@ -106,6 +113,9 @@ class BaseConfig(BaseSettings):
return self.__dict__.get(name)
class Config:
extra = Extra.allow
env_nested_delimiter = "__"
@classmethod
def customise_sources(
cls,
@ -117,7 +127,10 @@ class BaseConfig(BaseSettings):
return (
init_settings,
CustomEnvSettings(
env_settings.env_file, env_settings.env_file_encoding
env_settings.env_file,
env_settings.env_file_encoding,
env_settings.env_nested_delimiter,
env_settings.env_prefix_len,
),
InitSettingsSource(common_config),
file_secret_settings,
@ -137,7 +150,6 @@ class Env(BaseConfig):
"""
class Config:
extra = "allow"
env_file = ".env"
@ -150,8 +162,7 @@ class Config(BaseConfig):
配置方法参考: [配置](https://v2.nonebot.dev/docs/tutorial/configuration)
"""
_env_file: str = ".env"
_common_config: Dict[str, Any] = {}
_env_file: DotenvType = ".env", ".env.prod"
# nonebot configs
driver: str = "~fastapi"
@ -231,8 +242,7 @@ class Config(BaseConfig):
# or from env file using json loads
class Config:
extra = "allow"
env_file = ".env.prod"
env_file = ".env", ".env.prod"
__autodoc__ = {

13
poetry.lock generated
View File

@ -115,11 +115,11 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>
[[package]]
name = "black"
version = "22.8.0"
version = "22.10.0"
description = "The uncompromising code formatter."
category = "dev"
optional = false
python-versions = ">=3.6.2"
python-versions = ">=3.7"
[package.dependencies]
click = ">=8.0.0"
@ -920,7 +920,7 @@ python-versions = ">=3.6,<4.0"
[[package]]
name = "typing-extensions"
version = "4.3.0"
version = "4.4.0"
description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main"
optional = false
@ -1076,7 +1076,7 @@ websockets = ["websockets"]
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "a4a8da0510758a7da5a2c941518505d44099bad4d77ccef30812e5e73e8a9f7f"
content-hash = "9a5abe22ecaaa43b8e124bb2c01001da35f6243f3a827282ec4959912eeb5745"
[metadata.files]
aiodns = [
@ -1571,10 +1571,7 @@ tomli = [
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
tomlkit = []
typing-extensions = [
{file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"},
{file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"},
]
typing-extensions = []
urllib3 = []
uvicorn = []
uvloop = []

View File

@ -31,7 +31,7 @@ tomlkit = ">=0.10.0,<1.0.0"
typing-extensions = ">=3.10.0,<5.0.0"
Quart = { version = "^0.17.0", optional = true }
websockets = { version="^10.0", optional = true }
pydantic = { version = "^1.9.0", extras = ["dotenv"] }
pydantic = { version = "^1.10.0", extras = ["dotenv"] }
uvicorn = { version = "^0.18.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }

View File

@ -1,2 +1,3 @@
ENVIRONMENT=test
COMMON_CONFIG=common
COMMON_OVERRIDE=old

View File

@ -1,5 +1,13 @@
LOG_LEVEL=TRACE
NICKNAME=["test"]
SUPERUSERS=["test", "fake:faketest"]
COMMON_OVERRIDE=new
CONFIG_FROM_ENV=
CONFIG_OVERRIDE=old
NESTED_DICT={"a": 1}
NESTED_DICT__B=2
NESTED_DICT__C__D=3
NESTED_MISSING_DICT__A=1
NESTED_MISSING_DICT__B__C=2
NOT_NESTED=some string
NOT_NESTED__A=1

View File

@ -29,6 +29,10 @@ async def test_init(nonebug_init):
assert config.config_override == "new"
assert config.config_from_init == "init"
assert config.common_config == "common"
assert config.common_override == "new"
assert config.nested_dict == {"a": 1, "b": 2, "c": {"d": 3}}
assert config.nested_missing_dict == {"a": 1, "b": {"c": 2}}
assert config.not_nested == "some string"
@pytest.mark.asyncio