🏗️ change code structure

This commit is contained in:
yanyongyu 2021-11-16 18:30:16 +08:00
parent d1c6eeb6c2
commit 4cbdd726e5
19 changed files with 276 additions and 226 deletions

View File

@ -0,0 +1,12 @@
\-\-\-
contentSidebar: true
sidebarDepth: 0
\-\-\-
NoneBot.handler 模块
====================
.. automodule:: nonebot.dependencies
:members:
:private-members:
:show-inheritance:

View File

@ -278,6 +278,7 @@ def run(host: Optional[str] = None,
get_driver().run(host, port, *args, **kwargs) get_driver().run(host, port, *args, **kwargs)
import nonebot.params as params
from nonebot.plugin import export as export from nonebot.plugin import export as export
from nonebot.plugin import require as require from nonebot.plugin import require as require
from nonebot.plugin import on_regex as on_regex from nonebot.plugin import on_regex as on_regex

View File

@ -55,16 +55,6 @@ class Bot(abc.ABC):
def __getattr__(self, name: str) -> _ApiCall: def __getattr__(self, name: str) -> _ApiCall:
return partial(self.call_api, name) return partial(self.call_api, name)
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if not isinstance(v, cls):
raise TypeError(f"{v} is not an instance of {cls}")
return v
@property @property
@abc.abstractmethod @abc.abstractmethod
def type(self) -> str: def type(self) -> str:

View File

@ -1,3 +1,10 @@
"""
依赖注入处理模块
===============
该模块实现了依赖注入的定义与处理
"""
import inspect import inspect
from itertools import chain from itertools import chain
from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
@ -7,33 +14,38 @@ from pydantic import BaseConfig
from pydantic.fields import Required, ModelField from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info from pydantic.schema import get_annotation_from_field_info
from .models import Dependent
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import T_State from .models import Param as Param
from .utils import get_typed_signature from .utils import get_typed_signature
from nonebot.adapters import Bot, Event from .models import Dependent as Dependent
from .models import Depends as DependsClass from .models import DependsWrapper as DependsWrapper
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable) is_async_gen_callable, is_coroutine_callable)
def get_param_sub_dependent(*, param: inspect.Parameter) -> Dependent: class CustomConfig(BaseConfig):
depends: DependsClass = param.default arbitrary_types_allowed = True
def get_param_sub_dependent(
*,
param: inspect.Parameter,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
depends: DependsWrapper = param.default
if depends.dependency: if depends.dependency:
dependency = depends.dependency dependency = depends.dependency
else: else:
dependency = param.annotation dependency = param.annotation
return get_sub_dependant( return get_sub_dependant(depends=depends,
depends=depends,
dependency=dependency, dependency=dependency,
name=param.name, name=param.name,
) allow_types=allow_types)
def get_parameterless_sub_dependant( def get_parameterless_sub_dependant(
*, *,
depends: DependsClass, depends: DependsWrapper,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
assert callable( assert callable(
depends.dependency depends.dependency
), "A parameter-less dependency must have a callable dependency" ), "A parameter-less dependency must have a callable dependency"
@ -44,10 +56,10 @@ def get_parameterless_sub_dependant(
def get_sub_dependant( def get_sub_dependant(
*, *,
depends: DependsClass, depends: DependsWrapper,
dependency: Callable[..., Any], dependency: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
sub_dependant = get_dependent(func=dependency, sub_dependant = get_dependent(func=dependency,
name=name, name=name,
use_cache=depends.use_cache, use_cache=depends.use_cache,
@ -55,32 +67,32 @@ def get_sub_dependant(
return sub_dependant return sub_dependant
def get_dependent( def get_dependent(*,
*,
func: Callable[..., Any], func: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
use_cache: bool = True, use_cache: bool = True,
allow_types: Optional[List["ParamTypes"]] = None) -> Dependent: allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
signature = get_typed_signature(func) signature = get_typed_signature(func)
params = signature.parameters params = signature.parameters
allow_types = allow_types or [ dependent = Dependent(func=func,
ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE name=name,
] allow_types=allow_types,
dependent = Dependent(func=func, name=name, use_cache=use_cache) use_cache=use_cache)
for param_name, param in params.items(): for param_name, param in params.items():
if isinstance(param.default, DependsClass): if isinstance(param.default, DependsWrapper):
sub_dependent = get_param_sub_dependent(param=param) sub_dependent = get_param_sub_dependent(param=param,
allow_types=allow_types)
dependent.dependencies.append(sub_dependent) dependent.dependencies.append(sub_dependent)
continue continue
for allow_type in allow_types: for allow_type in dependent.allow_types:
field_info_class: Type[Param] = allow_type.value if allow_type._check(param_name, param):
if field_info_class._check(param_name, param): field_info = allow_type(param.default)
field_info = field_info_class(param.default)
break break
else: else:
raise ValueError( raise ValueError(
f"Unknown parameter {param_name} with type {param.annotation}") f"Unknown parameter {param_name} for funcction {func} with type {param.annotation}"
)
annotation: Any = Any annotation: Any = Any
if param.annotation != param.empty: if param.annotation != param.empty:
@ -91,7 +103,7 @@ def get_dependent(
ModelField(name=param_name, ModelField(name=param_name,
type_=annotation, type_=annotation,
class_validators=None, class_validators=None,
model_config=BaseConfig, model_config=CustomConfig,
default=Required, default=Required,
required=True, required=True,
field_info=field_info)) field_info=field_info))
@ -102,15 +114,11 @@ def get_dependent(
async def solve_dependencies( async def solve_dependencies(
*, *,
dependent: Dependent, dependent: Dependent,
bot: Bot,
event: Event,
state: T_State,
matcher: Optional["Matcher"] = None,
exception: Optional[Exception] = None,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None, sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None, dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]: ) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {} dependency_cache = dependency_cache or {}
@ -135,18 +143,15 @@ async def solve_dependencies(
use_sub_dependant = get_dependent( use_sub_dependant = get_dependent(
func=func, func=func,
name=sub_dependent.name, name=sub_dependent.name,
allow_types=sub_dependent.allow_types,
) )
# solve sub dependency with current cache # solve sub dependency with current cache
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
dependent=use_sub_dependant, dependent=use_sub_dependant,
bot=bot,
event=event,
state=state,
matcher=matcher,
dependency_overrides_provider=dependency_overrides_provider, dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache, dependency_cache=dependency_cache,
) **params)
sub_values, sub_dependency_cache, ignored = solved_result sub_values, sub_dependency_cache, ignored = solved_result
if ignored: if ignored:
return values, dependency_cache, True return values, dependency_cache, True
@ -182,18 +187,13 @@ async def solve_dependencies(
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, assert isinstance(field_info,
Param), "Params must be subclasses of Param" Param), "Params must be subclasses of Param"
value = field_info._solve(bot=bot, value = field_info._solve(**params)
event=event,
state=state,
matcher=matcher,
exception=exception)
_, errs_ = field.validate(value, _, errs_ = field.validate(value,
values, values,
loc=(ParamTypes(type(field_info)).name, loc=(str(field_info), field.alias))
field.alias))
if errs_: if errs_:
logger.debug( logger.debug(
f"Matcher {matcher} {ParamTypes(type(field_info)).name} " f"{field_info} "
f"type {type(value)} not match depends {dependent.func} " f"type {type(value)} not match depends {dependent.func} "
f"annotation {field._type_display()}, ignored") f"annotation {field._type_display()}, ignored")
return values, dependency_cache, True return values, dependency_cache, True
@ -206,11 +206,14 @@ async def solve_dependencies(
def Depends(dependency: Optional[Callable[..., Any]] = None, def Depends(dependency: Optional[Callable[..., Any]] = None,
*, *,
use_cache: bool = True) -> Any: use_cache: bool = True) -> Any:
return DependsClass(dependency=dependency, use_cache=use_cache) """
:说明:
参数依赖注入装饰器
from .params import Param :参数:
from .handler import Handler as Handler
from .matcher import Matcher as Matcher * ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数默认为参数的类型注释
from .matcher import matchers as matchers * ``use_cache: bool = True``: 是否使用缓存默认为 ``True``
from .params import ParamTypes as ParamTypes """
return DependsWrapper(dependency=dependency, use_cache=use_cache)

View File

@ -1,11 +1,31 @@
from typing import Any, List, Callable, Optional import abc
import inspect
from typing import Any, List, Type, Callable, Optional
from pydantic.fields import ModelField from pydantic.fields import FieldInfo, ModelField
from nonebot.utils import get_name from nonebot.utils import get_name
class Depends: class Param(FieldInfo, abc.ABC):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def __str__(self) -> str:
return repr(self)
@classmethod
@abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
raise NotImplementedError
@abc.abstractmethod
def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
class DependsWrapper:
def __init__(self, def __init__(self,
dependency: Optional[Callable[..., Any]] = None, dependency: Optional[Callable[..., Any]] = None,
@ -27,11 +47,13 @@ class Dependent:
func: Optional[Callable[..., Any]] = None, func: Optional[Callable[..., Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
params: Optional[List[ModelField]] = None, params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None,
dependencies: Optional[List["Dependent"]] = None, dependencies: Optional[List["Dependent"]] = None,
use_cache: bool = True) -> None: use_cache: bool = True) -> None:
self.func = func self.func = func
self.name = name self.name = name
self.params = params or [] self.params = params or []
self.allow_types = allow_types or []
self.dependencies = dependencies or [] self.dependencies = dependencies or []
self.use_cache = use_cache self.use_cache = use_cache
self.cache_key = self.func self.cache_key = self.func

View File

@ -1,6 +1,5 @@
import inspect import inspect
from typing import Any, Dict, Type, Tuple, Union, Callable from typing import Any, Dict, Callable
from typing_extensions import GenericAlias, get_args, get_origin # type: ignore
from loguru import logger from loguru import logger
from pydantic.typing import ForwardRef, evaluate_forwardref from pydantic.typing import ForwardRef, evaluate_forwardref
@ -34,26 +33,3 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str,
) )
return inspect.Parameter.empty return inspect.Parameter.empty
return annotation return annotation
def generic_check_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any],
...]]) -> bool:
try:
return issubclass(cls, class_or_tuple)
except TypeError:
if get_origin(cls) is Union:
for type_ in get_args(cls):
if not generic_check_issubclass(type_, class_or_tuple):
return False
return True
elif isinstance(cls, GenericAlias):
origin = get_origin(cls)
return bool(origin and issubclass(origin, class_or_tuple))
raise
def generic_get_types(cls: Any) -> Tuple[Type[Any], ...]:
if get_origin(cls) is Union:
return get_args(cls)
return (cls,)

View File

@ -7,83 +7,94 @@
import asyncio import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional from typing import TYPE_CHECKING, Any, Dict, List, Type, Callable, Optional
from .models import Depends, Dependent from nonebot.typing import T_Handler
from nonebot.utils import get_name, run_sync from nonebot.utils import get_name, run_sync
from nonebot.typing import T_State, T_Handler from nonebot.dependencies import (Param, Dependent, DependsWrapper,
from . import get_dependent, solve_dependencies, get_parameterless_sub_dependant get_dependent, solve_dependencies,
get_parameterless_sub_dependant)
if TYPE_CHECKING: if TYPE_CHECKING:
from .matcher import Matcher from nonebot.matcher import Matcher
from .params import ParamTypes
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
class Handler: class Handler:
"""事件处理函数类""" """事件处理器类。支持依赖注入。"""
def __init__(self, def __init__(self,
func: T_Handler, func: T_Handler,
*, *,
name: Optional[str] = None, name: Optional[str] = None,
dependencies: Optional[List[Depends]] = None, dependencies: Optional[List[DependsWrapper]] = None,
allow_types: Optional[List["ParamTypes"]] = None, allow_types: Optional[List[Type[Param]]] = None,
dependency_overrides_provider: Optional[Any] = None): dependency_overrides_provider: Optional[Any] = None):
"""装饰事件处理函数以便根据动态参数运行""" """
self.func: T_Handler = func :说明:
装饰一个函数为事件处理器
:参数:
* ``func: T_Handler``: 事件处理函数
* ``name: Optional[str]``: 事件处理器名称默认为函数名
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型
* ``dependency_overrides_provider: Optional[Any]``: 依赖注入覆盖提供者
"""
self.func = func
""" """
:类型: ``T_Handler`` :类型: ``T_Handler``
:说明: 事件处理函数 :说明: 事件处理函数
""" """
self.name = get_name(func) if name is None else name self.name = get_name(func) if name is None else name
self.allow_types = allow_types """
:类型: ``str``
:说明: 事件处理函数名
"""
self.allow_types = allow_types or []
"""
:类型: ``List[Type[Param]]``
:说明: 事件处理器允许的参数类型
"""
self.dependencies = dependencies or [] self.dependencies = dependencies or []
"""
:类型: ``List[DependsWrapper]``
:说明: 事件处理器的额外依赖
"""
self.sub_dependents: Dict[Callable[..., Any], Dependent] = {} self.sub_dependents: Dict[Callable[..., Any], Dependent] = {}
if dependencies: if dependencies:
for depends in dependencies: for depends in dependencies:
if not depends.dependency: self.cache_dependent(depends)
raise ValueError(f"{depends} has no dependency")
if depends.dependency in self.sub_dependents:
raise ValueError(f"{depends} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant(
depends=depends, allow_types=self.allow_types)
self.sub_dependents[depends.dependency] = sub_dependant
self.dependency_overrides_provider = dependency_overrides_provider self.dependency_overrides_provider = dependency_overrides_provider
self.dependent = get_dependent(func=func, allow_types=self.allow_types) self.dependent = get_dependent(func=func, allow_types=self.allow_types)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<Handler {self.func}({', '.join(map(str, self.dependent.params))})>" f"<Handler {self.name}({', '.join(map(str, self.dependent.params))})>"
) )
def __str__(self) -> str: def __str__(self) -> str:
return repr(self) return repr(self)
async def __call__( async def __call__(self,
self,
matcher: "Matcher",
bot: "Bot",
event: "Event",
state: T_State,
*, *,
stack: Optional[AsyncExitStack] = None, _stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any], _dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> Any: Any]] = None,
**params) -> Any:
values, _, ignored = await solve_dependencies( values, _, ignored = await solve_dependencies(
dependent=self.dependent, dependent=self.dependent,
bot=bot, stack=_stack,
event=event,
state=state,
matcher=matcher,
stack=stack,
sub_dependents=[ sub_dependents=[
self.sub_dependents[dependency.dependency] # type: ignore self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies for dependency in self.dependencies
], ],
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
dependency_cache=dependency_cache) dependency_cache=_dependency_cache,
**params)
if ignored: if ignored:
return return
@ -93,24 +104,24 @@ class Handler:
else: else:
await run_sync(self.func)(**values) await run_sync(self.func)(**values)
def cache_dependent(self, dependency: Depends): def cache_dependent(self, dependency: DependsWrapper):
if not dependency.dependency: if not dependency.dependency:
raise ValueError(f"{dependency} has no dependency") raise ValueError(f"{dependency} has no dependency")
if (dependency.dependency,) in self.sub_dependents: if dependency.dependency in self.sub_dependents:
raise ValueError(f"{dependency} is already in dependencies") raise ValueError(f"{dependency} is already in dependencies")
sub_dependant = get_parameterless_sub_dependant( sub_dependant = get_parameterless_sub_dependant(
depends=dependency, allow_types=self.allow_types) depends=dependency, allow_types=self.allow_types)
self.sub_dependents[dependency.dependency] = sub_dependant self.sub_dependents[dependency.dependency] = sub_dependant
def prepend_dependency(self, dependency: Depends): def prepend_dependency(self, dependency: DependsWrapper):
self.cache_dependent(dependency) self.cache_dependent(dependency)
self.dependencies.insert(0, dependency) self.dependencies.insert(0, dependency)
def append_dependency(self, dependency: Depends): def append_dependency(self, dependency: DependsWrapper):
self.cache_dependent(dependency) self.cache_dependent(dependency)
self.dependencies.append(dependency) self.dependencies.append(dependency)
def remove_dependency(self, dependency: Depends): def remove_dependency(self, dependency: DependsWrapper):
if not dependency.dependency: if not dependency.dependency:
raise ValueError(f"{dependency} has no dependency") raise ValueError(f"{dependency} has no dependency")
if dependency.dependency in self.sub_dependents: if dependency.dependency in self.sub_dependents:

View File

@ -12,12 +12,11 @@ from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
NoReturn, Optional) NoReturn, Optional)
from .models import Depends
from .handler import Handler
from nonebot.rule import Rule from nonebot.rule import Rule
from .params import ParamTypes
from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment, from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate) MessageTemplate)
@ -155,7 +154,8 @@ class Matcher(metaclass=MatcherMeta):
""" """
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [
ParamTypes.BOT, ParamTypes.EVENT, ParamTypes.STATE, ParamTypes.MATCHER params.BotParam, params.EventParam, params.StateParam,
params.MatcherParam
] ]
def __init__(self): def __init__(self):
@ -350,9 +350,10 @@ class Matcher(metaclass=MatcherMeta):
return func return func
@classmethod @classmethod
def append_handler(cls, def append_handler(
cls,
handler: T_Handler, handler: T_Handler,
dependencies: Optional[List[Depends]] = None) -> Handler: dependencies: Optional[List[DependsWrapper]] = None) -> Handler:
handler_ = Handler(handler, handler_ = Handler(handler,
dependencies=dependencies, dependencies=dependencies,
dependency_overrides_provider=get_driver(), dependency_overrides_provider=get_driver(),
@ -398,7 +399,7 @@ class Matcher(metaclass=MatcherMeta):
def _decorator(func: T_Handler) -> T_Handler: def _decorator(func: T_Handler) -> T_Handler:
depend = Depends(_receive) depend = DependsWrapper(_receive)
if cls.handlers and cls.handlers[-1].func is func: if cls.handlers and cls.handlers[-1].func is func:
func_handler = cls.handlers[-1] func_handler = cls.handlers[-1]
@ -461,8 +462,8 @@ class Matcher(metaclass=MatcherMeta):
def _decorator(func: T_Handler) -> T_Handler: def _decorator(func: T_Handler) -> T_Handler:
get_depend = Depends(_key_getter) get_depend = DependsWrapper(_key_getter)
parser_depend = Depends(_key_parser) parser_depend = DependsWrapper(_key_parser)
if cls.handlers and cls.handlers[-1].func is func: if cls.handlers and cls.handlers[-1].func is func:
func_handler = cls.handlers[-1] func_handler = cls.handlers[-1]
@ -600,7 +601,10 @@ class Matcher(metaclass=MatcherMeta):
while self.handlers: while self.handlers:
handler = self.handlers.pop(0) handler = self.handlers.pop(0)
logger.debug(f"Running handler {handler}") logger.debug(f"Running handler {handler}")
await handler(self, bot, event, self.state) await handler(matcher=self,
bot=bot,
event=event,
state=self.state)
except RejectedException: except RejectedException:
self.handlers.insert(0, handler) # type: ignore self.handlers.insert(0, handler) # type: ignore

View File

@ -12,8 +12,10 @@ from typing import TYPE_CHECKING, Set, Type
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import TrieRule from nonebot.rule import TrieRule
from nonebot.handler import Handler
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.processor import Matcher, matchers from nonebot import params, get_driver
from nonebot.matcher import Matcher, matchers
from nonebot.exception import NoLogException, StopPropagation, IgnoredException from nonebot.exception import NoLogException, StopPropagation, IgnoredException
from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor, from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor,
T_EventPreProcessor, T_EventPostProcessor) T_EventPreProcessor, T_EventPostProcessor)
@ -21,10 +23,19 @@ from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor,
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
_event_preprocessors: Set[T_EventPreProcessor] = set() _event_preprocessors: Set[Handler] = set()
_event_postprocessors: Set[T_EventPostProcessor] = set() _event_postprocessors: Set[Handler] = set()
_run_preprocessors: Set[T_RunPreProcessor] = set() _run_preprocessors: Set[Handler] = set()
_run_postprocessors: Set[T_RunPostProcessor] = set() _run_postprocessors: Set[Handler] = set()
EVENT_PCS_PARAMS = [params.BotParam, params.EventParam, params.StateParam]
RUN_PREPCS_PARAMS = [
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam
]
RUN_POSTPCS_PARAMS = [
params.MatcherParam, params.ExceptionParam, params.BotParam,
params.EventParam, params.StateParam
]
def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor: def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor:
@ -41,7 +52,10 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor:
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
* ``state: T_State``: 当前 State * ``state: T_State``: 当前 State
""" """
_event_preprocessors.add(func) _event_preprocessors.add(
Handler(func,
allow_types=EVENT_PCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -59,7 +73,10 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor:
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
* ``state: T_State``: 当前事件运行前 State * ``state: T_State``: 当前事件运行前 State
""" """
_event_postprocessors.add(func) _event_postprocessors.add(
Handler(func,
allow_types=EVENT_PCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -78,7 +95,10 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor:
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
* ``state: T_State``: 当前 State * ``state: T_State``: 当前 State
""" """
_run_preprocessors.add(func) _run_preprocessors.add(
Handler(func,
allow_types=RUN_PREPCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -98,7 +118,10 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
* ``state: T_State``: 当前 State * ``state: T_State``: 当前 State
""" """
_run_postprocessors.add(func) _run_postprocessors.add(
Handler(func,
allow_types=RUN_POSTPCS_PARAMS,
dependency_overrides_provider=get_driver()))
return func return func
@ -136,7 +159,8 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
matcher = Matcher() matcher = Matcher()
coros = list( coros = list(
map(lambda x: x(matcher, bot, event, state), _run_preprocessors)) map(lambda x: x(matcher=matcher, bot=bot, event=event, state=state),
_run_preprocessors))
if coros: if coros:
try: try:
await asyncio.gather(*coros) await asyncio.gather(*coros)
@ -162,8 +186,12 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
exception = e exception = e
coros = list( coros = list(
map(lambda x: x(matcher, exception, bot, event, state), map(
_run_postprocessors)) lambda x: x(matcher=matcher,
exception=exception,
bot=bot,
event=event,
state=state), _run_postprocessors))
if coros: if coros:
try: try:
await asyncio.gather(*coros) await asyncio.gather(*coros)
@ -208,7 +236,9 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
# TODO # TODO
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
coros = list(map(lambda x: x(bot, event, state), _event_preprocessors)) coros = list(
map(lambda x: x(bot=bot, event=event, state=state),
_event_preprocessors))
if coros: if coros:
try: try:
if show_log: if show_log:
@ -255,7 +285,9 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
"<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>" "<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
) )
coros = list(map(lambda x: x(bot, event, state), _event_postprocessors)) coros = list(
map(lambda x: x(bot=bot, event=event, state=state),
_event_postprocessors))
if coros: if coros:
try: try:
if show_log: if show_log:

View File

@ -1,38 +1,19 @@
import abc
import inspect import inspect
from enum import Enum
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic.fields import FieldInfo
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.dependencies import Param
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from .utils import generic_check_issubclass from nonebot.utils import generic_check_issubclass
class Param(FieldInfo, abc.ABC):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def __str__(self) -> str:
return repr(self)
@classmethod
@abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
raise NotImplementedError
@abc.abstractmethod
def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
class BotParam(Param): class BotParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Bot) return generic_check_issubclass(
param.annotation, Bot) or (param.annotation == param.empty and
name == "bot")
def _solve(self, bot: Bot, **kwargs: Any) -> Any: def _solve(self, bot: Bot, **kwargs: Any) -> Any:
return bot return bot
@ -42,7 +23,9 @@ class EventParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Event) return generic_check_issubclass(
param.annotation, Event) or (param.annotation == param.empty and
name == "event")
def _solve(self, event: Event, **kwargs: Any) -> Any: def _solve(self, event: Event, **kwargs: Any) -> Any:
return event return event
@ -52,7 +35,9 @@ class StateParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Dict) return generic_check_issubclass(
param.annotation, Dict) or (param.annotation == param.empty and
name == "state")
def _solve(self, state: T_State, **kwargs: Any) -> Any: def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state return state
@ -62,7 +47,9 @@ class MatcherParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Matcher) return generic_check_issubclass(
param.annotation, Matcher) or (param.annotation == param.empty and
name == "matcher")
def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any: def _solve(self, matcher: Optional["Matcher"] = None, **kwargs: Any) -> Any:
return matcher return matcher
@ -72,7 +59,9 @@ class ExceptionParam(Param):
@classmethod @classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:
return generic_check_issubclass(param.annotation, Exception) return generic_check_issubclass(
param.annotation, Exception) or (param.annotation == param.empty and
name == "exception")
def _solve(self, def _solve(self,
exception: Optional[Exception] = None, exception: Optional[Exception] = None,
@ -80,12 +69,4 @@ class ExceptionParam(Param):
return exception return exception
class ParamTypes(Enum): from nonebot.matcher import Matcher
BOT = BotParam
EVENT = EventParam
STATE = StateParam
MATCHER = MatcherParam
EXCEPTION = ExceptionParam
from .matcher import Matcher

View File

@ -4,10 +4,11 @@ import inspect
from types import ModuleType from types import ModuleType
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
from nonebot.handler import Handler
from nonebot.matcher import Matcher
from .manager import _current_plugin from .manager import _current_plugin
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.processor import Handler, Matcher
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword,
endswith, startswith, shell_command) endswith, startswith, shell_command)

