⚗️ change rule to use handler

This commit is contained in:
yanyongyu 2021-11-19 18:18:53 +08:00
parent ee619a33a9
commit 471d306e13
8 changed files with 182 additions and 148 deletions

View File

@ -18,6 +18,7 @@ from nonebot.log import logger
from .models import Param as Param
from .utils import get_typed_signature
from .models import Dependent as Dependent
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
@ -113,20 +114,19 @@ def get_dependent(*,
async def solve_dependencies(
*,
dependent: Dependent,
stack: Optional[AsyncExitStack] = None,
sub_dependents: Optional[List[Dependent]] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any
) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any], bool]:
_dependent: Dependent,
_stack: Optional[AsyncExitStack] = None,
_sub_dependents: Optional[List[Dependent]] = None,
_dependency_overrides_provider: Optional[Any] = None,
_dependency_cache: Optional[Dict[Callable[..., Any], Any]] = None,
**params: Any) -> Tuple[Dict[str, Any], Dict[Callable[..., Any], Any]]:
values: Dict[str, Any] = {}
dependency_cache = dependency_cache or {}
dependency_cache = _dependency_cache or {}
# solve sub dependencies
sub_dependent: Dependent
for sub_dependent in chain(sub_dependents or tuple(),
dependent.dependencies):
for sub_dependent in chain(_sub_dependents or tuple(),
_dependent.dependencies):
sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
sub_dependent.cache_key = cast(Callable[..., Any],
sub_dependent.cache_key)
@ -134,10 +134,10 @@ async def solve_dependencies(
# dependency overrides
use_sub_dependant = sub_dependent
if (dependency_overrides_provider and
hasattr(dependency_overrides_provider, "dependency_overrides")):
if (_dependency_overrides_provider and hasattr(
_dependency_overrides_provider, "dependency_overrides")):
original_call = sub_dependent.func
func = getattr(dependency_overrides_provider,
func = getattr(_dependency_overrides_provider,
"dependency_overrides",
{}).get(original_call, original_call)
use_sub_dependant = get_dependent(
@ -148,13 +148,11 @@ async def solve_dependencies(
# solve sub dependency with current cache
solved_result = await solve_dependencies(
dependent=use_sub_dependant,
dependency_overrides_provider=dependency_overrides_provider,
_dependent=use_sub_dependant,
_dependency_overrides_provider=_dependency_overrides_provider,
dependency_cache=dependency_cache,
**params)
sub_values, sub_dependency_cache, ignored = solved_result
if ignored:
return values, dependency_cache, True
sub_values, sub_dependency_cache = solved_result
# update cache?
dependency_cache.update(sub_dependency_cache)
@ -163,13 +161,13 @@ async def solve_dependencies(
solved = dependency_cache[sub_dependent.cache_key]
elif is_gen_callable(func) or is_async_gen_callable(func):
assert isinstance(
stack, AsyncExitStack
_stack, AsyncExitStack
), "Generator dependency should be called in context"
if is_gen_callable(func):
cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
else:
cm = asynccontextmanager(func)(**sub_values)
solved = await stack.enter_async_context(cm)
solved = await _stack.enter_async_context(cm)
elif is_coroutine_callable(func):
solved = await func(**sub_values)
else:
@ -183,7 +181,7 @@ async def solve_dependencies(
dependency_cache[sub_dependent.cache_key] = solved
# usual dependency
for field in dependent.params:
for field in _dependent.params:
field_info = field.field_info
assert isinstance(field_info,
Param), "Params must be subclasses of Param"
@ -194,13 +192,13 @@ async def solve_dependencies(
if errs_:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {dependent.func} "
f"type {type(value)} not match depends {_dependent.func} "
f"annotation {field._type_display()}, ignored")
return values, dependency_cache, True
raise SkippedException
else:
values[field.name] = value
return values, dependency_cache, False
return values, dependency_cache
def Depends(dependency: Optional[Callable[..., Any]] = None,

View File

@ -248,6 +248,8 @@ class Driver(ForwardDriver):
await asyncio.sleep(3)
continue
setup_ = cast(HTTPPollingSetup, setup_)
if not bot:
request = await _build_request(setup_)
if not request:
@ -264,7 +266,6 @@ class Driver(ForwardDriver):
bot.request = request
request = cast(HTTPRequest, request)
setup_ = cast(HTTPPollingSetup, setup_)
headers = request.headers
timeout = aiohttp.ClientTimeout(30)

View File

@ -409,6 +409,8 @@ class Driver(ReverseDriver, ForwardDriver):
await asyncio.sleep(3)
continue
setup_ = cast(HTTPPollingSetup, setup_)
if not bot:
request = await _build_request(setup_)
if not request:
@ -423,7 +425,6 @@ class Driver(ReverseDriver, ForwardDriver):
continue
bot.request = request
setup_ = cast(HTTPPollingSetup, setup_)
request = cast(HTTPRequest, request)
headers = request.headers

View File

@ -6,6 +6,8 @@
这些异常并非所有需要用户处理 NoneBot 内部运行时被捕获并进行对应操作
"""
from typing import Optional
class NoneBotException(Exception):
"""
@ -13,9 +15,33 @@ class NoneBotException(Exception):
所有 NoneBot 发生的异常基类
"""
pass
# Rule Exception
class ParserExit(NoneBotException):
"""
:说明:
``shell command`` 处理消息失败时返回的异常
:参数:
* ``status``
* ``message``
"""
def __init__(self, status: int = 0, message: Optional[str] = None):
self.status = status
self.message = message
def __repr__(self):
return f"<ParserExit status={self.status} message={self.message}>"
def __str__(self):
return self.__repr__()
# Processor Exception
class IgnoredException(NoneBotException):
"""
:说明:
@ -37,71 +63,6 @@ class IgnoredException(NoneBotException):
return self.__repr__()
class ParserExit(NoneBotException):
"""
:说明:
``shell command`` 处理消息失败时返回的异常
:参数:
* ``status``
* ``message``
"""
def __init__(self, status=0, message=None):
self.status = status
self.message = message
def __repr__(self):
return f"<ParserExit status={self.status} message={self.message}>"
def __str__(self):
return self.__repr__()
class PausedException(NoneBotException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``
可用于用户输入新信息
:用法:
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出
"""
pass
class RejectedException(NoneBotException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``
可用于用户重新输入
:用法:
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出
"""
pass
class FinishedException(NoneBotException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行
可用于结束用户会话
:用法:
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出
"""
pass
class StopPropagation(NoneBotException):
"""
:说明:
@ -112,9 +73,69 @@ class StopPropagation(NoneBotException):
``Matcher.block == True`` 时抛出
"""
pass
# Matcher Exceptions
class MatcherException(NoneBotException):
"""
:说明:
所有 Matcher 发生的异常基类
"""
class SkippedException(MatcherException):
"""
:说明:
指示 NoneBot 立即结束当前 ``Handler`` 的处理继续处理下一个 ``Handler``
:用法:
可以在 ``Handler`` 中通过 ``Matcher.skip()`` 抛出
"""
class PausedException(MatcherException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后继续下一个 ``Handler``
可用于用户输入新信息
:用法:
可以在 ``Handler`` 中通过 ``Matcher.pause()`` 抛出
"""
class RejectedException(MatcherException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 并等待下一条消息后重新运行当前 ``Handler``
可用于用户重新输入
:用法:
可以在 ``Handler`` 中通过 ``Matcher.reject()`` 抛出
"""
class FinishedException(MatcherException):
"""
:说明:
指示 NoneBot 结束当前 ``Handler`` 且后续 ``Handler`` 不再被运行
可用于结束用户会话
:用法:
可以在 ``Handler`` 中通过 ``Matcher.finish()`` 抛出
"""
# Adapter Exceptions
class AdapterException(NoneBotException):
"""
:说明:
@ -130,7 +151,7 @@ class AdapterException(NoneBotException):
self.adapter_name = adapter_name
class NoLogException(Exception):
class NoLogException(AdapterException):
"""
:说明:

View File

@ -7,24 +7,19 @@
import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Dict, List, Type, Callable, Optional
from typing import Any, Dict, List, Type, Callable, Optional
from nonebot.typing import T_Handler
from nonebot.utils import get_name, run_sync
from nonebot.dependencies import (Param, Dependent, DependsWrapper,
get_dependent, solve_dependencies,
get_parameterless_sub_dependant)
if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
class Handler:
"""事件处理器类。支持依赖注入。"""
def __init__(self,
func: T_Handler,
func: Callable[..., Any],
*,
name: Optional[str] = None,
dependencies: Optional[List[DependsWrapper]] = None,
@ -37,7 +32,7 @@ class Handler:
:参数:
* ``func: T_Handler``: 事件处理函数
* ``func: Callable[..., Any]``: 事件处理函数
* ``name: Optional[str]``: 事件处理器名称默认为函数名
* ``dependencies: Optional[List[DependsWrapper]]``: 额外的非参数依赖注入
* ``allow_types: Optional[List[Type[Param]]]``: 允许的参数类型
@ -45,7 +40,7 @@ class Handler:
"""
self.func = func
"""
:类型: ``T_Handler``
:类型: ``Callable[..., Any]``
:说明: 事件处理函数
"""
self.name = get_name(func) if name is None else name
@ -85,24 +80,21 @@ class Handler:
_dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None,
**params) -> Any:
values, _, ignored = await solve_dependencies(
dependent=self.dependent,
stack=_stack,
sub_dependents=[
values, cache = await solve_dependencies(
_dependent=self.dependent,
_stack=_stack,
_sub_dependents=[
self.sub_dependents[dependency.dependency] # type: ignore
for dependency in self.dependencies
],
dependency_overrides_provider=self.dependency_overrides_provider,
dependency_cache=_dependency_cache,
_dependency_overrides_provider=self.dependency_overrides_provider,
_dependency_cache=_dependency_cache,
**params)
if ignored:
return
if asyncio.iscoroutinefunction(self.func):
await self.func(**values)
return await self.func(**values)
else:
await run_sync(self.func)(**values)
return await run_sync(self.func)(**values)
def cache_dependent(self, dependency: DependsWrapper):
if not dependency.dependency:

View File

@ -20,10 +20,11 @@ from nonebot.dependencies import DependsWrapper
from nonebot.permission import USER, Permission
from nonebot.adapters import (Bot, Event, Message, MessageSegment,
MessageTemplate)
from nonebot.exception import (PausedException, StopPropagation,
FinishedException, RejectedException)
from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater,
T_StateFactory, T_PermissionUpdater)
from nonebot.exception import (PausedException, StopPropagation,
SkippedException, FinishedException,
RejectedException)
if TYPE_CHECKING:
from nonebot.plugin import Plugin
@ -601,10 +602,13 @@ class Matcher(metaclass=MatcherMeta):
while self.handlers:
handler = self.handlers.pop(0)
logger.debug(f"Running handler {handler}")
try:
await handler(matcher=self,
bot=bot,
event=event,
state=self.state)
except SkippedException:
pass
except RejectedException:
self.handlers.insert(0, handler) # type: ignore

View File

@ -14,16 +14,18 @@ import shlex
import asyncio
from itertools import product
from argparse import Namespace
from contextlib import AsyncExitStack
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence,
Awaitable)
from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Sequence, Awaitable)
from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger
from nonebot.utils import run_sync
from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment
@ -62,16 +64,22 @@ class Rule:
"""
__slots__ = ("checkers",)
def __init__(
self, *checkers: Callable[[Bot, Event, T_State],
Awaitable[bool]]) -> None:
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam
]
def __init__(self, *checkers: T_RuleChecker) -> None:
"""
:参数:
* ``*checkers: Callable[[Bot, Event, T_State], Awaitable[bool]]``: **异步** RuleChecker
* ``*checkers: T_RuleChecker``: RuleChecker
"""
self.checkers = set(checkers)
self.checkers = set(
Handler(checker,
allow_types=self.HANDLER_PARAM_TYPES,
dependency_overrides_provider=get_driver())
for checker in checkers)
"""
:说明:
@ -79,10 +87,17 @@ class Rule:
:类型:
* ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]``
* ``Set[Handler]``
"""
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
async def __call__(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[Dict[Callable[..., Any],
Any]] = None) -> bool:
"""
:说明:
@ -99,19 +114,21 @@ class Rule:
- ``bool``
"""
results = await asyncio.gather(
*map(lambda c: c(bot, event, state), self.checkers))
checker(bot=bot,
event=event,
state=state,
_stack=stack,
_dependency_cache=dependency_cache)
for checker in self.checkers)
return all(results)
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
checkers = self.checkers.copy()
if other is None:
return self
elif isinstance(other, Rule):
checkers |= other.checkers
elif asyncio.iscoroutinefunction(other):
checkers.add(other) # type: ignore
checkers = [*self.checkers, *other.checkers]
else:
checkers.add(run_sync(other))
checkers = [*self.checkers, other]
return Rule(*checkers)
def __or__(self, other) -> NoReturn:
@ -226,7 +243,7 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
async def _keyword(bot: Bot, event: Event, state: T_State) -> bool:
async def _keyword(event: Event) -> bool:
if event.get_type() != "message":
return False
text = event.get_plaintext()
@ -274,7 +291,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: Bot, event: Event, state: T_State) -> bool:
async def _command(state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command)
@ -294,7 +311,7 @@ class ArgumentParser(ArgParser):
old_message += message
setattr(self, "message", old_message)
def exit(self, status=0, message=None):
def exit(self, status: int = 0, message: Optional[str] = None):
raise ParserExit(status=status,
message=message or getattr(self, "message", None))
@ -360,7 +377,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool:
async def _shell_command(event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message())
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
@ -400,7 +417,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags)
async def _regex(bot: Bot, event: Event, state: T_State) -> bool:
async def _regex(event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = pattern.search(str(event.get_message()))
@ -415,6 +432,10 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
return Rule(_regex)
async def _to_me(event: Event) -> bool:
return event.is_tome()
def to_me() -> Rule:
"""
:说明:
@ -426,7 +447,4 @@ def to_me() -> Rule:
*
"""
async def _to_me(bot: Bot, event: Event, state: T_State) -> bool:
return event.is_tome()
return Rule(_to_me)

View File

@ -123,10 +123,9 @@ T_RunPostProcessor = Callable[..., Awaitable[None]]
事件响应器运行前预处理函数 RunPostProcessor 类型第二个参数为运行时产生的错误如果存在
"""
T_RuleChecker = Callable[["Bot", "Event", T_State], Union[bool,
Awaitable[bool]]]
T_RuleChecker = Callable[..., Union[bool, Awaitable[bool]]]
"""
:类型: ``Callable[[Bot, Event, T_State], Union[bool, Awaitable[bool]]]``
:类型: ``Callable[..., Union[bool, Awaitable[bool]]]``
:说明: