mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
add permission and command
This commit is contained in:
parent
865fd6af4c
commit
6435e29e8b
@ -4,10 +4,6 @@
|
||||
import logging
|
||||
import importlib
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||
from nonebot.typing import Type, Union, Driver, Optional, NoReturn
|
||||
|
||||
_driver: Optional[Driver] = None
|
||||
@ -34,6 +30,10 @@ def get_bots():
|
||||
return driver.bots
|
||||
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||
|
||||
try:
|
||||
import nonebot_test
|
||||
except ImportError:
|
||||
|
@ -83,6 +83,16 @@ class BaseEvent(abc.ABC):
|
||||
def sub_type(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def user_id(self) -> Optional[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@user_id.setter
|
||||
@abc.abstractmethod
|
||||
def user_id(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def message(self) -> Optional[Message]:
|
||||
@ -103,6 +113,21 @@ class BaseEvent(abc.ABC):
|
||||
def raw_message(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def plain_text(self) -> Optional[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def sender(self) -> Optional[dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
@sender.setter
|
||||
@abc.abstractmethod
|
||||
def sender(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseMessageSegment(abc.ABC):
|
||||
|
@ -142,7 +142,10 @@ class Bot(BaseBot):
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
response = await client.post(api_root + api, json=data)
|
||||
response = await client.post(
|
||||
api_root + api,
|
||||
json=data,
|
||||
timeout=self.config.api_timeout)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
result = response.json()
|
||||
@ -193,6 +196,16 @@ class Event(BaseEvent):
|
||||
def sub_type(self, value) -> None:
|
||||
self._raw_event["sub_type"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def user_id(self) -> Optional[int]:
|
||||
return self._raw_event.get("user_id")
|
||||
|
||||
@user_id.setter
|
||||
@overrides(BaseEvent)
|
||||
def user_id(self, value) -> None:
|
||||
self._raw_event["user_id"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def message(self) -> Optional["Message"]:
|
||||
@ -213,6 +226,21 @@ class Event(BaseEvent):
|
||||
def raw_message(self, value) -> None:
|
||||
self._raw_event["raw_message"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def plain_text(self) -> Optional[str]:
|
||||
return self.message and self.message.extract_plain_text()
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def sender(self) -> Optional[dict]:
|
||||
return self._raw_event.get("sender")
|
||||
|
||||
@sender.setter
|
||||
@overrides(BaseEvent)
|
||||
def sender(self, value) -> None:
|
||||
self._raw_event["sender"] = value
|
||||
|
||||
|
||||
class MessageSegment(BaseMessageSegment):
|
||||
|
||||
|
@ -103,12 +103,14 @@ class Config(BaseConfig):
|
||||
|
||||
# bot connection configs
|
||||
api_root: Dict[str, str] = {}
|
||||
api_timeout: float = 60.
|
||||
api_timeout: Optional[float] = 60.
|
||||
access_token: Optional[str] = None
|
||||
|
||||
# bot runtime configs
|
||||
superusers: Set[int] = set()
|
||||
nickname: Union[str, Set[str]] = ""
|
||||
command_start: Set[str] = {"/"}
|
||||
command_sep: Set[str] = {"."}
|
||||
session_expire_timeout: timedelta = timedelta(minutes=2)
|
||||
|
||||
# custom configs
|
||||
|
@ -6,8 +6,9 @@ from functools import wraps
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
from nonebot.rule import SyncRule, user
|
||||
from nonebot.typing import Bot, Rule, Event, Handler
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.permission import Permission, EVERYBODY, USER
|
||||
from nonebot.typing import Bot, Event, Handler
|
||||
from nonebot.typing import Type, List, Dict, Optional, NoReturn
|
||||
from nonebot.exception import PausedException, RejectedException, FinishedException
|
||||
|
||||
@ -18,7 +19,8 @@ class Matcher:
|
||||
"""`Matcher`类
|
||||
"""
|
||||
|
||||
rule: Rule = SyncRule()
|
||||
rule: Rule = Rule()
|
||||
permission: Permission = Permission()
|
||||
handlers: List[Handler] = []
|
||||
temp: bool = False
|
||||
expire_time: Optional[datetime] = None
|
||||
@ -38,7 +40,8 @@ class Matcher:
|
||||
|
||||
@classmethod
|
||||
def new(cls,
|
||||
rule: Rule = SyncRule(),
|
||||
rule: Rule = Rule(),
|
||||
permission: Permission = Permission(),
|
||||
handlers: list = [],
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
@ -54,6 +57,7 @@ class Matcher:
|
||||
NewMatcher = type(
|
||||
"Matcher", (Matcher,), {
|
||||
"rule": rule,
|
||||
"permission": permission,
|
||||
"handlers": handlers,
|
||||
"temp": temp,
|
||||
"expire_time": expire_time,
|
||||
@ -66,7 +70,11 @@ class Matcher:
|
||||
return NewMatcher
|
||||
|
||||
@classmethod
|
||||
async def check_rule(cls, bot: Bot, event: Event) -> bool:
|
||||
async def check_perm(cls, bot: Bot, event: Event) -> bool:
|
||||
return await cls.permission(bot, event)
|
||||
|
||||
@classmethod
|
||||
async def check_rule(cls, bot: Bot, event: Event, state: dict) -> bool:
|
||||
"""检查 Matcher 的 Rule 是否成立
|
||||
|
||||
Args:
|
||||
@ -75,7 +83,7 @@ class Matcher:
|
||||
Returns:
|
||||
bool: 条件成立与否
|
||||
"""
|
||||
return await cls.rule(bot, event)
|
||||
return await cls.rule(bot, event, state)
|
||||
|
||||
# @classmethod
|
||||
# def args_parser(cls, func: Callable[[Event, dict], None]):
|
||||
@ -144,11 +152,14 @@ class Matcher:
|
||||
# raise RejectedException
|
||||
|
||||
# 运行handlers
|
||||
async def run(self, bot: Bot, event: Event):
|
||||
async def run(self, bot: Bot, event: Event, state):
|
||||
try:
|
||||
# if self.parser:
|
||||
# await self.parser(event, state) # type: ignore
|
||||
|
||||
# Refresh preprocess state
|
||||
self.state.update(state)
|
||||
|
||||
for _ in range(len(self.handlers)):
|
||||
handler = self.handlers.pop(0)
|
||||
annotation = typing.get_type_hints(handler)
|
||||
@ -158,23 +169,25 @@ class Matcher:
|
||||
await handler(bot, event, self.state)
|
||||
except RejectedException:
|
||||
self.handlers.insert(0, handler) # type: ignore
|
||||
matcher = Matcher.new(user(event.user_id) & self.rule,
|
||||
matcher = Matcher.new(
|
||||
self.rule,
|
||||
USER(event.user_id, perm=self.permission), # type:ignore
|
||||
self.handlers,
|
||||
temp=True,
|
||||
priority=0,
|
||||
default_state=self.state,
|
||||
expire_time=datetime.now() +
|
||||
bot.config.session_expire_timeout)
|
||||
expire_time=datetime.now() + bot.config.session_expire_timeout)
|
||||
matchers[0].append(matcher)
|
||||
return
|
||||
except PausedException:
|
||||
matcher = Matcher.new(user(event.user_id) & self.rule,
|
||||
matcher = Matcher.new(
|
||||
self.rule,
|
||||
USER(event.user_id, perm=self.permission), # type:ignore
|
||||
self.handlers,
|
||||
temp=True,
|
||||
priority=0,
|
||||
default_state=self.state,
|
||||
expire_time=datetime.now() +
|
||||
bot.config.session_expire_timeout)
|
||||
expire_time=datetime.now() + bot.config.session_expire_timeout)
|
||||
matchers[0].append(matcher)
|
||||
return
|
||||
except FinishedException:
|
||||
|
@ -5,6 +5,7 @@ import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import TrieRule
|
||||
from nonebot.matcher import matchers
|
||||
from nonebot.exception import IgnoredException
|
||||
from nonebot.typing import Bot, Set, Event, PreProcessor
|
||||
@ -19,8 +20,9 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
|
||||
|
||||
async def handle_event(bot: Bot, event: Event):
|
||||
coros = []
|
||||
state = {}
|
||||
for preprocessor in _event_preprocessors:
|
||||
coros.append(preprocessor(bot, event))
|
||||
coros.append(preprocessor(bot, event, state))
|
||||
if coros:
|
||||
try:
|
||||
await asyncio.gather(*coros)
|
||||
@ -28,6 +30,9 @@ async def handle_event(bot: Bot, event: Event):
|
||||
logger.info(f"Event {event} is ignored")
|
||||
return
|
||||
|
||||
# Trie Match
|
||||
_, _ = TrieRule.get_value(bot, event, state)
|
||||
|
||||
for priority in sorted(matchers.keys()):
|
||||
index = 0
|
||||
while index <= len(matchers[priority]):
|
||||
@ -40,7 +45,9 @@ async def handle_event(bot: Bot, event: Event):
|
||||
|
||||
# Check rule
|
||||
try:
|
||||
if not await Matcher.check_rule(bot, event):
|
||||
if not await Matcher.check_perm(
|
||||
bot, event) or not await Matcher.check_rule(
|
||||
bot, event, state):
|
||||
index += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
@ -55,7 +62,7 @@ async def handle_event(bot: Bot, event: Event):
|
||||
del matchers[priority][index]
|
||||
|
||||
try:
|
||||
await matcher.run(bot, event)
|
||||
await matcher.run(bot, event, state)
|
||||
except Exception as e:
|
||||
logger.error(f"Running matcher {matcher} failed.")
|
||||
logger.exception(e)
|
||||
|
124
nonebot/permission.py
Normal file
124
nonebot/permission.py
Normal file
@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
|
||||
from nonebot.utils import run_sync
|
||||
from nonebot.typing import Bot, Event, Union, NoReturn, PermissionChecker
|
||||
|
||||
|
||||
class Permission:
|
||||
__slots__ = ("checkers",)
|
||||
|
||||
def __init__(self, *checkers: PermissionChecker) -> None:
|
||||
self.checkers = list(checkers)
|
||||
|
||||
async def __call__(self, bot: Bot, event: Event) -> bool:
|
||||
if not self.checkers:
|
||||
return True
|
||||
results = await asyncio.gather(
|
||||
*map(lambda c: c(bot, event), self.checkers))
|
||||
return any(results)
|
||||
|
||||
def __and__(self, other) -> NoReturn:
|
||||
raise RuntimeError("And operation between Permissions is not allowed.")
|
||||
|
||||
def __or__(self, other: Union["Permission",
|
||||
PermissionChecker]) -> "Permission":
|
||||
checkers = [*self.checkers]
|
||||
if isinstance(other, Permission):
|
||||
checkers.extend(other.checkers)
|
||||
elif asyncio.iscoroutinefunction(other):
|
||||
checkers.append(other)
|
||||
else:
|
||||
checkers.append(run_sync(other))
|
||||
return Permission(*checkers)
|
||||
|
||||
|
||||
async def _message(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "message"
|
||||
|
||||
|
||||
async def _notice(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "notice"
|
||||
|
||||
|
||||
async def _request(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "request"
|
||||
|
||||
|
||||
async def _metaevent(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "meta_event"
|
||||
|
||||
|
||||
MESSAGE = Permission(_message)
|
||||
NOTICE = Permission(_notice)
|
||||
REQUEST = Permission(_request)
|
||||
METAEVENT = Permission(_metaevent)
|
||||
|
||||
|
||||
def USER(*user: int, perm: Permission = Permission()):
|
||||
|
||||
async def _user(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "message" and event.user_id in user and await perm(
|
||||
bot, event)
|
||||
|
||||
return Permission(_user)
|
||||
|
||||
|
||||
async def _private(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "message" and event.detail_type == "private"
|
||||
|
||||
|
||||
async def _private_friend(bot: Bot, event: Event) -> bool:
|
||||
return (event.type == "message" and event.detail_type == "private" and
|
||||
event.sub_type == "friend")
|
||||
|
||||
|
||||
async def _private_group(bot: Bot, event: Event) -> bool:
|
||||
return (event.type == "message" and event.detail_type == "private" and
|
||||
event.sub_type == "group")
|
||||
|
||||
|
||||
async def _private_other(bot: Bot, event: Event) -> bool:
|
||||
return (event.type == "message" and event.detail_type == "private" and
|
||||
event.sub_type == "other")
|
||||
|
||||
|
||||
PRIVATE = Permission(_private)
|
||||
PRIVATE_FRIEND = Permission(_private_friend)
|
||||
PRIVATE_GROUP = Permission(_private_group)
|
||||
PRIVATE_OTHER = Permission(_private_other)
|
||||
|
||||
|
||||
async def _group(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "message" and event.detail_type == "group"
|
||||
|
||||
|
||||
async def _group_member(bot: Bot, event: Event) -> bool:
|
||||
return (event.type == "message" and event.detail_type == "group" and
|
||||
event.sender.get("role") == "member")
|
||||
|
||||
|
||||
async def _group_admin(bot: Bot, event: Event) -> bool:
|
||||
return (event.type == "message" and event.detail_type == "group" and
|
||||
event.sender.get("role") == "admin")
|
||||
|
||||
|
||||
async def _group_owner(bot: Bot, event: Event) -> bool:
|
||||
return (event.type == "message" and event.detail_type == "group" and
|
||||
event.sender.get("role") == "owner")
|
||||
|
||||
|
||||
GROUP = Permission(_group)
|
||||
GROUP_MEMBER = Permission(_group_member)
|
||||
GROUP_ADMIN = Permission(_group_admin)
|
||||
GROUP_OWNER = Permission(_group_owner)
|
||||
|
||||
|
||||
async def _superuser(bot: Bot, event: Event) -> bool:
|
||||
return event.type == "message" and event.user_id in bot.config.superusers
|
||||
|
||||
|
||||
SUPERUSER = Permission(_superuser)
|
||||
EVERYBODY = MESSAGE
|
@ -1,14 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import re
|
||||
import pkgutil
|
||||
import importlib
|
||||
from importlib.util import module_from_spec
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.rule import SyncRule, metaevent, message, notice, request
|
||||
from nonebot.typing import Set, Dict, Type, Rule, Union, Optional, ModuleType, RuleChecker
|
||||
from nonebot.rule import Rule, startswith, endswith, command, regex
|
||||
from nonebot.permission import Permission, METAEVENT, MESSAGE, NOTICE, REQUEST
|
||||
from nonebot.typing import Set, Dict, Type, Tuple, Union, Optional, ModuleType, RuleChecker
|
||||
|
||||
plugins: Dict[str, "Plugin"] = {}
|
||||
|
||||
@ -25,13 +27,14 @@ class Plugin(object):
|
||||
self.matchers = matchers
|
||||
|
||||
|
||||
def on_metaevent(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new(metaevent() & rule,
|
||||
matcher = Matcher.new(Rule() & rule,
|
||||
METAEVENT,
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
handlers=handlers,
|
||||
@ -40,13 +43,15 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
return matcher
|
||||
|
||||
|
||||
def on_message(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
def on_message(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
permission: Permission = MESSAGE,
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new(message() & rule,
|
||||
matcher = Matcher.new(Rule() & rule,
|
||||
permission,
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
handlers=handlers,
|
||||
@ -55,13 +60,14 @@ def on_message(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
return matcher
|
||||
|
||||
|
||||
def on_notice(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new(notice() & rule,
|
||||
matcher = Matcher.new(Rule() & rule,
|
||||
NOTICE,
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
handlers=handlers,
|
||||
@ -70,13 +76,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
return matcher
|
||||
|
||||
|
||||
def on_request(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
def on_request(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
priority: int = 1,
|
||||
state={}) -> Type[Matcher]:
|
||||
matcher = Matcher.new(request() & rule,
|
||||
matcher = Matcher.new(Rule() & rule,
|
||||
REQUEST,
|
||||
temp=temp,
|
||||
priority=priority,
|
||||
handlers=handlers,
|
||||
@ -85,22 +92,40 @@ def on_request(rule: Union[Rule, RuleChecker] = SyncRule(),
|
||||
return matcher
|
||||
|
||||
|
||||
# def on_startswith(msg,
|
||||
# start: int = None,
|
||||
# end: int = None,
|
||||
# rule: Optional[Rule] = None,
|
||||
# **kwargs) -> Type[Matcher]:
|
||||
# return on_message(startswith(msg, start, end) &
|
||||
# rule, **kwargs) if rule else on_message(
|
||||
# startswith(msg, start, end), **kwargs)
|
||||
def on_startswith(msg: str,
|
||||
rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
permission: Permission = MESSAGE,
|
||||
**kwargs) -> Type[Matcher]:
|
||||
return on_message(startswith(msg) &
|
||||
rule, permission, **kwargs) if rule else on_message(
|
||||
startswith(msg), permission, **kwargs)
|
||||
|
||||
# def on_regex(pattern,
|
||||
# flags: Union[int, re.RegexFlag] = 0,
|
||||
# rule: Optional[Rule] = None,
|
||||
# **kwargs) -> Type[Matcher]:
|
||||
# return on_message(regex(pattern, flags) &
|
||||
# rule, **kwargs) if rule else on_message(
|
||||
# regex(pattern, flags), **kwargs)
|
||||
|
||||
def on_endswith(msg: str,
|
||||
rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
permission: Permission = MESSAGE,
|
||||
**kwargs) -> Type[Matcher]:
|
||||
return on_message(endswith(msg) &
|
||||
rule, permission, **kwargs) if rule else on_message(
|
||||
startswith(msg), permission, **kwargs)
|
||||
|
||||
|
||||
def on_command(cmd: Tuple[str],
|
||||
rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
permission: Permission = MESSAGE,
|
||||
**kwargs) -> Type[Matcher]:
|
||||
return on_message(command(cmd) &
|
||||
rule, permission, **kwargs) if rule else on_message(
|
||||
command(cmd), permission, **kwargs)
|
||||
|
||||
|
||||
def on_regex(pattern: str,
|
||||
flags: Union[int, re.RegexFlag] = 0,
|
||||
rule: Optional[Rule] = None,
|
||||
**kwargs) -> Type[Matcher]:
|
||||
return on_message(regex(pattern, flags) &
|
||||
rule, **kwargs) if rule else on_message(
|
||||
regex(pattern, flags), **kwargs)
|
||||
|
||||
|
||||
def load_plugin(module_path: str) -> Optional[Plugin]:
|
||||
|
281
nonebot/rule.py
281
nonebot/rule.py
@ -2,207 +2,126 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import re
|
||||
import abc
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from itertools import product
|
||||
|
||||
from pygtrie import CharTrie
|
||||
|
||||
from nonebot import get_driver
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import run_sync
|
||||
from nonebot.typing import Bot, Event, Union, Optional, Awaitable
|
||||
from nonebot.typing import RuleChecker, SyncRuleChecker, AsyncRuleChecker
|
||||
from nonebot.typing import Bot, Any, Dict, Event, Union, Tuple, NoReturn, RuleChecker
|
||||
|
||||
|
||||
class BaseRule(abc.ABC):
|
||||
class Rule:
|
||||
__slots__ = ("checkers",)
|
||||
|
||||
def __init__(self, checker: RuleChecker):
|
||||
self.checker: RuleChecker = checker
|
||||
def __init__(self, *checkers: RuleChecker) -> None:
|
||||
self.checkers = list(checkers)
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]:
|
||||
raise NotImplementedError
|
||||
async def __call__(self, bot: Bot, event: Event, state: dict) -> bool:
|
||||
results = await asyncio.gather(
|
||||
*map(lambda c: c(bot, event, state), self.checkers))
|
||||
return all(results)
|
||||
|
||||
@abc.abstractmethod
|
||||
def __and__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule":
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __or__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule":
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __neg__(self) -> "BaseRule":
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncRule(BaseRule):
|
||||
|
||||
def __init__(self, checker: Optional[AsyncRuleChecker] = None):
|
||||
|
||||
async def always_true(bot: Bot, event: Event) -> bool:
|
||||
return True
|
||||
|
||||
self.checker: AsyncRuleChecker = checker or always_true
|
||||
|
||||
def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]:
|
||||
return self.checker(bot, event)
|
||||
|
||||
def __and__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule":
|
||||
func = other
|
||||
if isinstance(other, BaseRule):
|
||||
func = other.checker
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
func = run_sync(func)
|
||||
|
||||
async def tmp(bot: Bot, event: Event) -> bool:
|
||||
a, b = await asyncio.gather(self.checker(bot, event),
|
||||
func(bot, event))
|
||||
return a and b
|
||||
|
||||
return AsyncRule(tmp)
|
||||
|
||||
def __or__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule":
|
||||
func = other
|
||||
if isinstance(other, BaseRule):
|
||||
func = other.checker
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
func = run_sync(func)
|
||||
|
||||
async def tmp(bot: Bot, event: Event) -> bool:
|
||||
a, b = await asyncio.gather(self.checker(bot, event),
|
||||
func(bot, event))
|
||||
return a or b
|
||||
|
||||
return AsyncRule(tmp)
|
||||
|
||||
def __neg__(self) -> "AsyncRule":
|
||||
|
||||
async def neg(bot: Bot, event: Event) -> bool:
|
||||
result = await self.checker(bot, event)
|
||||
return not result
|
||||
|
||||
return AsyncRule(neg)
|
||||
|
||||
|
||||
class SyncRule(BaseRule):
|
||||
|
||||
def __init__(self, checker: Optional[SyncRuleChecker] = None):
|
||||
|
||||
def always_true(bot: Bot, event: Event) -> bool:
|
||||
return True
|
||||
|
||||
self.checker: SyncRuleChecker = checker or always_true
|
||||
|
||||
def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]:
|
||||
return run_sync(self.checker)(bot, event)
|
||||
|
||||
def __and__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule:
|
||||
func = other
|
||||
if isinstance(other, BaseRule):
|
||||
func = other.checker
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
# func: SyncRuleChecker
|
||||
syncfunc = cast(SyncRuleChecker, func)
|
||||
|
||||
def tmp(bot: Bot, event: Event) -> bool:
|
||||
return self.checker(bot, event) and syncfunc(bot, event)
|
||||
|
||||
return SyncRule(tmp)
|
||||
def __and__(self, other: Union["Rule", RuleChecker]) -> "Rule":
|
||||
checkers = [*self.checkers]
|
||||
if isinstance(other, Rule):
|
||||
checkers.extend(other.checkers)
|
||||
elif asyncio.iscoroutinefunction(other):
|
||||
checkers.append(other)
|
||||
else:
|
||||
# func: AsyncRuleChecker
|
||||
asyncfunc = cast(AsyncRuleChecker, func)
|
||||
checkers.append(run_sync(other))
|
||||
return Rule(*checkers)
|
||||
|
||||
async def tmp(bot: Bot, event: Event) -> bool:
|
||||
a, b = await asyncio.gather(
|
||||
run_sync(self.checker)(bot, event), asyncfunc(bot, event))
|
||||
return a and b
|
||||
|
||||
return AsyncRule(tmp)
|
||||
|
||||
def __or__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule:
|
||||
func = other
|
||||
if isinstance(other, BaseRule):
|
||||
func = other.checker
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
# func: SyncRuleChecker
|
||||
syncfunc = cast(SyncRuleChecker, func)
|
||||
|
||||
def tmp(bot: Bot, event: Event) -> bool:
|
||||
return self.checker(bot, event) or syncfunc(bot, event)
|
||||
|
||||
return SyncRule(tmp)
|
||||
else:
|
||||
# func: AsyncRuleChecker
|
||||
asyncfunc = cast(AsyncRuleChecker, func)
|
||||
|
||||
async def tmp(bot: Bot, event: Event) -> bool:
|
||||
a, b = await asyncio.gather(
|
||||
run_sync(self.checker)(bot, event), asyncfunc(bot, event))
|
||||
return a or b
|
||||
|
||||
return AsyncRule(tmp)
|
||||
|
||||
def __neg__(self) -> "SyncRule":
|
||||
|
||||
def neg(bot: Bot, event: Event) -> bool:
|
||||
return not self.checker(bot, event)
|
||||
|
||||
return SyncRule(neg)
|
||||
def __or__(self, other) -> NoReturn:
|
||||
raise RuntimeError("Or operation between rules is not allowed.")
|
||||
|
||||
|
||||
def Rule(func: Optional[RuleChecker] = None) -> BaseRule:
|
||||
if func and asyncio.iscoroutinefunction(func):
|
||||
asyncfunc = cast(AsyncRuleChecker, func)
|
||||
return AsyncRule(asyncfunc)
|
||||
else:
|
||||
syncfunc = cast(Optional[SyncRuleChecker], func)
|
||||
return SyncRule(syncfunc)
|
||||
class TrieRule:
|
||||
prefix: CharTrie = CharTrie()
|
||||
suffix: CharTrie = CharTrie()
|
||||
|
||||
@classmethod
|
||||
def add_prefix(cls, prefix: str, value: Any):
|
||||
if prefix in cls.prefix:
|
||||
logger.warning(f'Duplicated prefix rule "{prefix}"')
|
||||
return
|
||||
cls.prefix[prefix] = value
|
||||
|
||||
@classmethod
|
||||
def add_suffix(cls, suffix: str, value: Any):
|
||||
if suffix[::-1] in cls.suffix:
|
||||
logger.warning(f'Duplicated suffix rule "{suffix}"')
|
||||
return
|
||||
cls.suffix[suffix[::-1]] = value
|
||||
|
||||
@classmethod
|
||||
def get_value(cls, bot: Bot, event: Event,
|
||||
state: dict) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
prefix = None
|
||||
suffix = None
|
||||
message = event.message[0]
|
||||
if message.type == "text":
|
||||
prefix = cls.prefix.longest_prefix(message.data["text"].lstrip())
|
||||
message_r = event.message[-1]
|
||||
if message_r.type == "text":
|
||||
suffix = cls.suffix.longest_prefix(
|
||||
message_r.data["text"].rstrip()[::-1])
|
||||
|
||||
state["_prefix"] = {prefix.key: prefix.value} if prefix else {}
|
||||
state["_suffix"] = {suffix.key: suffix.value} if suffix else {}
|
||||
|
||||
return ({
|
||||
prefix.key: prefix.value
|
||||
} if prefix else {}, {
|
||||
suffix.key: suffix.value
|
||||
} if suffix else {})
|
||||
|
||||
|
||||
def message() -> BaseRule:
|
||||
return Rule(lambda bot, event: event.type == "message")
|
||||
def startswith(msg: str) -> Rule:
|
||||
TrieRule.add_prefix(msg, (msg,))
|
||||
|
||||
async def _startswith(bot: Bot, event: Event, state: dict) -> bool:
|
||||
return msg in state["_prefix"]
|
||||
|
||||
return Rule(_startswith)
|
||||
|
||||
|
||||
def notice() -> BaseRule:
|
||||
return Rule(lambda bot, event: event.type == "notice")
|
||||
def endswith(msg: str) -> Rule:
|
||||
TrieRule.add_suffix(msg, (msg,))
|
||||
|
||||
async def _endswith(bot: Bot, event: Event, state: dict) -> bool:
|
||||
return msg in state["_suffix"]
|
||||
|
||||
return Rule(_endswith)
|
||||
|
||||
|
||||
def request() -> BaseRule:
|
||||
return Rule(lambda bot, event: event.type == "request")
|
||||
def keyword(msg: str) -> Rule:
|
||||
|
||||
async def _keyword(bot: Bot, event: Event, state: dict) -> bool:
|
||||
return bool(event.plain_text and msg in event.plain_text)
|
||||
|
||||
return Rule(_keyword)
|
||||
|
||||
|
||||
def metaevent() -> BaseRule:
|
||||
return Rule(lambda bot, event: event.type == "meta_event")
|
||||
def command(command: Tuple[str]) -> Rule:
|
||||
config = get_driver().config
|
||||
command_start = config.command_start
|
||||
command_sep = config.command_sep
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
async def _command(bot: Bot, event: Event, state: dict) -> bool:
|
||||
return command in state["_prefix"].values()
|
||||
|
||||
return Rule(_command)
|
||||
|
||||
|
||||
def user(*qq: int) -> BaseRule:
|
||||
return Rule(lambda bot, event: event.user_id in qq)
|
||||
|
||||
|
||||
def private() -> BaseRule:
|
||||
return Rule(lambda bot, event: event.detail_type == "private")
|
||||
|
||||
|
||||
def group(*group: int) -> BaseRule:
|
||||
return Rule(lambda bot, event: event.detail_type == "group" and event.
|
||||
group_id in group)
|
||||
|
||||
|
||||
def startswith(msg, start: int = None, end: int = None) -> BaseRule:
|
||||
return Rule(lambda bot, event: event.message.startswith(msg, start, end))
|
||||
|
||||
|
||||
def endswith(msg, start: int = None, end: int = None) -> BaseRule:
|
||||
return Rule(
|
||||
lambda bot, event: event.message.endswith(msg, start=None, end=None))
|
||||
|
||||
|
||||
def has(msg: str) -> BaseRule:
|
||||
return Rule(lambda bot, event: msg in event.message)
|
||||
|
||||
|
||||
def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> BaseRule:
|
||||
def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||
pattern = re.compile(regex, flags)
|
||||
return Rule(lambda bot, event: bool(pattern.search(str(event.message))))
|
||||
|
||||
async def _regex(bot: Bot, event: Event, state: dict) -> bool:
|
||||
return bool(pattern.search(str(event.message)))
|
||||
|
||||
return Rule(_regex)
|
||||
|
@ -8,9 +8,10 @@ from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable
|
||||
|
||||
# import some modules needed when checking types
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.rule import BaseRule
|
||||
from nonebot.rule import Rule as RuleClass
|
||||
from nonebot.matcher import Matcher as MatcherClass
|
||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||
from nonebot.permission import Permission as PermissionClass
|
||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||
|
||||
|
||||
@ -32,14 +33,13 @@ Event = TypeVar("Event", bound="BaseEvent")
|
||||
Message = TypeVar("Message", bound="BaseMessage")
|
||||
MessageSegment = TypeVar("MessageSegment", bound="BaseMessageSegment")
|
||||
|
||||
PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
|
||||
PreProcessor = Callable[[Bot, Event, dict], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]]
|
||||
|
||||
Matcher = TypeVar("Matcher", bound="MatcherClass")
|
||||
Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]]
|
||||
Rule = TypeVar("Rule", bound="BaseRule")
|
||||
_RuleChecker_Return = TypeVar("_RuleChecker_Return", bool, Awaitable[bool])
|
||||
RuleChecker = Callable[[Bot, Event], _RuleChecker_Return]
|
||||
SyncRuleChecker = RuleChecker[Bot, Event, bool]
|
||||
AsyncRuleChecker = RuleChecker[Bot, Event, Awaitable[bool]]
|
||||
Rule = TypeVar("Rule", bound="RuleClass")
|
||||
RuleChecker = Callable[[Bot, Event, dict], Awaitable[bool]]
|
||||
Permission = TypeVar("Permission", bound="PermissionClass")
|
||||
PermissionChecker = Callable[[Bot, Event], Awaitable[bool]]
|
||||
|
27
poetry.lock
generated
27
poetry.lock
generated
@ -414,6 +414,19 @@ reference = "aliyun"
|
||||
type = "legacy"
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
|
||||
[[package]]
|
||||
category = "main"
|
||||
description = "Trie data structure implementation."
|
||||
name = "pygtrie"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "2.3.3"
|
||||
|
||||
[package.source]
|
||||
reference = "aliyun"
|
||||
type = "legacy"
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Python parsing module"
|
||||
@ -540,7 +553,7 @@ description = "Python documentation generator"
|
||||
name = "sphinx"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
version = "3.2.0"
|
||||
version = "3.2.1"
|
||||
|
||||
[package.dependencies]
|
||||
Jinja2 = ">=2.3"
|
||||
@ -587,9 +600,10 @@ unify = "*"
|
||||
yapf = "*"
|
||||
|
||||
[package.source]
|
||||
reference = "5254c22fad13be69d8301e184818c4578d0e4115"
|
||||
reference = "88a68ed340013067a1c673bdf7541680c581fa60"
|
||||
type = "git"
|
||||
url = "https://github.com/nonebot/sphinx-markdown-builder.git"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books"
|
||||
@ -838,7 +852,7 @@ scheduler = ["apscheduler"]
|
||||
test = []
|
||||
|
||||
[metadata]
|
||||
content-hash = "b89641a9b24184b999991e1534842905ece528b73824eb79d6d378d686526da2"
|
||||
content-hash = "4d16d7ad0930bc9851802bc149f843c4e990a987e89414d765579ea8dccc8d6e"
|
||||
python-versions = "^3.7"
|
||||
|
||||
[metadata.files]
|
||||
@ -1002,6 +1016,9 @@ pygments = [
|
||||
{file = "Pygments-2.6.1-py3-none-any.whl", hash = "sha256:ff7a40b4860b727ab48fad6360eb351cc1b33cbf9b15a0f689ca5353e9463324"},
|
||||
{file = "Pygments-2.6.1.tar.gz", hash = "sha256:647344a061c249a3b74e230c739f434d7ea4d8b1d5f3721bc0f3558049b38f44"},
|
||||
]
|
||||
pygtrie = [
|
||||
{file = "pygtrie-2.3.3.tar.gz", hash = "sha256:2204dbd95584f67821da5b3771c4305ac5585552b3230b210f1f05322608db2c"},
|
||||
]
|
||||
pyparsing = [
|
||||
{file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"},
|
||||
{file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"},
|
||||
@ -1035,8 +1052,8 @@ snowballstemmer = [
|
||||
{file = "snowballstemmer-2.0.0.tar.gz", hash = "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"},
|
||||
]
|
||||
sphinx = [
|
||||
{file = "Sphinx-3.2.0-py3-none-any.whl", hash = "sha256:f7db5b76c42c8b5ef31853c2de7178ef378b985d7793829ec071e120dac1d0ca"},
|
||||
{file = "Sphinx-3.2.0.tar.gz", hash = "sha256:cf2d5bc3c6c930ab0a1fbef3ad8a82994b1bf4ae923f8098a05c7e5516f07177"},
|
||||
{file = "Sphinx-3.2.1-py3-none-any.whl", hash = "sha256:ce6fd7ff5b215af39e2fcd44d4a321f6694b4530b6f2b2109b64d120773faea0"},
|
||||
{file = "Sphinx-3.2.1.tar.gz", hash = "sha256:321d6d9b16fa381a5306e5a0b76cd48ffbc588e6340059a729c6fdd66087e0e8"},
|
||||
]
|
||||
sphinx-markdown-builder = []
|
||||
sphinxcontrib-applehelp = [
|
||||
|
@ -21,6 +21,7 @@ classifiers = [
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.7"
|
||||
httpx = "^0.13.3"
|
||||
pygtrie = "^2.3.3"
|
||||
fastapi = "^0.58.1"
|
||||
uvicorn = "^0.11.5"
|
||||
pydantic = { extras = ["dotenv"], version = "^1.6.1" }
|
||||
|
@ -4,19 +4,19 @@
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.typing import Event
|
||||
from nonebot.plugin import on_message
|
||||
from nonebot.adapters.cqhttp import Bot, Message
|
||||
from nonebot.adapters.cqhttp import Bot
|
||||
|
||||
test_matcher = on_message(state={"default": 1})
|
||||
test_message = on_message(state={"default": 1})
|
||||
|
||||
|
||||
@test_matcher.handle()
|
||||
@test_message.handle()
|
||||
async def test_handler(bot: Bot, event: Event, state: dict):
|
||||
print("Test Matcher Received:", event)
|
||||
print("Current State:", state)
|
||||
state["event"] = event
|
||||
|
||||
|
||||
@test_matcher.receive()
|
||||
@test_message.receive()
|
||||
async def test_receive(bot: Bot, event: Event, state: dict):
|
||||
print("Test Matcher Received next time:", event)
|
||||
print("Current State:", state)
|
@ -1,4 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from . import matchers
|
||||
from . import test_command
|
||||
|
14
tests/test_plugins/test_package/test_command.py
Normal file
14
tests/test_plugins/test_package/test_command.py
Normal file
@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.typing import Event
|
||||
from nonebot.plugin import on_command
|
||||
from nonebot.adapters.cqhttp import Bot
|
||||
|
||||
test_command = on_command(("帮助",))
|
||||
|
||||
|
||||
@test_command.handle()
|
||||
async def test_handler(bot: Bot, event: Event, state: dict):
|
||||
print(state["_prefix"])
|
Loading…
Reference in New Issue
Block a user