add types

This commit is contained in:
yanyongyu 2020-08-10 13:06:02 +08:00
parent 00913f1a8f
commit 9e33a605a6
9 changed files with 87 additions and 50 deletions

View File

@ -6,14 +6,14 @@ import importlib
from ipaddress import IPv4Address
from nonebot.log import logger
from nonebot.typing import Optional
from nonebot.typing import Union, Optional, NoReturn
from nonebot.config import Env, Config
from nonebot.drivers import BaseDriver
_driver: Optional[BaseDriver] = None
def get_driver() -> BaseDriver:
def get_driver() -> Union[NoReturn, BaseDriver]:
if _driver is None:
raise ValueError("NoneBot has not been initialized.")
return _driver
@ -38,7 +38,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
logger.debug(f"Loaded config: {config.dict()}")
Driver = getattr(importlib.import_module(config.driver), "Driver")
_driver = Driver(config)
_driver = Driver(env, config)
def run(host: Optional[IPv4Address] = None,

View File

@ -3,13 +3,10 @@
import abc
from functools import reduce
from dataclasses import dataclass
# from pydantic.dataclasses import dataclass # dataclass with validation
from dataclasses import dataclass, field
from nonebot.config import Config
from nonebot.drivers import BaseWebSocket
from nonebot.typing import Dict, Union, Iterable, Optional
from nonebot.typing import Dict, Union, Iterable, WebSocket
class BaseBot(abc.ABC):
@ -20,7 +17,7 @@ class BaseBot(abc.ABC):
config: Config,
self_id: int,
*,
websocket: BaseWebSocket = None):
websocket: WebSocket = None):
self.connection_type = connection_type
self.config = config
self.self_id = self_id
@ -43,7 +40,7 @@ class BaseBot(abc.ABC):
@dataclass
class BaseMessageSegment(abc.ABC):
type: str
data: Dict[str, str] = {}
data: Dict[str, str] = field(default_factory=lambda: {})
@abc.abstractmethod
def __str__(self):

View File

@ -8,10 +8,9 @@ import httpx
from nonebot.event import Event
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.drivers import BaseWebSocket
from nonebot.exception import ApiNotAvailable
from nonebot.typing import Tuple, Iterable, Optional, overrides
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str:
@ -47,7 +46,7 @@ class Bot(BaseBot):
config: Config,
self_id: int,
*,
websocket: BaseWebSocket = None):
websocket: WebSocket = None):
if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")

View File

@ -4,17 +4,17 @@
import abc
from ipaddress import IPv4Address
from nonebot.config import Config
from nonebot.adapters import BaseBot
from nonebot.typing import Dict, Optional
from nonebot.config import Env, Config
from nonebot.typing import Bot, Dict, Optional
class BaseDriver(abc.ABC):
@abc.abstractmethod
def __init__(self, config: Config):
def __init__(self, env: Env, config: Config):
self.env = env.environment
self.config = config
self._clients: Dict[int, BaseBot] = {}
self._clients: Dict[int, Bot] = {}
@property
@abc.abstractmethod
@ -32,7 +32,7 @@ class BaseDriver(abc.ABC):
raise NotImplementedError
@property
def bots(self) -> Dict[int, BaseBot]:
def bots(self) -> Dict[int, Bot]:
return self._clients
@abc.abstractmethod
@ -59,7 +59,6 @@ class BaseWebSocket(object):
self._websocket = websocket
@property
@abc.abstractmethod
def websocket(self):
return self._websocket

View File

