🔥 remove dependency override provider

This commit is contained in:
yanyongyu 2021-11-21 16:12:36 +08:00
parent a864b36e9f
commit 760ac693c0
15 changed files with 35 additions and 83 deletions

View File

@ -67,8 +67,7 @@ class CustomEnvSettings(EnvSettingsSource):
env_val = settings.__config__.json_loads(env_val) env_val = settings.__config__.json_loads(env_val)
except ValueError as e: except ValueError as e:
raise SettingsError( raise SettingsError(
f'error parsing JSON for "{env_name}"' # type: ignore f'error parsing JSON for "{env_name}"') from e
) from e
d[field.alias] = env_val d[field.alias] = env_val
if env_file_vars: if env_file_vars:

View File

@ -121,7 +121,6 @@ async def solve_dependencies(
_dependent: Dependent, _dependent: Dependent,
_stack: Optional[AsyncExitStack] = None, _stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None, _sub_dependents: Optional[List[Dependent]] = None,
_dependency_overrides_provider: Optional[Any] = None,
_dependency_cache: Optional[T_DependencyCache] = None, _dependency_cache: Optional[T_DependencyCache] = None,
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]: **params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
@ -136,24 +135,9 @@ async def solve_dependencies(
sub_dependent.cache_key) sub_dependent.cache_key)
func = sub_dependent.func func = sub_dependent.func
# dependency overrides
use_sub_dependant = sub_dependent
if (_dependency_overrides_provider and hasattr(
_dependency_overrides_provider, "dependency_overrides")):
original_call = sub_dependent.func
func = getattr(_dependency_overrides_provider,
"dependency_overrides",
{}).get(original_call, original_call)
use_sub_dependant = get_dependent(
func=func,
name=sub_dependent.name,
allow_types=sub_dependent.allow_types,
)
# solve sub dependency with current cache # solve sub dependency with current cache
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
_dependent=use_sub_dependant, _dependent=sub_dependent,
_dependency_overrides_provider=_dependency_overrides_provider,
_dependency_cache=dependency_cache, _dependency_cache=dependency_cache,
**params) **params)
sub_values, sub_dependency_cache = solved_result sub_values, sub_dependency_cache = solved_result

View File

@ -1,6 +1,6 @@
import abc import abc
import inspect import inspect
from typing import Any, List, Type, Callable, Optional from typing import Any, List, Type, Optional
from pydantic.fields import FieldInfo, ModelField from pydantic.fields import FieldInfo, ModelField

View File

@ -40,11 +40,6 @@ class Driver(abc.ABC):
:类型: ``Set[T_BotDisconnectionHook]`` :类型: ``Set[T_BotDisconnectionHook]``
:说明: Bot 连接断开时执行的函数 :说明: Bot 连接断开时执行的函数
""" """
dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
"""
:类型: ``Dict[Callable[..., Any], Callable[..., Any]]``
:说明: Depends 函数的替换表
"""
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
""" """

View File

@ -23,8 +23,7 @@ class Handler:
*, *,
name: Optional[str] = None, name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None, dependencies: Optional[List[DependsWrapper]] = None,
allow_types: Optional[List[Type[Param]]] = None, allow_types: Optional[List[Type[Param]]] = None):
dependency_overrides_provider: Optional[Any] = None):
""" """
:说明: :说明:
@ -36,7 +35,6 @@ class Handler:
* ``name: Optional[str]``: 事件处理器名称默认为函数名 * ``name: Optional[str]``: 事件处理器名称默认为函数名
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入 * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型 * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型
* ``dependency_overrides_provider: Optional[Any]``: 依赖注入覆盖提供者
""" """
self.func = func self.func = func
""" """
@ -63,7 +61,6 @@ class Handler:
if dependencies: if dependencies:
for depends in dependencies: for depends in dependencies:
self.cache_dependent(depends) self.cache_dependent(depends)
self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func, allow_types=self.allow_types) self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str: def __repr__(self) -> str:
@ -87,7 +84,6 @@ class Handler:
self.sub_dependents[dependency.dependency] # type: ignore self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies for dependency in self.dependencies
], ],
_dependency_overrides_provider=self.dependency_overrides_provider,
_dependency_cache=_dependency_cache, _dependency_cache=_dependency_cache,
**params) **params)

View File

