change rule

This commit is contained in:
yanyongyu 2020-08-14 17:41:24 +08:00
parent 04f4d5028e
commit 1dcc43161a
13 changed files with 252 additions and 192 deletions

View File

@ -7,19 +7,18 @@ from ipaddress import IPv4Address
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.drivers import BaseDriver
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.typing import Union, Optional, NoReturn
from nonebot.typing import Type, Union, Driver, Optional, NoReturn
try:
import nonebot_test
except ImportError:
nonebot_test = None
_driver: Optional[BaseDriver] = None
_driver: Optional[Driver] = None
def get_driver() -> Union[NoReturn, BaseDriver]:
def get_driver() -> Union[NoReturn, Driver]:
if _driver is None:
raise ValueError("NoneBot has not been initialized.")
return _driver
@ -43,14 +42,16 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
logger.setLevel(logging.DEBUG if config.debug else logging.INFO)
logger.debug(f"Loaded config: {config.dict()}")
Driver = getattr(importlib.import_module(config.driver), "Driver")
_driver = Driver(env, config)
DriverClass: Type[Driver] = getattr(importlib.import_module(config.driver),
"Driver")
_driver = DriverClass(env, config)
# register build-in adapters
_driver.register_adapter("cqhttp", CQBot)
# load nonebot test frontend if debug
if config.debug and nonebot_test:
logger.debug("Loading nonebot test frontend...")
nonebot_test.init()

View File

@ -6,7 +6,7 @@ from functools import reduce, partial
from dataclasses import dataclass, field
from nonebot.config import Config
from nonebot.typing import Driver, WebSocket
from nonebot.typing import Driver, Message, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
@ -83,6 +83,26 @@ class BaseEvent(abc.ABC):
def sub_type(self, value) -> None:
raise NotImplementedError
@property
@abc.abstractmethod
def message(self) -> Optional[Message]:
raise NotImplementedError
@message.setter
@abc.abstractmethod
def message(self, value) -> None:
raise NotImplementedError
@property
@abc.abstractmethod
def raw_message(self) -> Optional[str]:
raise NotImplementedError
@raw_message.setter
@abc.abstractmethod
def raw_message(self, value) -> None:
raise NotImplementedError
@dataclass
class BaseMessageSegment(abc.ABC):

View File

@ -193,6 +193,26 @@ class Event(BaseEvent):
def sub_type(self, value) -> None:
self._raw_event["sub_type"] = value
@property
@overrides(BaseEvent)
def message(self) -> Optional["Message"]:
return self._raw_event.get("message")
@message.setter
@overrides(BaseEvent)
def message(self, value) -> None:
self._raw_event["message"] = value
@property
@overrides(BaseEvent)
def raw_message(self) -> Optional[str]:
return self._raw_event.get("raw_message")
@raw_message.setter
@overrides(BaseEvent)
def raw_message(self, value) -> None:
self._raw_event["raw_message"] = value
class MessageSegment(BaseMessageSegment):

View File

@ -4,6 +4,7 @@
import abc
from ipaddress import IPv4Address
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.typing import Bot, Dict, Type, Optional, Callable
@ -20,6 +21,7 @@ class BaseDriver(abc.ABC):
@classmethod
def register_adapter(cls, name: str, adapter: Type[Bot]):
cls._adapters[name] = adapter
logger.debug(f'Succeeded to load adapter "{name}"')
@property
@abc.abstractmethod

View File

@ -152,15 +152,16 @@ class Driver(BaseDriver):
await websocket.accept()
self._clients[x_self_id] = bot
while not websocket.closed:
data = await websocket.receive()
try:
while not websocket.closed:
data = await websocket.receive()
if not data:
continue
if not data:
continue
await bot.handle_message(data)
del self._clients[x_self_id]
await bot.handle_message(data)
finally:
del self._clients[x_self_id]
class WebSocket(BaseWebSocket):

View File

