add contextvars and fix mutable default args

This commit is contained in:
yanyongyu 2020-08-25 15:23:10 +08:00
parent d66259da2b
commit c5ea8bc1c3
5 changed files with 112 additions and 33 deletions

View File

@ -42,6 +42,10 @@ class BaseBot(abc.ABC):
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, data: dict):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def send(self, *args, **kwargs):
raise NotImplementedError
# TODO: improve event # TODO: improve event
class BaseEvent(abc.ABC): class BaseEvent(abc.ABC):
@ -102,6 +106,16 @@ class BaseEvent(abc.ABC):
def user_id(self, value) -> None: def user_id(self, value) -> None:
raise NotImplementedError raise NotImplementedError
@property
@abc.abstractmethod
def group_id(self) -> Optional[int]:
raise NotImplementedError
@group_id.setter
@abc.abstractmethod
def group_id(self, value) -> None:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def to_me(self) -> Optional[bool]: def to_me(self) -> Optional[bool]:

View File

@ -223,6 +223,36 @@ class Bot(BaseBot):
except httpx.HTTPError: except httpx.HTTPError:
raise NetworkError("HTTP request failed") raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self, event: "Event", message: Union[str, "Message",
"MessageSegment"],
**kwargs) -> Union[Any, NoReturn]:
msg = message if isinstance(message, Message) else Message(message)
at_sender = kwargs.pop("at_sender", False) and bool(event.user_id)
params = {}
if event.user_id:
params["user_id"] = event.user_id
if event.group_id:
params["group_id"] = event.group_id
params.update(kwargs)
if "message_type" not in params:
if "group_id" in params:
params["message_type"] = "group"
elif "user_id" in params:
params["message_type"] = "private"
else:
raise ValueError("Cannot guess message type to reply!")
if at_sender and params["message_type"] != "private":
params["message"] = MessageSegment.at(params["user_id"]) + \
MessageSegment.text(" ") + msg
else:
params["message"] = msg
return await self.send_msg(**params)
class Event(BaseEvent): class Event(BaseEvent):
@ -277,6 +307,16 @@ class Event(BaseEvent):
def user_id(self, value) -> None: def user_id(self, value) -> None:
self._raw_event["user_id"] = value self._raw_event["user_id"] = value
@property
@overrides(BaseEvent)
def group_id(self) -> Optional[int]:
return self._raw_event.get("group_id")
@group_id.setter
@overrides(BaseEvent)
def group_id(self, value) -> None:
self._raw_event["group_id"] = value
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def to_me(self) -> Optional[bool]: def to_me(self) -> Optional[bool]:

View File

@ -6,14 +6,17 @@ import inspect
from functools import wraps from functools import wraps
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
from contextvars import Context, ContextVar, copy_context
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.permission import Permission, USER from nonebot.permission import Permission, USER
from nonebot.typing import Bot, Event, Handler, ArgsParser from nonebot.typing import Type, List, Dict, Union, Callable, Optional, NoReturn
from nonebot.typing import Type, List, Dict, Callable, Optional, NoReturn from nonebot.typing import Bot, Event, Handler, Message, ArgsParser, MessageSegment
from nonebot.exception import PausedException, RejectedException, FinishedException from nonebot.exception import PausedException, RejectedException, FinishedException
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
current_bot: ContextVar = ContextVar("current_bot")
current_event: ContextVar = ContextVar("current_event")
class Matcher: class Matcher:
@ -51,12 +54,12 @@ class Matcher:
type_: str = "", type_: str = "",
rule: Rule = Rule(), rule: Rule = Rule(),
permission: Permission = Permission(), permission: Permission = Permission(),
handlers: list = [], handlers: Optional[list] = None,
temp: bool = False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
*, *,
default_state: dict = {}, default_state: Optional[dict] = None,
expire_time: Optional[datetime] = None) -> Type["Matcher"]: expire_time: Optional[datetime] = None) -> Type["Matcher"]:
"""创建新的 Matcher """创建新的 Matcher
@ -69,12 +72,12 @@ class Matcher:
"type": type_, "type": type_,
"rule": rule, "rule": rule,
"permission": permission, "permission": permission,
"handlers": handlers, "handlers": handlers or [],
"temp": temp, "temp": temp,
"expire_time": expire_time, "expire_time": expire_time,
"priority": priority, "priority": priority,
"block": block, "block": block,
"_default_state": default_state "_default_state": default_state or {}
}) })
matchers[priority].append(NewMatcher) matchers[priority].append(NewMatcher)
@ -117,12 +120,12 @@ class Matcher:
def receive(cls) -> Callable[[Handler], Handler]: def receive(cls) -> Callable[[Handler], Handler]:
"""接收一条新消息并处理""" """接收一条新消息并处理"""
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn: async def _receive(bot: Bot, event: Event, state: dict) -> NoReturn:
raise PausedException raise PausedException
if cls.handlers: if cls.handlers:
# 已有前置handlers则接受一条新的消息否则视为接收初始消息 # 已有前置handlers则接受一条新的消息否则视为接收初始消息
cls.handlers.append(_handler) cls.handlers.append(_receive)
def _decorator(func: Handler) -> Handler: def _decorator(func: Handler) -> Handler:
if not cls.handlers or cls.handlers[-1] is not func: if not cls.handlers or cls.handlers[-1] is not func:
@ -144,8 +147,7 @@ class Matcher:
if key not in state: if key not in state:
state["_current_key"] = key state["_current_key"] = key
if prompt: if prompt:
await bot.send_private_msg(user_id=event.user_id, await bot.send(event=event, message=prompt)
message=prompt)
raise PausedException raise PausedException
async def _key_parser(bot: Bot, event: Event, state: dict): async def _key_parser(bot: Bot, event: Event, state: dict):
@ -176,19 +178,42 @@ class Matcher:
return _decorator return _decorator
@classmethod @classmethod
def finish(cls) -> NoReturn: async def finish(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise FinishedException raise FinishedException
@classmethod @classmethod
def pause(cls) -> NoReturn: async def pause(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise PausedException raise PausedException
@classmethod @classmethod
def reject(cls) -> NoReturn: async def reject(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise RejectedException raise RejectedException
# 运行handlers # 运行handlers
async def run(self, bot: Bot, event: Event, state: dict): async def run(self, bot: Bot, event: Event, state: dict):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
try: try:
# Refresh preprocess state # Refresh preprocess state
self.state.update(state) self.state.update(state)
@ -214,7 +239,6 @@ class Matcher:
block=True, block=True,
default_state=self.state, default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout) expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except PausedException: except PausedException:
Matcher.new( Matcher.new(
self.type, self.type,
@ -226,6 +250,8 @@ class Matcher:
block=True, block=True,
default_state=self.state, default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout) expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except FinishedException: except FinishedException:
return pass
finally:
current_bot.reset(b_t)
current_event.reset(e_t)

View File

@ -31,11 +31,11 @@ class Plugin(object):
def on(rule: Union[Rule, RuleChecker] = Rule(), def on(rule: Union[Rule, RuleChecker] = Rule(),
permission: Permission = Permission(), permission: Permission = Permission(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("", matcher = Matcher.new("",
Rule() & rule, Rule() & rule,
permission, permission,
@ -50,11 +50,11 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("meta_event", matcher = Matcher.new("meta_event",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
@ -70,11 +70,11 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
def on_message(rule: Union[Rule, RuleChecker] = Rule(), def on_message(rule: Union[Rule, RuleChecker] = Rule(),
permission: Permission = Permission(), permission: Permission = Permission(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = True, block: bool = True,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("message", matcher = Matcher.new("message",
Rule() & rule, Rule() & rule,
permission, permission,
@ -89,11 +89,11 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
def on_notice(rule: Union[Rule, RuleChecker] = Rule(), def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("notice", matcher = Matcher.new("notice",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
@ -108,11 +108,11 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
def on_request(rule: Union[Rule, RuleChecker] = Rule(), def on_request(rule: Union[Rule, RuleChecker] = Rule(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("request", matcher = Matcher.new("request",
Rule() & rule, Rule() & rule,
Permission(), Permission(),

View File

@ -24,6 +24,5 @@ async def test_handler(bot: Bot, event: Event, state: dict):
async def test_handler(bot: Bot, event: Event, state: dict): async def test_handler(bot: Bot, event: Event, state: dict):
print("[!] Command 帮助:", state["help"]) print("[!] Command 帮助:", state["help"])
if state["help"] not in ["test1", "test2"]: if state["help"] not in ["test1", "test2"]:
await bot.send_private_msg(message=f"{state['help']} 不支持,请重新输入!") await test_command.reject(f"{state['help']} 不支持,请重新输入!")
test_command.reject()
await bot.send_private_msg(message=f"{state['help']} 帮助:\n...") await bot.send_private_msg(message=f"{state['help']} 帮助:\n...")