Merge pull request #5 from nonebot/dev

This commit is contained in:
CodeCreator 2020-08-26 14:52:34 +08:00 committed by GitHub
commit 310aeb8447
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 329 additions and 155 deletions

View File

@ -147,6 +147,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
""" """
global _driver global _driver
env = Env() env = Env()
logger.debug(f"Current Env: {env.environment}")
config = Config(**kwargs, _env_file=_env_file or f".env.{env.environment}") config = Config(**kwargs, _env_file=_env_file or f".env.{env.environment}")
logger.setLevel(logging.DEBUG if config.debug else logging.INFO) logger.setLevel(logging.DEBUG if config.debug else logging.INFO)

View File

@ -42,6 +42,10 @@ class BaseBot(abc.ABC):
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, data: dict):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def send(self, *args, **kwargs):
raise NotImplementedError
# TODO: improve event # TODO: improve event
class BaseEvent(abc.ABC): class BaseEvent(abc.ABC):
@ -50,18 +54,32 @@ class BaseEvent(abc.ABC):
self._raw_event = raw_event self._raw_event = raw_event
def __repr__(self) -> str: def __repr__(self) -> str:
# TODO: pretty print return f"<Event {self.self_id}: {self.name} {self.time}>"
return f"<Event: {self.type}/{self.detail_type} {self.raw_message}>"
@property @property
def raw_event(self) -> dict: def raw_event(self) -> dict:
return self._raw_event return self._raw_event
@property
@abc.abstractmethod
def id(self) -> int:
raise NotImplementedError
@property
@abc.abstractmethod
def name(self) -> str:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def self_id(self) -> str: def self_id(self) -> str:
raise NotImplementedError raise NotImplementedError
@property
@abc.abstractmethod
def time(self) -> int:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def type(self) -> str: def type(self) -> str:
@ -102,6 +120,16 @@ class BaseEvent(abc.ABC):
def user_id(self, value) -> None: def user_id(self, value) -> None:
raise NotImplementedError raise NotImplementedError
@property
@abc.abstractmethod
def group_id(self) -> Optional[int]:
raise NotImplementedError
@group_id.setter
@abc.abstractmethod
def group_id(self, value) -> None:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def to_me(self) -> Optional[bool]: def to_me(self) -> Optional[bool]:
@ -151,7 +179,7 @@ class BaseEvent(abc.ABC):
@dataclass @dataclass
class BaseMessageSegment(abc.ABC): class BaseMessageSegment(abc.ABC):
type: str type: str
data: Dict[str, str] = field(default_factory=lambda: {}) data: Dict[str, Union[str, list]] = field(default_factory=lambda: {})
@abc.abstractmethod @abc.abstractmethod
def __str__(self): def __str__(self):

View File

@ -1,5 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""
CQHTTP (OneBot) v11 协议适配
============================
协议详情请看: `CQHTTP`_ | `OneBot`_
.. _CQHTTP:
http://cqhttp.cc/
.. _OneBot:
https://github.com/howmanybots/onebot
"""
import re import re
import sys import sys
@ -38,8 +49,8 @@ def unescape(s: str) -> str:
.replace("&amp;", "&") .replace("&amp;", "&")
def _b2s(b: bool) -> str: def _b2s(b: Optional[bool]) -> Optional[str]:
return str(b).lower() return b if b is None else str(b).lower()
def _check_at_me(bot: "Bot", event: "Event"): def _check_at_me(bot: "Bot", event: "Event"):
@ -168,9 +179,14 @@ class Bot(BaseBot):
if not message: if not message:
return return
if "post_type" not in message:
ResultStore.add_result(message)
return
event = Event(message) event = Event(message)
# Check whether user is calling me # Check whether user is calling me
# TODO: Check reply
_check_at_me(self, event) _check_at_me(self, event)
_check_nickname(self, event) _check_nickname(self, event)
@ -223,6 +239,36 @@ class Bot(BaseBot):
except httpx.HTTPError: except httpx.HTTPError:
raise NetworkError("HTTP request failed") raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self, event: "Event", message: Union[str, "Message",
"MessageSegment"],
**kwargs) -> Union[Any, NoReturn]:
msg = message if isinstance(message, Message) else Message(message)
at_sender = kwargs.pop("at_sender", False) and bool(event.user_id)
params = {}
if event.user_id:
params["user_id"] = event.user_id
if event.group_id:
params["group_id"] = event.group_id
params.update(kwargs)
if "message_type" not in params:
if "group_id" in params:
params["message_type"] = "group"
elif "user_id" in params:
params["message_type"] = "private"
else:
raise ValueError("Cannot guess message type to reply!")
if at_sender and params["message_type"] != "private":
params["message"] = MessageSegment.at(params["user_id"]) + \
MessageSegment.text(" ") + msg
else:
params["message"] = msg
return await self.send_msg(**params)
class Event(BaseEvent): class Event(BaseEvent):
@ -232,11 +278,29 @@ class Event(BaseEvent):
super().__init__(raw_event) super().__init__(raw_event)
@property
@overrides(BaseEvent)
def id(self) -> Optional[int]:
return self._raw_event.get("message_id") or self._raw_event.get("flag")
@property
@overrides(BaseEvent)
def name(self) -> str:
n = self.type + "." + self.detail_type
if self.sub_type:
n += "." + self.sub_type
return n
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def self_id(self) -> str: def self_id(self) -> str:
return str(self._raw_event["self_id"]) return str(self._raw_event["self_id"])
@property
@overrides(BaseEvent)
def time(self) -> int:
return self._raw_event["time"]
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def type(self) -> str: def type(self) -> str:
@ -277,6 +341,16 @@ class Event(BaseEvent):
def user_id(self, value) -> None: def user_id(self, value) -> None:
self._raw_event["user_id"] = value self._raw_event["user_id"] = value
@property
@overrides(BaseEvent)
def group_id(self) -> Optional[int]:
return self._raw_event.get("group_id")
@group_id.setter
@overrides(BaseEvent)
def group_id(self, value) -> None:
self._raw_event["group_id"] = value
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def to_me(self) -> Optional[bool]: def to_me(self) -> Optional[bool]:
@ -326,14 +400,8 @@ class Event(BaseEvent):
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __init__(self, type: str, data: Dict[str, str]) -> None: def __init__(self, type: str, data: Dict[str, Union[str, list]]) -> None:
if type == "at" and data.get("qq") == "all": if type == "text":
type = "at_all"
data.clear()
elif type == "shake":
type = "poke"
data = {"type": "Poke"}
elif type == "text":
data["text"] = unescape(data["text"]) data["text"] = unescape(data["text"])
super().__init__(type=type, data=data) super().__init__(type=type, data=data)
@ -343,16 +411,11 @@ class MessageSegment(BaseMessageSegment):
data = self.data.copy() data = self.data.copy()
# process special types # process special types
if type_ == "at_all": if type_ == "text":
type_ = "at"
data = {"qq": "all"}
elif type_ == "poke":
type_ = "shake"
data.clear()
elif type_ == "text":
return escape(data.get("text", ""), escape_comma=False) return escape(data.get("text", ""), escape_comma=False)
params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) params = ",".join(
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None])
return f"[CQ:{type_}{',' if params else ''}{params}]" return f"[CQ:{type_}{',' if params else ''}{params}]"
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
@ -360,17 +423,13 @@ class MessageSegment(BaseMessageSegment):
return Message(self) + other return Message(self) + other
@staticmethod @staticmethod
def anonymous(ignore_failure: bool = False) -> "MessageSegment": def anonymous(ignore_failure: Optional[bool] = None) -> "MessageSegment":
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
@staticmethod @staticmethod
def at(user_id: Union[int, str]) -> "MessageSegment": def at(user_id: Union[int, str]) -> "MessageSegment":
return MessageSegment("at", {"qq": str(user_id)}) return MessageSegment("at", {"qq": str(user_id)})
@staticmethod
def at_all() -> "MessageSegment":
return MessageSegment("at_all")
@staticmethod @staticmethod
def contact_group(group_id: int) -> "MessageSegment": def contact_group(group_id: int) -> "MessageSegment":
return MessageSegment("contact", {"type": "group", "id": str(group_id)}) return MessageSegment("contact", {"type": "group", "id": str(group_id)})
@ -379,23 +438,43 @@ class MessageSegment(BaseMessageSegment):
def contact_user(user_id: int) -> "MessageSegment": def contact_user(user_id: int) -> "MessageSegment":
return MessageSegment("contact", {"type": "qq", "id": str(user_id)}) return MessageSegment("contact", {"type": "qq", "id": str(user_id)})
@staticmethod
def dice() -> "MessageSegment":
return MessageSegment("dice", {})
@staticmethod @staticmethod
def face(id_: int) -> "MessageSegment": def face(id_: int) -> "MessageSegment":
return MessageSegment("face", {"id": str(id_)}) return MessageSegment("face", {"id": str(id_)})
@staticmethod @staticmethod
def forward(id_: str) -> "MessageSegment": def forward(id_: str) -> "MessageSegment":
logger.warning("Forward Message only can be received!")
return MessageSegment("forward", {"id": id_}) return MessageSegment("forward", {"id": id_})
@staticmethod @staticmethod
def image(file: str) -> "MessageSegment": def image(file: str,
return MessageSegment("image", {"file": file}) type_: Optional[str] = None,
cache: bool = True,
proxy: bool = True,
timeout: Optional[int] = None) -> "MessageSegment":
return MessageSegment(
"image", {
"file": file,
"type": type_,
"cache": cache,
"proxy": proxy,
"timeout": timeout
})
@staticmethod
def json(data: str) -> "MessageSegment":
return MessageSegment("json", {"data": data})
@staticmethod @staticmethod
def location(latitude: float, def location(latitude: float,
longitude: float, longitude: float,
title: str = "", title: Optional[str] = None,
content: str = "") -> "MessageSegment": content: Optional[str] = None) -> "MessageSegment":
return MessageSegment( return MessageSegment(
"location", { "location", {
"lat": str(latitude), "lat": str(latitude),
@ -405,36 +484,18 @@ class MessageSegment(BaseMessageSegment):
}) })
@staticmethod @staticmethod
def magic_face(type_: str) -> "MessageSegment": def music(type_: str, id_: int) -> "MessageSegment":
if type_ not in ["dice", "rpc"]:
raise ValueError(
f"Coolq doesn't support magic face type {type_}. Supported types: dice, rpc."
)
return MessageSegment("magic_face", {"type": type_})
@staticmethod
def music(type_: str,
id_: int,
style: Optional[int] = None) -> "MessageSegment":
if style is None:
return MessageSegment("music", {"type": type_, "id": id_}) return MessageSegment("music", {"type": type_, "id": id_})
else:
return MessageSegment("music", {
"type": type_,
"id": id_,
"style": style
})
@staticmethod @staticmethod
def music_custom(type_: str, def music_custom(url: str,
url: str,
audio: str, audio: str,
title: str, title: str,
content: str = "", content: Optional[str] = None,
img_url: str = "") -> "MessageSegment": img_url: Optional[str] = None) -> "MessageSegment":
return MessageSegment( return MessageSegment(
"music", { "music", {
"type": type_, "type": "custom",
"url": url, "url": url,
"audio": audio, "audio": audio,
"title": title, "title": title,
@ -447,35 +508,43 @@ class MessageSegment(BaseMessageSegment):
return MessageSegment("node", {"id": str(id_)}) return MessageSegment("node", {"id": str(id_)})
@staticmethod @staticmethod
def node_custom(name: str, uin: int, def node_custom(user_id: int, nickname: str,
content: "Message") -> "MessageSegment": content: Union[str, "Message"]) -> "MessageSegment":
return MessageSegment("node", { return MessageSegment("node", {
"name": name, "user_id": str(user_id),
"uin": str(uin), "nickname": nickname,
"content": str(content) "content": content
}) })
@staticmethod @staticmethod
def poke(type_: str = "Poke") -> "MessageSegment": def poke(type_: str, id_: str) -> "MessageSegment":
if type_ not in ["Poke"]: return MessageSegment("poke", {"type": type_, "id": id_})
raise ValueError(
f"Coolq doesn't support poke type {type_}. Supported types: Poke."
)
return MessageSegment("poke", {"type": type_})
@staticmethod @staticmethod
def record(file: str, magic: bool = False) -> "MessageSegment": def record(file: str,
magic: Optional[bool] = None,
cache: Optional[bool] = None,
proxy: Optional[bool] = None,
timeout: Optional[int] = None) -> "MessageSegment":
return MessageSegment("record", {"file": file, "magic": _b2s(magic)}) return MessageSegment("record", {"file": file, "magic": _b2s(magic)})
@staticmethod @staticmethod
def replay(id_: int) -> "MessageSegment": def reply(id_: int) -> "MessageSegment":
return MessageSegment("replay", {"id": str(id_)}) return MessageSegment("reply", {"id": str(id_)})
@staticmethod
def rps() -> "MessageSegment":
return MessageSegment("rps", {})
@staticmethod
def shake() -> "MessageSegment":
return MessageSegment("shake", {})
@staticmethod @staticmethod
def share(url: str = "", def share(url: str = "",
title: str = "", title: str = "",
content: str = "", content: Optional[str] = None,
img_url: str = "") -> "MessageSegment": img_url: Optional[str] = None) -> "MessageSegment":
return MessageSegment("share", { return MessageSegment("share", {
"url": url, "url": url,
"title": title, "title": title,
@ -487,6 +556,22 @@ class MessageSegment(BaseMessageSegment):
def text(text: str) -> "MessageSegment": def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text}) return MessageSegment("text", {"text": text})
@staticmethod
def video(file: str,
cache: Optional[bool] = None,
proxy: Optional[bool] = None,
timeout: Optional[int] = None) -> "MessageSegment":
return MessageSegment("video", {
"file": file,
"cache": cache,
"proxy": proxy,
"timeout": timeout
})
@staticmethod
def xml(data: str) -> "MessageSegment":
return MessageSegment("xml", {"data": data})
class Message(BaseMessage): class Message(BaseMessage):
@ -501,7 +586,7 @@ class Message(BaseMessage):
yield MessageSegment(seg["type"], seg.get("data") or {}) yield MessageSegment(seg["type"], seg.get("data") or {})
return return
def _iter_message() -> Iterable[Tuple[str, str]]: def _iter_message(msg: str) -> Iterable[Tuple[str, str]]:
text_begin = 0 text_begin = 0
for cqcode in re.finditer( for cqcode in re.finditer(
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)" r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
@ -514,7 +599,7 @@ class Message(BaseMessage):
yield cqcode.group("type"), cqcode.group("params").lstrip(",") yield cqcode.group("type"), cqcode.group("params").lstrip(",")
yield "text", unescape(msg[text_begin:]) yield "text", unescape(msg[text_begin:])
for type_, data in _iter_message(): for type_, data in _iter_message(msg):
if type_ == "text": if type_ == "text":
if data: if data:
# only yield non-empty text segment # only yield non-empty text segment
@ -526,13 +611,4 @@ class Message(BaseMessage):
filter(lambda x: x, ( filter(lambda x: x, (
x.lstrip() for x in data.split(",")))) x.lstrip() for x in data.split(","))))
} }
if type_ == "at" and data["qq"] == "all":
type_ = "at_all"
data.clear()
elif type_ in ["dice", "rpc"]:
type_ = "magic_face"
data["type"] = type_
elif type_ == "shake":
type_ = "poke"
data["type"] = "Poke"
yield MessageSegment(type_, data) yield MessageSegment(type_, data)

View File

@ -1,19 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import hmac
import json import json
import logging import logging
import uvicorn import uvicorn
from fastapi.responses import Response
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Union, Optional, Callable, overrides from nonebot.typing import Optional, Callable, overrides
def get_auth_bearer(access_token: Optional[str] = Header( def get_auth_bearer(access_token: Optional[str] = Header(
@ -116,28 +117,50 @@ class Driver(BaseDriver):
**kwargs) **kwargs)
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_http( async def _handle_http(self,
self,
adapter: str, adapter: str,
data: dict = Body(...), data: dict = Body(...),
x_self_id: str = Header(None), x_self_id: Optional[str] = Header(None),
access_token: Optional[str] = Depends(get_auth_bearer)): x_signature: Optional[str] = Header(None)):
secret = self.config.secret # 检查self_id
if secret is not None and secret != access_token: if not x_self_id:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, logger.warning("Missing X-Self-ID Header")
detail="Not authenticated", raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
headers={"WWW-Authenticate": "Bearer"}) detail="Missing X-Self-ID Header")
# Create Bot Object # 检查签名
secret = self.config.secret
if secret:
if not x_signature:
logger.warning("Missing Signature Header")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Signature")
sig = hmac.new(secret.encode("utf-8"),
json.dumps(data).encode(), "sha1").hexdigest()
if x_signature != "sha1=" + sig:
logger.warning("Signature Header is invalid")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail="Signature is invalid")
if not isinstance(data, dict):
logger.warning("Data received is invalid")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if x_self_id in self._clients:
logger.warning("There's already a reverse websocket api connection,"
"so the event may be handled twice.")
# 创建 Bot 对象
if adapter in self._adapters: if adapter in self._adapters:
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id) bot = BotClass(self, "http", self.config, x_self_id)
else: else:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found") detail="adapter not found")
await bot.handle_message(data) await bot.handle_message(data)
return {"status": 200, "message": "success"} return Response("", 204)
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_ws_reverse( async def _handle_ws_reverse(
@ -146,19 +169,21 @@ class Driver(BaseDriver):
websocket: FastAPIWebSocket, websocket: FastAPIWebSocket,
x_self_id: str = Header(None), x_self_id: str = Header(None),
access_token: Optional[str] = Depends(get_auth_bearer)): access_token: Optional[str] = Depends(get_auth_bearer)):
ws = WebSocket(websocket)
secret = self.config.secret secret = self.config.secret
if secret is not None and secret != access_token: if secret is not None and secret != access_token:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION) logger.warning("Authorization Header is invalid"
if access_token else "Missing Authorization Header")
websocket = WebSocket(websocket) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
if not x_self_id: if not x_self_id:
logger.error(f"Error Connection Unkown: self_id {x_self_id}") logger.warning(f"Missing X-Self-ID Header")
await websocket.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
if x_self_id in self._clients: if x_self_id in self._clients:
logger.error(f"Error Connection Conflict: self_id {x_self_id}") logger.warning(f"Connection Conflict: self_id {x_self_id}")
await websocket.close(code=status.WS_1008_POLICY_VIOLATION) await ws.close(code=status.WS_1008_POLICY_VIOLATION)
# Create Bot Object # Create Bot Object
if adapter in self._adapters: if adapter in self._adapters:
@ -167,17 +192,18 @@ class Driver(BaseDriver):
"websocket", "websocket",
self.config, self.config,
x_self_id, x_self_id,
websocket=websocket) websocket=ws)
else: else:
logger.warning("Unknown adapter")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail="adapter not found") detail="adapter not found")
await websocket.accept() await ws.accept()
self._clients[x_self_id] = bot self._clients[x_self_id] = bot
try: try:
while not websocket.closed: while not ws.closed:
data = await websocket.receive() data = await ws.receive()
if not data: if not data:
continue continue
@ -213,8 +239,11 @@ class WebSocket(BaseWebSocket):
data = None data = None
try: try:
data = await self.websocket.receive_json() data = await self.websocket.receive_json()
if not isinstance(data, dict):
data = None
raise ValueError
except ValueError: except ValueError:
logger.debug("Received an invalid json message.") logger.warning("Received an invalid json message.")
except WebSocketDisconnect: except WebSocketDisconnect:
self._closed = True self._closed = True
logger.error("WebSocket disconnected by peer.") logger.error("WebSocket disconnected by peer.")

View File

@ -6,14 +6,17 @@ import inspect
from functools import wraps from functools import wraps
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
from contextvars import Context, ContextVar, copy_context
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.permission import Permission, USER from nonebot.permission import Permission, USER
from nonebot.typing import Bot, Event, Handler, ArgsParser from nonebot.typing import Type, List, Dict, Union, Callable, Optional, NoReturn
from nonebot.typing import Type, List, Dict, Callable, Optional, NoReturn from nonebot.typing import Bot, Event, Handler, Message, ArgsParser, MessageSegment
from nonebot.exception import PausedException, RejectedException, FinishedException from nonebot.exception import PausedException, RejectedException, FinishedException
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
current_bot: ContextVar = ContextVar("current_bot")
current_event: ContextVar = ContextVar("current_event")
class Matcher: class Matcher:
@ -51,12 +54,12 @@ class Matcher:
type_: str = "", type_: str = "",
rule: Rule = Rule(), rule: Rule = Rule(),
permission: Permission = Permission(), permission: Permission = Permission(),
handlers: list = [], handlers: Optional[list] = None,
temp: bool = False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
*, *,
default_state: dict = {}, default_state: Optional[dict] = None,
expire_time: Optional[datetime] = None) -> Type["Matcher"]: expire_time: Optional[datetime] = None) -> Type["Matcher"]:
"""创建新的 Matcher """创建新的 Matcher
@ -69,12 +72,12 @@ class Matcher:
"type": type_, "type": type_,
"rule": rule, "rule": rule,
"permission": permission, "permission": permission,
"handlers": handlers, "handlers": handlers or [],
"temp": temp, "temp": temp,
"expire_time": expire_time, "expire_time": expire_time,
"priority": priority, "priority": priority,
"block": block, "block": block,
"_default_state": default_state "_default_state": default_state or {}
}) })
matchers[priority].append(NewMatcher) matchers[priority].append(NewMatcher)
@ -117,12 +120,12 @@ class Matcher:
def receive(cls) -> Callable[[Handler], Handler]: def receive(cls) -> Callable[[Handler], Handler]:
"""接收一条新消息并处理""" """接收一条新消息并处理"""
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn: async def _receive(bot: Bot, event: Event, state: dict) -> NoReturn:
raise PausedException raise PausedException
if cls.handlers: if cls.handlers:
# 已有前置handlers则接受一条新的消息否则视为接收初始消息 # 已有前置handlers则接受一条新的消息否则视为接收初始消息
cls.handlers.append(_handler) cls.handlers.append(_receive)
def _decorator(func: Handler) -> Handler: def _decorator(func: Handler) -> Handler:
if not cls.handlers or cls.handlers[-1] is not func: if not cls.handlers or cls.handlers[-1] is not func:
@ -144,8 +147,7 @@ class Matcher:
if key not in state: if key not in state:
state["_current_key"] = key state["_current_key"] = key
if prompt: if prompt:
await bot.send_private_msg(user_id=event.user_id, await bot.send(event=event, message=prompt)
message=prompt)
raise PausedException raise PausedException
async def _key_parser(bot: Bot, event: Event, state: dict): async def _key_parser(bot: Bot, event: Event, state: dict):
@ -176,19 +178,42 @@ class Matcher:
return _decorator return _decorator
@classmethod @classmethod
def finish(cls) -> NoReturn: async def finish(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise FinishedException raise FinishedException
@classmethod @classmethod
def pause(cls) -> NoReturn: async def pause(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise PausedException raise PausedException
@classmethod @classmethod
def reject(cls) -> NoReturn: async def reject(
cls,
prompt: Optional[Union[str, Message,
MessageSegment]] = None) -> NoReturn:
bot: Bot = current_bot.get()
event: Event = current_event.get()
if prompt:
await bot.send(event=event, message=prompt)
raise RejectedException raise RejectedException
# 运行handlers # 运行handlers
async def run(self, bot: Bot, event: Event, state: dict): async def run(self, bot: Bot, event: Event, state: dict):
b_t = current_bot.set(bot)
e_t = current_event.set(event)
try: try:
# Refresh preprocess state # Refresh preprocess state
self.state.update(state) self.state.update(state)
@ -214,7 +239,6 @@ class Matcher:
block=True, block=True,
default_state=self.state, default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout) expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except PausedException: except PausedException:
Matcher.new( Matcher.new(
self.type, self.type,
@ -226,6 +250,8 @@ class Matcher:
block=True, block=True,
default_state=self.state, default_state=self.state,
expire_time=datetime.now() + bot.config.session_expire_timeout) expire_time=datetime.now() + bot.config.session_expire_timeout)
return
except FinishedException: except FinishedException:
return pass
finally:
current_bot.reset(b_t)
current_event.reset(e_t)

View File

@ -53,15 +53,31 @@ async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event,
async def handle_event(bot: Bot, event: Event): async def handle_event(bot: Bot, event: Event):
log_msg = f"{bot.type.upper()} Bot {event.self_id} [{event.name}]: "
if event.type == "message":
log_msg += f"Message {event.id} from "
log_msg += str(event.user_id)
if event.detail_type == "group":
log_msg += f"@[群:{event.group_id}]: "
log_msg += repr(str(event.message))
elif event.type == "notice":
log_msg += f"Notice {event.raw_event}"
elif event.type == "request":
log_msg += f"Request {event.raw_event}"
elif event.type == "meta_event":
log_msg += f"MetaEvent {event.raw_event}"
logger.info(log_msg)
coros = [] coros = []
state = {} state = {}
for preprocessor in _event_preprocessors: for preprocessor in _event_preprocessors:
coros.append(preprocessor(bot, event, state)) coros.append(preprocessor(bot, event, state))
if coros: if coros:
try: try:
logger.debug("Running PreProcessors...")
await asyncio.gather(*coros) await asyncio.gather(*coros)
except IgnoredException: except IgnoredException:
logger.info(f"Event {event} is ignored") logger.info(f"Event {event.name} is ignored")
return return
# Trie Match # Trie Match
@ -77,6 +93,7 @@ async def handle_event(bot: Bot, event: Event):
for matcher in matchers[priority] for matcher in matchers[priority]
] ]
logger.debug(f"Checking for all matchers in priority {priority}...")
results = await asyncio.gather(*pending_tasks, return_exceptions=True) results = await asyncio.gather(*pending_tasks, return_exceptions=True)
i = 0 i = 0
@ -85,6 +102,7 @@ async def handle_event(bot: Bot, event: Event):
e_list = result.exceptions e_list = result.exceptions
if StopPropagation in e_list: if StopPropagation in e_list:
break_flag = True break_flag = True
logger.debug("Stop event propafation")
if ExpiredException in e_list: if ExpiredException in e_list:
del matchers[priority][index - i] del matchers[priority][index - i]
i += 1 i += 1

View File

@ -31,11 +31,11 @@ class Plugin(object):
def on(rule: Union[Rule, RuleChecker] = Rule(), def on(rule: Union[Rule, RuleChecker] = Rule(),
permission: Permission = Permission(), permission: Permission = Permission(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("", matcher = Matcher.new("",
Rule() & rule, Rule() & rule,
permission, permission,
@ -50,11 +50,11 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(), def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("meta_event", matcher = Matcher.new("meta_event",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
@ -70,11 +70,11 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
def on_message(rule: Union[Rule, RuleChecker] = Rule(), def on_message(rule: Union[Rule, RuleChecker] = Rule(),
permission: Permission = Permission(), permission: Permission = Permission(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = True, block: bool = True,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("message", matcher = Matcher.new("message",
Rule() & rule, Rule() & rule,
permission, permission,
@ -89,11 +89,11 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
def on_notice(rule: Union[Rule, RuleChecker] = Rule(), def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("notice", matcher = Matcher.new("notice",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
@ -108,11 +108,11 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
def on_request(rule: Union[Rule, RuleChecker] = Rule(), def on_request(rule: Union[Rule, RuleChecker] = Rule(),
*, *,
handlers=[], handlers: Optional[list] = None,
temp=False, temp: bool = False,
priority: int = 1, priority: int = 1,
block: bool = False, block: bool = False,
state={}) -> Type[Matcher]: state: Optional[dict] = None) -> Type[Matcher]:
matcher = Matcher.new("request", matcher = Matcher.new("request",
Rule() & rule, Rule() & rule,
Permission(), Permission(),
@ -143,7 +143,7 @@ def on_endswith(msg: str,
startswith(msg), permission, **kwargs) startswith(msg), permission, **kwargs)
def on_command(cmd: Union[str, Tuple[str]], def on_command(cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, RuleChecker]] = None, rule: Optional[Union[Rule, RuleChecker]] = None,
permission: Permission = Permission(), permission: Permission = Permission(),
**kwargs) -> Type[Matcher]: **kwargs) -> Type[Matcher]:

View File

@ -110,7 +110,7 @@ def keyword(msg: str) -> Rule:
return Rule(_keyword) return Rule(_keyword)
def command(command: Tuple[str]) -> Rule: def command(command: Tuple[str, ...]) -> Rule:
config = get_driver().config config = get_driver().config
command_start = config.command_start command_start = config.command_start
command_sep = config.command_sep command_sep = config.command_sep

View File

@ -1,19 +1,16 @@
{ {
"name": "nonebot", "name": "nonebot",
"version": "2.0.0", "version": "2.0.0",
"description": "An asynchronous QQ bot framework.", "description": "An asynchronous python bot framework.",
"homepage": "https://nonebot.cqp.moe/", "homepage": "https://docs.nonebot.dev/",
"main": "index.js", "main": "index.js",
"contributors": [{ "contributors": [
"name": "Richard Chien",
"email": "richardchienthebest@gmail.com"
},
{ {
"name": "yanyongyu", "name": "yanyongyu",
"email": "yanyongyu_1@126.com" "email": "yanyongyu_1@126.com"
} }
], ],
"repository": "https://github.com/nonebot/nonebot/nonebot", "repository": "https://github.com/nonebot/nonebot/",
"bugs": { "bugs": {
"url": "https://github.com/nonebot/nonebot/issues" "url": "https://github.com/nonebot/nonebot/issues"
}, },

View File

@ -1,13 +1,13 @@
[tool.poetry] [tool.poetry]
name = "nonebot" name = "nonebot"
version = "2.0.0" version = "2.0.0"
description = "An asynchronous QQ bot framework." description = "An asynchronous python bot framework."
authors = ["Richard Chien <richardchienthebest@gmail.com>", "yanyongyu <yanyongyu_1@126.com>"] authors = ["yanyongyu <yanyongyu_1@126.com>"]
license = "MIT" license = "MIT"
readme = "README.md" readme = "README.md"
homepage = "https://nonebot.cqp.moe/" homepage = "https://docs.nonebot.dev/"
repository = "https://github.com/nonebot/nonebot" repository = "https://github.com/nonebot/nonebot"
documentation = "https://nonebot.cqp.moe/" documentation = "https://docs.nonebot.dev/"
keywords = ["bot", "qq", "qqbot", "mirai", "coolq"] keywords = ["bot", "qq", "qqbot", "mirai", "coolq"]
classifiers = [ classifiers = [
"Development Status :: 5 - Production/Stable", "Development Status :: 5 - Production/Stable",

View File

@ -24,6 +24,5 @@ async def test_handler(bot: Bot, event: Event, state: dict):
async def test_handler(bot: Bot, event: Event, state: dict): async def test_handler(bot: Bot, event: Event, state: dict):
print("[!] Command 帮助:", state["help"]) print("[!] Command 帮助:", state["help"])
if state["help"] not in ["test1", "test2"]: if state["help"] not in ["test1", "test2"]:
await bot.send_private_msg(message=f"{state['help']} 不支持,请重新输入!") await test_command.reject(f"{state['help']} 不支持,请重新输入!")
test_command.reject()
await bot.send_private_msg(message=f"{state['help']} 帮助:\n...") await bot.send_private_msg(message=f"{state['help']} 帮助:\n...")