diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index f9a6ab63..dd27ce9f 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -20,6 +20,7 @@ from .utils import get_typed_signature from .models import Dependent as Dependent from nonebot.exception import SkippedException from .models import DependsWrapper as DependsWrapper +from nonebot.typing import T_Handler, T_DependencyCache from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, is_async_gen_callable, is_coroutine_callable) @@ -58,7 +59,7 @@ def get_parameterless_sub_dependant( def get_sub_dependant( *, depends: DependsWrapper, - dependency: Callable[..., Any], + dependency: T_Handler, name: Optional[str] = None, allow_types: Optional[List[Type[Param]]] = None) -> Dependent: sub_dependant = get_dependent(func=dependency, @@ -69,7 +70,7 @@ def get_sub_dependant( def get_dependent(*, - func: Callable[..., Any], + func: T_Handler, name: Optional[str] = None, use_cache: bool = True, allow_types: Optional[List[Type[Param]]] = None) -> Dependent: @@ -118,8 +119,8 @@ async def solve_dependencies( _stack: Optional[AsyncExitStack] = None, _sub_dependents: Optional[List[Dependent]] = None, _dependency_overrides_provider: Optional[Any] = None, - _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, - **params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]: + _dependency_cache: Optional[T_DependencyCache] = None, + **params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]: values: Dict[str, Any] = {} dependency_cache = _dependency_cache or {} @@ -201,7 +202,7 @@ async def solve_dependencies( return values, dependency_cache -def Depends(dependency: Optional[Callable[..., Any]] = None, +def Depends(dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> Any: """ diff --git a/nonebot/dependencies/models.py b/nonebot/dependencies/models.py index ca764f9b..4431d11c 100644 --- a/nonebot/dependencies/models.py +++ b/nonebot/dependencies/models.py @@ -5,9 +5,10 @@ from typing import Any, List, Type, Callable, Optional from pydantic.fields import FieldInfo, ModelField from nonebot.utils import get_name +from nonebot.typing import T_Handler -class Param(FieldInfo, abc.ABC): +class Param(abc.ABC, FieldInfo): def __repr__(self) -> str: return f"{self.__class__.__name__}" @@ -28,7 +29,7 @@ class Param(FieldInfo, abc.ABC): class DependsWrapper: def __init__(self, - dependency: Optional[Callable[..., Any]] = None, + dependency: Optional[T_Handler] = None, *, use_cache: bool = True) -> None: self.dependency = dependency @@ -44,7 +45,7 @@ class Dependent: def __init__(self, *, - func: Optional[Callable[..., Any]] = None, + func: Optional[T_Handler] = None, name: Optional[str] = None, params: Optional[List[ModelField]] = None, allow_types: Optional[List[Type[Param]]] = None, diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index ed08b228..ade7fd97 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -1,13 +1,15 @@ import inspect -from typing import Any, Dict, Callable +from typing import Any, Dict from loguru import logger from pydantic.typing import ForwardRef, evaluate_forwardref +from nonebot.typing import T_Handler -def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: - signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) + +def get_typed_signature(func: T_Handler) -> inspect.Signature: + signature = inspect.signature(func) + globalns = getattr(func, "__globals__", {}) typed_params = [ inspect.Parameter( name=param.name, diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 0967d5c3..f68c78de 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -9,6 +9,7 @@ from types import ModuleType from datetime import datetime from contextvars import ContextVar from collections import defaultdict +from contextlib import AsyncExitStack from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable, NoReturn, Optional) @@ -20,11 +21,12 @@ from nonebot.dependencies import DependsWrapper from nonebot.permission import USER, Permission from nonebot.adapters import (Bot, Event, Message, MessageSegment, MessageTemplate) -from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, - T_StateFactory, T_PermissionUpdater) from nonebot.exception import (PausedException, StopPropagation, SkippedException, FinishedException, RejectedException) +from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, + T_StateFactory, T_DependencyCache, + T_PermissionUpdater) if TYPE_CHECKING: from nonebot.plugin import Plugin @@ -267,7 +269,13 @@ class Matcher(metaclass=MatcherMeta): return NewMatcher @classmethod - async def check_perm(cls, bot: Bot, event: Event) -> bool: + async def check_perm( + cls, + bot: Bot, + event: Event, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[Dict[Callable[..., Any], + Any]] = None) -> bool: """ :说明: @@ -284,10 +292,17 @@ class Matcher(metaclass=MatcherMeta): """ event_type = event.get_type() return (event_type == (cls.type or event_type) and - await cls.permission(bot, event)) + await cls.permission(bot, event, stack, dependency_cache)) @classmethod - async def check_rule(cls, bot: Bot, event: Event, state: T_State) -> bool: + async def check_rule( + cls, + bot: Bot, + event: Event, + state: T_State, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[Dict[Callable[..., Any], + Any]] = None) -> bool: """ :说明: @@ -305,7 +320,7 @@ class Matcher(metaclass=MatcherMeta): """ event_type = event.get_type() return (event_type == (cls.type or event_type) and - await cls.rule(bot, event, state)) + await cls.rule(bot, event, state, stack, dependency_cache)) @classmethod def args_parser(cls, func: T_ArgsParser) -> T_ArgsParser: @@ -589,7 +604,12 @@ class Matcher(metaclass=MatcherMeta): self.block = True # 运行handlers - async def run(self, bot: Bot, event: Event, state: T_State): + async def run(self, + bot: Bot, + event: Event, + state: T_State, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[T_DependencyCache] = None): b_t = current_bot.set(bot) e_t = current_event.set(event) s_t = current_state.set(self.state) @@ -606,7 +626,9 @@ class Matcher(metaclass=MatcherMeta): await handler(matcher=self, bot=bot, event=event, - state=self.state) + state=self.state, + _stack=stack, + _dependency_cache=dependency_cache) except SkippedException: pass @@ -624,11 +646,8 @@ class Matcher(metaclass=MatcherMeta): updater = self.__class__._default_permission_updater if updater: - permission = await updater( - bot, - event, - self.state, # type: ignore - self.permission) + permission = await updater(bot, event, self.state, + self.permission) else: permission = USER(event.get_session_id(), perm=self.permission) @@ -661,11 +680,8 @@ class Matcher(metaclass=MatcherMeta): updater = self.__class__._default_permission_updater if updater: - permission = await updater( - bot, - event, - self.state, # type: ignore - self.permission) + permission = await updater(bot, event, self.state, + self.permission) else: permission = USER(event.get_session_id(), perm=self.permission) diff --git a/nonebot/message.py b/nonebot/message.py index fec00687..66379e77 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -8,7 +8,7 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供 import asyncio from datetime import datetime from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Set, Type +from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Optional from nonebot.log import logger from nonebot.rule import TrieRule @@ -17,8 +17,9 @@ from nonebot.utils import escape_tag from nonebot import params, get_driver from nonebot.matcher import Matcher, matchers from nonebot.exception import NoLogException, StopPropagation, IgnoredException -from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor, - T_EventPreProcessor, T_EventPostProcessor) +from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor, + T_RunPostProcessor, T_EventPreProcessor, + T_EventPostProcessor) if TYPE_CHECKING: from nonebot.adapters import Bot, Event @@ -43,14 +44,6 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor: :说明: 事件预处理。装饰一个函数,使它在每次接收到事件并分发给各响应器之前执行。 - - :参数: - - 事件预处理函数接收三个参数。 - - * ``bot: Bot``: Bot 对象 - * ``event: Event``: Event 对象 - * ``state: T_State``: 当前 State """ _event_preprocessors.add( Handler(func, @@ -64,14 +57,6 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor: :说明: 事件后处理。装饰一个函数,使它在每次接收到事件并分发给各响应器之后执行。 - - :参数: - - 事件后处理函数接收三个参数。 - - * ``bot: Bot``: Bot 对象 - * ``event: Event``: Event 对象 - * ``state: T_State``: 当前事件运行前 State """ _event_postprocessors.add( Handler(func, @@ -85,15 +70,6 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor: :说明: 运行预处理。装饰一个函数,使它在每次事件响应器运行前执行。 - - :参数: - - 运行预处理函数接收四个参数。 - - * ``matcher: Matcher``: 当前要运行的事件响应器 - * ``bot: Bot``: Bot 对象 - * ``event: Event``: Event 对象 - * ``state: T_State``: 当前 State """ _run_preprocessors.add( Handler(func, @@ -107,16 +83,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor: :说明: 运行后处理。装饰一个函数,使它在每次事件响应器运行后执行。 - - :参数: - - 运行后处理函数接收五个参数。 - - * ``matcher: Matcher``: 运行完毕的事件响应器 - * ``exception: Optional[Exception]``: 事件响应器运行错误(如果存在) - * ``bot: Bot``: Bot 对象 - * ``event: Event``: Event 对象 - * ``state: T_State``: 当前 State """ _run_postprocessors.add( Handler(func, @@ -125,8 +91,14 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor: return func -async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot", - event: "Event", state: T_State) -> None: +async def _check_matcher( + priority: int, + Matcher: Type[Matcher], + bot: "Bot", + event: "Event", + state: T_State, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[T_DependencyCache] = None) -> None: if Matcher.expire_time and datetime.now() > Matcher.expire_time: try: matchers[priority].remove(Matcher) @@ -136,7 +108,9 @@ async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot", try: if not await Matcher.check_perm( - bot, event) or not await Matcher.check_rule(bot, event, state): + bot, event, stack, + dependency_cache) or not await Matcher.check_rule( + bot, event, state, stack, dependency_cache): return except Exception as e: logger.opt(colors=True, exception=e).error( @@ -149,17 +123,28 @@ async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot", except Exception: pass - await _run_matcher(Matcher, bot, event, state) + await _run_matcher(Matcher, bot, event, state, stack, dependency_cache) -async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event", - state: T_State) -> None: +async def _run_matcher( + Matcher: Type[Matcher], + bot: "Bot", + event: "Event", + state: T_State, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[T_DependencyCache] = None) -> None: logger.info(f"Event will be handled by {Matcher}") matcher = Matcher() coros = list( - map(lambda x: x(matcher=matcher, bot=bot, event=event, state=state), + map( + lambda x: x(matcher=matcher, + bot=bot, + event=event, + state=state, + _stack=stack, + _dependency_cache=dependency_cache), _run_preprocessors)) if coros: try: @@ -191,7 +176,10 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event", exception=exception, bot=bot, event=event, - state=state), _run_postprocessors)) + state=state, + _stack=stack, + _dependency_cache=dependency_cache), + _run_postprocessors)) if coros: try: await asyncio.gather(*coros) @@ -232,12 +220,17 @@ async def handle_event(bot: "Bot", event: "Event") -> None: if show_log: logger.opt(colors=True).success(log_msg) - state = {} + state: Dict[Any, Any] = {} + dependency_cache: T_DependencyCache = {} - # TODO async with AsyncExitStack() as stack: coros = list( - map(lambda x: x(bot=bot, event=event, state=state), + map( + lambda x: x(bot=bot, + event=event, + state=state, + _stack=stack, + _dependency_cache=dependency_cache), _event_preprocessors)) if coros: try: @@ -286,7 +279,12 @@ async def handle_event(bot: "Bot", event: "Event") -> None: ) coros = list( - map(lambda x: x(bot=bot, event=event, state=state), + map( + lambda x: x(bot=bot, + event=event, + state=state, + _stack=stack, + _dependency_cache=dependency_cache), _event_postprocessors)) if coros: try: diff --git a/nonebot/permission.py b/nonebot/permission.py index 7564bb13..b83be58a 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -10,9 +10,12 @@ r""" """ import asyncio -from typing import Union, Callable, NoReturn, Optional, Awaitable +from contextlib import AsyncExitStack +from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional -from nonebot.utils import run_sync +from nonebot import params +from nonebot.handler import Handler +from nonebot.dependencies import Param from nonebot.adapters import Bot, Event from nonebot.typing import T_PermissionChecker @@ -34,14 +37,23 @@ class Permission: """ __slots__ = ("checkers",) - def __init__(self, *checkers: Callable[[Bot, Event], - Awaitable[bool]]) -> None: + HANDLER_PARAM_TYPES: List[Type[Param]] = [ + params.BotParam, params.EventParam + ] + + def __init__(self, + *checkers: T_PermissionChecker, + dependency_overrides_provider: Optional[Any] = None) -> None: """ :参数: - * ``*checkers: Callable[[Bot, Event], Awaitable[bool]]``: **异步** PermissionChecker + * ``*checkers: T_PermissionChecker``: PermissionChecker """ - self.checkers = set(checkers) + self.checkers = set( + Handler(checker, + allow_types=self.HANDLER_PARAM_TYPES, + dependency_overrides_provider=dependency_overrides_provider) + for checker in checkers) """ :说明: @@ -49,10 +61,16 @@ class Permission: :类型: - * ``Set[Callable[[Bot, Event], Awaitable[bool]]]`` + * ``Set[Handler]`` """ - async def __call__(self, bot: Bot, event: Event) -> bool: + async def __call__( + self, + bot: Bot, + event: Event, + stack: Optional[AsyncExitStack] = None, + dependency_cache: Optional[Dict[Callable[..., Any], + Any]] = None) -> bool: """ :说明: @@ -62,6 +80,8 @@ class Permission: * ``bot: Bot``: Bot 对象 * ``event: Event``: Event 对象 + * ``stack: Optional[AsyncExitStack]``: 异步上下文栈 + * ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存 :返回: @@ -70,7 +90,11 @@ class Permission: if not self.checkers: return True results = await asyncio.gather( - *map(lambda c: c(bot, event), self.checkers)) + checker(bot=bot, + event=event, + _stack=stack, + _dependency_cache=dependency_cache) + for checker in self.checkers) return any(results) def __and__(self, other) -> NoReturn: @@ -79,16 +103,12 @@ class Permission: def __or__( self, other: Optional[Union["Permission", T_PermissionChecker]]) -> "Permission": - checkers = self.checkers.copy() if other is None: return self elif isinstance(other, Permission): - checkers |= other.checkers - elif asyncio.iscoroutinefunction(other): - checkers.add(other) # type: ignore + return Permission(*self.checkers, *other.checkers) else: - checkers.add(run_sync(other)) - return Permission(*checkers) + return Permission(*self.checkers, other) async def _message(bot: Bot, event: Event) -> bool: diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 1b2cf16e..5b0f2979 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -434,8 +434,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]], message = event.get_message() segment = message.pop(0) new_message = message.__class__( - str(segment) - [len(state["_prefix"]["raw_command"]):].strip()) # type: ignore + str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]):].strip()) for new_segment in reversed(new_message): message.insert(0, new_segment) diff --git a/nonebot/rule.py b/nonebot/rule.py index 3a1c2c33..c2ac4401 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -17,15 +17,15 @@ from argparse import Namespace from contextlib import AsyncExitStack from typing_extensions import TypedDict from argparse import ArgumentParser as ArgParser -from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional, - Sequence, Awaitable) +from typing import (Any, Dict, List, Type, Tuple, Union, Callable, NoReturn, + Optional, Sequence) from pygtrie import CharTrie from nonebot.log import logger -from nonebot.utils import run_sync from nonebot.handler import Handler from nonebot import params, get_driver +from nonebot.dependencies import Param from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker from nonebot.adapters import Bot, Event, MessageSegment @@ -64,11 +64,13 @@ class Rule: """ __slots__ = ("checkers",) - HANDLER_PARAM_TYPES = [ + HANDLER_PARAM_TYPES: List[Type[Param]] = [ params.BotParam, params.EventParam, params.StateParam ] - def __init__(self, *checkers: T_RuleChecker) -> None: + def __init__(self, + *checkers: T_RuleChecker, + dependency_overrides_provider: Optional[Any] = None) -> None: """ :参数: @@ -78,7 +80,7 @@ class Rule: self.checkers = set( Handler(checker, allow_types=self.HANDLER_PARAM_TYPES, - dependency_overrides_provider=get_driver()) + dependency_overrides_provider=dependency_overrides_provider) for checker in checkers) """ :说明: @@ -108,11 +110,15 @@ class Rule: * ``bot: Bot``: Bot 对象 * ``event: Event``: Event 对象 * ``state: T_State``: 当前 State + * ``stack: Optional[AsyncExitStack]``: 异步上下文栈 + * ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存 :返回: - ``bool`` """ + if not self.checkers: + return True results = await asyncio.gather( checker(bot=bot, event=event, @@ -126,10 +132,9 @@ class Rule: if other is None: return self elif isinstance(other, Rule): - checkers = [*self.checkers, *other.checkers] + return Rule(*self.checkers, *other.checkers) else: - checkers = [*self.checkers, other] - return Rule(*checkers) + return Rule(*self.checkers, other) def __or__(self, other) -> NoReturn: raise RuntimeError("Or operation between rules is not allowed.") diff --git a/nonebot/typing.py b/nonebot/typing.py index 337ff426..ac5a2335 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -22,7 +22,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable, NoReturn, Optional, Awaitable) if TYPE_CHECKING: - from nonebot.matcher import Matcher from nonebot.adapters import Bot, Event from nonebot.permission import Permission @@ -90,33 +89,60 @@ T_CalledAPIHook = Callable[ ``bot.call_api`` 后执行的函数,参数分别为 bot, exception, api, data, result """ -T_EventPreProcessor = Callable[..., Awaitable[None]] +T_EventPreProcessor = Callable[..., Union[None, Awaitable[None]]] """ -:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` +:类型: ``Callable[..., Union[None, Awaitable[None]]]`` + +:依赖参数: + + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 :说明: 事件预处理函数 EventPreProcessor 类型 """ -T_EventPostProcessor = Callable[..., Awaitable[None]] +T_EventPostProcessor = Callable[..., Union[None, Awaitable[None]]] """ -:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` +:类型: ``Callable[..., Union[None, Awaitable[None]]]`` + +:依赖参数: + + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 :说明: 事件预处理函数 EventPostProcessor 类型 """ -T_RunPreProcessor = Callable[..., Awaitable[None]] +T_RunPreProcessor = Callable[..., Union[None, Awaitable[None]]] """ -:类型: ``Callable[[Matcher, Bot, Event, T_State], Awaitable[None]]`` +:类型: ``Callable[..., Union[None, Awaitable[None]]]`` + +:依赖参数: + + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 + * ``MatcherParam``: Matcher 对象 :说明: 事件响应器运行前预处理函数 RunPreProcessor 类型 """ -T_RunPostProcessor = Callable[..., Awaitable[None]] +T_RunPostProcessor = Callable[..., Union[None, Awaitable[None]]] """ -:类型: ``Callable[[Matcher, Optional[Exception], Bot, Event, T_State], Awaitable[None]]`` +:类型: ``Callable[..., Union[None, Awaitable[None]]]`` + +:依赖参数: + + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 + * ``MatcherParam``: Matcher 对象 + * ``ExceptionParam``: 异常对象(可能为 None) :说明: @@ -127,28 +153,45 @@ T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]] """ :类型: ``Callable[..., Union[bool, Awaitable[bool]]]`` +:依赖参数: + + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 + * ``StateParam``: State 对象 + :说明: RuleChecker 即判断是否响应事件的处理函数。 """ -T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]] +T_PermissionChecker = Callable[..., Union[bool, Awaitable[bool]]] """ -:类型: ``Callable[[Bot, Event], Union[bool, Awaitable[bool]]]`` +:类型: ``Callable[..., Union[bool, Awaitable[bool]]]`` + +:依赖参数: + + * ``BotParam``: Bot 对象 + * ``EventParam``: Event 对象 :说明: RuleChecker 即判断是否响应消息的处理函数。 """ -T_Handler = Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]] +T_Handler = Callable[..., Any] """ -:类型: - - * ``Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]`` +:类型: ``Callable[..., Any]`` :说明: - Handler 即事件的处理函数。 + Handler 处理函数。 +""" +T_DependencyCache = Dict[T_Handler, Any] +""" +:类型: ``Dict[T_Handler, Any]`` + +:说明: + + 依赖缓存, 用于存储依赖函数的返回值 """ T_ArgsParser = Callable[["Bot", "Event", T_State], Union[Awaitable[None], Awaitable[NoReturn]]]