diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 54a409ce..1611e714 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -7,19 +7,18 @@ from ipaddress import IPv4Address from nonebot.log import logger from nonebot.config import Env, Config -from nonebot.drivers import BaseDriver from nonebot.adapters.cqhttp import Bot as CQBot -from nonebot.typing import Union, Optional, NoReturn +from nonebot.typing import Type, Union, Driver, Optional, NoReturn try: import nonebot_test except ImportError: nonebot_test = None -_driver: Optional[BaseDriver] = None +_driver: Optional[Driver] = None -def get_driver() -> Union[NoReturn, BaseDriver]: +def get_driver() -> Union[NoReturn, Driver]: if _driver is None: raise ValueError("NoneBot has not been initialized.") return _driver @@ -43,14 +42,16 @@ def init(*, _env_file: Optional[str] = None, **kwargs): logger.setLevel(logging.DEBUG if config.debug else logging.INFO) logger.debug(f"Loaded config: {config.dict()}") - Driver = getattr(importlib.import_module(config.driver), "Driver") - _driver = Driver(env, config) + DriverClass: Type[Driver] = getattr(importlib.import_module(config.driver), + "Driver") + _driver = DriverClass(env, config) # register build-in adapters _driver.register_adapter("cqhttp", CQBot) # load nonebot test frontend if debug if config.debug and nonebot_test: + logger.debug("Loading nonebot test frontend...") nonebot_test.init() diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 04183a24..11b639b2 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -6,7 +6,7 @@ from functools import reduce, partial from dataclasses import dataclass, field from nonebot.config import Config -from nonebot.typing import Driver, WebSocket +from nonebot.typing import Driver, Message, WebSocket from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable @@ -83,6 +83,26 @@ class BaseEvent(abc.ABC): def sub_type(self, value) -> None: raise NotImplementedError + @property + @abc.abstractmethod + def message(self) -> Optional[Message]: + raise NotImplementedError + + @message.setter + @abc.abstractmethod + def message(self, value) -> None: + raise NotImplementedError + + @property + @abc.abstractmethod + def raw_message(self) -> Optional[str]: + raise NotImplementedError + + @raw_message.setter + @abc.abstractmethod + def raw_message(self, value) -> None: + raise NotImplementedError + @dataclass class BaseMessageSegment(abc.ABC): diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index c1beaf7f..0b228b3d 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -193,6 +193,26 @@ class Event(BaseEvent): def sub_type(self, value) -> None: self._raw_event["sub_type"] = value + @property + @overrides(BaseEvent) + def message(self) -> Optional["Message"]: + return self._raw_event.get("message") + + @message.setter + @overrides(BaseEvent) + def message(self, value) -> None: + self._raw_event["message"] = value + + @property + @overrides(BaseEvent) + def raw_message(self) -> Optional[str]: + return self._raw_event.get("raw_message") + + @raw_message.setter + @overrides(BaseEvent) + def raw_message(self, value) -> None: + self._raw_event["raw_message"] = value + class MessageSegment(BaseMessageSegment): diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index baef3455..4d0f398b 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -4,6 +4,7 @@ import abc from ipaddress import IPv4Address +from nonebot.log import logger from nonebot.config import Env, Config from nonebot.typing import Bot, Dict, Type, Optional, Callable @@ -20,6 +21,7 @@ class BaseDriver(abc.ABC): @classmethod def register_adapter(cls, name: str, adapter: Type[Bot]): cls._adapters[name] = adapter + logger.debug(f'Succeeded to load adapter "{name}"') @property @abc.abstractmethod diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index fd39985d..618d61cb 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -152,15 +152,16 @@ class Driver(BaseDriver): await websocket.accept() self._clients[x_self_id] = bot - while not websocket.closed: - data = await websocket.receive() + try: + while not websocket.closed: + data = await websocket.receive() - if not data: - continue + if not data: + continue - await bot.handle_message(data) - - del self._clients[x_self_id] + await bot.handle_message(data) + finally: + del self._clients[x_self_id] class WebSocket(BaseWebSocket): diff --git a/nonebot/event.py b/nonebot/event.py deleted file mode 100644 index 6f4db964..00000000 --- a/nonebot/event.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -from nonebot.typing import Any, Dict, Optional - - -class Event(dict): - """ - 封装从 CQHTTP 收到的事件数据对象(字典),提供属性以获取其中的字段。 - - 除 `type` 和 `detail_type` 属性对于任何事件都有效外,其它属性存在与否(不存在则返回 - `None`)依事件不同而不同。 - """ - - @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["Event"]: - """ - 从 CQHTTP 事件数据构造 `Event` 对象。 - """ - try: - e = Event(payload) - _ = e.type, e.detail_type - return e - except KeyError: - return None - - @property - def type(self) -> str: - """ - 事件类型,有 ``message``、``notice``、``request``、``meta_event`` 等。 - """ - return self["post_type"] - - @property - def detail_type(self) -> str: - """ - 事件具体类型,依 `type` 的不同而不同,以 ``message`` 类型为例,有 - ``private``、``group``、``discuss`` 等。 - """ - return self[f"{self.type}_type"] - - @property - def sub_type(self) -> Optional[str]: - """ - 事件子类型,依 `detail_type` 不同而不同,以 ``message.private`` 为例,有 - ``friend``、``group``、``discuss``、``other`` 等。 - """ - return self.get("sub_type") - - @property - def name(self): - """ - 事件名,对于有 `sub_type` 的事件,为 ``{type}.{detail_type}.{sub_type}``,否则为 - ``{type}.{detail_type}``。 - """ - n = self.type + "." + self.detail_type - if self.sub_type: - n += "." + self.sub_type - return n - - @property - def self_id(self) -> int: - """机器人自身 ID。""" - return self["self_id"] - - @property - def user_id(self) -> Optional[int]: - """用户 ID。""" - return self.get("user_id") - - @property - def operator_id(self) -> Optional[int]: - """操作者 ID。""" - return self.get("operator_id") - - @property - def group_id(self) -> Optional[int]: - """群 ID。""" - return self.get("group_id") - - @property - def discuss_id(self) -> Optional[int]: - """讨论组 ID。""" - return self.get("discuss_id") - - @property - def message_id(self) -> Optional[int]: - """消息 ID。""" - return self.get("message_id") - - @property - def message(self) -> Optional[Any]: - """消息。""" - return self.get("message") - - @property - def raw_message(self) -> Optional[str]: - """未经 CQHTTP 处理的原始消息。""" - return self.get("raw_message") - - @property - def sender(self) -> Optional[Dict[str, Any]]: - """消息发送者信息。""" - return self.get("sender") - - @property - def anonymous(self) -> Optional[Dict[str, Any]]: - """匿名信息。""" - return self.get("anonymous") - - @property - def file(self) -> Optional[Dict[str, Any]]: - """文件信息。""" - return self.get("file") - - @property - def comment(self) -> Optional[str]: - """请求验证消息。""" - return self.get("comment") - - @property - def flag(self) -> Optional[str]: - """请求标识。""" - return self.get("flag") - - def __repr__(self) -> str: - return f"" diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 5016aef2..a33c3527 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -7,8 +7,8 @@ import importlib from nonebot.log import logger from nonebot.matcher import Matcher -from nonebot.typing import Set, Dict, Type, Optional, ModuleType from nonebot.rule import Rule, metaevent, message, notice, request +from nonebot.typing import Set, Dict, Type, Union, Optional, ModuleType, RuleChecker plugins: Dict[str, "Plugin"] = {} @@ -25,7 +25,7 @@ class Plugin(object): self.matchers = matchers -def on_metaevent(rule: Rule, +def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, @@ -40,7 +40,7 @@ def on_metaevent(rule: Rule, return matcher -def on_message(rule: Rule, +def on_message(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, @@ -55,7 +55,7 @@ def on_message(rule: Rule, return matcher -def on_notice(rule: Rule, +def on_notice(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, @@ -70,7 +70,7 @@ def on_notice(rule: Rule, return matcher -def on_request(rule: Rule, +def on_request(rule: Union[Rule, RuleChecker] = Rule(), *, handlers=[], temp=False, diff --git a/nonebot/rule.py b/nonebot/rule.py index ab81489f..9caa9d30 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -2,81 +2,207 @@ # -*- coding: utf-8 -*- import re +import abc +import asyncio +from typing import cast -from nonebot.event import Event -from nonebot.typing import Union, Callable, Optional +from nonebot.utils import run_sync +from nonebot.typing import Bot, Event, Union, Optional, Awaitable +from nonebot.typing import RuleChecker, SyncRuleChecker, AsyncRuleChecker -class Rule: +class BaseRule(abc.ABC): - def __init__( - self, - checker: Optional[Callable[["BaseBot", Event], # type: ignore - bool]] = None): - self.checker = checker or (lambda bot, event: True) + def __init__(self, checker: RuleChecker): + self.checker: RuleChecker = checker - def __call__(self, bot, event: Event) -> bool: + @abc.abstractmethod + def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]: + raise NotImplementedError + + @abc.abstractmethod + def __and__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule": + raise NotImplementedError + + @abc.abstractmethod + def __or__(self, other: Union["BaseRule", RuleChecker]) -> "BaseRule": + raise NotImplementedError + + @abc.abstractmethod + def __neg__(self) -> "BaseRule": + raise NotImplementedError + + +class AsyncRule(BaseRule): + + def __init__(self, checker: Optional[AsyncRuleChecker] = None): + + async def always_true(bot: Bot, event: Event) -> bool: + return True + + self.checker: AsyncRuleChecker = checker or always_true + + def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]: return self.checker(bot, event) - def __and__(self, other: "Rule") -> "Rule": - return Rule(lambda bot, event: self.checker(bot, event) and other. - checker(bot, event)) + def __and__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule": + func = other + if isinstance(other, BaseRule): + func = other.checker - def __or__(self, other: "Rule") -> "Rule": - return Rule(lambda bot, event: self.checker(bot, event) or other. - checker(bot, event)) + if not asyncio.iscoroutinefunction(func): + func = run_sync(func) - def __neg__(self) -> "Rule": - return Rule(lambda bot, event: not self.checker(bot, event)) + async def tmp(bot: Bot, event: Event) -> bool: + a, b = await asyncio.gather(self.checker(bot, event), + func(bot, event)) + return a and b + + return AsyncRule(tmp) + + def __or__(self, other: Union[BaseRule, RuleChecker]) -> "AsyncRule": + func = other + if isinstance(other, BaseRule): + func = other.checker + + if not asyncio.iscoroutinefunction(func): + func = run_sync(func) + + async def tmp(bot: Bot, event: Event) -> bool: + a, b = await asyncio.gather(self.checker(bot, event), + func(bot, event)) + return a or b + + return AsyncRule(tmp) + + def __neg__(self) -> "AsyncRule": + + async def neg(bot: Bot, event: Event) -> bool: + result = await self.checker(bot, event) + return not result + + return AsyncRule(neg) -def message() -> Rule: +class SyncRule(BaseRule): + + def __init__(self, checker: Optional[SyncRuleChecker] = None): + + def always_true(bot: Bot, event: Event) -> bool: + return True + + self.checker: SyncRuleChecker = checker or always_true + + def __call__(self, bot: Bot, event: Event) -> Awaitable[bool]: + return run_sync(self.checker)(bot, event) + + def __and__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule: + func = other + if isinstance(other, BaseRule): + func = other.checker + + if not asyncio.iscoroutinefunction(func): + # func: SyncRuleChecker + syncfunc = cast(SyncRuleChecker, func) + + def tmp(bot: Bot, event: Event) -> bool: + return self.checker(bot, event) and syncfunc(bot, event) + + return SyncRule(tmp) + else: + # func: AsyncRuleChecker + asyncfunc = cast(AsyncRuleChecker, func) + + async def tmp(bot: Bot, event: Event) -> bool: + a, b = await asyncio.gather( + run_sync(self.checker)(bot, event), asyncfunc(bot, event)) + return a and b + + return AsyncRule(tmp) + + def __or__(self, other: Union[BaseRule, RuleChecker]) -> BaseRule: + func = other + if isinstance(other, BaseRule): + func = other.checker + + if not asyncio.iscoroutinefunction(func): + # func: SyncRuleChecker + syncfunc = cast(SyncRuleChecker, func) + + def tmp(bot: Bot, event: Event) -> bool: + return self.checker(bot, event) or syncfunc(bot, event) + + return SyncRule(tmp) + else: + # func: AsyncRuleChecker + asyncfunc = cast(AsyncRuleChecker, func) + + async def tmp(bot: Bot, event: Event) -> bool: + a, b = await asyncio.gather( + run_sync(self.checker)(bot, event), asyncfunc(bot, event)) + return a or b + + return AsyncRule(tmp) + + def __neg__(self) -> "SyncRule": + + def neg(bot: Bot, event: Event) -> bool: + return not self.checker(bot, event) + + return SyncRule(neg) + + +def Rule(func: Optional[RuleChecker] = None) -> BaseRule: + if func and asyncio.iscoroutinefunction(func): + asyncfunc = cast(AsyncRuleChecker, func) + return AsyncRule(asyncfunc) + else: + syncfunc = cast(Optional[SyncRuleChecker], func) + return SyncRule(syncfunc) + + +def message() -> BaseRule: return Rule(lambda bot, event: event.type == "message") -def notice() -> Rule: +def notice() -> BaseRule: return Rule(lambda bot, event: event.type == "notice") -def request() -> Rule: +def request() -> BaseRule: return Rule(lambda bot, event: event.type == "request") -def metaevent() -> Rule: +def metaevent() -> BaseRule: return Rule(lambda bot, event: event.type == "meta_event") -def user(*qq: int) -> Rule: +def user(*qq: int) -> BaseRule: return Rule(lambda bot, event: event.user_id in qq) -def private() -> Rule: +def private() -> BaseRule: return Rule(lambda bot, event: event.detail_type == "private") -def group(*group: int) -> Rule: +def group(*group: int) -> BaseRule: return Rule(lambda bot, event: event.detail_type == "group" and event. group_id in group) -def discuss(*discuss: int) -> Rule: - return Rule(lambda bot, event: event.detail_type == "discuss" and event. - discuss_id in discuss) - - -def startswith(msg, start: int = None, end: int = None) -> Rule: +def startswith(msg, start: int = None, end: int = None) -> BaseRule: return Rule(lambda bot, event: event.message.startswith(msg, start, end)) -def endswith(msg, start: int = None, end: int = None) -> Rule: +def endswith(msg, start: int = None, end: int = None) -> BaseRule: return Rule( lambda bot, event: event.message.endswith(msg, start=None, end=None)) -def has(msg: str) -> Rule: +def has(msg: str) -> BaseRule: return Rule(lambda bot, event: msg in event.message) -def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> Rule: +def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> BaseRule: pattern = re.compile(regex, flags) return Rule(lambda bot, event: bool(pattern.search(str(event.message)))) diff --git a/nonebot/typing.py b/nonebot/typing.py index 816b62e2..21fe0389 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -8,6 +8,7 @@ from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable # import some modules needed when checking types if TYPE_CHECKING: + from nonebot.rule import BaseRule from nonebot.matcher import Matcher as MatcherClass from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment @@ -37,3 +38,8 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None], Matcher = TypeVar("Matcher", bound="MatcherClass") Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None], Awaitable[NoReturn]]] +Rule = TypeVar("Rule", bound="BaseRule") +_RuleChecker_Return = TypeVar("_RuleChecker_Return", bool, Awaitable[bool]) +RuleChecker = Callable[[Bot, Event], _RuleChecker_Return] +SyncRuleChecker = RuleChecker[Bot, Event, bool] +AsyncRuleChecker = RuleChecker[Bot, Event, Awaitable[bool]] diff --git a/nonebot/utils.py b/nonebot/utils.py index c1c0ddf2..631395f5 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -2,9 +2,23 @@ # -*- coding: utf-8 -*- import json +import asyncio import dataclasses +from functools import wraps, partial -from nonebot.typing import overrides +from nonebot.typing import Any, Callable, Awaitable, overrides + + +def run_sync(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: + + @wraps(func) + async def _wrapper(*args: Any, **kwargs: Any) -> Any: + loop = asyncio.get_running_loop() + pfunc = partial(func, *args, **kwargs) + result = await loop.run_in_executor(None, pfunc) + return result + + return _wrapper class DataclassEncoder(json.JSONEncoder): diff --git a/poetry.lock b/poetry.lock index 6208e3c7..032b4261 100644 --- a/poetry.lock +++ b/poetry.lock @@ -197,7 +197,7 @@ description = "Chromium HSTS Preload list as a Python package and updated daily" name = "hstspreload" optional = false python-versions = ">=3.6" -version = "2020.8.11" +version = "2020.8.12" [package.source] reference = "aliyun" @@ -838,7 +838,7 @@ scheduler = ["apscheduler"] test = [] [metadata] -content-hash = "ceb51a95975f80d81b1901bb634cc58a583d31914a99495b12df4679d27fe531" +content-hash = "b89641a9b24184b999991e1534842905ece528b73824eb79d6d378d686526da2" python-versions = "^3.7" [metadata.files] @@ -891,8 +891,8 @@ hpack = [ {file = "hpack-3.0.0.tar.gz", hash = "sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2"}, ] hstspreload = [ - {file = "hstspreload-2020.8.11-py3-none-any.whl", hash = "sha256:e9971e67ed1fe61da1ea4c145f6ebab4591e2cc934def81bbf9c37d20d0abab9"}, - {file = "hstspreload-2020.8.11.tar.gz", hash = "sha256:88b102ce3cdc1b27bb117d407f886ed16e35522564f1a31d64373ccde33b19af"}, + {file = "hstspreload-2020.8.12-py3-none-any.whl", hash = "sha256:64f4441066d5544873faccf2e0b5757c6670217d34dc31d362ca2977f44604ff"}, + {file = "hstspreload-2020.8.12.tar.gz", hash = "sha256:3f5c324b1eb9d924e32ffeb5fe265b879806b6e346b765f57566410344f4b41e"}, ] html2text = [ {file = "html2text-2020.1.16-py3-none-any.whl", hash = "sha256:c7c629882da0cf377d66f073329ccf34a12ed2adf0169b9285ae4e63ef54c82b"}, diff --git a/tests/test_plugins/test_matcher.py b/tests/test_plugins/test_matcher.py index fb36f821..2b59d4e9 100644 --- a/tests/test_plugins/test_matcher.py +++ b/tests/test_plugins/test_matcher.py @@ -6,9 +6,7 @@ from nonebot.typing import Event from nonebot.plugin import on_message from nonebot.adapters.cqhttp import Bot, Message -print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]"))) - -test_matcher = on_message(Rule(), state={"default": 1}) +test_matcher = on_message(state={"default": 1}) @test_matcher.handle() diff --git a/tests/test_plugins/test_metaevent.py b/tests/test_plugins/test_metaevent.py index 3fb7360e..88f6a82f 100644 --- a/tests/test_plugins/test_metaevent.py +++ b/tests/test_plugins/test_metaevent.py @@ -1,18 +1,17 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from nonebot.rule import Rule -from nonebot.event import Event from nonebot.plugin import on_metaevent +from nonebot.typing import Bot, Event -def heartbeat(bot, event: Event) -> bool: +def heartbeat(bot: Bot, event: Event) -> bool: return event.detail_type == "heartbeat" -test_matcher = on_metaevent(Rule(heartbeat)) +test_matcher = on_metaevent(heartbeat) @test_matcher.handle() -async def handle_heartbeat(bot, event: Event, state: dict): +async def handle_heartbeat(bot: Bot, event: Event, state: dict): print("[i] Heartbeat")