@ -13,10 +13,10 @@ from contextlib import AsyncExitStack
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
NoReturn, Optional) NoReturn, Optional)
from nonebot import params
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.log import logger from nonebot.log import logger
from nonebot.handler import Handler from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.dependencies import DependsWrapper from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment, from nonebot.adapters import (Bot, Event, Message, MessageSegment,
@ -238,9 +238,7 @@ class Matcher(metaclass=MatcherMeta):
permission or Permission(), permission or Permission(),
"handlers": [ "handlers": [
handler if isinstance(handler, Handler) else Handler( handler if isinstance(handler, Handler) else Handler(
handler, handler, allow_types=cls.HANDLER_PARAM_TYPES)
dependency_overrides_provider=get_driver(),
allow_types=cls.HANDLER_PARAM_TYPES)
for handler in handlers for handler in handlers
] if handlers else [], ] if handlers else [],
"temp": "temp":
@ -372,7 +370,6 @@ class Matcher(metaclass=MatcherMeta):
dependencies: Optional[List[DependsWrapper]] = None) -> Handler: dependencies: Optional[List[DependsWrapper]] = None) -> Handler:
handler_ = Handler(handler, handler_ = Handler(handler,
dependencies=dependencies, dependencies=dependencies,
dependency_overrides_provider=get_driver(),
allow_types=cls.HANDLER_PARAM_TYPES) allow_types=cls.HANDLER_PARAM_TYPES)
cls.handlers.append(handler_) cls.handlers.append(handler_)
return handler_ return handler_

View File

