merge changes

This commit is contained in:
yanyongyu 2020-08-06 17:22:56 +08:00
commit f1e62feb26
11 changed files with 207 additions and 47 deletions

View File

@ -11,7 +11,12 @@ from nonebot.config import Config
class BaseBot(abc.ABC): class BaseBot(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __init__(self, type: str, config: Config, *, websocket=None): def __init__(self,
type: str,
config: Config,
self_id: int,
*,
websocket=None):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod

View File

@ -10,6 +10,7 @@ from nonebot.event import Event
from nonebot.config import Config from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.drivers import BaseWebSocket from nonebot.drivers import BaseWebSocket
from nonebot.exception import ApiNotAvailable
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
@ -44,6 +45,7 @@ class Bot(BaseBot):
def __init__(self, def __init__(self,
connection_type: str, connection_type: str,
config: Config, config: Config,
self_id: int,
*, *,
websocket: BaseWebSocket = None): websocket: BaseWebSocket = None):
if connection_type not in ["http", "websocket"]: if connection_type not in ["http", "websocket"]:
@ -51,6 +53,7 @@ class Bot(BaseBot):
self.type = "coolq" self.type = "coolq"
self.connection_type = connection_type self.connection_type = connection_type
self.config = config self.config = config
self.self_id = self_id
self.websocket = websocket self.websocket = websocket
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
@ -63,10 +66,6 @@ class Bot(BaseBot):
if "message" in event.keys(): if "message" in event.keys():
event["message"] = Message(event["message"]) event["message"] = Message(event["message"])
# TODO: Handle Meta Event
if event.type == "meta_event":
pass
else:
await handle_event(self, event) await handle_event(self, event)
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, data: dict):
@ -74,7 +73,24 @@ class Bot(BaseBot):
if self.type == "websocket": if self.type == "websocket":
pass pass
elif self.type == "http": elif self.type == "http":
pass api_root = self.config.api_root.get(self.self_id)
if not api_root:
raise ApiNotAvailable
elif not api_root.endswith("/"):
api_root += "/"
headers = {}
if self.config.access_token:
headers["Authorization"] = "Bearer " + self.config.access_token
async with httpx.AsyncClient() as client:
response = await client.post(api_root + api)
if 200 <= response.status_code < 300:
# TODO: handle http api response
return ...
raise httpx.HTTPError(
"<HttpFailed {0.status_code} for url: {0.url}>", response)
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import Set, Union
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Set, Dict, Union, Optional
from pydantic import BaseSettings from pydantic import BaseSettings
@ -15,14 +15,22 @@ class Env(BaseSettings):
class Config(BaseSettings): class Config(BaseSettings):
# nonebot configs
driver: str = "nonebot.drivers.fastapi" driver: str = "nonebot.drivers.fastapi"
host: IPv4Address = IPv4Address("127.0.0.1") host: IPv4Address = IPv4Address("127.0.0.1")
port: int = 8080 port: int = 8080
secret: Optional[str] = None
debug: bool = False debug: bool = False
# bot connection configs
api_root: Dict[int, str] = {}
access_token: Optional[str] = None
# bot runtime configs
superusers: Set[int] = set() superusers: Set[int] = set()
nickname: Union[str, Set[str]] = "" nickname: Union[str, Set[str]] = ""
# custom configs
custom_config: dict = {} custom_config: dict = {}
class Config: class Config:

View File

@ -45,10 +45,6 @@ class BaseDriver(abc.ABC):
async def _handle_ws_reverse(self): async def _handle_ws_reverse(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
async def _handle_http_api(self):
raise NotImplementedError
class BaseWebSocket(object): class BaseWebSocket(object):
@ -71,7 +67,7 @@ class BaseWebSocket(object):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def close(self): async def close(self, code: int):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod

View File

@ -3,18 +3,19 @@
import json import json
import logging import logging
from typing import Optional from typing import Dict, Optional
from ipaddress import IPv4Address from ipaddress import IPv4Address
import uvicorn import uvicorn
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.adapters import BaseBot
from nonebot.adapters.cqhttp import Bot as CQBot from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.drivers import BaseDriver, BaseWebSocket
class Driver(BaseDriver): class Driver(BaseDriver):
@ -28,6 +29,7 @@ class Driver(BaseDriver):
) )
self.config = config self.config = config
self._clients: Dict[int, BaseBot] = {}
self._server_app.post("/{adapter}/")(self._handle_http) self._server_app.post("/{adapter}/")(self._handle_http)
self._server_app.post("/{adapter}/http")(self._handle_http) self._server_app.post("/{adapter}/http")(self._handle_http)
@ -43,9 +45,13 @@ class Driver(BaseDriver):
return self._server_app return self._server_app
@property @property
def logger(self): def logger(self) -> logging.Logger:
return logging.getLogger("fastapi") return logging.getLogger("fastapi")
@property
def bots(self) -> Dict[int, BaseBot]:
return self._clients
def run(self, def run(self,
host: Optional[IPv4Address] = None, host: Optional[IPv4Address] = None,
port: Optional[int] = None, port: Optional[int] = None,
@ -102,12 +108,22 @@ class Driver(BaseDriver):
async def _handle_ws_reverse(self, async def _handle_ws_reverse(self,
adapter: str, adapter: str,
websocket: FastAPIWebSocket, websocket: FastAPIWebSocket,
self_id: int = Header(None),
access_token: str = OAuth2PasswordBearer( access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)): "/", auto_error=False)):
websocket = WebSocket(websocket) websocket = WebSocket(websocket)
# TODO: Check authorization # TODO: Check authorization
# Create Bot Object
if adapter == "coolq":
bot = CQBot("websocket", self.config, self_id, websocket=websocket)
else:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
return
await websocket.accept() await websocket.accept()
self._clients[self_id] = bot
while not websocket.closed: while not websocket.closed:
data = await websocket.receive() data = await websocket.receive()
@ -115,11 +131,10 @@ class Driver(BaseDriver):
if not data: if not data:
continue continue
logger.debug(f"Received message: {data}")
if adapter == "cqhttp":
bot = CQBot("websocket", self.config, websocket=websocket)
await bot.handle_message(data) await bot.handle_message(data)
del self._clients[self_id]
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
@ -135,8 +150,8 @@ class WebSocket(BaseWebSocket):
await self.websocket.accept() await self.websocket.accept()
self._closed = False self._closed = False
async def close(self): async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close() await self.websocket.close(code=code)
self._closed = True self._closed = True
async def receive(self) -> Optional[dict]: async def receive(self) -> Optional[dict]:

View File

@ -2,6 +2,19 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class IgnoredException(Exception):
"""
Raised by event_preprocessor indicating that
the bot should ignore the event
"""
def __init__(self, reason):
"""
:param reason: reason to ignore the event
"""
self.reason = reason
class PausedException(Exception): class PausedException(Exception):
"""Block a message from further handling and try to receive a new message""" """Block a message from further handling and try to receive a new message"""
pass pass
@ -15,3 +28,8 @@ class RejectedException(Exception):
class FinishedException(Exception): class FinishedException(Exception):
"""Finish handling a message""" """Finish handling a message"""
pass pass
class ApiNotAvailable(Exception):
"""Api is not available"""
pass

View File

@ -62,7 +62,7 @@ class Matcher:
return NewMatcher return NewMatcher
@classmethod @classmethod
def check_rule(cls, event: Event) -> bool: def check_rule(cls, bot, event: Event) -> bool:
"""检查 Matcher 的 Rule 是否成立 """检查 Matcher 的 Rule 是否成立
Args: Args:
@ -71,7 +71,7 @@ class Matcher:
Returns: Returns:
bool: 条件成立与否 bool: 条件成立与否
""" """
return cls.rule(event) return cls.rule(bot, event)
# @classmethod # @classmethod
# def args_parser(cls, func: Callable[[Event, dict], None]): # def args_parser(cls, func: Callable[[Event, dict], None]):
@ -141,9 +141,6 @@ class Matcher:
# 运行handlers # 运行handlers
async def run(self, bot, event): async def run(self, bot, event):
if not self.rule(event):
return
try: try:
# if self.parser: # if self.parser:
# await self.parser(event, state) # type: ignore # await self.parser(event, state) # type: ignore

View File

@ -1,19 +1,39 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import asyncio
from typing import Set, Callable
from nonebot.log import logger from nonebot.log import logger
from nonebot.event import Event from nonebot.event import Event
from nonebot.matcher import matchers from nonebot.matcher import matchers
from nonebot.exception import IgnoredException
_event_preprocessors: Set[Callable] = set()
def event_preprocessor(func: Callable) -> Callable:
_event_preprocessors.add(func)
return func
async def handle_event(bot, event: Event): async def handle_event(bot, event: Event):
# TODO: PreProcess # TODO: PreProcess
coros = []
for preprocessor in _event_preprocessors:
coros.append(preprocessor(bot, event))
if coros:
try:
await asyncio.gather(*coros)
except IgnoredException:
logger.info(f"Event {event} is ignored")
return
for priority in sorted(matchers.keys()): for priority in sorted(matchers.keys()):
for index in range(len(matchers[priority])): for index in range(len(matchers[priority])):
Matcher = matchers[priority][index] Matcher = matchers[priority][index]
try: try:
if not Matcher.check_rule(event): if not Matcher.check_rule(bot, event):
continue continue
except Exception as e: except Exception as e:
logger.error( logger.error(

View File

@ -7,9 +7,9 @@ import importlib
from types import ModuleType from types import ModuleType
from typing import Set, Dict, Type, Optional from typing import Set, Dict, Type, Optional
from nonebot.rule import Rule
from nonebot.log import logger from nonebot.log import logger
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.rule import Rule, metaevent, message, notice, request
plugins: Dict[str, "Plugin"] = {} plugins: Dict[str, "Plugin"] = {}
@ -26,13 +26,58 @@ class Plugin(object):
self.matchers = matchers self.matchers = matchers
def on_metaevent(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(metaevent() & rule,
temp=temp,
priority=priority,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
return matcher
def on_message(rule: Rule, def on_message(rule: Rule,
*, *,
handlers=[], handlers=[],
temp=False, temp=False,
priority: int = 1, priority: int = 1,
state={}) -> Type[Matcher]: state={}) -> Type[Matcher]:
matcher = Matcher.new(rule, matcher = Matcher.new(message() & rule,
temp=temp,
priority=priority,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
return matcher
def on_notice(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(notice() & rule,
temp=temp,
priority=priority,
handlers=handlers,
default_state=state)
_tmp_matchers.add(matcher)
return matcher
def on_request(rule: Rule,
*,
handlers=[],
temp=False,
priority: int = 1,
state={}) -> Type[Matcher]:
matcher = Matcher.new(request() & rule,
temp=temp, temp=temp,
priority=priority, priority=priority,
handlers=handlers, handlers=handlers,

View File

@ -9,52 +9,74 @@ from nonebot.event import Event
class Rule: class Rule:
def __init__(self, checker: Optional[Callable[[Event], bool]] = None): def __init__(
self.checker = checker or (lambda event: True) self,
checker: Optional[Callable[["BaseBot", Event], # type: ignore
bool]] = None):
self.checker = checker or (lambda bot, event: True)
def __call__(self, event: Event) -> bool: def __call__(self, bot, event: Event) -> bool:
return self.checker(event) return self.checker(bot, event)
def __and__(self, other: "Rule") -> "Rule": def __and__(self, other: "Rule") -> "Rule":
return Rule(lambda event: self.checker(event) and other.checker(event)) return Rule(lambda bot, event: self.checker(bot, event) and other.
checker(bot, event))
def __or__(self, other: "Rule") -> "Rule": def __or__(self, other: "Rule") -> "Rule":
return Rule(lambda event: self.checker(event) or other.checker(event)) return Rule(lambda bot, event: self.checker(bot, event) or other.
checker(bot, event))
def __neg__(self) -> "Rule": def __neg__(self) -> "Rule":
return Rule(lambda event: not self.checker(event)) return Rule(lambda bot, event: not self.checker(bot, event))
def message() -> Rule:
return Rule(lambda bot, event: event.type == "message")
def notice() -> Rule:
return Rule(lambda bot, event: event.type == "notice")
def request() -> Rule:
return Rule(lambda bot, event: event.type == "request")
def metaevent() -> Rule:
return Rule(lambda bot, event: event.type == "meta_event")
def user(*qq: int) -> Rule: def user(*qq: int) -> Rule:
return Rule(lambda event: event.user_id in qq) return Rule(lambda bot, event: event.user_id in qq)
def private() -> Rule: def private() -> Rule:
return Rule(lambda event: event.detail_type == "private") return Rule(lambda bot, event: event.detail_type == "private")
def group(*group: int) -> Rule: def group(*group: int) -> Rule:
return Rule( return Rule(lambda bot, event: event.detail_type == "group" and event.
lambda event: event.detail_type == "group" and event.group_id in group) group_id in group)
def discuss(*discuss: int) -> Rule: def discuss(*discuss: int) -> Rule:
return Rule(lambda event: event.detail_type == "discuss" and event. return Rule(lambda bot, event: event.detail_type == "discuss" and event.
discuss_id in discuss) discuss_id in discuss)
def startswith(msg, start: int = None, end: int = None) -> Rule: def startswith(msg, start: int = None, end: int = None) -> Rule:
return Rule(lambda event: event.message.startswith(msg, start, end)) return Rule(lambda bot, event: event.message.startswith(msg, start, end))
def endswith(msg, start: int = None, end: int = None) -> Rule: def endswith(msg, start: int = None, end: int = None) -> Rule:
return Rule(lambda event: event.message.endswith(msg, start=None, end=None)) return Rule(
lambda bot, event: event.message.endswith(msg, start=None, end=None))
def has(msg: str) -> Rule: def has(msg: str) -> Rule:
return Rule(lambda event: msg in event.message) return Rule(lambda bot, event: msg in event.message)
def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> Rule: def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> Rule:
pattern = re.compile(regex, flags) pattern = re.compile(regex, flags)
return Rule(lambda event: bool(pattern.search(event.message))) return Rule(lambda bot, event: bool(pattern.search(str(event.message))))

View File

@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from nonebot.rule import Rule
from nonebot.event import Event
from nonebot.plugin import on_metaevent
def heartbeat(bot, event: Event) -> bool:
return event.detail_type == "heartbeat"
test_matcher = on_metaevent(Rule(heartbeat))
@test_matcher.handle()
async def handle_heartbeat(bot, event: Event, state: dict):
print("[i] Heartbeat")