diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 9f7c69b7..ef6f4cc5 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -11,7 +11,12 @@ from nonebot.config import Config class BaseBot(abc.ABC): @abc.abstractmethod - def __init__(self, type: str, config: Config, *, websocket=None): + def __init__(self, + type: str, + config: Config, + self_id: int, + *, + websocket=None): raise NotImplementedError @abc.abstractmethod diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 376e27f7..2bc46b75 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -10,6 +10,7 @@ from nonebot.event import Event from nonebot.config import Config from nonebot.message import handle_event from nonebot.drivers import BaseWebSocket +from nonebot.exception import ApiNotAvailable from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment @@ -44,6 +45,7 @@ class Bot(BaseBot): def __init__(self, connection_type: str, config: Config, + self_id: int, *, websocket: BaseWebSocket = None): if connection_type not in ["http", "websocket"]: @@ -51,6 +53,7 @@ class Bot(BaseBot): self.type = "coolq" self.connection_type = connection_type self.config = config + self.self_id = self_id self.websocket = websocket async def handle_message(self, message: dict): @@ -63,18 +66,31 @@ class Bot(BaseBot): if "message" in event.keys(): event["message"] = Message(event["message"]) - # TODO: Handle Meta Event - if event.type == "meta_event": - pass - else: - await handle_event(self, event) + await handle_event(self, event) async def call_api(self, api: str, data: dict): # TODO: Call API if self.type == "websocket": pass elif self.type == "http": - pass + api_root = self.config.api_root.get(self.self_id) + if not api_root: + raise ApiNotAvailable + elif not api_root.endswith("/"): + api_root += "/" + + headers = {} + if self.config.access_token: + headers["Authorization"] = "Bearer " + self.config.access_token + + async with httpx.AsyncClient() as client: + response = await client.post(api_root + api) + + if 200 <= response.status_code < 300: + # TODO: handle http api response + return ... + raise httpx.HTTPError( + "", response) class MessageSegment(BaseMessageSegment): diff --git a/nonebot/config.py b/nonebot/config.py index b06c0625..8e4d8efb 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Set, Union from ipaddress import IPv4Address +from typing import Set, Dict, Union, Optional from pydantic import BaseSettings @@ -15,14 +15,22 @@ class Env(BaseSettings): class Config(BaseSettings): + # nonebot configs driver: str = "nonebot.drivers.fastapi" host: IPv4Address = IPv4Address("127.0.0.1") port: int = 8080 + secret: Optional[str] = None debug: bool = False + # bot connection configs + api_root: Dict[int, str] = {} + access_token: Optional[str] = None + + # bot runtime configs superusers: Set[int] = set() nickname: Union[str, Set[str]] = "" + # custom configs custom_config: dict = {} class Config: diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index e4b88764..0028a827 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -45,10 +45,6 @@ class BaseDriver(abc.ABC): async def _handle_ws_reverse(self): raise NotImplementedError - @abc.abstractmethod - async def _handle_http_api(self): - raise NotImplementedError - class BaseWebSocket(object): @@ -71,7 +67,7 @@ class BaseWebSocket(object): raise NotImplementedError @abc.abstractmethod - async def close(self): + async def close(self, code: int): raise NotImplementedError @abc.abstractmethod diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index a2774d8d..5c22c94e 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -3,18 +3,19 @@ import json import logging -from typing import Optional +from typing import Dict, Optional from ipaddress import IPv4Address import uvicorn from fastapi.security import OAuth2PasswordBearer from starlette.websockets import WebSocketDisconnect -from fastapi import Body, FastAPI, WebSocket as FastAPIWebSocket +from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket from nonebot.log import logger from nonebot.config import Config -from nonebot.drivers import BaseDriver, BaseWebSocket +from nonebot.adapters import BaseBot from nonebot.adapters.cqhttp import Bot as CQBot +from nonebot.drivers import BaseDriver, BaseWebSocket class Driver(BaseDriver): @@ -28,6 +29,7 @@ class Driver(BaseDriver): ) self.config = config + self._clients: Dict[int, BaseBot] = {} self._server_app.post("/{adapter}/")(self._handle_http) self._server_app.post("/{adapter}/http")(self._handle_http) @@ -43,9 +45,13 @@ class Driver(BaseDriver): return self._server_app @property - def logger(self): + def logger(self) -> logging.Logger: return logging.getLogger("fastapi") + @property + def bots(self) -> Dict[int, BaseBot]: + return self._clients + def run(self, host: Optional[IPv4Address] = None, port: Optional[int] = None, @@ -102,12 +108,22 @@ class Driver(BaseDriver): async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket, + self_id: int = Header(None), access_token: str = OAuth2PasswordBearer( "/", auto_error=False)): websocket = WebSocket(websocket) # TODO: Check authorization + + # Create Bot Object + if adapter == "coolq": + bot = CQBot("websocket", self.config, self_id, websocket=websocket) + else: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + return + await websocket.accept() + self._clients[self_id] = bot while not websocket.closed: data = await websocket.receive() @@ -115,10 +131,9 @@ class Driver(BaseDriver): if not data: continue - logger.debug(f"Received message: {data}") - if adapter == "cqhttp": - bot = CQBot("websocket", self.config, websocket=websocket) - await bot.handle_message(data) + await bot.handle_message(data) + + del self._clients[self_id] class WebSocket(BaseWebSocket): @@ -135,8 +150,8 @@ class WebSocket(BaseWebSocket): await self.websocket.accept() self._closed = False - async def close(self): - await self.websocket.close() + async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): + await self.websocket.close(code=code) self._closed = True async def receive(self) -> Optional[dict]: diff --git a/nonebot/exception.py b/nonebot/exception.py index a0a5ca59..de001dad 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -2,6 +2,19 @@ # -*- coding: utf-8 -*- +class IgnoredException(Exception): + """ + Raised by event_preprocessor indicating that + the bot should ignore the event + """ + + def __init__(self, reason): + """ + :param reason: reason to ignore the event + """ + self.reason = reason + + class PausedException(Exception): """Block a message from further handling and try to receive a new message""" pass @@ -15,3 +28,8 @@ class RejectedException(Exception): class FinishedException(Exception): """Finish handling a message""" pass + + +class ApiNotAvailable(Exception): + """Api is not available""" + pass diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 89e22b56..a65ba161 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -62,7 +62,7 @@ class Matcher: return NewMatcher @classmethod - def check_rule(cls, event: Event) -> bool: + def check_rule(cls, bot, event: Event) -> bool: """检查 Matcher 的 Rule 是否成立 Args: @@ -71,7 +71,7 @@ class Matcher: Returns: bool: 条件成立与否 """ - return cls.rule(event) + return cls.rule(bot, event) # @classmethod # def args_parser(cls, func: Callable[[Event, dict], None]): @@ -141,9 +141,6 @@ class Matcher: # 运行handlers async def run(self, bot, event): - if not self.rule(event): - return - try: # if self.parser: # await self.parser(event, state) # type: ignore diff --git a/nonebot/message.py b/nonebot/message.py index 032615b8..81a5778b 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -1,19 +1,39 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import asyncio +from typing import Set, Callable + from nonebot.log import logger from nonebot.event import Event from nonebot.matcher import matchers +from nonebot.exception import IgnoredException + +_event_preprocessors: Set[Callable] = set() + + +def event_preprocessor(func: Callable) -> Callable: + _event_preprocessors.add(func) + return func async def handle_event(bot, event: Event): # TODO: PreProcess + coros = [] + for preprocessor in _event_preprocessors: + coros.append(preprocessor(bot, event)) + if coros: + try: + await asyncio.gather(*coros) + except IgnoredException: + logger.info(f"Event {event} is ignored") + return for priority in sorted(matchers.keys()): for index in range(len(matchers[priority])): Matcher = matchers[priority][index] try: - if not Matcher.check_rule(event): + if not Matcher.check_rule(bot, event): continue except Exception as e: logger.error( diff --git a/nonebot/plugin.py b/nonebot/plugin.py index 5b724763..50576842 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -7,9 +7,9 @@ import importlib from types import ModuleType from typing import Set, Dict, Type, Optional -from nonebot.rule import Rule from nonebot.log import logger from nonebot.matcher import Matcher +from nonebot.rule import Rule, metaevent, message, notice, request plugins: Dict[str, "Plugin"] = {} @@ -26,13 +26,58 @@ class Plugin(object): self.matchers = matchers +def on_metaevent(rule: Rule, + *, + handlers=[], + temp=False, + priority: int = 1, + state={}) -> Type[Matcher]: + matcher = Matcher.new(metaevent() & rule, + temp=temp, + priority=priority, + handlers=handlers, + default_state=state) + _tmp_matchers.add(matcher) + return matcher + + def on_message(rule: Rule, *, handlers=[], temp=False, priority: int = 1, state={}) -> Type[Matcher]: - matcher = Matcher.new(rule, + matcher = Matcher.new(message() & rule, + temp=temp, + priority=priority, + handlers=handlers, + default_state=state) + _tmp_matchers.add(matcher) + return matcher + + +def on_notice(rule: Rule, + *, + handlers=[], + temp=False, + priority: int = 1, + state={}) -> Type[Matcher]: + matcher = Matcher.new(notice() & rule, + temp=temp, + priority=priority, + handlers=handlers, + default_state=state) + _tmp_matchers.add(matcher) + return matcher + + +def on_request(rule: Rule, + *, + handlers=[], + temp=False, + priority: int = 1, + state={}) -> Type[Matcher]: + matcher = Matcher.new(request() & rule, temp=temp, priority=priority, handlers=handlers, diff --git a/nonebot/rule.py b/nonebot/rule.py index 814069b9..7f3f68fa 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -9,52 +9,74 @@ from nonebot.event import Event class Rule: - def __init__(self, checker: Optional[Callable[[Event], bool]] = None): - self.checker = checker or (lambda event: True) + def __init__( + self, + checker: Optional[Callable[["BaseBot", Event], # type: ignore + bool]] = None): + self.checker = checker or (lambda bot, event: True) - def __call__(self, event: Event) -> bool: - return self.checker(event) + def __call__(self, bot, event: Event) -> bool: + return self.checker(bot, event) def __and__(self, other: "Rule") -> "Rule": - return Rule(lambda event: self.checker(event) and other.checker(event)) + return Rule(lambda bot, event: self.checker(bot, event) and other. + checker(bot, event)) def __or__(self, other: "Rule") -> "Rule": - return Rule(lambda event: self.checker(event) or other.checker(event)) + return Rule(lambda bot, event: self.checker(bot, event) or other. + checker(bot, event)) def __neg__(self) -> "Rule": - return Rule(lambda event: not self.checker(event)) + return Rule(lambda bot, event: not self.checker(bot, event)) + + +def message() -> Rule: + return Rule(lambda bot, event: event.type == "message") + + +def notice() -> Rule: + return Rule(lambda bot, event: event.type == "notice") + + +def request() -> Rule: + return Rule(lambda bot, event: event.type == "request") + + +def metaevent() -> Rule: + return Rule(lambda bot, event: event.type == "meta_event") def user(*qq: int) -> Rule: - return Rule(lambda event: event.user_id in qq) + return Rule(lambda bot, event: event.user_id in qq) def private() -> Rule: - return Rule(lambda event: event.detail_type == "private") + return Rule(lambda bot, event: event.detail_type == "private") def group(*group: int) -> Rule: - return Rule( - lambda event: event.detail_type == "group" and event.group_id in group) + return Rule(lambda bot, event: event.detail_type == "group" and event. + group_id in group) def discuss(*discuss: int) -> Rule: - return Rule(lambda event: event.detail_type == "discuss" and event. + return Rule(lambda bot, event: event.detail_type == "discuss" and event. discuss_id in discuss) def startswith(msg, start: int = None, end: int = None) -> Rule: - return Rule(lambda event: event.message.startswith(msg, start, end)) + return Rule(lambda bot, event: event.message.startswith(msg, start, end)) def endswith(msg, start: int = None, end: int = None) -> Rule: - return Rule(lambda event: event.message.endswith(msg, start=None, end=None)) + return Rule( + lambda bot, event: event.message.endswith(msg, start=None, end=None)) def has(msg: str) -> Rule: - return Rule(lambda event: msg in event.message) + return Rule(lambda bot, event: msg in event.message) def regex(regex, flags: Union[int, re.RegexFlag] = 0) -> Rule: pattern = re.compile(regex, flags) - return Rule(lambda event: bool(pattern.search(event.message))) + return Rule(lambda bot, event: bool(pattern.search(str(event.message)))) diff --git a/tests/test_plugins/test_metaevent.py b/tests/test_plugins/test_metaevent.py new file mode 100644 index 00000000..3fb7360e --- /dev/null +++ b/tests/test_plugins/test_metaevent.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from nonebot.rule import Rule +from nonebot.event import Event +from nonebot.plugin import on_metaevent + + +def heartbeat(bot, event: Event) -> bool: + return event.detail_type == "heartbeat" + + +test_matcher = on_metaevent(Rule(heartbeat)) + + +@test_matcher.handle() +async def handle_heartbeat(bot, event: Event, state: dict): + print("[i] Heartbeat")