@ -8,13 +8,13 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Optional from typing import TYPE_CHECKING, Any, Set, Dict, Type, Optional
from nonebot import params
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import TrieRule from nonebot.rule import TrieRule
from nonebot.handler import Handler from nonebot.handler import Handler
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot import params, get_driver
from nonebot.matcher import Matcher, matchers from nonebot.matcher import Matcher, matchers
from nonebot.exception import NoLogException, StopPropagation, IgnoredException from nonebot.exception import NoLogException, StopPropagation, IgnoredException
from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor, from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor,
@ -45,10 +45,7 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor:
事件预处理装饰一个函数使它在每次接收到事件并分发给各响应器之前执行 事件预处理装饰一个函数使它在每次接收到事件并分发给各响应器之前执行
""" """
_event_preprocessors.add( _event_preprocessors.add(Handler(func, allow_types=EVENT_PCS_PARAMS))
Handler(func,
allow_types=EVENT_PCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -58,10 +55,7 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor:
事件后处理装饰一个函数使它在每次接收到事件并分发给各响应器之后执行 事件后处理装饰一个函数使它在每次接收到事件并分发给各响应器之后执行
""" """
_event_postprocessors.add( _event_postprocessors.add(Handler(func, allow_types=EVENT_PCS_PARAMS))
Handler(func,
allow_types=EVENT_PCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -71,10 +65,7 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor:
运行预处理装饰一个函数使它在每次事件响应器运行前执行 运行预处理装饰一个函数使它在每次事件响应器运行前执行
""" """
_run_preprocessors.add( _run_preprocessors.add(Handler(func, allow_types=RUN_PREPCS_PARAMS))
Handler(func,
allow_types=RUN_PREPCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -84,10 +75,7 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
运行后处理装饰一个函数使它在每次事件响应器运行后执行 运行后处理装饰一个函数使它在每次事件响应器运行后执行
""" """
_run_postprocessors.add( _run_postprocessors.add(Handler(func, allow_types=RUN_POSTPCS_PARAMS))
Handler(func,
allow_types=RUN_POSTPCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func

View File

@ -2,7 +2,7 @@ r"""
权限 权限
==== ====
每个 ``Matcher`` 拥有一个 ``Permission`` 其中是 **异步** ``PermissionChecker`` 的集合只要有一个 ``PermissionChecker`` 检查结果为 ``True`` 时就会继续运行 每个 ``Matcher`` 拥有一个 ``Permission`` 其中是 ``PermissionChecker`` 的集合只要有一个 ``PermissionChecker`` 检查结果为 ``True`` 时就会继续运行
\:\:\:tip 提示 \:\:\:tip 提示
``PermissionChecker`` 既可以是 async function 也可以是 sync function ``PermissionChecker`` 既可以是 async function 也可以是 sync function
@ -41,9 +41,7 @@ class Permission:
params.BotParam, params.EventParam params.BotParam, params.EventParam
] ]
def __init__(self, def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None:
*checkers: Union[T_PermissionChecker, Handler],
dependency_overrides_provider: Optional[Any] = None) -> None:
""" """
:参数: :参数:
@ -52,9 +50,7 @@ class Permission:
self.checkers = set( self.checkers = set(
checker if isinstance(checker, Handler) else Handler( checker if isinstance(checker, Handler) else Handler(
checker, checker, allow_types=self.HANDLER_PARAM_TYPES)
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers) for checker in checkers)
""" """
:说明: :说明:

View File

@ -26,4 +26,4 @@ echo = on_command("echo", to_me())
@echo.handle() @echo.handle()
async def echo_escape(event: MessageEvent): async def echo_escape(event: MessageEvent):
await say.send(message=event.get_message()) await echo.send(message=event.get_message())

View File

@ -1,8 +1,6 @@
from typing import Dict, Optional from typing import Dict
from nonebot.typing import T_State from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
from nonebot.message import (IgnoredException, run_preprocessor, from nonebot.message import (IgnoredException, run_preprocessor,
run_postprocessor) run_postprocessor)

View File

@ -1 +0,0 @@

View File

@ -2,10 +2,10 @@ r"""
规则 规则
==== ====
每个事件响应器 ``Matcher`` 拥有一个匹配规则 ``Rule`` 其中是 **异步** ``RuleChecker`` 的集合只有当所有 ``RuleChecker`` 检查结果为 ``True`` 时继续运行 每个事件响应器 ``Matcher`` 拥有一个匹配规则 ``Rule`` 其中是 ``RuleChecker`` 的集合只有当所有 ``RuleChecker`` 检查结果为 ``True`` 时继续运行
\:\:\:tip 提示 \:\:\:tip 提示
``RuleChecker`` 既可以是 async function 也可以是 sync function但在最终会被 ``nonebot.utils.run_sync`` 转换为 async function ``RuleChecker`` 既可以是 async function 也可以是 sync function
\:\:\: \:\:\:
""" """
@ -68,9 +68,7 @@ class Rule:
params.BotParam, params.EventParam, params.StateParam params.BotParam, params.EventParam, params.StateParam
] ]
def __init__(self, def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None:
*checkers: Union[T_RuleChecker, Handler],
dependency_overrides_provider: Optional[Any] = None) -> None:
""" """
:参数: :参数:
@ -79,9 +77,7 @@ class Rule:
""" """
self.checkers = set( self.checkers = set(
checker if isinstance(checker, Handler) else Handler( checker if isinstance(checker, Handler) else Handler(
checker, checker, allow_types=self.HANDLER_PARAM_TYPES)
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers) for checker in checkers)
""" """
:说明: :说明:

View File

@ -17,7 +17,7 @@
.. _typing: .. _typing:
https://docs.python.org/3/library/typing.html https://docs.python.org/3/library/typing.html
""" """
from collections.abc import Callable as BaseCallable
from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable, from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable,
NoReturn, Optional, Awaitable) NoReturn, Optional, Awaitable)
@ -25,7 +25,7 @@ if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
T_Wrapped = TypeVar("T_Wrapped", bound=BaseCallable) T_Wrapped = TypeVar("T_Wrapped", bound=Callable)
def overrides(InterfaceClass: object): def overrides(InterfaceClass: object):

View File

@ -5,13 +5,12 @@ from typing import TYPE_CHECKING, List, Type, Optional
from pydantic import BaseModel from pydantic import BaseModel
from pygtrie import StringTrie from pygtrie import StringTrie
from .message import Message
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.exception import NoLogException from .exception import NoLogException
from nonebot.adapters import Event as BaseEvent from nonebot.adapters import Event as BaseEvent
from .message import Message
if TYPE_CHECKING: if TYPE_CHECKING:
from .bot import Bot from .bot import Bot

View File

@ -1,9 +1,10 @@
from typing import Optional from typing import Optional
from nonebot.exception import (AdapterException, ActionFailed as from nonebot.exception import AdapterException
BaseActionFailed, NetworkError as from nonebot.exception import ActionFailed as BaseActionFailed
BaseNetworkError, ApiNotAvailable as from nonebot.exception import NetworkError as BaseNetworkError
BaseApiNotAvailable) from nonebot.exception import NoLogException as BaseNoLogException
from nonebot.exception import ApiNotAvailable as BaseApiNotAvailable
class CQHTTPAdapterException(AdapterException): class CQHTTPAdapterException(AdapterException):
@ -12,6 +13,10 @@ class CQHTTPAdapterException(AdapterException):
super().__init__("cqhttp") super().__init__("cqhttp")
class NoLogException(BaseNoLogException, CQHTTPAdapterException):
pass
class ActionFailed(BaseActionFailed, CQHTTPAdapterException): class ActionFailed(BaseActionFailed, CQHTTPAdapterException):
""" """
:说明: :说明: