nonebot2/nonebot/internal/rule.py

122 lines
3.5 KiB
Python
Raw Permalink Normal View History

from contextlib import AsyncExitStack
from typing import Union, ClassVar, NoReturn, Optional
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
2022-02-06 17:08:11 +08:00
from .adapter import Bot, Event
from .params import Param, BotParam, EventParam, StateParam, DependParam, DefaultParam
class Rule:
"""{ref}`nonebot.matcher.Matcher` 规则类。
当事件传递时 {ref}`nonebot.matcher.Matcher` 运行前进行检查
参数:
*checkers: RuleChecker
用法:
```python
Rule(async_function) & sync_function
# 等价于
Rule(async_function, sync_function)
```
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES: ClassVar[list[type[Param]]] = [
DependParam,
BotParam,
EventParam,
StateParam,
DefaultParam,
]
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
self.checkers: set[Dependent[bool]] = {
(
checker
if isinstance(checker, Dependent)
else Dependent[bool].parse(
call=checker, allow_types=self.HANDLER_PARAM_TYPES
)
)
for checker in checkers
}
"""存储 `RuleChecker`"""
def __repr__(self) -> str:
return f"Rule({', '.join(repr(checker) for checker in self.checkers)})"
async def __call__(
self,
bot: Bot,
event: Event,
state: T_State,
stack: Optional[AsyncExitStack] = None,
dependency_cache: Optional[T_DependencyCache] = None,
) -> bool:
"""检查是否符合所有规则
参数:
bot: Bot 对象
event: Event 对象
state: 当前 State
stack: 异步上下文栈
dependency_cache: 依赖缓存
"""
if not self.checkers:
return True
result = True
def _handle_skipped_exception(
exc_group: BaseExceptionGroup[SkippedException],
) -> None:
nonlocal result
result = False
async def _run_checker(checker: Dependent[bool]) -> None:
nonlocal result
# calculate the result first to avoid data racing
is_passed = await checker(
bot=bot,
event=event,
state=state,
stack=stack,
dependency_cache=dependency_cache,
)
result &= is_passed
with catch({SkippedException: _handle_skipped_exception}):
async with anyio.create_task_group() as tg:
for checker in self.checkers:
tg.start_soon(_run_checker, checker)
return result
def __and__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None:
return self
elif isinstance(other, Rule):
return Rule(*self.checkers, *other.checkers)
else:
return Rule(*self.checkers, other)
def __rand__(self, other: Optional[Union["Rule", T_RuleChecker]]) -> "Rule":
if other is None:
return self
elif isinstance(other, Rule):
return Rule(*other.checkers, *self.checkers)
else:
return Rule(other, *self.checkers)
def __or__(self, other: object) -> NoReturn:
raise RuntimeError("Or operation between rules is not allowed.")