mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 09:05:04 +08:00
Merge pull request #5 from nonebot/dev
This commit is contained in:
commit
310aeb8447
@ -147,6 +147,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
global _driver
|
||||
env = Env()
|
||||
logger.debug(f"Current Env: {env.environment}")
|
||||
config = Config(**kwargs, _env_file=_env_file or f".env.{env.environment}")
|
||||
|
||||
logger.setLevel(logging.DEBUG if config.debug else logging.INFO)
|
||||
|
@ -42,6 +42,10 @@ class BaseBot(abc.ABC):
|
||||
async def call_api(self, api: str, data: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: improve event
|
||||
class BaseEvent(abc.ABC):
|
||||
@ -50,18 +54,32 @@ class BaseEvent(abc.ABC):
|
||||
self._raw_event = raw_event
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO: pretty print
|
||||
return f"<Event: {self.type}/{self.detail_type} {self.raw_message}>"
|
||||
return f"<Event {self.self_id}: {self.name} {self.time}>"
|
||||
|
||||
@property
|
||||
def raw_event(self) -> dict:
|
||||
return self._raw_event
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def self_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def time(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
@ -102,6 +120,16 @@ class BaseEvent(abc.ABC):
|
||||
def user_id(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def group_id(self) -> Optional[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@group_id.setter
|
||||
@abc.abstractmethod
|
||||
def group_id(self, value) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def to_me(self) -> Optional[bool]:
|
||||
@ -151,7 +179,7 @@ class BaseEvent(abc.ABC):
|
||||
@dataclass
|
||||
class BaseMessageSegment(abc.ABC):
|
||||
type: str
|
||||
data: Dict[str, str] = field(default_factory=lambda: {})
|
||||
data: Dict[str, Union[str, list]] = field(default_factory=lambda: {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def __str__(self):
|
||||
|
@ -1,5 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
CQHTTP (OneBot) v11 协议适配
|
||||
============================
|
||||
|
||||
协议详情请看: `CQHTTP`_ | `OneBot`_
|
||||
|
||||
.. _CQHTTP:
|
||||
http://cqhttp.cc/
|
||||
.. _OneBot:
|
||||
https://github.com/howmanybots/onebot
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
@ -38,8 +49,8 @@ def unescape(s: str) -> str:
|
||||
.replace("&", "&")
|
||||
|
||||
|
||||
def _b2s(b: bool) -> str:
|
||||
return str(b).lower()
|
||||
def _b2s(b: Optional[bool]) -> Optional[str]:
|
||||
return b if b is None else str(b).lower()
|
||||
|
||||
|
||||
def _check_at_me(bot: "Bot", event: "Event"):
|
||||
@ -168,9 +179,14 @@ class Bot(BaseBot):
|
||||
if not message:
|
||||
return
|
||||
|
||||
if "post_type" not in message:
|
||||
ResultStore.add_result(message)
|
||||
return
|
||||
|
||||
event = Event(message)
|
||||
|
||||
# Check whether user is calling me
|
||||
# TODO: Check reply
|
||||
_check_at_me(self, event)
|
||||
_check_nickname(self, event)
|
||||
|
||||
@ -223,6 +239,36 @@ class Bot(BaseBot):
|
||||
except httpx.HTTPError:
|
||||
raise NetworkError("HTTP request failed")
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def send(self, event: "Event", message: Union[str, "Message",
|
||||
"MessageSegment"],
|
||||
**kwargs) -> Union[Any, NoReturn]:
|
||||
msg = message if isinstance(message, Message) else Message(message)
|
||||
|
||||
at_sender = kwargs.pop("at_sender", False) and bool(event.user_id)
|
||||
|
||||
params = {}
|
||||
if event.user_id:
|
||||
params["user_id"] = event.user_id
|
||||
if event.group_id:
|
||||
params["group_id"] = event.group_id
|
||||
params.update(kwargs)
|
||||
|
||||
if "message_type" not in params:
|
||||
if "group_id" in params:
|
||||
params["message_type"] = "group"
|
||||
elif "user_id" in params:
|
||||
params["message_type"] = "private"
|
||||
else:
|
||||
raise ValueError("Cannot guess message type to reply!")
|
||||
|
||||
if at_sender and params["message_type"] != "private":
|
||||
params["message"] = MessageSegment.at(params["user_id"]) + \
|
||||
MessageSegment.text(" ") + msg
|
||||
else:
|
||||
params["message"] = msg
|
||||
return await self.send_msg(**params)
|
||||
|
||||
|
||||
class Event(BaseEvent):
|
||||
|
||||
@ -232,11 +278,29 @@ class Event(BaseEvent):
|
||||
|
||||
super().__init__(raw_event)
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def id(self) -> Optional[int]:
|
||||
return self._raw_event.get("message_id") or self._raw_event.get("flag")
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def name(self) -> str:
|
||||
n = self.type + "." + self.detail_type
|
||||
if self.sub_type:
|
||||
n += "." + self.sub_type
|
||||
return n
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def self_id(self) -> str:
|
||||
return str(self._raw_event["self_id"])
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def time(self) -> int:
|
||||
return self._raw_event["time"]
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def type(self) -> str:
|
||||
@ -277,6 +341,16 @@ class Event(BaseEvent):
|
||||
def user_id(self, value) -> None:
|
||||
self._raw_event["user_id"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def group_id(self) -> Optional[int]:
|
||||
return self._raw_event.get("group_id")
|
||||
|
||||
@group_id.setter
|
||||
@overrides(BaseEvent)
|
||||
def group_id(self, value) -> None:
|
||||
self._raw_event["group_id"] = value
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def to_me(self) -> Optional[bool]:
|
||||
@ -326,14 +400,8 @@ class Event(BaseEvent):
|
||||
class MessageSegment(BaseMessageSegment):
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __init__(self, type: str, data: Dict[str, str]) -> None:
|
||||
if type == "at" and data.get("qq") == "all":
|
||||
type = "at_all"
|
||||
data.clear()
|
||||
elif type == "shake":
|
||||
type = "poke"
|
||||
data = {"type": "Poke"}
|
||||
elif type == "text":
|
||||
def __init__(self, type: str, data: Dict[str, Union[str, list]]) -> None:
|
||||
if type == "text":
|
||||
data["text"] = unescape(data["text"])
|
||||
super().__init__(type=type, data=data)
|
||||
|
||||
@ -343,16 +411,11 @@ class MessageSegment(BaseMessageSegment):
|
||||
data = self.data.copy()
|
||||
|
||||
# process special types
|
||||
if type_ == "at_all":
|
||||
type_ = "at"
|
||||
data = {"qq": "all"}
|
||||
elif type_ == "poke":
|
||||
type_ = "shake"
|
||||
data.clear()
|
||||
elif type_ == "text":
|
||||
if type_ == "text":
|
||||
return escape(data.get("text", ""), escape_comma=False)
|
||||
|
||||
params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()])
|
||||
params = ",".join(
|
||||
[f"{k}={escape(str(v))}" for k, v in data.items() if v is not None])
|
||||
return f"[CQ:{type_}{',' if params else ''}{params}]"
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
@ -360,17 +423,13 @@ class MessageSegment(BaseMessageSegment):
|
||||
return Message(self) + other
|
||||
|
||||
@staticmethod
|
||||
def anonymous(ignore_failure: bool = False) -> "MessageSegment":
|
||||
def anonymous(ignore_failure: Optional[bool] = None) -> "MessageSegment":
|
||||
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
|
||||
|
||||
@staticmethod
|
||||
def at(user_id: Union[int, str]) -> "MessageSegment":
|
||||
return MessageSegment("at", {"qq": str(user_id)})
|
||||
|
||||
@staticmethod
|
||||
def at_all() -> "MessageSegment":
|
||||
return MessageSegment("at_all")
|
||||
|
||||
@staticmethod
|
||||
def contact_group(group_id: int) -> "MessageSegment":
|
||||
return MessageSegment("contact", {"type": "group", "id": str(group_id)})
|
||||
@ -379,23 +438,43 @@ class MessageSegment(BaseMessageSegment):
|
||||
def contact_user(user_id: int) -> "MessageSegment":
|
||||
return MessageSegment("contact", {"type": "qq", "id": str(user_id)})
|
||||
|
||||
@staticmethod
|
||||
def dice() -> "MessageSegment":
|
||||
return MessageSegment("dice", {})
|
||||
|
||||
@staticmethod
|
||||
def face(id_: int) -> "MessageSegment":
|
||||
return MessageSegment("face", {"id": str(id_)})
|
||||
|
||||
@staticmethod
|
||||
def forward(id_: str) -> "MessageSegment":
|
||||
logger.warning("Forward Message only can be received!")
|
||||
return MessageSegment("forward", {"id": id_})
|
||||
|
||||
@staticmethod
|
||||
def image(file: str) -> "MessageSegment":
|
||||
return MessageSegment("image", {"file": file})
|
||||
def image(file: str,
|
||||
type_: Optional[str] = None,
|
||||
cache: bool = True,
|
||||
proxy: bool = True,
|
||||
timeout: Optional[int] = None) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"image", {
|
||||
"file": file,
|
||||
"type": type_,
|
||||
"cache": cache,
|
||||
"proxy": proxy,
|
||||
"timeout": timeout
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def json(data: str) -> "MessageSegment":
|
||||
return MessageSegment("json", {"data": data})
|
||||
|
||||
@staticmethod
|
||||
def location(latitude: float,
|
||||
longitude: float,
|
||||
title: str = "",
|
||||
content: str = "") -> "MessageSegment":
|
||||
title: Optional[str] = None,
|
||||
content: Optional[str] = None) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"location", {
|
||||
"lat": str(latitude),
|
||||
@ -405,36 +484,18 @@ class MessageSegment(BaseMessageSegment):
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def magic_face(type_: str) -> "MessageSegment":
|
||||
if type_ not in ["dice", "rpc"]:
|
||||
raise ValueError(
|
||||
f"Coolq doesn't support magic face type {type_}. Supported types: dice, rpc."
|
||||
)
|
||||
return MessageSegment("magic_face", {"type": type_})
|
||||
def music(type_: str, id_: int) -> "MessageSegment":
|
||||
return MessageSegment("music", {"type": type_, "id": id_})
|
||||
|
||||
@staticmethod
|
||||
def music(type_: str,
|
||||
id_: int,
|
||||
style: Optional[int] = None) -> "MessageSegment":
|
||||
if style is None:
|
||||
return MessageSegment("music", {"type": type_, "id": id_})
|
||||
else:
|
||||
return MessageSegment("music", {
|
||||
"type": type_,
|
||||
"id": id_,
|
||||
"style": style
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def music_custom(type_: str,
|
||||
url: str,
|
||||
def music_custom(url: str,
|
||||
audio: str,
|
||||
title: str,
|
||||
content: str = "",
|
||||
img_url: str = "") -> "MessageSegment":
|
||||
content: Optional[str] = None,
|
||||
img_url: Optional[str] = None) -> "MessageSegment":
|
||||
return MessageSegment(
|
||||
"music", {
|
||||
"type": type_,
|
||||
"type": "custom",
|
||||
"url": url,
|
||||
"audio": audio,
|
||||
"title": title,
|
||||
@ -447,35 +508,43 @@ class MessageSegment(BaseMessageSegment):
|
||||
return MessageSegment("node", {"id": str(id_)})
|
||||
|
||||
@staticmethod
|
||||
def node_custom(name: str, uin: int,
|
||||
content: "Message") -> "MessageSegment":
|
||||
def node_custom(user_id: int, nickname: str,
|
||||
content: Union[str, "Message"]) -> "MessageSegment":
|
||||
return MessageSegment("node", {
|
||||
"name": name,
|
||||
"uin": str(uin),
|
||||
"content": str(content)
|
||||
"user_id": str(user_id),
|
||||
"nickname": nickname,
|
||||
"content": content
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def poke(type_: str = "Poke") -> "MessageSegment":
|
||||
if type_ not in ["Poke"]:
|
||||
raise ValueError(
|
||||
f"Coolq doesn't support poke type {type_}. Supported types: Poke."
|
||||
)
|
||||
return MessageSegment("poke", {"type": type_})
|
||||
def poke(type_: str, id_: str) -> "MessageSegment":
|
||||
return MessageSegment("poke", {"type": type_, "id": id_})
|
||||
|
||||
@staticmethod
|
||||
def record(file: str, magic: bool = False) -> "MessageSegment":
|
||||
def record(file: str,
|
||||
magic: Optional[bool] = None,
|
||||
cache: Optional[bool] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
timeout: Optional[int] = None) -> "MessageSegment":
|
||||
return MessageSegment("record", {"file": file, "magic": _b2s(magic)})
|
||||
|
||||
@staticmethod
|
||||
def replay(id_: int) -> "MessageSegment":
|
||||
return MessageSegment("replay", {"id": str(id_)})
|
||||
def reply(id_: int) -> "MessageSegment":
|
||||
return MessageSegment("reply", {"id": str(id_)})
|
||||
|
||||
@staticmethod
|
||||
def rps() -> "MessageSegment":
|
||||
return MessageSegment("rps", {})
|
||||
|
||||
@staticmethod
|
||||
def shake() -> "MessageSegment":
|
||||
return MessageSegment("shake", {})
|
||||
|
||||
@staticmethod
|
||||
def share(url: str = "",
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
img_url: str = "") -> "MessageSegment":
|
||||
content: Optional[str] = None,
|
||||
img_url: Optional[str] = None) -> "MessageSegment":
|
||||
return MessageSegment("share", {
|
||||
"url": url,
|
||||
"title": title,
|
||||
@ -487,6 +556,22 @@ class MessageSegment(BaseMessageSegment):
|
||||
def text(text: str) -> "MessageSegment":
|
||||
return MessageSegment("text", {"text": text})
|
||||
|
||||
@staticmethod
|
||||
def video(file: str,
|
||||
cache: Optional[bool] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
timeout: Optional[int] = None) -> "MessageSegment":
|
||||
return MessageSegment("video", {
|
||||
"file": file,
|
||||
"cache": cache,
|
||||
"proxy": proxy,
|
||||
"timeout": timeout
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def xml(data: str) -> "MessageSegment":
|
||||
return MessageSegment("xml", {"data": data})
|
||||
|
||||
|
||||
class Message(BaseMessage):
|
||||
|
||||
@ -501,7 +586,7 @@ class Message(BaseMessage):
|
||||
yield MessageSegment(seg["type"], seg.get("data") or {})
|
||||
return
|
||||
|
||||
def _iter_message() -> Iterable[Tuple[str, str]]:
|
||||
def _iter_message(msg: str) -> Iterable[Tuple[str, str]]:
|
||||
text_begin = 0
|
||||
for cqcode in re.finditer(
|
||||
r"\[CQ:(?P<type>[a-zA-Z0-9-_.]+)"
|
||||
@ -514,7 +599,7 @@ class Message(BaseMessage):
|
||||
yield cqcode.group("type"), cqcode.group("params").lstrip(",")
|
||||
yield "text", unescape(msg[text_begin:])
|
||||
|
||||
for type_, data in _iter_message():
|
||||
for type_, data in _iter_message(msg):
|
||||
if type_ == "text":
|
||||
if data:
|
||||
# only yield non-empty text segment
|
||||
@ -526,13 +611,4 @@ class Message(BaseMessage):
|
||||
filter(lambda x: x, (
|
||||
x.lstrip() for x in data.split(","))))
|
||||
}
|
||||
if type_ == "at" and data["qq"] == "all":
|
||||
type_ = "at_all"
|
||||
data.clear()
|
||||
elif type_ in ["dice", "rpc"]:
|
||||
type_ = "magic_face"
|
||||
data["type"] = type_
|
||||
elif type_ == "shake":
|
||||
type_ = "poke"
|
||||
data["type"] = "Poke"
|
||||
yield MessageSegment(type_, data)
|
||||
|
@ -1,19 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import Response
|
||||
from fastapi import Body, status, Header, FastAPI, Depends, HTTPException
|
||||
from starlette.websockets import WebSocketDisconnect, WebSocket as FastAPIWebSocket
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||
from nonebot.typing import Union, Optional, Callable, overrides
|
||||
from nonebot.typing import Optional, Callable, overrides
|
||||
|
||||
|
||||
def get_auth_bearer(access_token: Optional[str] = Header(
|
||||
@ -116,28 +117,50 @@ class Driver(BaseDriver):
|
||||
**kwargs)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_http(
|
||||
self,
|
||||
adapter: str,
|
||||
data: dict = Body(...),
|
||||
x_self_id: str = Header(None),
|
||||
access_token: Optional[str] = Depends(get_auth_bearer)):
|
||||
secret = self.config.secret
|
||||
if secret is not None and secret != access_token:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"})
|
||||
async def _handle_http(self,
|
||||
adapter: str,
|
||||
data: dict = Body(...),
|
||||
x_self_id: Optional[str] = Header(None),
|
||||
x_signature: Optional[str] = Header(None)):
|
||||
# 检查self_id
|
||||
if not x_self_id:
|
||||
logger.warning("Missing X-Self-ID Header")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing X-Self-ID Header")
|
||||
|
||||
# Create Bot Object
|
||||
# 检查签名
|
||||
secret = self.config.secret
|
||||
if secret:
|
||||
if not x_signature:
|
||||
logger.warning("Missing Signature Header")
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Signature")
|
||||
sig = hmac.new(secret.encode("utf-8"),
|
||||
json.dumps(data).encode(), "sha1").hexdigest()
|
||||
if x_signature != "sha1=" + sig:
|
||||
logger.warning("Signature Header is invalid")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Signature is invalid")
|
||||
|
||||
if not isinstance(data, dict):
|
||||
logger.warning("Data received is invalid")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.warning("There's already a reverse websocket api connection,"
|
||||
"so the event may be handled twice.")
|
||||
|
||||
# 创建 Bot 对象
|
||||
if adapter in self._adapters:
|
||||
BotClass = self._adapters[adapter]
|
||||
bot = BotClass(self, "http", self.config, x_self_id)
|
||||
else:
|
||||
logger.warning("Unknown adapter")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="adapter not found")
|
||||
|
||||
await bot.handle_message(data)
|
||||
return {"status": 200, "message": "success"}
|
||||
return Response("", 204)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
async def _handle_ws_reverse(
|
||||
@ -146,19 +169,21 @@ class Driver(BaseDriver):
|
||||
websocket: FastAPIWebSocket,
|
||||
x_self_id: str = Header(None),
|
||||
access_token: Optional[str] = Depends(get_auth_bearer)):
|
||||
ws = WebSocket(websocket)
|
||||
|
||||
secret = self.config.secret
|
||||
if secret is not None and secret != access_token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
websocket = WebSocket(websocket)
|
||||
logger.warning("Authorization Header is invalid"
|
||||
if access_token else "Missing Authorization Header")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
if not x_self_id:
|
||||
logger.error(f"Error Connection Unkown: self_id {x_self_id}")
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
logger.warning(f"Missing X-Self-ID Header")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
if x_self_id in self._clients:
|
||||
logger.error(f"Error Connection Conflict: self_id {x_self_id}")
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
logger.warning(f"Connection Conflict: self_id {x_self_id}")
|
||||
await ws.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
# Create Bot Object
|
||||
if adapter in self._adapters:
|
||||
@ -167,17 +192,18 @@ class Driver(BaseDriver):
|
||||
"websocket",
|
||||
self.config,
|
||||
x_self_id,
|
||||
websocket=websocket)
|
||||
websocket=ws)
|
||||
else:
|
||||
logger.warning("Unknown adapter")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="adapter not found")
|
||||
|
||||
await websocket.accept()
|
||||
await ws.accept()
|
||||
self._clients[x_self_id] = bot
|
||||
|
||||
try:
|
||||
while not websocket.closed:
|
||||
data = await websocket.receive()
|
||||
while not ws.closed:
|
||||
data = await ws.receive()
|
||||
|
||||
if not data:
|
||||
continue
|
||||
@ -213,8 +239,11 @@ class WebSocket(BaseWebSocket):
|
||||
data = None
|
||||
try:
|
||||
data = await self.websocket.receive_json()
|
||||
if not isinstance(data, dict):
|
||||
data = None
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
logger.debug("Received an invalid json message.")
|
||||
logger.warning("Received an invalid json message.")
|
||||
except WebSocketDisconnect:
|
||||
self._closed = True
|
||||
logger.error("WebSocket disconnected by peer.")
|
||||
|
@ -6,14 +6,17 @@ import inspect
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from contextvars import Context, ContextVar, copy_context
|
||||
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.permission import Permission, USER
|
||||
from nonebot.typing import Bot, Event, Handler, ArgsParser
|
||||
from nonebot.typing import Type, List, Dict, Callable, Optional, NoReturn
|
||||
from nonebot.typing import Type, List, Dict, Union, Callable, Optional, NoReturn
|
||||
from nonebot.typing import Bot, Event, Handler, Message, ArgsParser, MessageSegment
|
||||
from nonebot.exception import PausedException, RejectedException, FinishedException
|
||||
|
||||
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
|
||||
current_bot: ContextVar = ContextVar("current_bot")
|
||||
current_event: ContextVar = ContextVar("current_event")
|
||||
|
||||
|
||||
class Matcher:
|
||||
@ -51,12 +54,12 @@ class Matcher:
|
||||
type_: str = "",
|
||||
rule: Rule = Rule(),
|
||||
permission: Permission = Permission(),
|
||||
handlers: list = [],
|
||||
handlers: Optional[list] = None,
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
*,
|
||||
default_state: dict = {},
|
||||
default_state: Optional[dict] = None,
|
||||
expire_time: Optional[datetime] = None) -> Type["Matcher"]:
|
||||
"""创建新的 Matcher
|
||||
|
||||
@ -69,12 +72,12 @@ class Matcher:
|
||||
"type": type_,
|
||||
"rule": rule,
|
||||
"permission": permission,
|
||||
"handlers": handlers,
|
||||
"handlers": handlers or [],
|
||||
"temp": temp,
|
||||
"expire_time": expire_time,
|
||||
"priority": priority,
|
||||
"block": block,
|
||||
"_default_state": default_state
|
||||
"_default_state": default_state or {}
|
||||
})
|
||||
|
||||
matchers[priority].append(NewMatcher)
|
||||
@ -117,12 +120,12 @@ class Matcher:
|
||||
def receive(cls) -> Callable[[Handler], Handler]:
|
||||
"""接收一条新消息并处理"""
|
||||
|
||||
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn:
|
||||
async def _receive(bot: Bot, event: Event, state: dict) -> NoReturn:
|
||||
raise PausedException
|
||||
|
||||
if cls.handlers:
|
||||
# 已有前置handlers则接受一条新的消息,否则视为接收初始消息
|
||||
cls.handlers.append(_handler)
|
||||
cls.handlers.append(_receive)
|
||||
|
||||
def _decorator(func: Handler) -> Handler:
|
||||
if not cls.handlers or cls.handlers[-1] is not func:
|
||||
@ -144,8 +147,7 @@ class Matcher:
|
||||
if key not in state:
|
||||
state["_current_key"] = key
|
||||
if prompt:
|
||||
await bot.send_private_msg(user_id=event.user_id,
|
||||
message=prompt)
|
||||
await bot.send(event=event, message=prompt)
|
||||
raise PausedException
|
||||
|
||||
async def _key_parser(bot: Bot, event: Event, state: dict):
|
||||
@ -176,19 +178,42 @@ class Matcher:
|
||||
return _decorator
|
||||
|
||||
@classmethod
|
||||
def finish(cls) -> NoReturn:
|
||||
async def finish(
|
||||
cls,
|
||||
prompt: Optional[Union[str, Message,
|
||||
MessageSegment]] = None) -> NoReturn:
|
||||
bot: Bot = current_bot.get()
|
||||
event: Event = current_event.get()
|
||||
if prompt:
|
||||
await bot.send(event=event, message=prompt)
|
||||
raise FinishedException
|
||||
|
||||
@classmethod
|
||||
def pause(cls) -> NoReturn:
|
||||
async def pause(
|
||||
cls,
|
||||
prompt: Optional[Union[str, Message,
|
||||
MessageSegment]] = None) -> NoReturn:
|
||||
bot: Bot = current_bot.get()
|
||||
event: Event = current_event.get()
|
||||
if prompt:
|
||||
await bot.send(event=event, message=prompt)
|
||||
raise PausedException
|
||||
|
||||
@classmethod
|
||||
def reject(cls) -> NoReturn:
|
||||
async def reject(
|
||||
cls,
|
||||
prompt: Optional[Union[str, Message,
|
||||
MessageSegment]] = None) -> NoReturn:
|
||||
bot: Bot = current_bot.get()
|
||||
event: Event = current_event.get()
|
||||
if prompt:
|
||||
await bot.send(event=event, message=prompt)
|
||||
raise RejectedException
|
||||
|
||||
# 运行handlers
|
||||
async def run(self, bot: Bot, event: Event, state: dict):
|
||||
b_t = current_bot.set(bot)
|
||||
e_t = current_event.set(event)
|
||||
try:
|
||||
# Refresh preprocess state
|
||||
self.state.update(state)
|
||||
@ -214,7 +239,6 @@ class Matcher:
|
||||
block=True,
|
||||
default_state=self.state,
|
||||
expire_time=datetime.now() + bot.config.session_expire_timeout)
|
||||
return
|
||||
except PausedException:
|
||||
Matcher.new(
|
||||
self.type,
|
||||
@ -226,6 +250,8 @@ class Matcher:
|
||||
block=True,
|
||||
default_state=self.state,
|
||||
expire_time=datetime.now() + bot.config.session_expire_timeout)
|
||||
return
|
||||
except FinishedException:
|
||||
return
|
||||
pass
|
||||
finally:
|
||||
current_bot.reset(b_t)
|
||||
current_event.reset(e_t)
|
||||
|
@ -53,15 +53,31 @@ async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event,
|
||||
|
||||
|
||||
async def handle_event(bot: Bot, event: Event):
|
||||
log_msg = f"{bot.type.upper()} Bot {event.self_id} [{event.name}]: "
|
||||
if event.type == "message":
|
||||
log_msg += f"Message {event.id} from "
|
||||
log_msg += str(event.user_id)
|
||||
if event.detail_type == "group":
|
||||
log_msg += f"@[群:{event.group_id}]: "
|
||||
log_msg += repr(str(event.message))
|
||||
elif event.type == "notice":
|
||||
log_msg += f"Notice {event.raw_event}"
|
||||
elif event.type == "request":
|
||||
log_msg += f"Request {event.raw_event}"
|
||||
elif event.type == "meta_event":
|
||||
log_msg += f"MetaEvent {event.raw_event}"
|
||||
logger.info(log_msg)
|
||||
|
||||
coros = []
|
||||
state = {}
|
||||
for preprocessor in _event_preprocessors:
|
||||
coros.append(preprocessor(bot, event, state))
|
||||
if coros:
|
||||
try:
|
||||
logger.debug("Running PreProcessors...")
|
||||
await asyncio.gather(*coros)
|
||||
except IgnoredException:
|
||||
logger.info(f"Event {event} is ignored")
|
||||
logger.info(f"Event {event.name} is ignored")
|
||||
return
|
||||
|
||||
# Trie Match
|
||||
@ -77,6 +93,7 @@ async def handle_event(bot: Bot, event: Event):
|
||||
for matcher in matchers[priority]
|
||||
]
|
||||
|
||||
logger.debug(f"Checking for all matchers in priority {priority}...")
|
||||
results = await asyncio.gather(*pending_tasks, return_exceptions=True)
|
||||
|
||||
i = 0
|
||||
@ -85,6 +102,7 @@ async def handle_event(bot: Bot, event: Event):
|
||||
e_list = result.exceptions
|
||||
if StopPropagation in e_list:
|
||||
break_flag = True
|
||||
logger.debug("Stop event propafation")
|
||||
if ExpiredException in e_list:
|
||||
del matchers[priority][index - i]
|
||||
i += 1
|
||||
|
@ -31,11 +31,11 @@ class Plugin(object):
|
||||
def on(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
permission: Permission = Permission(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
handlers: Optional[list] = None,
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
state: Optional[dict] = None) -> Type[Matcher]:
|
||||
matcher = Matcher.new("",
|
||||
Rule() & rule,
|
||||
permission,
|
||||
@ -50,11 +50,11 @@ def on(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
|
||||
def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
handlers: Optional[list] = None,
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
state: Optional[dict] = None) -> Type[Matcher]:
|
||||
matcher = Matcher.new("meta_event",
|
||||
Rule() & rule,
|
||||
Permission(),
|
||||
@ -70,11 +70,11 @@ def on_metaevent(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
def on_message(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
permission: Permission = Permission(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
handlers: Optional[list] = None,
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = True,
|
||||
state={}) -> Type[Matcher]:
|
||||
state: Optional[dict] = None) -> Type[Matcher]:
|
||||
matcher = Matcher.new("message",
|
||||
Rule() & rule,
|
||||
permission,
|
||||
@ -89,11 +89,11 @@ def on_message(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
|
||||
def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
handlers: Optional[list] = None,
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
state: Optional[dict] = None) -> Type[Matcher]:
|
||||
matcher = Matcher.new("notice",
|
||||
Rule() & rule,
|
||||
Permission(),
|
||||
@ -108,11 +108,11 @@ def on_notice(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
|
||||
def on_request(rule: Union[Rule, RuleChecker] = Rule(),
|
||||
*,
|
||||
handlers=[],
|
||||
temp=False,
|
||||
handlers: Optional[list] = None,
|
||||
temp: bool = False,
|
||||
priority: int = 1,
|
||||
block: bool = False,
|
||||
state={}) -> Type[Matcher]:
|
||||
state: Optional[dict] = None) -> Type[Matcher]:
|
||||
matcher = Matcher.new("request",
|
||||
Rule() & rule,
|
||||
Permission(),
|
||||
@ -143,7 +143,7 @@ def on_endswith(msg: str,
|
||||
startswith(msg), permission, **kwargs)
|
||||
|
||||
|
||||
def on_command(cmd: Union[str, Tuple[str]],
|
||||
def on_command(cmd: Union[str, Tuple[str, ...]],
|
||||
rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
permission: Permission = Permission(),
|
||||
**kwargs) -> Type[Matcher]:
|
||||
|
@ -110,7 +110,7 @@ def keyword(msg: str) -> Rule:
|
||||
return Rule(_keyword)
|
||||
|
||||
|
||||
def command(command: Tuple[str]) -> Rule:
|
||||
def command(command: Tuple[str, ...]) -> Rule:
|
||||
config = get_driver().config
|
||||
command_start = config.command_start
|
||||
command_sep = config.command_sep
|
||||
|
11
package.json
11
package.json
@ -1,19 +1,16 @@
|
||||
{
|
||||
"name": "nonebot",
|
||||
"version": "2.0.0",
|
||||
"description": "An asynchronous QQ bot framework.",
|
||||
"homepage": "https://nonebot.cqp.moe/",
|
||||
"description": "An asynchronous python bot framework.",
|
||||
"homepage": "https://docs.nonebot.dev/",
|
||||
"main": "index.js",
|
||||
"contributors": [{
|
||||
"name": "Richard Chien",
|
||||
"email": "richardchienthebest@gmail.com"
|
||||
},
|
||||
"contributors": [
|
||||
{
|
||||
"name": "yanyongyu",
|
||||
"email": "yanyongyu_1@126.com"
|
||||
}
|
||||
],
|
||||
"repository": "https://github.com/nonebot/nonebot/nonebot",
|
||||
"repository": "https://github.com/nonebot/nonebot/",
|
||||
"bugs": {
|
||||
"url": "https://github.com/nonebot/nonebot/issues"
|
||||
},
|
||||
|
@ -1,13 +1,13 @@
|
||||
[tool.poetry]
|
||||
name = "nonebot"
|
||||
version = "2.0.0"
|
||||
description = "An asynchronous QQ bot framework."
|
||||
authors = ["Richard Chien <richardchienthebest@gmail.com>", "yanyongyu <yanyongyu_1@126.com>"]
|
||||
description = "An asynchronous python bot framework."
|
||||
authors = ["yanyongyu <yanyongyu_1@126.com>"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
homepage = "https://nonebot.cqp.moe/"
|
||||
homepage = "https://docs.nonebot.dev/"
|
||||
repository = "https://github.com/nonebot/nonebot"
|
||||
documentation = "https://nonebot.cqp.moe/"
|
||||
documentation = "https://docs.nonebot.dev/"
|
||||
keywords = ["bot", "qq", "qqbot", "mirai", "coolq"]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
|
@ -24,6 +24,5 @@ async def test_handler(bot: Bot, event: Event, state: dict):
|
||||
async def test_handler(bot: Bot, event: Event, state: dict):
|
||||
print("[!] Command 帮助:", state["help"])
|
||||
if state["help"] not in ["test1", "test2"]:
|
||||
await bot.send_private_msg(message=f"{state['help']} 不支持,请重新输入!")
|
||||
test_command.reject()
|
||||
await test_command.reject(f"{state['help']} 不支持,请重新输入!")
|
||||
await bot.send_private_msg(message=f"{state['help']} 帮助:\n...")
|
||||
|
Loading…
Reference in New Issue
Block a user