diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0ca03c5d..852d6a85 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,13 +13,13 @@ repos: stages: [commit] - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 22.10.0 hooks: - id: black stages: [commit] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.0 + rev: v3.0.0-alpha.1 hooks: - id: prettier types_or: [javascript, jsx, ts, tsx, markdown, yaml] diff --git a/nonebot/__init__.py b/nonebot/__init__.py index a0d3bbfb..52b6ac07 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -37,10 +37,12 @@ FrontMatter: description: nonebot 模块 """ +import os import importlib from typing import Any, Dict, Type, Optional import loguru +from pydantic.env_settings import DotenvType from nonebot.log import logger from nonebot.adapters import Bot @@ -217,7 +219,7 @@ def _log_patcher(record: "loguru.Record"): ) -def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None: +def init(*, _env_file: Optional[DotenvType] = None, **kwargs: Any) -> None: """初始化 NoneBot 以及 全局 {ref}`nonebot.drivers.Driver` 对象。 NoneBot 将会从 .env 文件中读取环境信息,并使用相应的 env 文件配置。 @@ -237,10 +239,12 @@ def init(*, _env_file: Optional[str] = None, **kwargs: Any) -> None: if not _driver: logger.success("NoneBot is initializing...") env = Env() + _env_file = _env_file or f".env.{env.environment}" config = Config( **kwargs, - _common_config=env.dict(), - _env_file=_env_file or f".env.{env.environment}", + _env_file=(".env", _env_file) + if isinstance(_env_file, (str, os.PathLike)) + else _env_file, ) logger.configure( diff --git a/nonebot/config.py b/nonebot/config.py index f0bb374f..95afb401 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -14,14 +14,14 @@ from datetime import timedelta from ipaddress import IPv4Address from typing import TYPE_CHECKING, Any, Set, Dict, Tuple, Union, Mapping, Optional -from pydantic import BaseSettings, IPvAnyAddress +from pydantic.utils import deep_update +from pydantic import Extra, BaseSettings, IPvAnyAddress from pydantic.env_settings import ( + DotenvType, SettingsError, EnvSettingsSource, InitSettingsSource, SettingsSourceCallable, - read_env_file, - env_file_sentinel, ) from nonebot.log import logger @@ -32,33 +32,15 @@ class CustomEnvSettings(EnvSettingsSource): """ Build environment variables suitable for passing to the Model. """ - d: Dict[str, Optional[str]] = {} + d: Dict[str, Any] = {} if settings.__config__.case_sensitive: env_vars: Mapping[str, Optional[str]] = os.environ # pragma: no cover else: env_vars = {k.lower(): v for k, v in os.environ.items()} - env_file_vars: Dict[str, Optional[str]] = {} - env_file = ( - self.env_file - if self.env_file != env_file_sentinel - else settings.__config__.env_file - ) - env_file_encoding = ( - self.env_file_encoding - if self.env_file_encoding is not None - else settings.__config__.env_file_encoding - ) - if env_file is not None: - env_path = Path(env_file) - if env_path.is_file(): - env_file_vars = read_env_file( - env_path, - encoding=env_file_encoding, # type: ignore - case_sensitive=settings.__config__.case_sensitive, - ) - env_vars = {**env_file_vars, **env_vars} + env_file_vars = self._read_env_files(settings.__config__.case_sensitive) + env_vars = {**env_file_vars, **env_vars} for field in settings.__fields__.values(): env_val: Optional[str] = None @@ -69,31 +51,56 @@ class CustomEnvSettings(EnvSettingsSource): if env_val is not None: break - if env_val is None: - continue - - if field.is_complex(): - try: - env_val = settings.__config__.json_loads(env_val) - except ValueError as e: # pragma: no cover - raise SettingsError( - f'error parsing JSON for "{env_name}"' # type: ignore - ) from e - d[field.alias] = env_val - - if env_file_vars: - for env_name in env_file_vars.keys(): - env_val = env_vars[env_name] - if env_val and (val_striped := env_val.strip()): + is_complex, allow_parse_failure = self.field_is_complex(field) + if is_complex: + if env_val is None: + if env_val_built := self.explode_env_vars(field, env_vars): + d[field.alias] = env_val_built + else: + # field is complex and there's a value, decode that as JSON, then add explode_env_vars try: - env_val = settings.__config__.json_loads(val_striped) + env_val = settings.__config__.parse_env_var(field.name, env_val) except ValueError as e: - logger.trace( - "Error while parsing JSON for " - f"{env_name!r}={val_striped!r}. " - "Assumed as string." - ) + if not allow_parse_failure: + raise SettingsError( + f'error parsing env var "{env_name}"' # type: ignore + ) from e + if isinstance(env_val, dict): + d[field.alias] = deep_update( + env_val, self.explode_env_vars(field, env_vars) + ) + else: + d[field.alias] = env_val + elif env_val is not None: + # simplest case, field is not complex, we only need to add the value if it was found + d[field.alias] = env_val + + # remain user custom config + for env_name in env_file_vars: + env_val = env_vars[env_name] + if env_val and (val_striped := env_val.strip()): + # there's a value, decode that as JSON + try: + env_val = settings.__config__.parse_env_var(env_name, val_striped) + except ValueError as e: + logger.trace( + "Error while parsing JSON for " + f"{env_name!r}={val_striped!r}. " + "Assumed as string." + ) + + # explode value when it's a nested dict + env_name, *nested_keys = env_name.split(self.env_nested_delimiter) + if nested_keys and (env_name not in d or isinstance(d[env_name], dict)): + result = {} + *keys, last_key = nested_keys + _tmp = result + for key in keys: + _tmp = _tmp.setdefault(key, {}) + _tmp[last_key] = env_val + d[env_name] = deep_update(d.get(env_name, {}), result) + elif not nested_keys: d[env_name] = env_val return d @@ -106,6 +113,9 @@ class BaseConfig(BaseSettings): return self.__dict__.get(name) class Config: + extra = Extra.allow + env_nested_delimiter = "__" + @classmethod def customise_sources( cls, @@ -117,7 +127,10 @@ class BaseConfig(BaseSettings): return ( init_settings, CustomEnvSettings( - env_settings.env_file, env_settings.env_file_encoding + env_settings.env_file, + env_settings.env_file_encoding, + env_settings.env_nested_delimiter, + env_settings.env_prefix_len, ), InitSettingsSource(common_config), file_secret_settings, @@ -137,7 +150,6 @@ class Env(BaseConfig): """ class Config: - extra = "allow" env_file = ".env" @@ -150,8 +162,7 @@ class Config(BaseConfig): 配置方法参考: [配置](https://v2.nonebot.dev/docs/tutorial/configuration) """ - _env_file: str = ".env" - _common_config: Dict[str, Any] = {} + _env_file: DotenvType = ".env", ".env.prod" # nonebot configs driver: str = "~fastapi" @@ -231,8 +242,7 @@ class Config(BaseConfig): # or from env file using json loads class Config: - extra = "allow" - env_file = ".env.prod" + env_file = ".env", ".env.prod" __autodoc__ = { diff --git a/poetry.lock b/poetry.lock index 6d96b29a..7bcb1ea2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -115,11 +115,11 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (> [[package]] name = "black" -version = "22.8.0" +version = "22.10.0" description = "The uncompromising code formatter." category = "dev" optional = false -python-versions = ">=3.6.2" +python-versions = ">=3.7" [package.dependencies] click = ">=8.0.0" @@ -920,7 +920,7 @@ python-versions = ">=3.6,<4.0" [[package]] name = "typing-extensions" -version = "4.3.0" +version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false @@ -1076,7 +1076,7 @@ websockets = ["websockets"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "a4a8da0510758a7da5a2c941518505d44099bad4d77ccef30812e5e73e8a9f7f" +content-hash = "9a5abe22ecaaa43b8e124bb2c01001da35f6243f3a827282ec4959912eeb5745" [metadata.files] aiodns = [ @@ -1571,10 +1571,7 @@ tomli = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] tomlkit = [] -typing-extensions = [ - {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"}, - {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"}, -] +typing-extensions = [] urllib3 = [] uvicorn = [] uvloop = [] diff --git a/pyproject.toml b/pyproject.toml index c931efea..f9f6aad2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ tomlkit = ">=0.10.0,<1.0.0" typing-extensions = ">=3.10.0,<5.0.0" Quart = { version = "^0.17.0", optional = true } websockets = { version="^10.0", optional = true } -pydantic = { version = "^1.9.0", extras = ["dotenv"] } +pydantic = { version = "^1.10.0", extras = ["dotenv"] } uvicorn = { version = "^0.18.0", extras = ["standard"] } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"], optional = true } diff --git a/tests/.env b/tests/.env index 46048df6..87410a28 100644 --- a/tests/.env +++ b/tests/.env @@ -1,2 +1,3 @@ ENVIRONMENT=test COMMON_CONFIG=common +COMMON_OVERRIDE=old diff --git a/tests/.env.test b/tests/.env.test index ba30adf9..e4fba171 100644 --- a/tests/.env.test +++ b/tests/.env.test @@ -1,5 +1,13 @@ LOG_LEVEL=TRACE NICKNAME=["test"] SUPERUSERS=["test", "fake:faketest"] +COMMON_OVERRIDE=new CONFIG_FROM_ENV= CONFIG_OVERRIDE=old +NESTED_DICT={"a": 1} +NESTED_DICT__B=2 +NESTED_DICT__C__D=3 +NESTED_MISSING_DICT__A=1 +NESTED_MISSING_DICT__B__C=2 +NOT_NESTED=some string +NOT_NESTED__A=1 diff --git a/tests/test_init.py b/tests/test_init.py index e6c7fae1..3d49d5ce 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -29,6 +29,10 @@ async def test_init(nonebug_init): assert config.config_override == "new" assert config.config_from_init == "init" assert config.common_config == "common" + assert config.common_override == "new" + assert config.nested_dict == {"a": 1, "b": 2, "c": {"d": 3}} + assert config.nested_missing_dict == {"a": 1, "b": {"c": 2}} + assert config.not_nested == "some string" @pytest.mark.asyncio