@ -1,127 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from nonebot.typing import Any, Dict, Optional
class Event(dict):
"""
封装从 CQHTTP 收到的事件数据对象字典提供属性以获取其中的字段
`type` `detail_type` 属性对于任何事件都有效外其它属性存在与否不存在则返回
`None`依事件不同而不同
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["Event"]:
"""
CQHTTP 事件数据构造 `Event` 对象
"""
try:
e = Event(payload)
_ = e.type, e.detail_type
return e
except KeyError:
return None
@property
def type(self) -> str:
"""
事件类型 ``message````notice````request````meta_event``
"""
return self["post_type"]
@property
def detail_type(self) -> str:
"""
事件具体类型 `type` 的不同而不同 ``message`` 类型为例
``private````group````discuss``
"""
return self[f"{self.type}_type"]
@property
def sub_type(self) -> Optional[str]:
"""
事件子类型 `detail_type` 不同而不同 ``message.private`` 为例
``friend````group````discuss````other``
"""
return self.get("sub_type")
@property
def name(self):
"""
事件名对于有 `sub_type` 的事件 ``{type}.{detail_type}.{sub_type}``否则为
``{type}.{detail_type}``
"""
n = self.type + "." + self.detail_type
if self.sub_type:
n += "." + self.sub_type
return n
@property
def self_id(self) -> int:
"""机器人自身 ID。"""
return self["self_id"]
@property
def user_id(self) -> Optional[int]:
"""用户 ID。"""
return self.get("user_id")
@property
def operator_id(self) -> Optional[int]:
"""操作者 ID。"""
return self.get("operator_id")
@property
def group_id(self) -> Optional[int]:
"""群 ID。"""
return self.get("group_id")
@property
def discuss_id(self) -> Optional[int]:
"""讨论组 ID。"""
return self.get("discuss_id")
@property
def message_id(self) -> Optional[int]:
"""消息 ID。"""
return self.get("message_id")
@property
def message(self) -> Optional[Any]:
"""消息。"""
return self.get("message")
@property
def raw_message(self) -> Optional[str]:
"""未经 CQHTTP 处理的原始消息。"""
return self.get("raw_message")
@property
def sender(self) -> Optional[Dict[str, Any]]:
"""消息发送者信息。"""
return self.get("sender")
@property
def anonymous(self) -> Optional[Dict[str, Any]]:
"""匿名信息。"""
return self.get("anonymous")
@property
def file(self) -> Optional[Dict[str, Any]]:
"""文件信息。"""
return self.get("file")
@property
def comment(self) -> Optional[str]:
"""请求验证消息。"""
return self.get("comment")
@property
def flag(self) -> Optional[str]:
"""请求标识。"""
return self.get("flag")
def __repr__(self) -> str:
return f"<Event, {super().__repr__()}>"

View File

@ -7,8 +7,8 @@ import importlib
from nonebot.log import logger
from nonebot.matcher import Matcher
from nonebot.typing import Set, Dict, Type, Optional, ModuleType
from nonebot.rule import Rule, metaevent, message, notice, request
from nonebot.typing import Set, Dict, Type, Union, Optional, ModuleType, RuleChecker
plugins: Dict[str, "Plugin"] = {}
@ -25,7 +25,7 @@ class Plugin(object):
self.matchers = matchers
def on_metaevent(rule: Rule,
def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,
@ -40,7 +40,7 @@ def on_metaevent(rule: Rule,
return matcher
def on_message(rule: Rule,
def on_message(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,
@ -55,7 +55,7 @@ def on_message(rule: Rule,
return matcher
def on_notice(rule: Rule,
def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,
@ -70,7 +70,7 @@ def on_notice(rule: Rule,
return matcher
def on_request(rule: Rule,
def on_request(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,

View File

@ -2,81 +2,207 @@
# -*- coding: utf-8 -*-
import re
import abc
import asyncio
from typing import cast
from nonebot.event import Event
from nonebot.typing import Union, Callable, Optional
from nonebot.utils import run_sync
from nonebot.typing import Bot, Event, Union, Optional, Awaitable
from nonebot.typing import RuleChecker, SyncRuleChecker, AsyncRuleChecker
class Rule:
class BaseRule(abc.ABC):
def __init__(
self,
checker: Optional[Callable[["BaseBot", Event], # type: ignore
bool]] = None):
self.checker = checker or (lambda bot, event: True)
def __init__(self, checker: RuleChecker):
self.checker: RuleChecker = checker
def __call__(self, bot, event: Event) -> bool:
@abc.abstractmethod
def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]:
raise NotImplementedError
@abc.abstractmethod
def __and__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule":
raise NotImplementedError
@abc.abstractmethod
def __or__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule":
raise NotImplementedError
@abc.abstractmethod
def __neg__(self) -> "BaseRule":
raise NotImplementedError
class AsyncRule(BaseRule):
def __init__(self, checker: Optional[AsyncRuleChecker] = None):
async def always_true(bot: Bot, event: Event) -> bool:
return True
self.checker: AsyncRuleChecker = checker or always_true
def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]:
return self.checker(bot, event)
def __and__(self, other: "Rule") -> "Rule":
return Rule(lambda bot, event: self.checker(bot, event) and other.
checker(bot, event))
def __and__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule":
func = other
if isinstance(other, BaseRule):
func = other.checker
def __or__(self, other: "Rule") -> "Rule":
return Rule(lambda bot, event: self.checker(bot, event) or other.
checker(bot, event))
if not asyncio.iscoroutinefunction(func):
func = run_sync(func)
def __neg__(self) -> "Rule":
return Rule(lambda bot, event: not self.checker(bot, event))
async def tmp(bot: Bot, event: Event) -> bool:
a, b = await asyncio.gather(self.checker(bot, event),
func(bot, event))
return a and b
return AsyncRule(tmp)
def __or__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule":
func = other
if isinstance(other, BaseRule):
func = other.checker
if not asyncio.iscoroutinefunction(func):
func = run_sync(func)
async def tmp(bot: Bot, event: Event) -> bool:
a, b = await asyncio.gather(self.checker(bot, event),
func(bot, event))
return a or b
return AsyncRule(tmp)
def __neg__(self) -> "AsyncRule":
async def neg(bot: Bot, event: Event) -> bool:
result = await self.checker(bot, event)
return not result
return AsyncRule(neg)
def message() -> Rule:
class SyncRule(BaseRule):
def __init__(self, checker: Optional[SyncRuleChecker] = None):
def always_true(bot: Bot, event: Event) -> bool:
return True
self.checker: SyncRuleChecker = checker or always_true
def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]:
return run_sync(self.checker)(bot, event)
def __and__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule:
func = other
if isinstance(other, BaseRule):
func = other.checker
if not asyncio.iscoroutinefunction(func):
# func: SyncRuleChecker
syncfunc = cast(SyncRuleChecker, func)
def tmp(bot: Bot, event: Event) -> bool:
return self.checker(bot, event) and syncfunc(bot, event)
return SyncRule(tmp)
else:
# func: AsyncRuleChecker
asyncfunc = cast(AsyncRuleChecker, func)
async def tmp(bot: Bot, event: Event) -> bool:
a, b = await asyncio.gather(
run_sync(self.checker)(bot, event), asyncfunc(bot, event))
return a and b
return AsyncRule(tmp)
def __or__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule:
func = other
if isinstance(other, BaseRule):
func = other.checker
if not asyncio.iscoroutinefunction(func):
# func: SyncRuleChecker
syncfunc = cast(SyncRuleChecker, func)
def tmp(bot: Bot, event: Event) -> bool:
return self.checker(bot, event) or syncfunc(bot, event)
return SyncRule(tmp)
else:
# func: AsyncRuleChecker
asyncfunc = cast(AsyncRuleChecker, func)
async def tmp(bot: Bot, event: Event) -> bool:
a, b = await asyncio.gather(
run_sync(self.checker)(bot, event), asyncfunc(bot, event))
return a or b
return AsyncRule(tmp)
def __neg__(self) -> "SyncRule":
def neg(bot: Bot, event: Event) -> bool:
return not self.checker(bot, event)
return SyncRule(neg)
def Rule(func: Optional[RuleChecker] = None) -> BaseRule:
if func and asyncio.iscoroutinefunction(func):
asyncfunc = cast(AsyncRuleChecker, func)
return AsyncRule(asyncfunc)
else:
syncfunc = cast(Optional[SyncRuleChecker], func)
return SyncRule(syncfunc)
def message() -> BaseRule:
return Rule(lambda bot, event: event.type == "message")
def notice() -> Rule:
def notice() -> BaseRule:
return Rule(lambda bot, event: event.type == "notice")
def request() -> Rule:
def request() -> BaseRule:
return Rule(lambda bot, event: event.type == "request")
def metaevent() -> Rule:
def metaevent() -> BaseRule:
return Rule(lambda bot, event: event.type == "meta_event")
def user(*qq: int) -> Rule:
def user(*qq: int) -> BaseRule:
return Rule(lambda bot, event: event.user_id in qq)
def private() -> Rule:
def private() -> BaseRule:
return Rule(lambda bot, event: event.detail_type == "private")
def group(*group: int) -> Rule:
def group(*group: int) -> BaseRule:
return Rule(lambda bot, event: event.detail_type == "group" and event.
group_id in group)
def discuss(*discuss: int) -> Rule:
return Rule(lambda bot, event: event.detail_type == "discuss" and event.
discuss_id in discuss)
def startswith(msg, start: int = None, end: int = None) -> Rule:
def startswith(msg, start: int = None, end: int = None) -> BaseRule:
return Rule(lambda bot, event: event.message.startswith(msg, start, end))
def endswith(msg, start: int = None, end: int = None) -> Rule:
def endswith(msg, start: int = None, end: int = None) -> BaseRule:
return Rule(
lambda bot, event: event.message.endswith(msg, start=None, end=None))
def has(msg: str) -> Rule:
def has(msg: str) -> BaseRule:
return Rule(lambda bot, event: msg in event.message)
def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> Rule:
def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> BaseRule:
pattern = re.compile(regex, flags)
return Rule(lambda bot, event: bool(pattern.search(str(event.message))))

View File

@ -8,6 +8,7 @@ from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable
# import some modules needed when checking types
if TYPE_CHECKING:
from nonebot.rule import BaseRule
from nonebot.matcher import Matcher as MatcherClass
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
@ -37,3 +38,8 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
Matcher = TypeVar("Matcher", bound="MatcherClass")
Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None],
Awaitable[NoReturn]]]
Rule = TypeVar("Rule", bound="BaseRule")
_RuleChecker_Return = TypeVar("_RuleChecker_Return", bool, Awaitable[bool])
RuleChecker = Callable[[Bot, Event], _RuleChecker_Return]
SyncRuleChecker = RuleChecker[Bot, Event, bool]
AsyncRuleChecker = RuleChecker[Bot, Event, Awaitable[bool]]

View File

@ -2,9 +2,23 @@
# -*- coding: utf-8 -*-
import json
import asyncio
import dataclasses
from functools import wraps, partial
from nonebot.typing import overrides
from nonebot.typing import Any, Callable, Awaitable, overrides
def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
@wraps(func)
async def _wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.get_running_loop()
pfunc = partial(func, *args, **kwargs)
result = await loop.run_in_executor(None, pfunc)
return result
return _wrapper
class DataclassEncoder(json.JSONEncoder):

8
poetry.lock generated
View File

@ -197,7 +197,7 @@ description = "Chromium HSTS Preload list as a Python package and updated daily"
name = "hstspreload"
optional = false
python-versions = ">=3.6"
version = "2020.8.11"
version = "2020.8.12"
[package.source]
reference = "aliyun"
@ -838,7 +838,7 @@ scheduler = ["apscheduler"]
test = []
[metadata]
content-hash = "ceb51a95975f80d81b1901bb634cc58a583d31914a99495b12df4679d27fe531"
content-hash = "b89641a9b24184b999991e1534842905ece528b73824eb79d6d378d686526da2"
python-versions = "^3.7"
[metadata.files]
@ -891,8 +891,8 @@ hpack = [
{file = "hpack-3.0.0.tar.gz", hash = "sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2"},
]
hstspreload = [
{file = "hstspreload-2020.8.11-py3-none-any.whl", hash = "sha256:e9971e67ed1fe61da1ea4c145f6ebab4591e2cc934def81bbf9c37d20d0abab9"},
{file = "hstspreload-2020.8.11.tar.gz", hash = "sha256:88b102ce3cdc1b27bb117d407f886ed16e35522564f1a31d64373ccde33b19af"},
{file = "hstspreload-2020.8.12-py3-none-any.whl", hash = "sha256:64f4441066d5544873faccf2e0b5757c6670217d34dc31d362ca2977f44604ff"},
{file = "hstspreload-2020.8.12.tar.gz", hash = "sha256:3f5c324b1eb9d924e32ffeb5fe265b879806b6e346b765f57566410344f4b41e"},
]
html2text = [
{file = "html2text-2020.1.16-py3-none-any.whl", hash = "sha256:c7c629882da0cf377d66f073329ccf34a12ed2adf0169b9285ae4e63ef54c82b"},

View File

@ -6,9 +6,7 @@ from nonebot.typing import Event
from nonebot.plugin import on_message
from nonebot.adapters.cqhttp import Bot, Message
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
test_matcher = on_message(Rule(), state={"default": 1})
test_matcher = on_message(state={"default": 1})
@test_matcher.handle()

View File

@ -1,18 +1,17 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from nonebot.rule import Rule
from nonebot.event import Event
from nonebot.plugin import on_metaevent
from nonebot.typing import Bot, Event
def heartbeat(bot, event: Event) -> bool:
def heartbeat(bot: Bot, event: Event) -> bool:
return event.detail_type == "heartbeat"
test_matcher = on_metaevent(Rule(heartbeat))
test_matcher = on_metaevent(heartbeat)
@test_matcher.handle()
async def handle_heartbeat(bot, event: Event, state: dict):
async def handle_heartbeat(bot: Bot, event: Event, state: dict):
print("[i] Heartbeat")