View File

@ -1,9 +1,10 @@
import re import re
from typing import Set, List, Type, Tuple, Union, Optional from typing import Set, List, Type, Tuple, Union, Optional
from nonebot.handler import Handler
from nonebot.matcher import Matcher
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.rule import Rule, ArgumentParser from nonebot.rule import Rule, ArgumentParser
from nonebot.processor import Handler, Matcher
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory

View File

@ -3,7 +3,7 @@ from dataclasses import field, dataclass
from typing import Set, Dict, Type, Optional from typing import Set, Dict, Type, Optional
from .export import Export from .export import Export
from nonebot.processor import Matcher from nonebot.matcher import Matcher
plugins: Dict[str, "Plugin"] = {} plugins: Dict[str, "Plugin"] = {}
""" """

View File

@ -3,14 +3,14 @@ from functools import reduce
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.plugin import on_command from nonebot.plugin import on_command
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.adapters.cqhttp import (Bot, Message, MessageEvent, MessageSegment, from nonebot.adapters.cqhttp import (Message, MessageEvent, MessageSegment,
unescape) unescape)
say = on_command("say", to_me(), permission=SUPERUSER) say = on_command("say", to_me(), permission=SUPERUSER)
@say.handle() @say.handle()
async def say_unescape(bot: Bot, event: MessageEvent): async def say_unescape(event: MessageEvent):
def _unescape(message: Message, segment: MessageSegment): def _unescape(message: Message, segment: MessageSegment):
if segment.is_text(): if segment.is_text():
@ -18,12 +18,12 @@ async def say_unescape(bot: Bot, event: MessageEvent):
return message.append(segment) return message.append(segment)
message = reduce(_unescape, event.get_message(), Message()) # type: ignore message = reduce(_unescape, event.get_message(), Message()) # type: ignore
await bot.send(message=message, event=event) await say.send(message=message)
echo = on_command("echo", to_me()) echo = on_command("echo", to_me())
@echo.handle() @echo.handle()
async def echo_escape(bot: Bot, event: MessageEvent): async def echo_escape(event: MessageEvent):
await bot.send(message=event.get_message(), event=event) await say.send(message=event.get_message())

