⚗️ change rule to use handler

This commit is contained in:
yanyongyu 2021-11-19 18:18:53 +08:00
parent ee619a33a9
commit 471d306e13
8 changed files with 182 additions and 148 deletions

View File

@ -18,6 +18,7 @@ from nonebot.log import logger
from .models import Param as Param from .models import Param as Param
from .utils import get_typed_signature from .utils import get_typed_signature
from .models import Dependent as Dependent from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper from .models import DependsWrapper as DependsWrapper
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager, from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable) is_async_gen_callable, is_coroutine_callable)
@ -112,21 +113,20 @@ def get_dependent(*,
async def solve_dependencies( async def solve_dependencies(
*, *,
dependent: Dependent, _dependent: Dependent,
stack: Optional[AsyncExitStack] = None, _stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None, _sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None, _dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None, _dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any **params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]:
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
values: Dict[str, Any] = {} values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {} dependency_cache = _dependency_cache or {}
# solve sub dependencies # solve sub dependencies
sub_dependent: Dependent sub_dependent: Dependent
for sub_dependent in chain(sub_dependents or tuple(), for sub_dependent in chain(_sub_dependents or tuple(),
dependent.dependencies): _dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func) sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Callable[..., Any], sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key) sub_dependent.cache_key)
@ -134,10 +134,10 @@ async def solve_dependencies(
# dependency overrides # dependency overrides
use_sub_dependant = sub_dependent use_sub_dependant = sub_dependent
if (dependency_overrides_provider and if (_dependency_overrides_provider and hasattr(
hasattr(dependency_overrides_provider, "dependency_overrides")): _dependency_overrides_provider, "dependency_overrides")):
original_call = sub_dependent.func original_call = sub_dependent.func
func = getattr(dependency_overrides_provider, func = getattr(_dependency_overrides_provider,
"dependency_overrides", "dependency_overrides",
{}).get(original_call, original_call) {}).get(original_call, original_call)
use_sub_dependant = get_dependent( use_sub_dependant = get_dependent(
@ -148,13 +148,11 @@ async def solve_dependencies(
# solve sub dependency with current cache # solve sub dependency with current cache
solved_result = await solve_dependencies( solved_result = await solve_dependencies(
dependent=use_sub_dependant, _dependent=use_sub_dependant,
dependency_overrides_provider=dependency_overrides_provider, _dependency_overrides_provider=_dependency_overrides_provider,
dependency_cache=dependency_cache, dependency_cache=dependency_cache,
**params) **params)
sub_values, sub_dependency_cache, ignored = solved_result sub_values, sub_dependency_cache = solved_result
if ignored:
return values, dependency_cache, True
# update cache? # update cache?
dependency_cache.update(sub_dependency_cache) dependency_cache.update(sub_dependency_cache)
@ -163,13 +161,13 @@ async def solve_dependencies(
solved = dependency_cache[sub_dependent.cache_key] solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func): elif is_gen_callable(func) or is_async_gen_callable(func):
assert isinstance( assert isinstance(
stack, AsyncExitStack _stack, AsyncExitStack
), "Generator dependency should be called in context" ), "Generator dependency should be called in context"
if is_gen_callable(func): if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values)) cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else: else:
cm = asynccontextmanager(func)(**sub_values) cm = asynccontextmanager(func)(**sub_values)
solved = await stack.enter_async_context(cm) solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func): elif is_coroutine_callable(func):
solved = await func(**sub_values) solved = await func(**sub_values)
else: else:
@ -183,7 +181,7 @@ async def solve_dependencies(
dependency_cache[sub_dependent.cache_key] = solved dependency_cache[sub_dependent.cache_key] = solved
# usual dependency # usual dependency
for field in dependent.params: for field in _dependent.params:
field_info = field.field_info field_info = field.field_info
assert isinstance(field_info, assert isinstance(field_info,
Param), "Params must be subclasses of Param" Param), "Params must be subclasses of Param"
@ -194,13 +192,13 @@ async def solve_dependencies(
if errs_: if errs_:
logger.debug( logger.debug(
f"{field_info} " f"{field_info} "
f"type {type(value)} not match depends {dependent.func} " f"type {type(value)} not match depends {_dependent.func} "
f"annotation {field._type_display()}, ignored") f"annotation {field._type_display()}, ignored")
return values, dependency_cache, True raise SkippedException
else: else:
values[field.name] = value values[field.name] = value
return values, dependency_cache, False return values, dependency_cache
def Depends(dependency: Optional[Callable[..., Any]] = None, def Depends(dependency: Optional[Callable[..., Any]] = None,

View File

@ -248,6 +248,8 @@ class Driver(ForwardDriver):
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
setup_ = cast(HTTPPollingSetup, setup_)
if not bot: if not bot:
request = await _build_request(setup_) request = await _build_request(setup_)
if not request: if not request:
@ -264,7 +266,6 @@ class Driver(ForwardDriver):
bot.request = request bot.request = request
request = cast(HTTPRequest, request) request = cast(HTTPRequest, request)
setup_ = cast(HTTPPollingSetup, setup_)
headers = request.headers headers = request.headers
timeout = aiohttp.ClientTimeout(30) timeout = aiohttp.ClientTimeout(30)

View File

@ -409,6 +409,8 @@ class Driver(ReverseDriver, ForwardDriver):
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
setup_ = cast(HTTPPollingSetup, setup_)
if not bot: if not bot:
request = await _build_request(setup_) request = await _build_request(setup_)
if not request: if not request:
@ -423,7 +425,6 @@ class Driver(ReverseDriver, ForwardDriver):
continue continue
bot.request = request bot.request = request
setup_ = cast(HTTPPollingSetup, setup_)
request = cast(HTTPRequest, request) request = cast(HTTPRequest, request)
headers = request.headers headers = request.headers

View File

@ -6,6 +6,8 @@
这些异常并非所有需要用户处理 NoneBot 内部运行时被捕获并进行对应操作 这些异常并非所有需要用户处理 NoneBot 内部运行时被捕获并进行对应操作
""" """
from typing import Optional
class NoneBotException(Exception): class NoneBotException(Exception):
""" """
@ -13,9 +15,33 @@ class NoneBotException(Exception):
所有 NoneBot 发生的异常基类 所有 NoneBot 发生的异常基类
""" """
pass
# Rule Exception
class ParserExit(NoneBotException):
"""
:说明:
``shell command`` 处理消息失败时返回的异常
:参数:
* ``status``
* ``message``
"""
def __init__(self, status: int = 0, message: Optional[str] = None):
self.status = status
self.message = message
def __repr__(self):
return f"<ParserExit status={self.status} message={self.message}>"
def __str__(self):
return self.__repr__()
# Processor Exception
class IgnoredException(NoneBotException): class IgnoredException(NoneBotException):
""" """
:说明: :说明:
@ -37,71 +63,6 @@ class IgnoredException(NoneBotException):
return self.__repr__() return self.__repr__()
class ParserExit(NoneBotException):
"""
:说明:
``shell command`` 处理消息失败时返回的异常
:参数:
* ``status``
* ``message``
"""
def __init__(self, status=0, message=None):
self.status = status
self.message = message
def __repr__(self):
return f"<ParserExit status={self.status} message={self.message}>"
def __str__(self):
return self.__repr__()
class PausedException(NoneBotException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``
可用于用户输入新信息
:用法:
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出
"""
pass
class RejectedException(NoneBotException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``
可用于用户重新输入
:用法:
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出
"""
pass
class FinishedException(NoneBotException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行
可用于结束用户会话
:用法:
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出
"""
pass
class StopPropagation(NoneBotException): class StopPropagation(NoneBotException):
""" """
:说明: :说明:
@ -112,9 +73,69 @@ class StopPropagation(NoneBotException):
``Matcher.block == True`` 时抛出 ``Matcher.block == True`` 时抛出
""" """
pass
# Matcher Exceptions
class MatcherException(NoneBotException):
"""
:说明:
所有 Matcher 发生的异常基类
"""
class SkippedException(MatcherException):
"""
:说明:
指示 NoneBot 立即结束当前 ``Handler`` 的处理继续处理下一个 ``Handler``
:用法:
可以在 ``Handler`` 中通过 ``Matcher.skip()`` 抛出
"""
class PausedException(MatcherException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``
可用于用户输入新信息
:用法:
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出
"""
class RejectedException(MatcherException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``
可用于用户重新输入
:用法:
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出
"""
class FinishedException(MatcherException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行
可用于结束用户会话
:用法:
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出
"""
# Adapter Exceptions
class AdapterException(NoneBotException): class AdapterException(NoneBotException):
""" """
:说明: :说明:
@ -130,7 +151,7 @@ class AdapterException(NoneBotException):
self.adapter_name = adapter_name self.adapter_name = adapter_name
class NoLogException(Exception): class NoLogException(AdapterException):
""" """
:说明: :说明:

View File

@ -7,24 +7,19 @@
import asyncio import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Type, Callable, Optional from typing import Any, Dict, List, Type, Callable, Optional
from nonebot.typing import T_Handler
from nonebot.utils import get_name, run_sync from nonebot.utils import get_name, run_sync
from nonebot.dependencies import (Param, Dependent, DependsWrapper, from nonebot.dependencies import (Param, Dependent, DependsWrapper,
get_dependent, solve_dependencies, get_dependent, solve_dependencies,
get_parameterless_sub_dependant) get_parameterless_sub_dependant)
if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
class Handler: class Handler:
"""事件处理器类。支持依赖注入。""" """事件处理器类。支持依赖注入。"""
def __init__(self, def __init__(self,
func: T_Handler, func: Callable[..., Any],
*, *,
name: Optional[str] = None, name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None, dependencies: Optional[List[DependsWrapper]] = None,
@ -37,7 +32,7 @@ class Handler:
:参数: :参数:
* ``func: T_Handler``: 事件处理函数 * ``func: Callable[..., Any]``: 事件处理函数
* ``name: Optional[str]``: 事件处理器名称默认为函数名 * ``name: Optional[str]``: 事件处理器名称默认为函数名
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入 * ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型 * ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型
@ -45,7 +40,7 @@ class Handler:
""" """
self.func = func self.func = func
""" """
:类型: ``T_Handler`` :类型: ``Callable[..., Any]``
:说明: 事件处理函数 :说明: 事件处理函数
""" """
self.name = get_name(func) if name is None else name self.name = get_name(func) if name is None else name
@ -85,24 +80,21 @@ class Handler:
_dependency_cache: Optional[Dict[Callable[..., Any], _dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None, Any]] = None,
**params) -> Any: **params) -> Any:
values, _, ignored = await solve_dependencies( values, cache = await solve_dependencies(
dependent=self.dependent, _dependent=self.dependent,
stack=_stack, _stack=_stack,
sub_dependents=[ _sub_dependents=[
self.sub_dependents[dependency.dependency] # type: ignore self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies for dependency in self.dependencies
], ],
dependency_overrides_provider=self.dependency_overrides_provider, _dependency_overrides_provider=self.dependency_overrides_provider,
dependency_cache=_dependency_cache, _dependency_cache=_dependency_cache,
**params) **params)
if ignored:
return
if asyncio.iscoroutinefunction(self.func): if asyncio.iscoroutinefunction(self.func):
await self.func(**values) return await self.func(**values)
else: else:
await run_sync(self.func)(**values) return await run_sync(self.func)(**values)
def cache_dependent(self, dependency: DependsWrapper): def cache_dependent(self, dependency: DependsWrapper):
if not dependency.dependency: if not dependency.dependency:

View File

@ -20,10 +20,11 @@ from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment, from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate) MessageTemplate)
from nonebot.exception import (PausedException, StopPropagation,
FinishedException, RejectedException)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
T_StateFactory, T_PermissionUpdater) T_StateFactory, T_PermissionUpdater)
from nonebot.exception import (PausedException, StopPropagation,
SkippedException, FinishedException,
RejectedException)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.plugin import Plugin from nonebot.plugin import Plugin
@ -601,10 +602,13 @@ class Matcher(metaclass=MatcherMeta):
while self.handlers: while self.handlers:
handler = self.handlers.pop(0) handler = self.handlers.pop(0)
logger.debug(f"Running handler {handler}") logger.debug(f"Running handler {handler}")
await handler(matcher=self, try:
bot=bot, await handler(matcher=self,
event=event, bot=bot,
state=self.state) event=event,
state=self.state)
except SkippedException:
pass
except RejectedException: except RejectedException:
self.handlers.insert(0, handler) # type: ignore self.handlers.insert(0, handler) # type: ignore

View File

@ -14,16 +14,18 @@ import shlex
import asyncio import asyncio
from itertools import product from itertools import product
from argparse import Namespace from argparse import Namespace
from contextlib import AsyncExitStack
from typing_extensions import TypedDict from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser from argparse import ArgumentParser as ArgParser
from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence, from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Awaitable) Sequence, Awaitable)
from pygtrie import CharTrie from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import run_sync from nonebot.utils import run_sync
from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.exception import ParserExit from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment from nonebot.adapters import Bot, Event, MessageSegment
@ -62,16 +64,22 @@ class Rule:
""" """
__slots__ = ("checkers",) __slots__ = ("checkers",)
def __init__( HANDLER_PARAM_TYPES = [
self, *checkers: Callable[[Bot, Event, T_State], params.BotParam, params.EventParam, params.StateParam
Awaitable[bool]]) -> None: ]
def __init__(self, *checkers: T_RuleChecker) -> None:
""" """
:参数: :参数:
* ``*checkers: Callable[[Bot, Event, T_State], Awaitable[bool]]``: **异步** RuleChecker * ``*checkers: T_RuleChecker``: RuleChecker
""" """
self.checkers = set(checkers) self.checkers = set(
Handler(checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=get_driver())
for checker in checkers)
""" """
:说明: :说明:
@ -79,10 +87,17 @@ class Rule:
:类型: :类型:
* ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]`` * ``Set[Handler]``
""" """
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: async def __call__(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
""" """
:说明: :说明:
@ -99,19 +114,21 @@ class Rule:
- ``bool`` - ``bool``
""" """
results = await asyncio.gather( results = await asyncio.gather(
*map(lambda c: c(bot, event, state), self.checkers)) checker(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers)
return all(results) return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule": def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
checkers = self.checkers.copy()
if other is None: if other is None:
return self return self
elif isinstance(other, Rule): elif isinstance(other, Rule):
checkers |= other.checkers checkers = [*self.checkers, *other.checkers]
elif asyncio.iscoroutinefunction(other):
checkers.add(other) # type: ignore
else: else:
checkers.add(run_sync(other)) checkers = [*self.checkers, other]
return Rule(*checkers) return Rule(*checkers)
def __or__(self, other) -> NoReturn: def __or__(self, other) -> NoReturn:
@ -226,7 +243,7 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词 * ``*keywords: str``: 关键词
""" """
async def _keyword(bot: Bot, event: Event, state: T_State) -> bool: async def _keyword(event: Event) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
return False return False
text = event.get_plaintext() text = event.get_plaintext()
@ -274,7 +291,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep): for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: Bot, event: Event, state: T_State) -> bool: async def _command(state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in commands return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command) return Rule(_command)
@ -294,7 +311,7 @@ class ArgumentParser(ArgParser):
old_message += message old_message += message
setattr(self, "message", old_message) setattr(self, "message", old_message)
def exit(self, status=0, message=None): def exit(self, status: int = 0, message: Optional[str] = None):
raise ParserExit(status=status, raise ParserExit(status=status,
message=message or getattr(self, "message", None)) message=message or getattr(self, "message", None))
@ -360,7 +377,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
for start, sep in product(command_start, command_sep): for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool: async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands: if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message()) message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY] strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
@ -400,7 +417,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags) pattern = re.compile(regex, flags)
async def _regex(bot: Bot, event: Event, state: T_State) -> bool: async def _regex(event: Event, state: T_State) -> bool:
if event.get_type() != "message": if event.get_type() != "message":
return False return False
matched = pattern.search(str(event.get_message())) matched = pattern.search(str(event.get_message()))
@ -415,6 +432,10 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
return Rule(_regex) return Rule(_regex)
async def _to_me(event: Event) -> bool:
return event.is_tome()
def to_me() -> Rule: def to_me() -> Rule:
""" """
:说明: :说明:
@ -426,7 +447,4 @@ def to_me() -> Rule:
* *
""" """
async def _to_me(bot: Bot, event: Event, state: T_State) -> bool:
return event.is_tome()
return Rule(_to_me) return Rule(_to_me)

View File

@ -123,10 +123,9 @@ T_RunPostProcessor = Callable[..., Awaitable[None]]
事件响应器运行前预处理函数 RunPostProcessor 类型第二个参数为运行时产生的错误如果存在 事件响应器运行前预处理函数 RunPostProcessor 类型第二个参数为运行时产生的错误如果存在
""" """
T_RuleChecker = Callable[["Bot", "Event", T_State], Union[bool, T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]]
Awaitable[bool]]]
""" """
:类型: ``Callable[[Bot, Event, T_State], Union[bool, Awaitable[bool]]]`` :类型: ``Callable[..., Union[bool, Awaitable[bool]]]``
:说明: :说明: