Feature: 新增 dotenv 嵌套配置项支持 (#1324)

Co-authored-by: hemengyang <hmy0119@hotmail.com>
This commit is contained in:
Ju4tCode 2022-10-14 09:58:44 +08:00 committed by GitHub
parent 67b96528af
commit db534b8824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 90 additions and 66 deletions

View File

@ -13,13 +13,13 @@ repos:
stages: [commit] stages: [commit]
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 22.8.0 rev: 22.10.0
hooks: hooks:
- id: black - id: black
stages: [commit] stages: [commit]
- repo: https://github.com/pre-commit/mirrors-prettier - repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.0 rev: v3.0.0-alpha.1
hooks: hooks:
- id: prettier - id: prettier
types_or: [javascript, jsx, ts, tsx, markdown, yaml] types_or: [javascript, jsx, ts, tsx, markdown, yaml]

View File

@ -37,10 +37,12 @@ FrontMatter:
description: nonebot 模块 description: nonebot 模块
""" """
import os
import importlib import importlib
from typing import Any, Dict, Type, Optional from typing import Any, Dict, Type, Optional
import loguru import loguru
from pydantic.env_settings import DotenvType
from nonebot.log import logger from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -217,7 +219,7 @@ def _log_patcher(record: "loguru.Record"):
) )
def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None: def init(*, _env_file: Optional[DotenvType] = None, **kwargs: Any) -> None:
"""初始化 NoneBot 以及 全局 {ref}`nonebot.drivers.Driver` 对象。 """初始化 NoneBot 以及 全局 {ref}`nonebot.drivers.Driver` 对象。
NoneBot 将会从 .env 文件中读取环境信息并使用相应的 env 文件配置 NoneBot 将会从 .env 文件中读取环境信息并使用相应的 env 文件配置
@ -237,10 +239,12 @@ def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None:
if not _driver: if not _driver:
logger.success("NoneBot is initializing...") logger.success("NoneBot is initializing...")
env = Env() env = Env()
_env_file = _env_file or f".env.{env.environment}"
config = Config( config = Config(
**kwargs, **kwargs,
_common_config=env.dict(), _env_file=(".env", _env_file)
_env_file=_env_file or f".env.{env.environment}", if isinstance(_env_file, (str, os.PathLike))
else _env_file,
) )
logger.configure( logger.configure(

View File

@ -14,14 +14,14 @@ from datetime import timedelta
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import TYPE_CHECKING, Any, Set, Dict, Tuple, Union, Mapping, Optional from typing import TYPE_CHECKING, Any, Set, Dict, Tuple, Union, Mapping, Optional
from pydantic import BaseSettings, IPvAnyAddress from pydantic.utils import deep_update
from pydantic import Extra, BaseSettings, IPvAnyAddress
from pydantic.env_settings import ( from pydantic.env_settings import (
DotenvType,
SettingsError, SettingsError,
EnvSettingsSource, EnvSettingsSource,
InitSettingsSource, InitSettingsSource,
SettingsSourceCallable, SettingsSourceCallable,
read_env_file,
env_file_sentinel,
) )
from nonebot.log import logger from nonebot.log import logger
@ -32,33 +32,15 @@ class CustomEnvSettings(EnvSettingsSource):
""" """
Build environment variables suitable for passing to the Model. Build environment variables suitable for passing to the Model.
""" """
d: Dict[str, Optional[str]] = {} d: Dict[str, Any] = {}
if settings.__config__.case_sensitive: if settings.__config__.case_sensitive:
env_vars: Mapping[str, Optional[str]] = os.environ # pragma: no cover env_vars: Mapping[str, Optional[str]] = os.environ # pragma: no cover
else: else:
env_vars = {k.lower(): v for k, v in os.environ.items()} env_vars = {k.lower(): v for k, v in os.environ.items()}
env_file_vars: Dict[str, Optional[str]] = {} env_file_vars = self._read_env_files(settings.__config__.case_sensitive)
env_file = ( env_vars = {**env_file_vars, **env_vars}
self.env_file
if self.env_file != env_file_sentinel
else settings.__config__.env_file
)
env_file_encoding = (
self.env_file_encoding
if self.env_file_encoding is not None
else settings.__config__.env_file_encoding
)
if env_file is not None:
env_path = Path(env_file)
if env_path.is_file():
env_file_vars = read_env_file(
env_path,
encoding=env_file_encoding, # type: ignore
case_sensitive=settings.__config__.case_sensitive,
)
env_vars = {**env_file_vars, **env_vars}
for field in settings.__fields__.values(): for field in settings.__fields__.values():
env_val: Optional[str] = None env_val: Optional[str] = None
@ -69,31 +51,56 @@ class CustomEnvSettings(EnvSettingsSource):
if env_val is not None: if env_val is not None:
break break
if env_val is None: is_complex, allow_parse_failure = self.field_is_complex(field)
continue if is_complex:
if env_val is None:
if field.is_complex(): if env_val_built := self.explode_env_vars(field, env_vars):
try: d[field.alias] = env_val_built
env_val = settings.__config__.json_loads(env_val) else:
except ValueError as e: # pragma: no cover # field is complex and there's a value, decode that as JSON, then add explode_env_vars
raise SettingsError(
f'error parsing JSON for "{env_name}"' # type: ignore
) from e
d[field.alias] = env_val
if env_file_vars:
for env_name in env_file_vars.keys():
env_val = env_vars[env_name]
if env_val and (val_striped := env_val.strip()):
try: try:
env_val = settings.__config__.json_loads(val_striped) env_val = settings.__config__.parse_env_var(field.name, env_val)
except ValueError as e: except ValueError as e:
logger.trace( if not allow_parse_failure:
"Error while parsing JSON for " raise SettingsError(
f"{env_name!r}={val_striped!r}. " f'error parsing env var "{env_name}"' # type: ignore
"Assumed as string." ) from e
)
if isinstance(env_val, dict):
d[field.alias] = deep_update(
env_val, self.explode_env_vars(field, env_vars)
)
else:
d[field.alias] = env_val
elif env_val is not None:
# simplest case, field is not complex, we only need to add the value if it was found
d[field.alias] = env_val
# remain user custom config
for env_name in env_file_vars:
env_val = env_vars[env_name]
if env_val and (val_striped := env_val.strip()):
# there's a value, decode that as JSON
try:
env_val = settings.__config__.parse_env_var(env_name, val_striped)
except ValueError as e:
logger.trace(
"Error while parsing JSON for "
f"{env_name!r}={val_striped!r}. "
"Assumed as string."
)
# explode value when it's a nested dict
env_name, *nested_keys = env_name.split(self.env_nested_delimiter)
if nested_keys and (env_name not in d or isinstance(d[env_name], dict)):
result = {}
*keys, last_key = nested_keys
_tmp = result
for key in keys:
_tmp = _tmp.setdefault(key, {})
_tmp[last_key] = env_val
d[env_name] = deep_update(d.get(env_name, {}), result)
elif not nested_keys:
d[env_name] = env_val d[env_name] = env_val
return d return d
@ -106,6 +113,9 @@ class BaseConfig(BaseSettings):
return self.__dict__.get(name) return self.__dict__.get(name)
class Config: class Config:
extra = Extra.allow
env_nested_delimiter = "__"
@classmethod @classmethod
def customise_sources( def customise_sources(
cls, cls,
@ -117,7 +127,10 @@ class BaseConfig(BaseSettings):
return ( return (
init_settings, init_settings,
CustomEnvSettings( CustomEnvSettings(
env_settings.env_file, env_settings.env_file_encoding env_settings.env_file,
env_settings.env_file_encoding,
env_settings.env_nested_delimiter,
env_settings.env_prefix_len,
), ),
InitSettingsSource(common_config), InitSettingsSource(common_config),
file_secret_settings, file_secret_settings,
@ -137,7 +150,6 @@ class Env(BaseConfig):
""" """
class Config: class Config:
extra = "allow"
env_file = ".env" env_file = ".env"
@ -150,8 +162,7 @@ class Config(BaseConfig):
配置方法参考: [配置](https://v2.nonebot.dev/docs/tutorial/configuration) 配置方法参考: [配置](https://v2.nonebot.dev/docs/tutorial/configuration)
""" """
_env_file: str = ".env" _env_file: DotenvType = ".env", ".env.prod"
_common_config: Dict[str, Any] = {}
# nonebot configs # nonebot configs
driver: str = "~fastapi" driver: str = "~fastapi"
@ -231,8 +242,7 @@ class Config(BaseConfig):
# or from env file using json loads # or from env file using json loads
class Config: class Config:
extra = "allow" env_file = ".env", ".env.prod"
env_file = ".env.prod"
__autodoc__ = { __autodoc__ = {

13
poetry.lock generated
View File

@ -115,11 +115,11 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>
[[package]] [[package]]
name = "black" name = "black"
version = "22.8.0" version = "22.10.0"
description = "The uncompromising code formatter." description = "The uncompromising code formatter."
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=3.6.2" python-versions = ">=3.7"
[package.dependencies] [package.dependencies]
click = ">=8.0.0" click = ">=8.0.0"
@ -920,7 +920,7 @@ python-versions = ">=3.6,<4.0"
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.3.0" version = "4.4.0"
description = "Backported and Experimental Type Hints for Python 3.7+" description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main" category = "main"
optional = false optional = false
@ -1076,7 +1076,7 @@ websockets = ["websockets"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "a4a8da0510758a7da5a2c941518505d44099bad4d77ccef30812e5e73e8a9f7f" content-hash = "9a5abe22ecaaa43b8e124bb2c01001da35f6243f3a827282ec4959912eeb5745"
[metadata.files] [metadata.files]
aiodns = [ aiodns = [
@ -1571,10 +1571,7 @@ tomli = [
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
] ]
tomlkit = [] tomlkit = []
typing-extensions = [ typing-extensions = []
{file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"},
{file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"},
]
urllib3 = [] urllib3 = []
uvicorn = [] uvicorn = []
uvloop = [] uvloop = []

View File

@ -31,7 +31,7 @@ tomlkit = ">=0.10.0,<1.0.0"
typing-extensions = ">=3.10.0,<5.0.0" typing-extensions = ">=3.10.0,<5.0.0"
Quart = { version = "^0.17.0", optional = true } Quart = { version = "^0.17.0", optional = true }
websockets = { version="^10.0", optional = true } websockets = { version="^10.0", optional = true }
pydantic = { version = "^1.9.0", extras = ["dotenv"] } pydantic = { version = "^1.10.0", extras = ["dotenv"] }
uvicorn = { version = "^0.18.0", extras = ["standard"] } uvicorn = { version = "^0.18.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true } httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true }

View File

@ -1,2 +1,3 @@
ENVIRONMENT=test ENVIRONMENT=test
COMMON_CONFIG=common COMMON_CONFIG=common
COMMON_OVERRIDE=old

View File

@ -1,5 +1,13 @@
LOG_LEVEL=TRACE LOG_LEVEL=TRACE
NICKNAME=["test"] NICKNAME=["test"]
SUPERUSERS=["test", "fake:faketest"] SUPERUSERS=["test", "fake:faketest"]
COMMON_OVERRIDE=new
CONFIG_FROM_ENV= CONFIG_FROM_ENV=
CONFIG_OVERRIDE=old CONFIG_OVERRIDE=old
NESTED_DICT={"a": 1}
NESTED_DICT__B=2
NESTED_DICT__C__D=3
NESTED_MISSING_DICT__A=1
NESTED_MISSING_DICT__B__C=2
NOT_NESTED=some string
NOT_NESTED__A=1

View File

@ -29,6 +29,10 @@ async def test_init(nonebug_init):
assert config.config_override == "new" assert config.config_override == "new"
assert config.config_from_init == "init" assert config.config_from_init == "init"
assert config.common_config == "common" assert config.common_config == "common"
assert config.common_override == "new"
assert config.nested_dict == {"a": 1, "b": 2, "c": {"d": 3}}
assert config.nested_missing_dict == {"a": 1, "b": {"c": 2}}
assert config.not_nested == "some string"
@pytest.mark.asyncio @pytest.mark.asyncio