View File

@ -1,7 +1,7 @@
from typing import Dict, Optional from typing import Dict, Optional
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.processor import Matcher from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.message import (IgnoredException, run_preprocessor, from nonebot.message import (IgnoredException, run_preprocessor,
run_postprocessor) run_postprocessor)
@ -10,7 +10,7 @@ _running_matcher: Dict[str, int] = {}
@run_preprocessor @run_preprocessor
async def preprocess(matcher: Matcher, bot: Bot, event: Event, state: T_State): async def preprocess(event: Event):
try: try:
session_id = event.get_session_id() session_id = event.get_session_id()
except Exception: except Exception:
@ -24,8 +24,7 @@ async def preprocess(matcher: Matcher, bot: Bot, event: Event, state: T_State):
@run_postprocessor @run_postprocessor
async def postprocess(matcher: Matcher, exception: Optional[Exception], async def postprocess(event: Event):
bot: Bot, event: Event, state: T_State):
try: try:
session_id = event.get_session_id() session_id = event.get_session_id()
except Exception: except Exception:

View File

@ -22,7 +22,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable,
NoReturn, Optional, Awaitable) NoReturn, Optional, Awaitable)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.processor import Matcher from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
@ -90,7 +90,7 @@ T_CalledAPIHook = Callable[
``bot.call_api`` 后执行的函数参数分别为 bot, exception, api, data, result ``bot.call_api`` 后执行的函数参数分别为 bot, exception, api, data, result
""" """
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] T_EventPreProcessor = Callable[..., Awaitable[None]]
""" """
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` :类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
@ -98,7 +98,7 @@ T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
事件预处理函数 EventPreProcessor 类型 事件预处理函数 EventPreProcessor 类型
""" """
T_EventPostProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] T_EventPostProcessor = Callable[..., Awaitable[None]]
""" """
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` :类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
@ -106,8 +106,7 @@ T_EventPostProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
事件预处理函数 EventPostProcessor 类型 事件预处理函数 EventPostProcessor 类型
""" """
T_RunPreProcessor = Callable[["Matcher", "Bot", "Event", T_State], T_RunPreProcessor = Callable[..., Awaitable[None]]
Awaitable[None]]
""" """
:类型: ``Callable[[Matcher, Bot, Event, T_State], Awaitable[None]]`` :类型: ``Callable[[Matcher, Bot, Event, T_State], Awaitable[None]]``
@ -115,8 +114,7 @@ T_RunPreProcessor = Callable[["Matcher", "Bot", "Event", T_State],
事件响应器运行前预处理函数 RunPreProcessor 类型 事件响应器运行前预处理函数 RunPreProcessor 类型
""" """
T_RunPostProcessor = Callable[ T_RunPostProcessor = Callable[..., Awaitable[None]]
["Matcher", Optional[Exception], "Bot", "Event", T_State], Awaitable[None]]
""" """
:类型: ``Callable[[Matcher, Optional[Exception], Bot, Event, T_State], Awaitable[None]]`` :类型: ``Callable[[Matcher, Optional[Exception], Bot, Event, T_State], Awaitable[None]]``

View File

@ -4,10 +4,11 @@ import asyncio
import inspect import inspect
import dataclasses import dataclasses
from functools import wraps, partial from functools import wraps, partial
from typing_extensions import ParamSpec
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import (Any, TypeVar, Callable, Optional, Awaitable, AsyncGenerator, from typing_extensions import GenericAlias # type: ignore
ContextManager) from typing_extensions import ParamSpec, get_args, get_origin
from typing import (Any, Type, Tuple, Union, TypeVar, Callable, Optional,
Awaitable, AsyncGenerator, ContextManager)
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import overrides from nonebot.typing import overrides
@ -34,6 +35,24 @@ def escape_tag(s: str) -> str:
return re.sub(r"</?((?:[fb]g\s)?[^<>\s]*)>", r"\\\g<0>", s) return re.sub(r"</?((?:[fb]g\s)?[^<>\s]*)>", r"\\\g<0>", s)
def generic_check_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any],
...]]) -> bool:
try:
return issubclass(cls, class_or_tuple)
except TypeError:
if get_origin(cls) is Union:
for type_ in get_args(cls):
if type_ is not type(None) and not generic_check_issubclass(
type_, class_or_tuple):
return False
return True
elif isinstance(cls, GenericAlias):
origin = get_origin(cls)
return bool(origin and issubclass(origin, class_or_tuple))
raise
def is_coroutine_callable(func: Callable[..., Any]) -> bool: def is_coroutine_callable(func: Callable[..., Any]) -> bool:
if inspect.isroutine(func): if inspect.isroutine(func):
return inspect.iscoroutinefunction(func) return inspect.iscoroutinefunction(func)

View File

@ -1,6 +1,6 @@
from nonebot import on_command from nonebot import on_command
from nonebot.log import logger from nonebot.log import logger
from nonebot.processor import Depends from nonebot.dependencies import Depends
test = on_command("123") test = on_command("123")

View File

@ -1,15 +1,15 @@
from nonebot.adapters import Event
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.processor import Matcher from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
from nonebot.message import run_preprocessor, event_preprocessor from nonebot.message import run_preprocessor, event_preprocessor
@event_preprocessor @event_preprocessor
async def handle(bot: Bot, event: Event, state: T_State): async def handle(event: Event, state: T_State):
state["preprocessed"] = True state["preprocessed"] = True
print(type(event), event) print(type(event), event)
@run_preprocessor @run_preprocessor
async def run(matcher: Matcher, bot: Bot, event: Event, state: T_State): async def run(matcher: Matcher):
print(matcher) print(matcher)