⚗️ change permission to use handler

This commit is contained in:
yanyongyu 2021-11-21 12:36:44 +08:00
parent a5948fb5e3
commit b4d12d905d
9 changed files with 207 additions and 122 deletions

View File

@ -20,6 +20,7 @@ from .utils import get_typed_signature
from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
@ -58,7 +59,7 @@ def get_parameterless_sub_dependant(
def get_sub_dependant(
*,
depends: DependsWrapper,
dependency: Callable[..., Any],
dependency: T_Handler,
name: Optional[str] = None,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
sub_dependant = get_dependent(func=dependency,
@ -69,7 +70,7 @@ def get_sub_dependant(
def get_dependent(*,
func: Callable[..., Any],
func: T_Handler,
name: Optional[str] = None,
use_cache: bool = True,
allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
@ -118,8 +119,8 @@ async def solve_dependencies(
_stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None,
_dependency_overrides_provider: Optional[Any] = None,
_dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]:
_dependency_cache: Optional[T_DependencyCache] = None,
**params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
values: Dict[str, Any] = {}
dependency_cache = _dependency_cache or {}
@ -201,7 +202,7 @@ async def solve_dependencies(
return values, dependency_cache
def Depends(dependency: Optional[Callable[..., Any]] = None,
def Depends(dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True) -> Any:
"""

View File

@ -5,9 +5,10 @@ from typing import Any, List, Type, Callable, Optional
from pydantic.fields import FieldInfo, ModelField
from nonebot.utils import get_name
from nonebot.typing import T_Handler
class Param(FieldInfo, abc.ABC):
class Param(abc.ABC, FieldInfo):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
@ -28,7 +29,7 @@ class Param(FieldInfo, abc.ABC):
class DependsWrapper:
def __init__(self,
dependency: Optional[Callable[..., Any]] = None,
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True) -> None:
self.dependency = dependency
@ -44,7 +45,7 @@ class Dependent:
def __init__(self,
*,
func: Optional[Callable[..., Any]] = None,
func: Optional[T_Handler] = None,
name: Optional[str] = None,
params: Optional[List[ModelField]] = None,
allow_types: Optional[List[Type[Param]]] = None,

View File

@ -1,13 +1,15 @@
import inspect
from typing import Any, Dict, Callable
from typing import Any, Dict
from loguru import logger
from pydantic.typing import ForwardRef, evaluate_forwardref
from nonebot.typing import T_Handler
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
def get_typed_signature(func: T_Handler) -> inspect.Signature:
signature = inspect.signature(func)
globalns = getattr(func, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,

View File

@ -9,6 +9,7 @@ from types import ModuleType
from datetime import datetime
from contextvars import ContextVar
from collections import defaultdict
from contextlib import AsyncExitStack
from typing import (TYPE_CHECKING, Any, Dict, List, Type, Union, Callable,
NoReturn, Optional)
@ -20,11 +21,12 @@ from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
T_StateFactory, T_PermissionUpdater)
from nonebot.exception import (PausedException, StopPropagation,
SkippedException, FinishedException,
RejectedException)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
T_StateFactory, T_DependencyCache,
T_PermissionUpdater)
if TYPE_CHECKING:
from nonebot.plugin import Plugin
@ -267,7 +269,13 @@ class Matcher(metaclass=MatcherMeta):
return NewMatcher
@classmethod
async def check_perm(cls, bot: Bot, event: Event) -> bool:
async def check_perm(
cls,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
"""
:说明:
@ -284,10 +292,17 @@ class Matcher(metaclass=MatcherMeta):
"""
event_type = event.get_type()
return (event_type == (cls.type or event_type) and
await cls.permission(bot, event))
await cls.permission(bot, event, stack, dependency_cache))
@classmethod
async def check_rule(cls, bot: Bot, event: Event, state: T_State) -> bool:
async def check_rule(
cls,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
"""
:说明:
@ -305,7 +320,7 @@ class Matcher(metaclass=MatcherMeta):
"""
event_type = event.get_type()
return (event_type == (cls.type or event_type) and
await cls.rule(bot, event, state))
await cls.rule(bot, event, state, stack, dependency_cache))
@classmethod
def args_parser(cls, func: T_ArgsParser) -> T_ArgsParser:
@ -589,7 +604,12 @@ class Matcher(metaclass=MatcherMeta):
self.block = True
# 运行handlers
async def run(self, bot: Bot, event: Event, state: T_State):
async def run(self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
s_t = current_state.set(self.state)
@ -606,7 +626,9 @@ class Matcher(metaclass=MatcherMeta):
await handler(matcher=self,
bot=bot,
event=event,
state=self.state)
state=self.state,
_stack=stack,
_dependency_cache=dependency_cache)
except SkippedException:
pass
@ -624,11 +646,8 @@ class Matcher(metaclass=MatcherMeta):
updater = self.__class__._default_permission_updater
if updater:
permission = await updater(
bot,
event,
self.state, # type: ignore
self.permission)
permission = await updater(bot, event, self.state,
self.permission)
else:
permission = USER(event.get_session_id(), perm=self.permission)
@ -661,11 +680,8 @@ class Matcher(metaclass=MatcherMeta):
updater = self.__class__._default_permission_updater
if updater:
permission = await updater(
bot,
event,
self.state, # type: ignore
self.permission)
permission = await updater(bot, event, self.state,
self.permission)
else:
permission = USER(event.get_session_id(), perm=self.permission)

View File

@ -8,7 +8,7 @@ NoneBot 内部处理并按优先级分发事件给所有事件响应器,提供
import asyncio
from datetime import datetime
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Set, Type
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Optional
from nonebot.log import logger
from nonebot.rule import TrieRule
@ -17,8 +17,9 @@ from nonebot.utils import escape_tag
from nonebot import params, get_driver
from nonebot.matcher import Matcher, matchers
from nonebot.exception import NoLogException, StopPropagation, IgnoredException
from nonebot.typing import (T_State, T_RunPreProcessor, T_RunPostProcessor,
T_EventPreProcessor, T_EventPostProcessor)
from nonebot.typing import (T_State, T_DependencyCache, T_RunPreProcessor,
T_RunPostProcessor, T_EventPreProcessor,
T_EventPostProcessor)
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
@ -43,14 +44,6 @@ def event_preprocessor(func: T_EventPreProcessor) -> T_EventPreProcessor:
:说明:
事件预处理装饰一个函数使它在每次接收到事件并分发给各响应器之前执行
:参数:
事件预处理函数接收三个参数
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``state: T_State``: 当前 State
"""
_event_preprocessors.add(
Handler(func,
@ -64,14 +57,6 @@ def event_postprocessor(func: T_EventPostProcessor) -> T_EventPostProcessor:
:说明:
事件后处理装饰一个函数使它在每次接收到事件并分发给各响应器之后执行
:参数:
事件后处理函数接收三个参数
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``state: T_State``: 当前事件运行前 State
"""
_event_postprocessors.add(
Handler(func,
@ -85,15 +70,6 @@ def run_preprocessor(func: T_RunPreProcessor) -> T_RunPreProcessor:
:说明:
运行预处理装饰一个函数使它在每次事件响应器运行前执行
:参数:
运行预处理函数接收四个参数
* ``matcher: Matcher``: 当前要运行的事件响应器
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``state: T_State``: 当前 State
"""
_run_preprocessors.add(
Handler(func,
@ -107,16 +83,6 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
:说明:
运行后处理装饰一个函数使它在每次事件响应器运行后执行
:参数:
运行后处理函数接收五个参数
* ``matcher: Matcher``: 运行完毕的事件响应器
* ``exception: Optional[Exception]``: 事件响应器运行错误如果存在
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``state: T_State``: 当前 State
"""
_run_postprocessors.add(
Handler(func,
@ -125,8 +91,14 @@ def run_postprocessor(func: T_RunPostProcessor) -> T_RunPostProcessor:
return func
async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot",
event: "Event", state: T_State) -> None:
async def _check_matcher(
priority: int,
Matcher: Type[Matcher],
bot: "Bot",
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None) -> None:
if Matcher.expire_time and datetime.now() > Matcher.expire_time:
try:
matchers[priority].remove(Matcher)
@ -136,7 +108,9 @@ async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot",
try:
if not await Matcher.check_perm(
bot, event) or not await Matcher.check_rule(bot, event, state):
bot, event, stack,
dependency_cache) or not await Matcher.check_rule(
bot, event, state, stack, dependency_cache):
return
except Exception as e:
logger.opt(colors=True, exception=e).error(
@ -149,17 +123,28 @@ async def _check_matcher(priority: int, Matcher: Type[Matcher], bot: "Bot",
except Exception:
pass
await _run_matcher(Matcher, bot, event, state)
await _run_matcher(Matcher, bot, event, state, stack, dependency_cache)
async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
state: T_State) -> None:
async def _run_matcher(
Matcher: Type[Matcher],
bot: "Bot",
event: "Event",
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None) -> None:
logger.info(f"Event will be handled by {Matcher}")
matcher = Matcher()
coros = list(
map(lambda x: x(matcher=matcher, bot=bot, event=event, state=state),
map(
lambda x: x(matcher=matcher,
bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_run_preprocessors))
if coros:
try:
@ -191,7 +176,10 @@ async def _run_matcher(Matcher: Type[Matcher], bot: "Bot", event: "Event",
exception=exception,
bot=bot,
event=event,
state=state), _run_postprocessors))
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_run_postprocessors))
if coros:
try:
await asyncio.gather(*coros)
@ -232,12 +220,17 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
if show_log:
logger.opt(colors=True).success(log_msg)
state = {}
state: Dict[Any, Any] = {}
dependency_cache: T_DependencyCache = {}
# TODO
async with AsyncExitStack() as stack:
coros = list(
map(lambda x: x(bot=bot, event=event, state=state),
map(
lambda x: x(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_event_preprocessors))
if coros:
try:
@ -286,7 +279,12 @@ async def handle_event(bot: "Bot", event: "Event") -> None:
)
coros = list(
map(lambda x: x(bot=bot, event=event, state=state),
map(
lambda x: x(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache),
_event_postprocessors))
if coros:
try:

View File

@ -10,9 +10,12 @@ r"""
"""
import asyncio
from typing import Union, Callable, NoReturn, Optional, Awaitable
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
from nonebot.utils import run_sync
from nonebot import params
from nonebot.handler import Handler
from nonebot.dependencies import Param
from nonebot.adapters import Bot, Event
from nonebot.typing import T_PermissionChecker
@ -34,14 +37,23 @@ class Permission:
"""
__slots__ = ("checkers",)
def __init__(self, *checkers: Callable[[Bot, Event],
Awaitable[bool]]) -> None:
HANDLER_PARAM_TYPES: List[Type[Param]] = [
params.BotParam, params.EventParam
]
def __init__(self,
*checkers: T_PermissionChecker,
dependency_overrides_provider: Optional[Any] = None) -> None:
"""
:参数:
* ``*checkers: Callable[[Bot, Event], Awaitable[bool]]``: **异步** PermissionChecker
* ``*checkers: T_PermissionChecker``: PermissionChecker
"""
self.checkers = set(checkers)
self.checkers = set(
Handler(checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers)
"""
:说明:
@ -49,10 +61,16 @@ class Permission:
:类型:
* ``Set[Callable[[Bot, Event], Awaitable[bool]]]``
* ``Set[Handler]``
"""
async def __call__(self, bot: Bot, event: Event) -> bool:
async def __call__(
self,
bot: Bot,
event: Event,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
"""
:说明:
@ -62,6 +80,8 @@ class Permission:
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``stack: Optional[AsyncExitStack]``: 异步上下文栈
* ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存
:返回:
@ -70,7 +90,11 @@ class Permission:
if not self.checkers:
return True
results = await asyncio.gather(
*map(lambda c: c(bot, event), self.checkers))
checker(bot=bot,
event=event,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers)
return any(results)
def __and__(self, other) -> NoReturn:
@ -79,16 +103,12 @@ class Permission:
def __or__(
self, other: Optional[Union["Permission",
T_PermissionChecker]]) -> "Permission":
checkers = self.checkers.copy()
if other is None:
return self
elif isinstance(other, Permission):
checkers |= other.checkers
elif asyncio.iscoroutinefunction(other):
checkers.add(other) # type: ignore
return Permission(*self.checkers, *other.checkers)
else:
checkers.add(run_sync(other))
return Permission(*checkers)
return Permission(*self.checkers, other)
async def _message(bot: Bot, event: Event) -> bool:

View File

@ -434,8 +434,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(
str(segment)
[len(state["_prefix"]["raw_command"]):].strip()) # type: ignore
str(segment)[len(state[PREFIX_KEY][RAW_CMD_KEY]):].strip())
for new_segment in reversed(new_message):
message.insert(0, new_segment)

View File

@ -17,15 +17,15 @@ from argparse import Namespace
from contextlib import AsyncExitStack
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Sequence, Awaitable)
from typing import (Any, Dict, List, Type, Tuple, Union, Callable, NoReturn,
Optional, Sequence)
from pygtrie import CharTrie
from nonebot.log import logger
from nonebot.utils import run_sync
from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.dependencies import Param
from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment
@ -64,11 +64,13 @@ class Rule:
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
HANDLER_PARAM_TYPES: List[Type[Param]] = [
params.BotParam, params.EventParam, params.StateParam
]
def __init__(self, *checkers: T_RuleChecker) -> None:
def __init__(self,
*checkers: T_RuleChecker,
dependency_overrides_provider: Optional[Any] = None) -> None:
"""
:参数:
@ -78,7 +80,7 @@ class Rule:
self.checkers = set(
Handler(checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=get_driver())
dependency_overrides_provider=dependency_overrides_provider)
for checker in checkers)
"""
:说明:
@ -108,11 +110,15 @@ class Rule:
* ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象
* ``state: T_State``: 当前 State
* ``stack: Optional[AsyncExitStack]``: 异步上下文栈
* ``dependency_cache: Optional[Dict[Callable[..., Any], Any]]``: 依赖缓存
:返回:
- ``bool``
"""
if not self.checkers:
return True
results = await asyncio.gather(
checker(bot=bot,
event=event,
@ -126,10 +132,9 @@ class Rule:
if other is None:
return self
elif isinstance(other, Rule):
checkers = [*self.checkers, *other.checkers]
return Rule(*self.checkers, *other.checkers)
else:
checkers = [*self.checkers, other]
return Rule(*checkers)
return Rule(*self.checkers, other)
def __or__(self, other) -> NoReturn:
raise RuntimeError("Or operation between rules is not allowed.")

View File

@ -22,7 +22,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Union, TypeVar, Callable,
NoReturn, Optional, Awaitable)
if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission
@ -90,33 +89,60 @@ T_CalledAPIHook = Callable[
``bot.call_api`` 后执行的函数参数分别为 bot, exception, api, data, result
"""
T_EventPreProcessor = Callable[..., Awaitable[None]]
T_EventPreProcessor = Callable[..., Union[None, Awaitable[None]]]
"""
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
:依赖参数:
* ``BotParam``: Bot 对象
* ``EventParam``: Event 对象
* ``StateParam``: State 对象
:说明:
事件预处理函数 EventPreProcessor 类型
"""
T_EventPostProcessor = Callable[..., Awaitable[None]]
T_EventPostProcessor = Callable[..., Union[None, Awaitable[None]]]
"""
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
:依赖参数:
* ``BotParam``: Bot 对象
* ``EventParam``: Event 对象
* ``StateParam``: State 对象
:说明:
事件预处理函数 EventPostProcessor 类型
"""
T_RunPreProcessor = Callable[..., Awaitable[None]]
T_RunPreProcessor = Callable[..., Union[None, Awaitable[None]]]
"""
:类型: ``Callable[[Matcher, Bot, Event, T_State], Awaitable[None]]``
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
:依赖参数:
* ``BotParam``: Bot 对象
* ``EventParam``: Event 对象
* ``StateParam``: State 对象
* ``MatcherParam``: Matcher 对象
:说明:
事件响应器运行前预处理函数 RunPreProcessor 类型
"""
T_RunPostProcessor = Callable[..., Awaitable[None]]
T_RunPostProcessor = Callable[..., Union[None, Awaitable[None]]]
"""
:类型: ``Callable[[Matcher, Optional[Exception], Bot, Event, T_State], Awaitable[None]]``
:类型: ``Callable[..., Union[None, Awaitable[None]]]``
:依赖参数:
* ``BotParam``: Bot 对象
* ``EventParam``: Event 对象
* ``StateParam``: State 对象
* ``MatcherParam``: Matcher 对象
* ``ExceptionParam``: 异常对象可能为 None
:说明:
@ -127,28 +153,45 @@ T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]]
"""
:类型: ``Callable[..., Union[bool, Awaitable[bool]]]``
:依赖参数:
* ``BotParam``: Bot 对象
* ``EventParam``: Event 对象
* ``StateParam``: State 对象
:说明:
RuleChecker 即判断是否响应事件的处理函数
"""
T_PermissionChecker = Callable[["Bot", "Event"], Union[bool, Awaitable[bool]]]
T_PermissionChecker = Callable[..., Union[bool, Awaitable[bool]]]
"""
:类型: ``Callable[[Bot, Event], Union[bool, Awaitable[bool]]]``
:类型: ``Callable[..., Union[bool, Awaitable[bool]]]``
:依赖参数:
* ``BotParam``: Bot 对象
* ``EventParam``: Event 对象
:说明:
RuleChecker 即判断是否响应消息的处理函数
"""
T_Handler = Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]
T_Handler = Callable[..., Any]
"""
:类型:
* ``Callable[..., Union[Awaitable[None], Awaitable[NoReturn]]]``
:类型: ``Callable[..., Any]``
:说明:
Handler 即事件的处理函数
Handler 处理函数
"""
T_DependencyCache = Dict[T_Handler, Any]
"""
:类型: ``Dict[T_Handler, Any]``
:说明:
依赖缓存, 用于存储依赖函数的返回值
"""
T_ArgsParser = Callable[["Bot", "Event", T_State], Union[Awaitable[None],
Awaitable[NoReturn]]]