change typing

This commit is contained in:
yanyongyu 2020-08-07 17:51:57 +08:00
parent f2b6f08599
commit 332aac6497
5 changed files with 56 additions and 19 deletions

View File

@ -5,6 +5,7 @@ import abc
from functools import reduce from functools import reduce
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import BaseWebSocket
from nonebot.typing import Dict, Union, Iterable, Optional from nonebot.typing import Dict, Union, Iterable, Optional
@ -12,11 +13,19 @@ class BaseBot(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __init__(self, def __init__(self,
type: str, connection_type: str,
config: Config, config: Config,
self_id: int, self_id: int,
*, *,
websocket=None): websocket: BaseWebSocket = None):
self.connection_type = connection_type
self.config = config
self.self_id = self_id
self.websocket = websocket
@property
@abc.abstractmethod
def type(self) -> str:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod

View File

@ -10,7 +10,7 @@ from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.drivers import BaseWebSocket from nonebot.drivers import BaseWebSocket
from nonebot.exception import ApiNotAvailable from nonebot.exception import ApiNotAvailable
from nonebot.typing import Tuple, Iterable, Optional from nonebot.typing import Tuple, Iterable, Optional, overrides
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
@ -50,12 +50,15 @@ class Bot(BaseBot):
websocket: BaseWebSocket = None): websocket: BaseWebSocket = None):
if connection_type not in ["http", "websocket"]: if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type") raise ValueError("Unsupported connection type")
self.type = "coolq"
self.connection_type = connection_type
self.config = config
self.self_id = self_id
self.websocket = websocket
super().__init__(connection_type, config, self_id, websocket=websocket)
@property
@overrides(BaseBot)
def type(self) -> str:
return "cqhttp"
@overrides(BaseBot)
async def handle_message(self, message: dict): async def handle_message(self, message: dict):
# TODO: convert message into event # TODO: convert message into event
event = Event.from_payload(message) event = Event.from_payload(message)
@ -68,6 +71,7 @@ class Bot(BaseBot):
await handle_event(self, event) await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, data: dict):
# TODO: Call API # TODO: Call API
if self.type == "websocket": if self.type == "websocket":

View File

@ -5,14 +5,16 @@ import abc
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.config import Config from nonebot.config import Config
from nonebot.typing import Optional from nonebot.adapters import BaseBot
from nonebot.typing import Dict, Optional
class BaseDriver(abc.ABC): class BaseDriver(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __init__(self, config: Config): def __init__(self, config: Config):
raise NotImplementedError self.config = config
self._clients: Dict[int, BaseBot] = {}
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -29,6 +31,10 @@ class BaseDriver(abc.ABC):
def logger(self): def logger(self):
raise NotImplementedError raise NotImplementedError
@property
def bots(self) -> Dict[int, BaseBot]:
return self._clients
@abc.abstractmethod @abc.abstractmethod
def run(self, def run(self,
host: Optional[IPv4Address] = None, host: Optional[IPv4Address] = None,

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import logging import logging
from ipaddress import IPv4Address from ipaddress import IPv4Address
@ -13,14 +12,16 @@ from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.adapters import BaseBot from nonebot.adapters import BaseBot
from nonebot.typing import Dict, Optional
from nonebot.adapters.cqhttp import Bot as CQBot from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.typing import Dict, Optional, overrides
from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.drivers import BaseDriver, BaseWebSocket
class Driver(BaseDriver): class Driver(BaseDriver):
def __init__(self, config: Config): def __init__(self, config: Config):
super().__init__(config)
self._server_app = FastAPI( self._server_app = FastAPI(
debug=config.debug, debug=config.debug,
openapi_url=None, openapi_url=None,
@ -28,30 +29,27 @@ class Driver(BaseDriver):
redoc_url=None, redoc_url=None,
) )
self.config = config
self._clients: Dict[int, BaseBot] = {}
self._server_app.post("/{adapter}/")(self._handle_http) self._server_app.post("/{adapter}/")(self._handle_http)
self._server_app.post("/{adapter}/http")(self._handle_http) self._server_app.post("/{adapter}/http")(self._handle_http)
self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse) self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)
@property @property
@overrides(BaseDriver)
def server_app(self): def server_app(self):
return self._server_app return self._server_app
@property @property
@overrides(BaseDriver)
def asgi(self): def asgi(self):
return self._server_app return self._server_app
@property @property
@overrides(BaseDriver)
def logger(self) -> logging.Logger: def logger(self) -> logging.Logger:
return logging.getLogger("fastapi") return logging.getLogger("fastapi")
@property @overrides(BaseDriver)
def bots(self) -> Dict[int, BaseBot]:
return self._clients
def run(self, def run(self,
host: Optional[IPv4Address] = None, host: Optional[IPv4Address] = None,
port: Optional[int] = None, port: Optional[int] = None,
@ -93,6 +91,7 @@ class Driver(BaseDriver):
log_config=LOGGING_CONFIG, log_config=LOGGING_CONFIG,
**kwargs) **kwargs)
@overrides(BaseDriver)
async def _handle_http(self, async def _handle_http(self,
adapter: str, adapter: str,
data: dict = Body(...), data: dict = Body(...),
@ -105,6 +104,7 @@ class Driver(BaseDriver):
await bot.handle_message(data) await bot.handle_message(data)
return {"status": 200, "message": "success"} return {"status": 200, "message": "success"}
@overrides(BaseDriver)
async def _handle_ws_reverse(self, async def _handle_ws_reverse(self,
adapter: str, adapter: str,
websocket: FastAPIWebSocket, websocket: FastAPIWebSocket,
@ -143,17 +143,21 @@ class WebSocket(BaseWebSocket):
self._closed = None self._closed = None
@property @property
@overrides(BaseWebSocket)
def closed(self): def closed(self):
return self._closed return self._closed
@overrides(BaseWebSocket)
async def accept(self): async def accept(self):
await self.websocket.accept() await self.websocket.accept()
self._closed = False self._closed = False
@overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code=code) await self.websocket.close(code=code)
self._closed = True self._closed = True
@overrides(BaseWebSocket)
async def receive(self) -> Optional[dict]: async def receive(self) -> Optional[dict]:
data = None data = None
try: try:
@ -166,5 +170,6 @@ class WebSocket(BaseWebSocket):
return data return data
@overrides(BaseWebSocket)
async def send(self, data: dict) -> None: async def send(self, data: dict) -> None:
await self.websocket.send_json(data) await self.websocket.send_json(data)

View File

@ -1,13 +1,26 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import ABC
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Any, Set, List, Dict, Type, Tuple, Mapping from typing import Any, Set, List, Dict, Type, Tuple, Mapping
from typing import Union, Optional, Iterable, Callable, Awaitable from typing import Union, Optional, Iterable, Callable, Awaitable
# import some modules needed when checking types
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import BaseBot as Bot from nonebot.adapters import BaseBot as Bot
from nonebot.event import Event from nonebot.event import Event
def overrides(InterfaceClass: ABC):
def overrider(func):
assert func.__name__ in dir(
InterfaceClass), f"Error method: {func.__name__}"
return func
return overrider
Handler = Callable[["Bot", "Event", dict], Awaitable[None]] Handler = Callable[["Bot", "Event", dict], Awaitable[None]]