From c5ea8bc1c31c2f2eb6460b6718c7d7a7368a469c Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Tue, 25 Aug 2020 15:23:10 +0800 Subject: [PATCH] add contextvars and fix mutable default args --- nonebot/adapters/__init__.py | 14 +++++ nonebot/adapters/cqhttp.py | 40 +++++++++++++ nonebot/matcher.py | 58 ++++++++++++++----- nonebot/plugin.py | 30 +++++----- .../test_plugins/test_package/test_command.py | 3 +- 5 files changed, 112 insertions(+), 33 deletions(-) diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index b12652d9..2903b0f9 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -42,6 +42,10 @@ class BaseBot(abc.ABC): async def call_api(self, api: str, data: dict): raise NotImplementedError + @abc.abstractmethod + async def send(self, *args, **kwargs): + raise NotImplementedError + # TODO: improve event class BaseEvent(abc.ABC): @@ -102,6 +106,16 @@ class BaseEvent(abc.ABC): def user_id(self, value) -> None: raise NotImplementedError + @property + @abc.abstractmethod + def group_id(self) -> Optional[int]: + raise NotImplementedError + + @group_id.setter + @abc.abstractmethod + def group_id(self, value) -> None: + raise NotImplementedError + @property @abc.abstractmethod def to_me(self) -> Optional[bool]: diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index a62d8546..b6332d8e 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -223,6 +223,36 @@ class Bot(BaseBot): except httpx.HTTPError: raise NetworkError("HTTP request failed") + @overrides(BaseBot) + async def send(self, event: "Event", message: Union[str, "Message", + "MessageSegment"], + **kwargs) -> Union[Any, NoReturn]: + msg = message if isinstance(message, Message) else Message(message) + + at_sender = kwargs.pop("at_sender", False) and bool(event.user_id) + + params = {} + if event.user_id: + params["user_id"] = event.user_id + if event.group_id: + params["group_id"] = event.group_id + params.update(kwargs) + + if "message_type" not in params: + if "group_id" in params: + params["message_type"] = "group" + elif "user_id" in params: + params["message_type"] = "private" + else: + raise ValueError("Cannot guess message type to reply!") + + if at_sender and params["message_type"] != "private": + params["message"] = MessageSegment.at(params["user_id"]) + \ + MessageSegment.text(" ") + msg + else: + params["message"] = msg + return await self.send_msg(**params) + class Event(BaseEvent): @@ -277,6 +307,16 @@ class Event(BaseEvent): def user_id(self, value) -> None: self._raw_event["user_id"] = value + @property + @overrides(BaseEvent) + def group_id(self) -> Optional[int]: + return self._raw_event.get("group_id") + + @group_id.setter + @overrides(BaseEvent) + def group_id(self, value) -> None: + self._raw_event["group_id"] = value + @property @overrides(BaseEvent) def to_me(self) -> Optional[bool]: diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 92fffbc1..6ef7d8f1 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -6,14 +6,17 @@ import inspect from functools import wraps from datetime import datetime from collections import defaultdict +from contextvars import Context, ContextVar, copy_context from nonebot.rule import Rule from nonebot.permission import Permission, USER -from nonebot.typing import Bot, Event, Handler, ArgsParser -from nonebot.typing import Type, List, Dict, Callable, Optional, NoReturn +from nonebot.typing import Type, List, Dict, Union, Callable, Optional, NoReturn +from nonebot.typing import Bot, Event, Handler, Message, ArgsParser, MessageSegment from nonebot.exception import PausedException, RejectedException, FinishedException matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) +current_bot: ContextVar = ContextVar("current_bot") +current_event: ContextVar = ContextVar("current_event") class Matcher: @@ -51,12 +54,12 @@ class Matcher: type_: str = "", rule: Rule = Rule(), permission: Permission = Permission(), - handlers: list = [], + handlers: Optional[list] = None, temp: bool = False, priority: int = 1, block: bool = False, *, - default_state: dict = {}, + default_state: Optional[dict] = None, expire_time: Optional[datetime] = None) -> Type["Matcher"]: """创建新的 Matcher @@ -69,12 +72,12 @@ class Matcher: "type": type_, "rule": rule, "permission": permission, - "handlers": handlers, + "handlers": handlers or [], "temp": temp, "expire_time": expire_time, "priority": priority, "block": block, - "_default_state": default_state + "_default_state": default_state or {} }) matchers[priority].append(NewMatcher) @@ -117,12 +120,12 @@ class Matcher: def receive(cls) -> Callable[[Handler], Handler]: """接收一条新消息并处理""" - async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn: + async def _receive(bot: Bot, event: Event, state: dict) -> NoReturn: raise PausedException if cls.handlers: # 已有前置handlers则接受一条新的消息,否则视为接收初始消息 - cls.handlers.append(_handler) + cls.handlers.append(_receive) def _decorator(func: Handler) -> Handler: if not cls.handlers or cls.handlers[-1] is not func: @@ -144,8 +147,7 @@ class Matcher: if key not in state: state["_current_key"] = key if prompt: - await bot.send_private_msg(user_id=event.user_id, - message=prompt) + await bot.send(event=event, message=prompt) raise PausedException async def _key_parser(bot: Bot, event: Event, state: dict): @@ -176,19 +178,42 @@ class Matcher: return _decorator @classmethod - def finish(cls) -> NoReturn: + async def finish( + cls, + prompt: Optional[Union[str, Message, + MessageSegment]] = None) -> NoReturn: + bot: Bot = current_bot.get() + event: Event = current_event.get() + if prompt: + await bot.send(event=event, message=prompt) raise FinishedException @classmethod - def pause(cls) -> NoReturn: + async def pause( + cls, + prompt: Optional[Union[str, Message, + MessageSegment]] = None) -> NoReturn: + bot: Bot = current_bot.get() + event: Event = current_event.get() + if prompt: + await bot.send(event=event, message=prompt) raise PausedException @classmethod - def reject(cls) -> NoReturn: + async def reject( + cls, + prompt: Optional[Union[str, Message, + MessageSegment]] = None) -> NoReturn: + bot: Bot = current_bot.get() + event: Event = current_event.get() + if prompt: + await bot.send(event=event, message=prompt) raise RejectedException # 运行handlers async def run(self, bot: Bot, event: Event, state: dict): + b_t = current_bot.set(bot) + e_t = current_event.set(event) try: # Refresh preprocess state self.state.update(state) @@ -214,7 +239,6 @@ class Matcher: block=True, default_state=self.state, expire_time=datetime.now() + bot.config.session_expire_timeout) - return except PausedException: Matcher.new( self.type, @@ -226,6 +250,8 @@ class Matcher: block=True, default_state=self.state, expire_time=datetime.now() + bot.config.session_expire_timeout) - return except FinishedException: - return + pass + finally: + current_bot.reset(b_t) + current_event.reset(e_t) diff --git a/nonebot/plugin.py b/nonebot/plugin.py index c6584615..0ea04046 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -31,11 +31,11 @@ class Plugin(object): def on(rule: Union[Rule, RuleChecker] = Rule(), permission: Permission = Permission(), *, - handlers=[], - temp=False, + handlers: Optional[list] = None, + temp: bool = False, priority: int = 1, block: bool = False, - state={}) -> Type[Matcher]: + state: Optional[dict] = None) -> Type[Matcher]: matcher = Matcher.new("", Rule() & rule, permission, @@ -50,11 +50,11 @@ def on(rule: Union[Rule, RuleChecker] = Rule(), def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), *, - handlers=[], - temp=False, + handlers: Optional[list] = None, + temp: bool = False, priority: int = 1, block: bool = False, - state={}) -> Type[Matcher]: + state: Optional[dict] = None) -> Type[Matcher]: matcher = Matcher.new("meta_event", Rule() & rule, Permission(), @@ -70,11 +70,11 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), def on_message(rule: Union[Rule, RuleChecker] = Rule(), permission: Permission = Permission(), *, - handlers=[], - temp=False, + handlers: Optional[list] = None, + temp: bool = False, priority: int = 1, block: bool = True, - state={}) -> Type[Matcher]: + state: Optional[dict] = None) -> Type[Matcher]: matcher = Matcher.new("message", Rule() & rule, permission, @@ -89,11 +89,11 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(), def on_notice(rule: Union[Rule, RuleChecker] = Rule(), *, - handlers=[], - temp=False, + handlers: Optional[list] = None, + temp: bool = False, priority: int = 1, block: bool = False, - state={}) -> Type[Matcher]: + state: Optional[dict] = None) -> Type[Matcher]: matcher = Matcher.new("notice", Rule() & rule, Permission(), @@ -108,11 +108,11 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(), def on_request(rule: Union[Rule, RuleChecker] = Rule(), *, - handlers=[], - temp=False, + handlers: Optional[list] = None, + temp: bool = False, priority: int = 1, block: bool = False, - state={}) -> Type[Matcher]: + state: Optional[dict] = None) -> Type[Matcher]: matcher = Matcher.new("request", Rule() & rule, Permission(), diff --git a/tests/test_plugins/test_package/test_command.py b/tests/test_plugins/test_package/test_command.py index 57551450..c2b5f334 100644 --- a/tests/test_plugins/test_package/test_command.py +++ b/tests/test_plugins/test_package/test_command.py @@ -24,6 +24,5 @@ async def test_handler(bot: Bot, event: Event, state: dict): async def test_handler(bot: Bot, event: Event, state: dict): print("[!] Command 帮助:", state["help"]) if state["help"] not in ["test1", "test2"]: - await bot.send_private_msg(message=f"{state['help']} 不支持,请重新输入!") - test_command.reject() + await test_command.reject(f"{state['help']} 不支持,请重新输入!") await bot.send_private_msg(message=f"{state['help']} 帮助:\n...")