diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 274c82c3..e8e5f2c3 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -5,6 +5,7 @@ import abc from functools import reduce from nonebot.config import Config +from nonebot.drivers import BaseWebSocket from nonebot.typing import Dict, Union, Iterable, Optional @@ -12,11 +13,19 @@ class BaseBot(abc.ABC): @abc.abstractmethod def __init__(self, - type: str, + connection_type: str, config: Config, self_id: int, *, - websocket=None): + websocket: BaseWebSocket = None): + self.connection_type = connection_type + self.config = config + self.self_id = self_id + self.websocket = websocket + + @property + @abc.abstractmethod + def type(self) -> str: raise NotImplementedError @abc.abstractmethod diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 9efeff44..7e2fc827 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -10,7 +10,7 @@ from nonebot.config import Config from nonebot.message import handle_event from nonebot.drivers import BaseWebSocket from nonebot.exception import ApiNotAvailable -from nonebot.typing import Tuple, Iterable, Optional +from nonebot.typing import Tuple, Iterable, Optional, overrides from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment @@ -50,12 +50,15 @@ class Bot(BaseBot): websocket: BaseWebSocket = None): if connection_type not in ["http", "websocket"]: raise ValueError("Unsupported connection type") - self.type = "coolq" - self.connection_type = connection_type - self.config = config - self.self_id = self_id - self.websocket = websocket + super().__init__(connection_type, config, self_id, websocket=websocket) + + @property + @overrides(BaseBot) + def type(self) -> str: + return "cqhttp" + + @overrides(BaseBot) async def handle_message(self, message: dict): # TODO: convert message into event event = Event.from_payload(message) @@ -68,6 +71,7 @@ class Bot(BaseBot): await handle_event(self, event) + @overrides(BaseBot) async def call_api(self, api: str, data: dict): # TODO: Call API if self.type == "websocket": diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 70ab96d4..75fe6ba3 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -5,14 +5,16 @@ import abc from ipaddress import IPv4Address from nonebot.config import Config -from nonebot.typing import Optional +from nonebot.adapters import BaseBot +from nonebot.typing import Dict, Optional class BaseDriver(abc.ABC): @abc.abstractmethod def __init__(self, config: Config): - raise NotImplementedError + self.config = config + self._clients: Dict[int, BaseBot] = {} @property @abc.abstractmethod @@ -29,6 +31,10 @@ class BaseDriver(abc.ABC): def logger(self): raise NotImplementedError + @property + def bots(self) -> Dict[int, BaseBot]: + return self._clients + @abc.abstractmethod def run(self, host: Optional[IPv4Address] = None, diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 7a6ba0ef..c0ee663d 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import json import logging from ipaddress import IPv4Address @@ -13,14 +12,16 @@ from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket from nonebot.log import logger from nonebot.config import Config from nonebot.adapters import BaseBot -from nonebot.typing import Dict, Optional from nonebot.adapters.cqhttp import Bot as CQBot +from nonebot.typing import Dict, Optional, overrides from nonebot.drivers import BaseDriver, BaseWebSocket class Driver(BaseDriver): def __init__(self, config: Config): + super().__init__(config) + self._server_app = FastAPI( debug=config.debug, openapi_url=None, @@ -28,30 +29,27 @@ class Driver(BaseDriver): redoc_url=None, ) - self.config = config - self._clients: Dict[int, BaseBot] = {} - self._server_app.post("/{adapter}/")(self._handle_http) self._server_app.post("/{adapter}/http")(self._handle_http) self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) @property + @overrides(BaseDriver) def server_app(self): return self._server_app @property + @overrides(BaseDriver) def asgi(self): return self._server_app @property + @overrides(BaseDriver) def logger(self) -> logging.Logger: return logging.getLogger("fastapi") - @property - def bots(self) -> Dict[int, BaseBot]: - return self._clients - + @overrides(BaseDriver) def run(self, host: Optional[IPv4Address] = None, port: Optional[int] = None, @@ -93,6 +91,7 @@ class Driver(BaseDriver): log_config=LOGGING_CONFIG, **kwargs) + @overrides(BaseDriver) async def _handle_http(self, adapter: str, data: dict = Body(...), @@ -105,6 +104,7 @@ class Driver(BaseDriver): await bot.handle_message(data) return {"status": 200, "message": "success"} + @overrides(BaseDriver) async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket, @@ -143,17 +143,21 @@ class WebSocket(BaseWebSocket): self._closed = None @property + @overrides(BaseWebSocket) def closed(self): return self._closed + @overrides(BaseWebSocket) async def accept(self): await self.websocket.accept() self._closed = False + @overrides(BaseWebSocket) async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): await self.websocket.close(code=code) self._closed = True + @overrides(BaseWebSocket) async def receive(self) -> Optional[dict]: data = None try: @@ -166,5 +170,6 @@ class WebSocket(BaseWebSocket): return data + @overrides(BaseWebSocket) async def send(self, data: dict) -> None: await self.websocket.send_json(data) diff --git a/nonebot/typing.py b/nonebot/typing.py index 400666fe..cb798274 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -1,13 +1,26 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from abc import ABC from types import ModuleType from typing import TYPE_CHECKING from typing import Any, Set, List, Dict, Type, Tuple, Mapping from typing import Union, Optional, Iterable, Callable, Awaitable +# import some modules needed when checking types if TYPE_CHECKING: from nonebot.adapters import BaseBot as Bot from nonebot.event import Event + +def overrides(InterfaceClass: ABC): + + def overrider(func): + assert func.__name__ in dir( + InterfaceClass), f"Error method: {func.__name__}" + return func + + return overrider + + Handler = Callable[["Bot", "Event", dict], Awaitable[None]]