merge changes

This commit is contained in:
yanyongyu 2020-08-06 17:22:56 +08:00
commit f1e62feb26
11 changed files with 207 additions and 47 deletions

View File

@ -11,7 +11,12 @@ from nonebot.config import Config
class BaseBot(abc.ABC):
@abc.abstractmethod
def __init__(self, type: str, config: Config, *, websocket=None):
def __init__(self,
type: str,
config: Config,
self_id: int,
*,
websocket=None):
raise NotImplementedError
@abc.abstractmethod

View File

@ -10,6 +10,7 @@ from nonebot.event import Event
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.drivers import BaseWebSocket
from nonebot.exception import ApiNotAvailable
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
@ -44,6 +45,7 @@ class Bot(BaseBot):
def __init__(self,
connection_type: str,
config: Config,
self_id: int,
*,
websocket: BaseWebSocket = None):
if connection_type not in ["http", "websocket"]:
@ -51,6 +53,7 @@ class Bot(BaseBot):
self.type = "coolq"
self.connection_type = connection_type
self.config = config
self.self_id = self_id
self.websocket = websocket
async def handle_message(self, message: dict):
@ -63,18 +66,31 @@ class Bot(BaseBot):
if "message" in event.keys():
event["message"] = Message(event["message"])
# TODO: Handle Meta Event
if event.type == "meta_event":
pass
else:
await handle_event(self, event)
await handle_event(self, event)
async def call_api(self, api: str, data: dict):
# TODO: Call API
if self.type == "websocket":
pass
elif self.type == "http":
pass
api_root = self.config.api_root.get(self.self_id)
if not api_root:
raise ApiNotAvailable
elif not api_root.endswith("/"):
api_root += "/"
headers = {}
if self.config.access_token:
headers["Authorization"] = "Bearer " + self.config.access_token
async with httpx.AsyncClient() as client:
response = await client.post(api_root + api)
if 200 <= response.status_code < 300:
# TODO: handle http api response
return ...
raise httpx.HTTPError(
"<HttpFailed {0.status_code} for url: {0.url}>", response)
class MessageSegment(BaseMessageSegment):

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Set, Union
from ipaddress import IPv4Address
from typing import Set, Dict, Union, Optional
from pydantic import BaseSettings
@ -15,14 +15,22 @@ class Env(BaseSettings):
class Config(BaseSettings):
# nonebot configs
driver: str = "nonebot.drivers.fastapi"
host: IPv4Address = IPv4Address("127.0.0.1")
port: int = 8080
secret: Optional[str] = None
debug: bool = False
# bot connection configs
api_root: Dict[int, str] = {}
access_token: Optional[str] = None
# bot runtime configs
superusers: Set[int] = set()
nickname: Union[str, Set[str]] = ""
# custom configs
custom_config: dict = {}
class Config:

View File

@ -45,10 +45,6 @@ class BaseDriver(abc.ABC):
async def _handle_ws_reverse(self):
raise NotImplementedError
@abc.abstractmethod
async def _handle_http_api(self):
raise NotImplementedError
class BaseWebSocket(object):
@ -71,7 +67,7 @@ class BaseWebSocket(object):
raise NotImplementedError
@abc.abstractmethod
async def close(self):
async def close(self, code: int):
raise NotImplementedError
@abc.abstractmethod

View File

@ -3,18 +3,19 @@
import json
import logging
from typing import Optional
from typing import Dict, Optional
from ipaddress import IPv4Address
import uvicorn
from fastapi.security import OAuth2PasswordBearer
from starlette.websockets import WebSocketDisconnect
from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket
from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Config
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.adapters import BaseBot
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.drivers import BaseDriver, BaseWebSocket
class Driver(BaseDriver):
@ -28,6 +29,7 @@ class Driver(BaseDriver):
)
self.config = config
self._clients: Dict[int, BaseBot] = {}
self._server_app.post("/{adapter}/")(self._handle_http)
self._server_app.post("/{adapter}/http")(self._handle_http)
@ -43,9 +45,13 @@ class Driver(BaseDriver):
return self._server_app
@property
def logger(self):
def logger(self) -> logging.Logger:
return logging.getLogger("fastapi")
@property
def bots(self) -> Dict[int, BaseBot]:
return self._clients
def run(self,
host: Optional[IPv4Address] = None,
port: Optional[int] = None,
@ -102,12 +108,22 @@ class Driver(BaseDriver):
async def _handle_ws_reverse(self,
adapter: str,
websocket: FastAPIWebSocket,
self_id: int = Header(None),
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
websocket = WebSocket(websocket)
# TODO: Check authorization
# Create Bot Object
if adapter == "coolq":
bot = CQBot("websocket", self.config, self_id, websocket=websocket)
else:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
return
await websocket.accept()
self._clients[self_id] = bot
while not websocket.closed:
data = await websocket.receive()
@ -115,10 +131,9 @@ class Driver(BaseDriver):
if not data:
continue
logger.debug(f"Received message: {data}")
if adapter == "cqhttp":
bot = CQBot("websocket", self.config, websocket=websocket)
await bot.handle_message(data)
await bot.handle_message(data)
del self._clients[self_id]
class WebSocket(BaseWebSocket):
@ -135,8 +150,8 @@ class WebSocket(BaseWebSocket):
await self.websocket.accept()
self._closed = False
async def close(self):
await self.websocket.close()
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code)
self._closed = True
async def receive(self) -> Optional[dict]:

View File

@ -2,6 +2,19 @@
# -*- coding: utf-8 -*-
class IgnoredException(Exception):
"""
Raised by event_preprocessor indicating that
the bot should ignore the event
"""
def __init__(self, reason):
"""
:param reason: reason to ignore the event
"""
self.reason = reason
class PausedException(Exception):
"""Block a message from further handling and try to receive a new message"""
pass
@ -15,3 +28,8 @@ class RejectedException(Exception):
class FinishedException(Exception):
"""Finish handling a message"""
pass
class ApiNotAvailable(Exception):
"""Api is not available"""
pass

View File

@ -62,7 +62,7 @@ class Matcher:
return NewMatcher
@classmethod
def check_rule(cls, event: Event) -> bool:
def check_rule(cls, bot, event: Event) -> bool:
"""检查 Matcher 的 Rule 是否成立
Args:
@ -71,7 +71,7 @@ class Matcher:
Returns:
bool: 条件成立与否
"""
return cls.rule(event)
return cls.rule(bot, event)
# @classmethod
# def args_parser(cls, func: Callable[[Event, dict], None]):
@ -141,9 +141,6 @@ class Matcher:
# 运行handlers
async def run(self, bot, event):
if not self.rule(event):
return
try:
# if self.parser:
# await self.parser(event, state) # type: ignore

View File

@ -1,19 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
from typing import Set, Callable
from nonebot.log import logger
from nonebot.event import Event
from nonebot.matcher import matchers
from nonebot.exception import IgnoredException
_event_preprocessors: Set[Callable] = set()
def event_preprocessor(func: Callable) -> Callable:
_event_preprocessors.add(func)
return func
async def handle_event(bot, event: Event):
# TODO: PreProcess
coros = []
for preprocessor in _event_preprocessors:
coros.append(preprocessor(bot, event))
if coros:
try:
await asyncio.gather(*coros)
except IgnoredException:
logger.info(f"Event {event} is ignored")
return
for priority in sorted(matchers.keys()):
for index in range(len(matchers[priority])):
Matcher = matchers[priority][index]
try:
if not Matcher.check_rule(event):
if not Matcher.check_rule(bot, event):
continue
except Exception as e:
logger.error(

View File

@ -7,9 +7,9 @@ import importlib
from types import ModuleType
from typing import Set, Dict, Type, Optional
from nonebot.rule import Rule
from nonebot.log import logger
from nonebot.matcher import Matcher
from nonebot.rule import Rule, metaevent, message, notice, request
plugins: Dict[str, "Plugin"] = {}
@ -26,13 +26,58 @@ class Plugin(object):
self.matchers = matchers
def on_metaevent(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(metaevent() & rule,
temp=temp,
priority=priority,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
return matcher
def on_message(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(rule,
matcher = Matcher.new(message() & rule,
temp=temp,
priority=priority,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
return matcher
def on_notice(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(notice() & rule,
temp=temp,
priority=priority,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
return matcher
def on_request(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(request() & rule,
temp=temp,
priority=priority,
handlers=handlers,

View File

@ -9,52 +9,74 @@ from nonebot.event import Event
class Rule:
def __init__(self, checker: Optional[Callable[[Event], bool]] = None):
self.checker = checker or (lambda event: True)
def __init__(
self,
checker: Optional[Callable[["BaseBot", Event], # type: ignore
bool]] = None):
self.checker = checker or (lambda bot, event: True)
def __call__(self, event: Event) -> bool:
return self.checker(event)
def __call__(self, bot, event: Event) -> bool:
return self.checker(bot, event)
def __and__(self, other: "Rule") -> "Rule":
return Rule(lambda event: self.checker(event) and other.checker(event))
return Rule(lambda bot, event: self.checker(bot, event) and other.
checker(bot, event))
def __or__(self, other: "Rule") -> "Rule":
return Rule(lambda event: self.checker(event) or other.checker(event))
return Rule(lambda bot, event: self.checker(bot, event) or other.
checker(bot, event))
def __neg__(self) -> "Rule":
return Rule(lambda event: not self.checker(event))
return Rule(lambda bot, event: not self.checker(bot, event))
def message() -> Rule:
return Rule(lambda bot, event: event.type == "message")
def notice() -> Rule:
return Rule(lambda bot, event: event.type == "notice")
def request() -> Rule:
return Rule(lambda bot, event: event.type == "request")
def metaevent() -> Rule:
return Rule(lambda bot, event: event.type == "meta_event")
def user(*qq: int) -> Rule:
return Rule(lambda event: event.user_id in qq)
return Rule(lambda bot, event: event.user_id in qq)
def private() -> Rule:
return Rule(lambda event: event.detail_type == "private")
return Rule(lambda bot, event: event.detail_type == "private")
def group(*group: int) -> Rule:
return Rule(
lambda event: event.detail_type == "group" and event.group_id in group)
return Rule(lambda bot, event: event.detail_type == "group" and event.
group_id in group)
def discuss(*discuss: int) -> Rule:
return Rule(lambda event: event.detail_type == "discuss" and event.
return Rule(lambda bot, event: event.detail_type == "discuss" and event.
discuss_id in discuss)
def startswith(msg, start: int = None, end: int = None) -> Rule:
return Rule(lambda event: event.message.startswith(msg, start, end))
return Rule(lambda bot, event: event.message.startswith(msg, start, end))
def endswith(msg, start: int = None, end: int = None) -> Rule:
return Rule(lambda event: event.message.endswith(msg, start=None, end=None))
return Rule(
lambda bot, event: event.message.endswith(msg, start=None, end=None))
def has(msg: str) -> Rule:
return Rule(lambda event: msg in event.message)
return Rule(lambda bot, event: msg in event.message)
def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags)
return Rule(lambda event: bool(pattern.search(event.message)))
return Rule(lambda bot, event: bool(pattern.search(str(event.message))))

View File

@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from nonebot.rule import Rule
from nonebot.event import Event
from nonebot.plugin import on_metaevent
def heartbeat(bot, event: Event) -> bool:
return event.detail_type == "heartbeat"
test_matcher = on_metaevent(Rule(heartbeat))
@test_matcher.handle()
async def handle_heartbeat(bot, event: Event, state: dict):
print("[i] Heartbeat")