add contextvars and fix mutable default args

This commit is contained in:
yanyongyu 2020-08-25 15:23:10 +08:00
parent d66259da2b
commit c5ea8bc1c3
5 changed files with 112 additions and 33 deletions

View File

@ -42,6 +42,10 @@ class BaseBot(abc.ABC):
async def call_api(self, api: str, data: dict):
raise NotImplementedError
@abc.abstractmethod
async def send(self, *args, **kwargs):
raise NotImplementedError
# TODO: improve event
class BaseEvent(abc.ABC):
@ -102,6 +106,16 @@ class BaseEvent(abc.ABC):
def user_id(self, value) -> None:
raise NotImplementedError
@property
@abc.abstractmethod
def group_id(self) -> Optional[int]:
raise NotImplementedError
@group_id.setter
@abc.abstractmethod
def group_id(self, value) -> None:
raise NotImplementedError
@property
@abc.abstractmethod
def to_me(self) -> Optional[bool]:

View File

@ -223,6 +223,36 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self, event: "Event", message: Union[str, "Message",
"MessageSegment"],
**kwargs) -> Union[Any, NoReturn]:
msg = message if isinstance(message, Message) else Message(message)
at_sender = kwargs.pop("at_sender", False) and bool(event.user_id)
params = {}
if event.user_id:
params["user_id"] = event.user_id
if event.group_id:
params["group_id"] = event.group_id
params.update(kwargs)
if "message_type" not in params:
if "group_id" in params:
params["message_type"] = "group"
elif "user_id" in params:
params["message_type"] = "private"
else:
raise ValueError("Cannot guess message type to reply!")
if at_sender and params["message_type"] != "private":
params["message"] = MessageSegment.at(params["user_id"]) + \
MessageSegment.text(" ") + msg
else:
params["message"] = msg
return await self.send_msg(**params)
class Event(BaseEvent):
@ -277,6 +307,16 @@ class Event(BaseEvent):
def user_id(self, value) -> None:
self._raw_event["user_id"] = value
@property
@overrides(BaseEvent)
def group_id(self) -> Optional[int]:
return self._raw_event.get("group_id")
@group_id.setter
@overrides(BaseEvent)
def group_id(self, value) -> None:
self._raw_event["group_id"] = value
@property
@overrides(BaseEvent)
def to_me(self) -> Optional[bool]:

View File

@ -6,14 +6,17 @@ import inspect
from functools import wraps
from datetime import datetime
from collections import defaultdict
from contextvars import Context, ContextVar, copy_context
from nonebot.rule import Rule
from nonebot.permission import Permission, USER
from nonebot.typing import Bot, Event, Handler, ArgsParser
from nonebot.typing import Type, List, Dict, Callable, Optional, NoReturn
from nonebot.typing import Type, List, Dict, Union, Callable, Optional, NoReturn
from nonebot.typing import Bot, Event, Handler, Message, ArgsParser, MessageSegment
from nonebot.exception import PausedException, RejectedException, FinishedException
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
current_bot: ContextVar = ContextVar("current_bot")
current_event: ContextVar = ContextVar("current_event")
class Matcher:
@ -51,12 +54,12 @@ class Matcher:
type_: str = "",
rule: Rule = Rule(),
permission: Permission = Permission(),
handlers: list = [],
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
*,
default_state: dict = {},
default_state: Optional[dict] = None,
expire_time: Optional[datetime] = None) -> Type["Matcher"]:
"""创建新的 Matcher
@ -69,12 +72,12 @@ class Matcher:
"type": type_,
"rule": rule,
"permission": permission,
"handlers": handlers,
"handlers": handlers or [],
"temp": temp,
"expire_time": expire_time,
"priority": priority,
"block": block,
"_default_state": default_state
"_default_state": default_state or {}
})
matchers[priority].append(NewMatcher)
@ -117,12 +120,12 @@ class Matcher:
def receive(cls) -> Callable[[Handler], Handler]:
"""接收一条新消息并处理"""
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn:
async def _receive(bot: Bot, event: Event, state: dict) -> NoReturn:
raise PausedException
if cls.handlers:
# 已有前置handlers则接受一条新的消息否则视为接收初始消息
cls.handlers.append(_handler)
cls.handlers.append(_receive)
def _decorator(func: Handler) -> Handler:
if not cls.handlers or cls.handlers[-1] is not func:
@ -144,8 +147,7 @@ class Matcher:
if key not in state:
state["_current_key"] = key
if prompt:
await bot.send_private_msg(user_id=event.user_id,
message=prompt)
await bot.send(event=event, message=prompt)
raise PausedException
async def _key_parser(bot: Bot, event: Event, state: dict):
@ -176,19 +178,42 @@ class Matcher:
return _decorator
@classmethod
def finish(cls) -> NoReturn:
async def finish(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise FinishedException
@classmethod
def pause(cls) -> NoReturn:
async def pause(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise PausedException
@classmethod
def reject(cls) -> NoReturn:
async def reject(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise RejectedException
# 运行handlers
async def run(self, bot: Bot, event: Event, state: dict):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
try:
# Refresh preprocess state
self.state.update(state)
@ -214,7 +239,6 @@ class Matcher:
block=True,
default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except PausedException:
Matcher.new(
self.type,
@ -226,6 +250,8 @@ class Matcher:
block=True,
default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except FinishedException:
return
pass
finally:
current_bot.reset(b_t)
current_event.reset(e_t)

View File

@ -31,11 +31,11 @@ class Plugin(object):
def on(rule: Union[Rule, RuleChecker] = Rule(),
permission: Permission = Permission(),
*,
handlers=[],
temp=False,
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("",
Rule() & rule,
permission,
@ -50,11 +50,11 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("meta_event",
Rule() & rule,
Permission(),
@ -70,11 +70,11 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
def on_message(rule: Union[Rule, RuleChecker] = Rule(),
permission: Permission = Permission(),
*,
handlers=[],
temp=False,
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = True,
state={}) -> Type[Matcher]:
state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("message",
Rule() & rule,
permission,
@ -89,11 +89,11 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("notice",
Rule() & rule,
Permission(),
@ -108,11 +108,11 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
def on_request(rule: Union[Rule, RuleChecker] = Rule(),
*,
handlers=[],
temp=False,
handlers: Optional[list] = None,
temp: bool = False,
priority: int = 1,
block: bool = False,
state={}) -> Type[Matcher]:
state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("request",
Rule() & rule,
Permission(),

View File

@ -24,6 +24,5 @@ async def test_handler(bot: Bot, event: Event, state: dict):
async def test_handler(bot: Bot, event: Event, state: dict):
print("[!] Command 帮助:", state["help"])
if state["help"] not in ["test1", "test2"]:
await bot.send_private_msg(message=f"{state['help']} 不支持,请重新输入!")
test_command.reject()
await test_command.reject(f"{state['help']} 不支持,请重新输入!")
await bot.send_private_msg(message=f"{state['help']} 帮助:\n...")