From 6435e29e8ba052eab77c317eb3f3898897ab1d91 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 17 Aug 2020 16:09:41 +0800 Subject: [PATCH] add permission and command --- nonebot/__init__.py | 8 +- nonebot/adapters/__init__.py | 25 ++ nonebot/adapters/cqhttp.py | 30 +- nonebot/config.py | 4 +- nonebot/matcher.py | 55 ++-- nonebot/message.py | 13 +- nonebot/permission.py | 124 ++++++++ nonebot/plugin.py | 75 +++-- nonebot/rule.py | 281 +++++++----------- nonebot/typing.py | 16 +- poetry.lock | 27 +- pyproject.toml | 1 + .../{test_matcher.py => test_message.py} | 8 +- tests/test_plugins/test_package/__init__.py | 2 +- tests/test_plugins/test_package/matchers.py | 0 .../test_plugins/test_package/test_command.py | 14 + 16 files changed, 429 insertions(+), 254 deletions(-) create mode 100644 nonebot/permission.py rename tests/test_plugins/{test_matcher.py => test_message.py} (76%) delete mode 100644 tests/test_plugins/test_package/matchers.py create mode 100644 tests/test_plugins/test_package/test_command.py diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 0049bf76..66da4436 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -4,10 +4,6 @@ import logging import importlib from ipaddress import IPv4Address - -from nonebot.log import logger -from nonebot.config import Env, Config -from nonebot.adapters.cqhttp import Bot as CQBot from nonebot.typing import Type, Union, Driver, Optional, NoReturn _driver: Optional[Driver] = None @@ -34,6 +30,10 @@ def get_bots(): return driver.bots +from nonebot.log import logger +from nonebot.config import Env, Config +from nonebot.adapters.cqhttp import Bot as CQBot + try: import nonebot_test except ImportError: diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 11b639b2..75c59541 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -83,6 +83,16 @@ class BaseEvent(abc.ABC): def sub_type(self, value) -> None: raise NotImplementedError + @property + @abc.abstractmethod + def user_id(self) -> Optional[int]: + raise NotImplementedError + + @user_id.setter + @abc.abstractmethod + def user_id(self, value) -> None: + raise NotImplementedError + @property @abc.abstractmethod def message(self) -> Optional[Message]: @@ -103,6 +113,21 @@ class BaseEvent(abc.ABC): def raw_message(self, value) -> None: raise NotImplementedError + @property + @abc.abstractmethod + def plain_text(self) -> Optional[str]: + raise NotImplementedError + + @property + @abc.abstractmethod + def sender(self) -> Optional[dict]: + raise NotImplementedError + + @sender.setter + @abc.abstractmethod + def sender(self, value) -> None: + raise NotImplementedError + @dataclass class BaseMessageSegment(abc.ABC): diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 0b228b3d..07d2305f 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -142,7 +142,10 @@ class Bot(BaseBot): try: async with httpx.AsyncClient(headers=headers) as client: - response = await client.post(api_root + api, json=data) + response = await client.post( + api_root + api, + json=data, + timeout=self.config.api_timeout) if 200 <= response.status_code < 300: result = response.json() @@ -193,6 +196,16 @@ class Event(BaseEvent): def sub_type(self, value) -> None: self._raw_event["sub_type"] = value + @property + @overrides(BaseEvent) + def user_id(self) -> Optional[int]: + return self._raw_event.get("user_id") + + @user_id.setter + @overrides(BaseEvent) + def user_id(self, value) -> None: + self._raw_event["user_id"] = value + @property @overrides(BaseEvent) def message(self) -> Optional["Message"]: @@ -213,6 +226,21 @@ class Event(BaseEvent): def raw_message(self, value) -> None: self._raw_event["raw_message"] = value + @property + @overrides(BaseEvent) + def plain_text(self) -> Optional[str]: + return self.message and self.message.extract_plain_text() + + @property + @overrides(BaseEvent) + def sender(self) -> Optional[dict]: + return self._raw_event.get("sender") + + @sender.setter + @overrides(BaseEvent) + def sender(self, value) -> None: + self._raw_event["sender"] = value + class MessageSegment(BaseMessageSegment): diff --git a/nonebot/config.py b/nonebot/config.py index 0223d4b0..f5618862 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -103,12 +103,14 @@ class Config(BaseConfig): # bot connection configs api_root: Dict[str, str] = {} - api_timeout: float = 60. + api_timeout: Optional[float] = 60. access_token: Optional[str] = None # bot runtime configs superusers: Set[int] = set() nickname: Union[str, Set[str]] = "" + command_start: Set[str] = {"/"} + command_sep: Set[str] = {"."} session_expire_timeout: timedelta = timedelta(minutes=2) # custom configs diff --git a/nonebot/matcher.py b/nonebot/matcher.py index b5405686..bdf9b474 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -6,8 +6,9 @@ from functools import wraps from datetime import datetime from collections import defaultdict -from nonebot.rule import SyncRule, user -from nonebot.typing import Bot, Rule, Event, Handler +from nonebot.rule import Rule +from nonebot.permission import Permission, EVERYBODY, USER +from nonebot.typing import Bot, Event, Handler from nonebot.typing import Type, List, Dict, Optional, NoReturn from nonebot.exception import PausedException, RejectedException, FinishedException @@ -18,7 +19,8 @@ class Matcher: """`Matcher`类 """ - rule: Rule = SyncRule() + rule: Rule = Rule() + permission: Permission = Permission() handlers: List[Handler] = [] temp: bool = False expire_time: Optional[datetime] = None @@ -38,7 +40,8 @@ class Matcher: @classmethod def new(cls, - rule: Rule = SyncRule(), + rule: Rule = Rule(), + permission: Permission = Permission(), handlers: list = [], temp: bool = False, priority: int = 1, @@ -54,6 +57,7 @@ class Matcher: NewMatcher = type( "Matcher", (Matcher,), { "rule": rule, + "permission": permission, "handlers": handlers, "temp": temp, "expire_time": expire_time, @@ -66,7 +70,11 @@ class Matcher: return NewMatcher @classmethod - async def check_rule(cls, bot: Bot, event: Event) -> bool: + async def check_perm(cls, bot: Bot, event: Event) -> bool: + return await cls.permission(bot, event) + + @classmethod + async def check_rule(cls, bot: Bot, event: Event, state: dict) -> bool: """检查 Matcher 的 Rule 是否成立 Args: @@ -75,7 +83,7 @@ class Matcher: Returns: bool: 条件成立与否 """ - return await cls.rule(bot, event) + return await cls.rule(bot, event, state) # @classmethod # def args_parser(cls, func: Callable[[Event, dict], None]): @@ -144,11 +152,14 @@ class Matcher: # raise RejectedException # 运行handlers - async def run(self, bot: Bot, event: Event): + async def run(self, bot: Bot, event: Event, state): try: # if self.parser: # await self.parser(event, state) # type: ignore + # Refresh preprocess state + self.state.update(state) + for _ in range(len(self.handlers)): handler = self.handlers.pop(0) annotation = typing.get_type_hints(handler) @@ -158,23 +169,25 @@ class Matcher: await handler(bot, event, self.state) except RejectedException: self.handlers.insert(0, handler) # type: ignore - matcher = Matcher.new(user(event.user_id) & self.rule, - self.handlers, - temp=True, - priority=0, - default_state=self.state, - expire_time=datetime.now() + - bot.config.session_expire_timeout) + matcher = Matcher.new( + self.rule, + USER(event.user_id, perm=self.permission), # type:ignore + self.handlers, + temp=True, + priority=0, + default_state=self.state, + expire_time=datetime.now() + bot.config.session_expire_timeout) matchers[0].append(matcher) return except PausedException: - matcher = Matcher.new(user(event.user_id) & self.rule, - self.handlers, - temp=True, - priority=0, - default_state=self.state, - expire_time=datetime.now() + - bot.config.session_expire_timeout) + matcher = Matcher.new( + self.rule, + USER(event.user_id, perm=self.permission), # type:ignore + self.handlers, + temp=True, + priority=0, + default_state=self.state, + expire_time=datetime.now() + bot.config.session_expire_timeout) matchers[0].append(matcher) return except FinishedException: diff --git a/nonebot/message.py b/nonebot/message.py index 32c77517..5a6306bf 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -5,6 +5,7 @@ import asyncio from datetime import datetime from nonebot.log import logger +from nonebot.rule import TrieRule from nonebot.matcher import matchers from nonebot.exception import IgnoredException from nonebot.typing import Bot, Set, Event, PreProcessor @@ -19,8 +20,9 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor: async def handle_event(bot: Bot, event: Event): coros = [] + state = {} for preprocessor in _event_preprocessors: - coros.append(preprocessor(bot, event)) + coros.append(preprocessor(bot, event, state)) if coros: try: await asyncio.gather(*coros) @@ -28,6 +30,9 @@ async def handle_event(bot: Bot, event: Event): logger.info(f"Event {event} is ignored") return + # Trie Match + _, _ = TrieRule.get_value(bot, event, state) + for priority in sorted(matchers.keys()): index = 0 while index <= len(matchers[priority]): @@ -40,7 +45,9 @@ async def handle_event(bot: Bot, event: Event): # Check rule try: - if not await Matcher.check_rule(bot, event): + if not await Matcher.check_perm( + bot, event) or not await Matcher.check_rule( + bot, event, state): index += 1 continue except Exception as e: @@ -55,7 +62,7 @@ async def handle_event(bot: Bot, event: Event): del matchers[priority][index] try: - await matcher.run(bot, event) + await matcher.run(bot, event, state) except Exception as e: logger.error(f"Running matcher {matcher} failed.") logger.exception(e) diff --git a/nonebot/permission.py b/nonebot/permission.py new file mode 100644 index 00000000..fefda749 --- /dev/null +++ b/nonebot/permission.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import asyncio + +from nonebot.utils import run_sync +from nonebot.typing import Bot, Event, Union, NoReturn, PermissionChecker + + +class Permission: + __slots__ = ("checkers",) + + def __init__(self, *checkers: PermissionChecker) -> None: + self.checkers = list(checkers) + + async def __call__(self, bot: Bot, event: Event) -> bool: + if not self.checkers: + return True + results = await asyncio.gather( + *map(lambda c: c(bot, event), self.checkers)) + return any(results) + + def __and__(self, other) -> NoReturn: + raise RuntimeError("And operation between Permissions is not allowed.") + + def __or__(self, other: Union["Permission", + PermissionChecker]) -> "Permission": + checkers = [*self.checkers] + if isinstance(other, Permission): + checkers.extend(other.checkers) + elif asyncio.iscoroutinefunction(other): + checkers.append(other) + else: + checkers.append(run_sync(other)) + return Permission(*checkers) + + +async def _message(bot: Bot, event: Event) -> bool: + return event.type == "message" + + +async def _notice(bot: Bot, event: Event) -> bool: + return event.type == "notice" + + +async def _request(bot: Bot, event: Event) -> bool: + return event.type == "request" + + +async def _metaevent(bot: Bot, event: Event) -> bool: + return event.type == "meta_event" + + +MESSAGE = Permission(_message) +NOTICE = Permission(_notice) +REQUEST = Permission(_request) +METAEVENT = Permission(_metaevent) + + +def USER(*user: int, perm: Permission = Permission()): + + async def _user(bot: Bot, event: Event) -> bool: + return event.type == "message" and event.user_id in user and await perm( + bot, event) + + return Permission(_user) + + +async def _private(bot: Bot, event: Event) -> bool: + return event.type == "message" and event.detail_type == "private" + + +async def _private_friend(bot: Bot, event: Event) -> bool: + return (event.type == "message" and event.detail_type == "private" and + event.sub_type == "friend") + + +async def _private_group(bot: Bot, event: Event) -> bool: + return (event.type == "message" and event.detail_type == "private" and + event.sub_type == "group") + + +async def _private_other(bot: Bot, event: Event) -> bool: + return (event.type == "message" and event.detail_type == "private" and + event.sub_type == "other") + + +PRIVATE = Permission(_private) +PRIVATE_FRIEND = Permission(_private_friend) +PRIVATE_GROUP = Permission(_private_group) +PRIVATE_OTHER = Permission(_private_other) + + +async def _group(bot: Bot, event: Event) -> bool: + return event.type == "message" and event.detail_type == "group" + + +async def _group_member(bot: Bot, event: Event) -> bool: + return (event.type == "message" and event.detail_type == "group" and + event.sender.get("role") == "member") + + +async def _group_admin(bot: Bot, event: Event) -> bool: + return (event.type == "message" and event.detail_type == "group" and + event.sender.get("role") == "admin") + + +async def _group_owner(bot: Bot, event: Event) -> bool: + return (event.type == "message" and event.detail_type == "group" and + event.sender.get("role") == "owner") + + +GROUP = Permission(_group) +GROUP_MEMBER = Permission(_group_member) +GROUP_ADMIN = Permission(_group_admin) +GROUP_OWNER = Permission(_group_owner) + + +async def _superuser(bot: Bot, event: Event) -> bool: + return event.type == "message" and event.user_id in bot.config.superusers + + +SUPERUSER = Permission(_superuser) +EVERYBODY = MESSAGE diff --git a/nonebot/plugin.py b/nonebot/plugin.py index bd9832d8..73c326db 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import re import pkgutil import importlib from importlib.util import module_from_spec from nonebot.log import logger from nonebot.matcher import Matcher -from nonebot.rule import SyncRule, metaevent, message, notice, request -from nonebot.typing import Set, Dict, Type, Rule, Union, Optional, ModuleType, RuleChecker +from nonebot.rule import Rule, startswith, endswith, command, regex +from nonebot.permission import Permission, METAEVENT, MESSAGE, NOTICE, REQUEST +from nonebot.typing import Set, Dict, Type, Tuple, Union, Optional, ModuleType, RuleChecker plugins: Dict[str, "Plugin"] = {} @@ -25,13 +27,14 @@ class Plugin(object): self.matchers = matchers -def on_metaevent(rule: Union[Rule, RuleChecker] = SyncRule(), +def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, priority: int = 1, state={}) -> Type[Matcher]: - matcher = Matcher.new(metaevent() & rule, + matcher = Matcher.new(Rule() & rule, + METAEVENT, temp=temp, priority=priority, handlers=handlers, @@ -40,13 +43,15 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = SyncRule(), return matcher -def on_message(rule: Union[Rule, RuleChecker] = SyncRule(), +def on_message(rule: Union[Rule, RuleChecker] = Rule(), + permission: Permission = MESSAGE, *, handlers=[], temp=False, priority: int = 1, state={}) -> Type[Matcher]: - matcher = Matcher.new(message() & rule, + matcher = Matcher.new(Rule() & rule, + permission, temp=temp, priority=priority, handlers=handlers, @@ -55,13 +60,14 @@ def on_message(rule: Union[Rule, RuleChecker] = SyncRule(), return matcher -def on_notice(rule: Union[Rule, RuleChecker] = SyncRule(), +def on_notice(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, priority: int = 1, state={}) -> Type[Matcher]: - matcher = Matcher.new(notice() & rule, + matcher = Matcher.new(Rule() & rule, + NOTICE, temp=temp, priority=priority, handlers=handlers, @@ -70,13 +76,14 @@ def on_notice(rule: Union[Rule, RuleChecker] = SyncRule(), return matcher -def on_request(rule: Union[Rule, RuleChecker] = SyncRule(), +def on_request(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, priority: int = 1, state={}) -> Type[Matcher]: - matcher = Matcher.new(request() & rule, + matcher = Matcher.new(Rule() & rule, + REQUEST, temp=temp, priority=priority, handlers=handlers, @@ -85,22 +92,40 @@ def on_request(rule: Union[Rule, RuleChecker] = SyncRule(), return matcher -# def on_startswith(msg, -# start: int = None, -# end: int = None, -# rule: Optional[Rule] = None, -# **kwargs) -> Type[Matcher]: -# return on_message(startswith(msg, start, end) & -# rule, **kwargs) if rule else on_message( -# startswith(msg, start, end), **kwargs) +def on_startswith(msg: str, + rule: Optional[Union[Rule, RuleChecker]] = None, + permission: Permission = MESSAGE, + **kwargs) -> Type[Matcher]: + return on_message(startswith(msg) & + rule, permission, **kwargs) if rule else on_message( + startswith(msg), permission, **kwargs) -# def on_regex(pattern, -# flags: Union[int, re.RegexFlag] = 0, -# rule: Optional[Rule] = None, -# **kwargs) -> Type[Matcher]: -# return on_message(regex(pattern, flags) & -# rule, **kwargs) if rule else on_message( -# regex(pattern, flags), **kwargs) + +def on_endswith(msg: str, + rule: Optional[Union[Rule, RuleChecker]] = None, + permission: Permission = MESSAGE, + **kwargs) -> Type[Matcher]: + return on_message(endswith(msg) & + rule, permission, **kwargs) if rule else on_message( + startswith(msg), permission, **kwargs) + + +def on_command(cmd: Tuple[str], + rule: Optional[Union[Rule, RuleChecker]] = None, + permission: Permission = MESSAGE, + **kwargs) -> Type[Matcher]: + return on_message(command(cmd) & + rule, permission, **kwargs) if rule else on_message( + command(cmd), permission, **kwargs) + + +def on_regex(pattern: str, + flags: Union[int, re.RegexFlag] = 0, + rule: Optional[Rule] = None, + **kwargs) -> Type[Matcher]: + return on_message(regex(pattern, flags) & + rule, **kwargs) if rule else on_message( + regex(pattern, flags), **kwargs) def load_plugin(module_path: str) -> Optional[Plugin]: diff --git a/nonebot/rule.py b/nonebot/rule.py index 9caa9d30..e76c2925 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -2,207 +2,126 @@ # -*- coding: utf-8 -*- import re -import abc import asyncio -from typing import cast +from itertools import product +from pygtrie import CharTrie + +from nonebot import get_driver +from nonebot.log import logger from nonebot.utils import run_sync -from nonebot.typing import Bot, Event, Union, Optional, Awaitable -from nonebot.typing import RuleChecker, SyncRuleChecker, AsyncRuleChecker +from nonebot.typing import Bot, Any, Dict, Event, Union, Tuple, NoReturn, RuleChecker -class BaseRule(abc.ABC): +class Rule: + __slots__ = ("checkers",) - def __init__(self, checker: RuleChecker): - self.checker: RuleChecker = checker + def __init__(self, *checkers: RuleChecker) -> None: + self.checkers = list(checkers) - @abc.abstractmethod - def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]: - raise NotImplementedError + async def __call__(self, bot: Bot, event: Event, state: dict) -> bool: + results = await asyncio.gather( + *map(lambda c: c(bot, event, state), self.checkers)) + return all(results) - @abc.abstractmethod - def __and__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule": - raise NotImplementedError - - @abc.abstractmethod - def __or__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule": - raise NotImplementedError - - @abc.abstractmethod - def __neg__(self) -> "BaseRule": - raise NotImplementedError - - -class AsyncRule(BaseRule): - - def __init__(self, checker: Optional[AsyncRuleChecker] = None): - - async def always_true(bot: Bot, event: Event) -> bool: - return True - - self.checker: AsyncRuleChecker = checker or always_true - - def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]: - return self.checker(bot, event) - - def __and__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule": - func = other - if isinstance(other, BaseRule): - func = other.checker - - if not asyncio.iscoroutinefunction(func): - func = run_sync(func) - - async def tmp(bot: Bot, event: Event) -> bool: - a, b = await asyncio.gather(self.checker(bot, event), - func(bot, event)) - return a and b - - return AsyncRule(tmp) - - def __or__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule": - func = other - if isinstance(other, BaseRule): - func = other.checker - - if not asyncio.iscoroutinefunction(func): - func = run_sync(func) - - async def tmp(bot: Bot, event: Event) -> bool: - a, b = await asyncio.gather(self.checker(bot, event), - func(bot, event)) - return a or b - - return AsyncRule(tmp) - - def __neg__(self) -> "AsyncRule": - - async def neg(bot: Bot, event: Event) -> bool: - result = await self.checker(bot, event) - return not result - - return AsyncRule(neg) - - -class SyncRule(BaseRule): - - def __init__(self, checker: Optional[SyncRuleChecker] = None): - - def always_true(bot: Bot, event: Event) -> bool: - return True - - self.checker: SyncRuleChecker = checker or always_true - - def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]: - return run_sync(self.checker)(bot, event) - - def __and__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule: - func = other - if isinstance(other, BaseRule): - func = other.checker - - if not asyncio.iscoroutinefunction(func): - # func: SyncRuleChecker - syncfunc = cast(SyncRuleChecker, func) - - def tmp(bot: Bot, event: Event) -> bool: - return self.checker(bot, event) and syncfunc(bot, event) - - return SyncRule(tmp) + def __and__(self, other: Union["Rule", RuleChecker]) -> "Rule": + checkers = [*self.checkers] + if isinstance(other, Rule): + checkers.extend(other.checkers) + elif asyncio.iscoroutinefunction(other): + checkers.append(other) else: - # func: AsyncRuleChecker - asyncfunc = cast(AsyncRuleChecker, func) + checkers.append(run_sync(other)) + return Rule(*checkers) - async def tmp(bot: Bot, event: Event) -> bool: - a, b = await asyncio.gather( - run_sync(self.checker)(bot, event), asyncfunc(bot, event)) - return a and b - - return AsyncRule(tmp) - - def __or__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule: - func = other - if isinstance(other, BaseRule): - func = other.checker - - if not asyncio.iscoroutinefunction(func): - # func: SyncRuleChecker - syncfunc = cast(SyncRuleChecker, func) - - def tmp(bot: Bot, event: Event) -> bool: - return self.checker(bot, event) or syncfunc(bot, event) - - return SyncRule(tmp) - else: - # func: AsyncRuleChecker - asyncfunc = cast(AsyncRuleChecker, func) - - async def tmp(bot: Bot, event: Event) -> bool: - a, b = await asyncio.gather( - run_sync(self.checker)(bot, event), asyncfunc(bot, event)) - return a or b - - return AsyncRule(tmp) - - def __neg__(self) -> "SyncRule": - - def neg(bot: Bot, event: Event) -> bool: - return not self.checker(bot, event) - - return SyncRule(neg) + def __or__(self, other) -> NoReturn: + raise RuntimeError("Or operation between rules is not allowed.") -def Rule(func: Optional[RuleChecker] = None) -> BaseRule: - if func and asyncio.iscoroutinefunction(func): - asyncfunc = cast(AsyncRuleChecker, func) - return AsyncRule(asyncfunc) - else: - syncfunc = cast(Optional[SyncRuleChecker], func) - return SyncRule(syncfunc) +class TrieRule: + prefix: CharTrie = CharTrie() + suffix: CharTrie = CharTrie() + + @classmethod + def add_prefix(cls, prefix: str, value: Any): + if prefix in cls.prefix: + logger.warning(f'Duplicated prefix rule "{prefix}"') + return + cls.prefix[prefix] = value + + @classmethod + def add_suffix(cls, suffix: str, value: Any): + if suffix[::-1] in cls.suffix: + logger.warning(f'Duplicated suffix rule "{suffix}"') + return + cls.suffix[suffix[::-1]] = value + + @classmethod + def get_value(cls, bot: Bot, event: Event, + state: dict) -> Tuple[Dict[str, Any], Dict[str, Any]]: + prefix = None + suffix = None + message = event.message[0] + if message.type == "text": + prefix = cls.prefix.longest_prefix(message.data["text"].lstrip()) + message_r = event.message[-1] + if message_r.type == "text": + suffix = cls.suffix.longest_prefix( + message_r.data["text"].rstrip()[::-1]) + + state["_prefix"] = {prefix.key: prefix.value} if prefix else {} + state["_suffix"] = {suffix.key: suffix.value} if suffix else {} + + return ({ + prefix.key: prefix.value + } if prefix else {}, { + suffix.key: suffix.value + } if suffix else {}) -def message() -> BaseRule: - return Rule(lambda bot, event: event.type == "message") +def startswith(msg: str) -> Rule: + TrieRule.add_prefix(msg, (msg,)) + + async def _startswith(bot: Bot, event: Event, state: dict) -> bool: + return msg in state["_prefix"] + + return Rule(_startswith) -def notice() -> BaseRule: - return Rule(lambda bot, event: event.type == "notice") +def endswith(msg: str) -> Rule: + TrieRule.add_suffix(msg, (msg,)) + + async def _endswith(bot: Bot, event: Event, state: dict) -> bool: + return msg in state["_suffix"] + + return Rule(_endswith) -def request() -> BaseRule: - return Rule(lambda bot, event: event.type == "request") +def keyword(msg: str) -> Rule: + + async def _keyword(bot: Bot, event: Event, state: dict) -> bool: + return bool(event.plain_text and msg in event.plain_text) + + return Rule(_keyword) -def metaevent() -> BaseRule: - return Rule(lambda bot, event: event.type == "meta_event") +def command(command: Tuple[str]) -> Rule: + config = get_driver().config + command_start = config.command_start + command_sep = config.command_sep + for start, sep in product(command_start, command_sep): + TrieRule.add_prefix(f"{start}{sep.join(command)}", command) + + async def _command(bot: Bot, event: Event, state: dict) -> bool: + return command in state["_prefix"].values() + + return Rule(_command) -def user(*qq: int) -> BaseRule: - return Rule(lambda bot, event: event.user_id in qq) - - -def private() -> BaseRule: - return Rule(lambda bot, event: event.detail_type == "private") - - -def group(*group: int) -> BaseRule: - return Rule(lambda bot, event: event.detail_type == "group" and event. - group_id in group) - - -def startswith(msg, start: int = None, end: int = None) -> BaseRule: - return Rule(lambda bot, event: event.message.startswith(msg, start, end)) - - -def endswith(msg, start: int = None, end: int = None) -> BaseRule: - return Rule( - lambda bot, event: event.message.endswith(msg, start=None, end=None)) - - -def has(msg: str) -> BaseRule: - return Rule(lambda bot, event: msg in event.message) - - -def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> BaseRule: +def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: pattern = re.compile(regex, flags) - return Rule(lambda bot, event: bool(pattern.search(str(event.message)))) + + async def _regex(bot: Bot, event: Event, state: dict) -> bool: + return bool(pattern.search(str(event.message))) + + return Rule(_regex) diff --git a/nonebot/typing.py b/nonebot/typing.py index 21fe0389..f389e177 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -8,9 +8,10 @@ from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable # import some modules needed when checking types if TYPE_CHECKING: - from nonebot.rule import BaseRule + from nonebot.rule import Rule as RuleClass from nonebot.matcher import Matcher as MatcherClass from nonebot.drivers import BaseDriver, BaseWebSocket + from nonebot.permission import Permission as PermissionClass from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment @@ -32,14 +33,13 @@ Event = TypeVar("Event", bound="BaseEvent") Message = TypeVar("Message", bound="BaseMessage") MessageSegment = TypeVar("MessageSegment", bound="BaseMessageSegment") -PreProcessor = Callable[[Bot, Event], Union[Awaitable[None], - Awaitable[NoReturn]]] +PreProcessor = Callable[[Bot, Event, dict], Union[Awaitable[None], + Awaitable[NoReturn]]] Matcher = TypeVar("Matcher", bound="MatcherClass") Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None], Awaitable[NoReturn]]] -Rule = TypeVar("Rule", bound="BaseRule") -_RuleChecker_Return = TypeVar("_RuleChecker_Return", bool, Awaitable[bool]) -RuleChecker = Callable[[Bot, Event], _RuleChecker_Return] -SyncRuleChecker = RuleChecker[Bot, Event, bool] -AsyncRuleChecker = RuleChecker[Bot, Event, Awaitable[bool]] +Rule = TypeVar("Rule", bound="RuleClass") +RuleChecker = Callable[[Bot, Event, dict], Awaitable[bool]] +Permission = TypeVar("Permission", bound="PermissionClass") +PermissionChecker = Callable[[Bot, Event], Awaitable[bool]] diff --git a/poetry.lock b/poetry.lock index 032b4261..de883bca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -414,6 +414,19 @@ reference = "aliyun" type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" +[[package]] +category = "main" +description = "Trie data structure implementation." +name = "pygtrie" +optional = false +python-versions = "*" +version = "2.3.3" + +[package.source] +reference = "aliyun" +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" + [[package]] category = "dev" description = "Python parsing module" @@ -540,7 +553,7 @@ description = "Python documentation generator" name = "sphinx" optional = false python-versions = ">=3.5" -version = "3.2.0" +version = "3.2.1" [package.dependencies] Jinja2 = ">=2.3" @@ -587,9 +600,10 @@ unify = "*" yapf = "*" [package.source] -reference = "5254c22fad13be69d8301e184818c4578d0e4115" +reference = "88a68ed340013067a1c673bdf7541680c581fa60" type = "git" url = "https://github.com/nonebot/sphinx-markdown-builder.git" + [[package]] category = "dev" description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books" @@ -838,7 +852,7 @@ scheduler = ["apscheduler"] test = [] [metadata] -content-hash = "b89641a9b24184b999991e1534842905ece528b73824eb79d6d378d686526da2" +content-hash = "4d16d7ad0930bc9851802bc149f843c4e990a987e89414d765579ea8dccc8d6e" python-versions = "^3.7" [metadata.files] @@ -1002,6 +1016,9 @@ pygments = [ {file = "Pygments-2.6.1-py3-none-any.whl", hash = "sha256:ff7a40b4860b727ab48fad6360eb351cc1b33cbf9b15a0f689ca5353e9463324"}, {file = "Pygments-2.6.1.tar.gz", hash = "sha256:647344a061c249a3b74e230c739f434d7ea4d8b1d5f3721bc0f3558049b38f44"}, ] +pygtrie = [ + {file = "pygtrie-2.3.3.tar.gz", hash = "sha256:2204dbd95584f67821da5b3771c4305ac5585552b3230b210f1f05322608db2c"}, +] pyparsing = [ {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, @@ -1035,8 +1052,8 @@ snowballstemmer = [ {file = "snowballstemmer-2.0.0.tar.gz", hash = "sha256:df3bac3df4c2c01363f3dd2cfa78cce2840a79b9f1c2d2de9ce8d31683992f52"}, ] sphinx = [ - {file = "Sphinx-3.2.0-py3-none-any.whl", hash = "sha256:f7db5b76c42c8b5ef31853c2de7178ef378b985d7793829ec071e120dac1d0ca"}, - {file = "Sphinx-3.2.0.tar.gz", hash = "sha256:cf2d5bc3c6c930ab0a1fbef3ad8a82994b1bf4ae923f8098a05c7e5516f07177"}, + {file = "Sphinx-3.2.1-py3-none-any.whl", hash = "sha256:ce6fd7ff5b215af39e2fcd44d4a321f6694b4530b6f2b2109b64d120773faea0"}, + {file = "Sphinx-3.2.1.tar.gz", hash = "sha256:321d6d9b16fa381a5306e5a0b76cd48ffbc588e6340059a729c6fdd66087e0e8"}, ] sphinx-markdown-builder = [] sphinxcontrib-applehelp = [ diff --git a/pyproject.toml b/pyproject.toml index 57535f89..cf97cda3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" httpx = "^0.13.3" +pygtrie = "^2.3.3" fastapi = "^0.58.1" uvicorn = "^0.11.5" pydantic = { extras = ["dotenv"], version = "^1.6.1" } diff --git a/tests/test_plugins/test_matcher.py b/tests/test_plugins/test_message.py similarity index 76% rename from tests/test_plugins/test_matcher.py rename to tests/test_plugins/test_message.py index 2b59d4e9..0ef9f938 100644 --- a/tests/test_plugins/test_matcher.py +++ b/tests/test_plugins/test_message.py @@ -4,19 +4,19 @@ from nonebot.rule import Rule from nonebot.typing import Event from nonebot.plugin import on_message -from nonebot.adapters.cqhttp import Bot, Message +from nonebot.adapters.cqhttp import Bot -test_matcher = on_message(state={"default": 1}) +test_message = on_message(state={"default": 1}) -@test_matcher.handle() +@test_message.handle() async def test_handler(bot: Bot, event: Event, state: dict): print("Test Matcher Received:", event) print("Current State:", state) state["event"] = event -@test_matcher.receive() +@test_message.receive() async def test_receive(bot: Bot, event: Event, state: dict): print("Test Matcher Received next time:", event) print("Current State:", state) diff --git a/tests/test_plugins/test_package/__init__.py b/tests/test_plugins/test_package/__init__.py index 38e61a36..60c5cee2 100644 --- a/tests/test_plugins/test_package/__init__.py +++ b/tests/test_plugins/test_package/__init__.py @@ -1,4 +1,4 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from . import matchers +from . import test_command diff --git a/tests/test_plugins/test_package/matchers.py b/tests/test_plugins/test_package/matchers.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_plugins/test_package/test_command.py b/tests/test_plugins/test_package/test_command.py new file mode 100644 index 00000000..8e6d4af7 --- /dev/null +++ b/tests/test_plugins/test_package/test_command.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from nonebot.rule import Rule +from nonebot.typing import Event +from nonebot.plugin import on_command +from nonebot.adapters.cqhttp import Bot + +test_command = on_command(("帮助",)) + + +@test_command.handle() +async def test_handler(bot: Bot, event: Event, state: dict): + print(state["_prefix"])