@ -1,26 +1,28 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import logging
from ipaddress import IPv4Address
import uvicorn
from fastapi import FastAPI, status
from fastapi.security import OAuth2PasswordBearer
from starlette.websockets import WebSocketDisconnect
from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket
from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Config
from nonebot.adapters import BaseBot
from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder
from nonebot.typing import Optional, overrides
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.typing import Dict, Optional, overrides
from nonebot.drivers import BaseDriver, BaseWebSocket
class Driver(BaseDriver):
def __init__(self, config: Config):
super().__init__(config)
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self._server_app = FastAPI(
debug=config.debug,
@ -94,21 +96,28 @@ class Driver(BaseDriver):
@overrides(BaseDriver)
async def _handle_http(self,
adapter: str,
response: Response,
data: dict = Body(...),
x_self_id: int = Header(None),
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
# TODO: Check authorization
logger.debug(f"Received message: {data}")
# Create Bot Object
if adapter == "cqhttp":
bot = CQBot("http", self.config)
await bot.handle_message(data)
bot = CQBot("http", self.config, x_self_id)
else:
response.status_code = status.HTTP_404_NOT_FOUND
return {"status": 404, "message": "adapter not found"}
await bot.handle_message(data)
return {"status": 200, "message": "success"}
@overrides(BaseDriver)
async def _handle_ws_reverse(self,
adapter: str,
websocket: FastAPIWebSocket,
self_id: int = Header(None),
x_self_id: int = Header(None),
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
websocket = WebSocket(websocket)
@ -117,13 +126,16 @@ class Driver(BaseDriver):
# Create Bot Object
if adapter == "coolq":
bot = CQBot("websocket", self.config, self_id, websocket=websocket)
bot = CQBot("websocket",
self.config,
x_self_id,
websocket=websocket)
else:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
return
await websocket.accept()
self._clients[self_id] = bot
self._clients[x_self_id] = bot
while not websocket.closed:
data = await websocket.receive()
@ -133,7 +145,7 @@ class Driver(BaseDriver):
await bot.handle_message(data)
del self._clients[self_id]
del self._clients[x_self_id]
class WebSocket(BaseWebSocket):
@ -172,4 +184,5 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def send(self, data: dict) -> None:
await self.websocket.send_json(data)
text = json.dumps(data, cls=DataclassEncoder)
await self.websocket.send({"type": "websocket.send", "text": text})

View File

@ -5,10 +5,9 @@ from functools import wraps
from datetime import datetime
from collections import defaultdict
from nonebot.event import Event
from nonebot.typing import Handler
from nonebot.rule import Rule, user
from nonebot.typing import Type, List, Dict, Optional, Callable
from nonebot.typing import Bot, Event, Handler
from nonebot.typing import Type, List, Dict, Optional, NoReturn
from nonebot.exception import PausedException, RejectedException, FinishedException
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
@ -66,7 +65,7 @@ class Matcher:
return NewMatcher
@classmethod
def check_rule(cls, bot, event: Event) -> bool:
def check_rule(cls, bot: Bot, event: Event) -> bool:
"""检查 Matcher 的 Rule 是否成立
Args:
@ -98,7 +97,7 @@ class Matcher:
def _decorator(func: Handler) -> Handler:
async def _handler(bot, event: Event, state: dict):
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn:
raise PausedException
cls.handlers.append(_handler)
@ -144,7 +143,7 @@ class Matcher:
# raise RejectedException
# 运行handlers
async def run(self, bot, event):
async def run(self, bot: Bot, event: Event):
try:
# if self.parser:
# await self.parser(event, state) # type: ignore

View File

@ -5,20 +5,19 @@ import asyncio
from datetime import datetime
from nonebot.log import logger
from nonebot.event import Event
from nonebot.matcher import matchers
from nonebot.typing import Set, Callable
from nonebot.exception import IgnoredException
from nonebot.typing import Bot, Set, Event, PreProcessor
_event_preprocessors: Set[Callable] = set()
_event_preprocessors: Set[PreProcessor] = set()
def event_preprocessor(func: Callable) -> Callable:
def event_preprocessor(func: PreProcessor) -> PreProcessor:
_event_preprocessors.add(func)
return func
async def handle_event(bot, event: Event):
async def handle_event(bot: Bot, event: Event):
# TODO: PreProcess
coros = []
for preprocessor in _event_preprocessors:

View File

@ -2,14 +2,16 @@
# -*- coding: utf-8 -*-
from types import ModuleType
from typing import TYPE_CHECKING
from typing import NoReturn, TYPE_CHECKING
from typing import Any, Set, List, Dict, Type, Tuple, Mapping
from typing import Union, Optional, Iterable, Callable, Awaitable
from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable
# import some modules needed when checking types
if TYPE_CHECKING:
from nonebot.adapters import BaseBot as Bot
from nonebot.event import Event
from nonebot.event import Event as EventClass
from nonebot.matcher import Matcher as MatcherClass
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
def overrides(InterfaceClass: object):
@ -22,4 +24,17 @@ def overrides(InterfaceClass: object):
return overrider
Handler = Callable[["Bot", "Event", dict], Awaitable[None]]
Driver = TypeVar("Driver", bound="BaseDriver")
WebSocket = TypeVar("WebSocket", bound="BaseWebSocket")
Bot = TypeVar("Bot", bound="BaseBot")
Event = TypeVar("Event", bound="EventClass")
Message = TypeVar("Message", bound="BaseMessage")
MessageSegment = TypeVar("MessageSegment", bound="BaseMessageSegment")
PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
Awaitable[NoReturn]]]
Matcher = TypeVar("Matcher", bound="MatcherClass")
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None],
Awaitable[NoReturn]]]

16
nonebot/utils.py Normal file
View File

@ -0,0 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import dataclasses
from nonebot.typing import overrides
class DataclassEncoder(json.JSONEncoder):
@overrides(json.JSONEncoder)
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)