add types

This commit is contained in:
yanyongyu 2020-08-10 13:06:02 +08:00
parent 00913f1a8f
commit 9e33a605a6
9 changed files with 87 additions and 50 deletions

View File

@ -6,14 +6,14 @@ import importlib
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import Optional from nonebot.typing import Union, Optional, NoReturn
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import BaseDriver from nonebot.drivers import BaseDriver
_driver: Optional[BaseDriver] = None _driver: Optional[BaseDriver] = None
def get_driver() -> BaseDriver: def get_driver() -> Union[NoReturn, BaseDriver]:
if _driver is None: if _driver is None:
raise ValueError("NoneBot has not been initialized.") raise ValueError("NoneBot has not been initialized.")
return _driver return _driver
@ -38,7 +38,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
logger.debug(f"Loaded config: {config.dict()}") logger.debug(f"Loaded config: {config.dict()}")
Driver = getattr(importlib.import_module(config.driver), "Driver") Driver = getattr(importlib.import_module(config.driver), "Driver")
_driver = Driver(config) _driver = Driver(env, config)
def run(host: Optional[IPv4Address] = None, def run(host: Optional[IPv4Address] = None,

View File

@ -3,13 +3,10 @@
import abc import abc
from functools import reduce from functools import reduce
from dataclasses import dataclass from dataclasses import dataclass, field
# from pydantic.dataclasses import dataclass # dataclass with validation
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import BaseWebSocket from nonebot.typing import Dict, Union, Iterable, WebSocket
from nonebot.typing import Dict, Union, Iterable, Optional
class BaseBot(abc.ABC): class BaseBot(abc.ABC):
@ -20,7 +17,7 @@ class BaseBot(abc.ABC):
config: Config, config: Config,
self_id: int, self_id: int,
*, *,
websocket: BaseWebSocket = None): websocket: WebSocket = None):
self.connection_type = connection_type self.connection_type = connection_type
self.config = config self.config = config
self.self_id = self_id self.self_id = self_id
@ -43,7 +40,7 @@ class BaseBot(abc.ABC):
@dataclass @dataclass
class BaseMessageSegment(abc.ABC): class BaseMessageSegment(abc.ABC):
type: str type: str
data: Dict[str, str] = {} data: Dict[str, str] = field(default_factory=lambda: {})
@abc.abstractmethod @abc.abstractmethod
def __str__(self): def __str__(self):

View File

@ -8,10 +8,9 @@ import httpx
from nonebot.event import Event from nonebot.event import Event
from nonebot.config import Config from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.drivers import BaseWebSocket
from nonebot.exception import ApiNotAvailable from nonebot.exception import ApiNotAvailable
from nonebot.typing import Tuple, Iterable, Optional, overrides
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str: def escape(s: str, *, escape_comma: bool = True) -> str:
@ -47,7 +46,7 @@ class Bot(BaseBot):
config: Config, config: Config,
self_id: int, self_id: int,
*, *,
websocket: BaseWebSocket = None): websocket: WebSocket = None):
if connection_type not in ["http", "websocket"]: if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type") raise ValueError("Unsupported connection type")

View File

@ -4,17 +4,17 @@
import abc import abc
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.config import Config from nonebot.config import Env, Config
from nonebot.adapters import BaseBot from nonebot.typing import Bot, Dict, Optional
from nonebot.typing import Dict, Optional
class BaseDriver(abc.ABC): class BaseDriver(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __init__(self, config: Config): def __init__(self, env: Env, config: Config):
self.env = env.environment
self.config = config self.config = config
self._clients: Dict[int, BaseBot] = {} self._clients: Dict[int, Bot] = {}
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -32,7 +32,7 @@ class BaseDriver(abc.ABC):
raise NotImplementedError raise NotImplementedError
@property @property
def bots(self) -> Dict[int, BaseBot]: def bots(self) -> Dict[int, Bot]:
return self._clients return self._clients
@abc.abstractmethod @abc.abstractmethod
@ -59,7 +59,6 @@ class BaseWebSocket(object):
self._websocket = websocket self._websocket = websocket
@property @property
@abc.abstractmethod
def websocket(self): def websocket(self):
return self._websocket return self._websocket

View File

@ -1,26 +1,28 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import logging import logging
from ipaddress import IPv4Address from ipaddress import IPv4Address
import uvicorn import uvicorn
from fastapi import FastAPI, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from starlette.websockets import WebSocketDisconnect from starlette.websockets import WebSocketDisconnect
from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Env, Config
from nonebot.adapters import BaseBot from nonebot.utils import DataclassEncoder
from nonebot.typing import Optional, overrides
from nonebot.adapters.cqhttp import Bot as CQBot from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.typing import Dict, Optional, overrides
from nonebot.drivers import BaseDriver, BaseWebSocket from nonebot.drivers import BaseDriver, BaseWebSocket
class Driver(BaseDriver): class Driver(BaseDriver):
def __init__(self, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(config) super().__init__(env, config)
self._server_app = FastAPI( self._server_app = FastAPI(
debug=config.debug, debug=config.debug,
@ -94,21 +96,28 @@ class Driver(BaseDriver):
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_http(self, async def _handle_http(self,
adapter: str, adapter: str,
response: Response,
data: dict = Body(...), data: dict = Body(...),
x_self_id: int = Header(None),
access_token: str = OAuth2PasswordBearer( access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)): "/", auto_error=False)):
# TODO: Check authorization # TODO: Check authorization
logger.debug(f"Received message: {data}")
# Create Bot Object
if adapter == "cqhttp": if adapter == "cqhttp":
bot = CQBot("http", self.config) bot = CQBot("http", self.config, x_self_id)
await bot.handle_message(data) else:
response.status_code = status.HTTP_404_NOT_FOUND
return {"status": 404, "message": "adapter not found"}
await bot.handle_message(data)
return {"status": 200, "message": "success"} return {"status": 200, "message": "success"}
@overrides(BaseDriver) @overrides(BaseDriver)
async def _handle_ws_reverse(self, async def _handle_ws_reverse(self,
adapter: str, adapter: str,
websocket: FastAPIWebSocket, websocket: FastAPIWebSocket,
self_id: int = Header(None), x_self_id: int = Header(None),
access_token: str = OAuth2PasswordBearer( access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)): "/", auto_error=False)):
websocket = WebSocket(websocket) websocket = WebSocket(websocket)
@ -117,13 +126,16 @@ class Driver(BaseDriver):
# Create Bot Object # Create Bot Object
if adapter == "coolq": if adapter == "coolq":
bot = CQBot("websocket", self.config, self_id, websocket=websocket) bot = CQBot("websocket",
self.config,
x_self_id,
websocket=websocket)
else: else:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
return return
await websocket.accept() await websocket.accept()
self._clients[self_id] = bot self._clients[x_self_id] = bot
while not websocket.closed: while not websocket.closed:
data = await websocket.receive() data = await websocket.receive()
@ -133,7 +145,7 @@ class Driver(BaseDriver):
await bot.handle_message(data) await bot.handle_message(data)
del self._clients[self_id] del self._clients[x_self_id]
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
@ -172,4 +184,5 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: dict) -> None: async def send(self, data: dict) -> None:
await self.websocket.send_json(data) text = json.dumps(data, cls=DataclassEncoder)
await self.websocket.send({"type": "websocket.send", "text": text})

View File

@ -5,10 +5,9 @@ from functools import wraps
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
from nonebot.event import Event
from nonebot.typing import Handler
from nonebot.rule import Rule, user from nonebot.rule import Rule, user
from nonebot.typing import Type, List, Dict, Optional, Callable from nonebot.typing import Bot, Event, Handler
from nonebot.typing import Type, List, Dict, Optional, NoReturn
from nonebot.exception import PausedException, RejectedException, FinishedException from nonebot.exception import PausedException, RejectedException, FinishedException
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
@ -66,7 +65,7 @@ class Matcher:
return NewMatcher return NewMatcher
@classmethod @classmethod
def check_rule(cls, bot, event: Event) -> bool: def check_rule(cls, bot: Bot, event: Event) -> bool:
"""检查 Matcher 的 Rule 是否成立 """检查 Matcher 的 Rule 是否成立
Args: Args:
@ -98,7 +97,7 @@ class Matcher:
def _decorator(func: Handler) -> Handler: def _decorator(func: Handler) -> Handler:
async def _handler(bot, event: Event, state: dict): async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn:
raise PausedException raise PausedException
cls.handlers.append(_handler) cls.handlers.append(_handler)
@ -144,7 +143,7 @@ class Matcher:
# raise RejectedException # raise RejectedException
# 运行handlers # 运行handlers
async def run(self, bot, event): async def run(self, bot: Bot, event: Event):
try: try:
# if self.parser: # if self.parser:
# await self.parser(event, state) # type: ignore # await self.parser(event, state) # type: ignore

View File

@ -5,20 +5,19 @@ import asyncio
from datetime import datetime from datetime import datetime
from nonebot.log import logger from nonebot.log import logger
from nonebot.event import Event
from nonebot.matcher import matchers from nonebot.matcher import matchers
from nonebot.typing import Set, Callable
from nonebot.exception import IgnoredException from nonebot.exception import IgnoredException
from nonebot.typing import Bot, Set, Event, PreProcessor
_event_preprocessors: Set[Callable] = set() _event_preprocessors: Set[PreProcessor] = set()
def event_preprocessor(func: Callable) -> Callable: def event_preprocessor(func: PreProcessor) -> PreProcessor:
_event_preprocessors.add(func) _event_preprocessors.add(func)
return func return func
async def handle_event(bot, event: Event): async def handle_event(bot: Bot, event: Event):
# TODO: PreProcess # TODO: PreProcess
coros = [] coros = []
for preprocessor in _event_preprocessors: for preprocessor in _event_preprocessors:

View File

@ -2,14 +2,16 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING from typing import NoReturn, TYPE_CHECKING
from typing import Any, Set, List, Dict, Type, Tuple, Mapping from typing import Any, Set, List, Dict, Type, Tuple, Mapping
from typing import Union, Optional, Iterable, Callable, Awaitable from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable
# import some modules needed when checking types # import some modules needed when checking types
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import BaseBot as Bot from nonebot.event import Event as EventClass
from nonebot.event import Event from nonebot.matcher import Matcher as MatcherClass
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
def overrides(InterfaceClass: object): def overrides(InterfaceClass: object):
@ -22,4 +24,17 @@ def overrides(InterfaceClass: object):
return overrider return overrider
Handler = Callable[["Bot", "Event", dict], Awaitable[None]] Driver = TypeVar("Driver", bound="BaseDriver")
WebSocket = TypeVar("WebSocket", bound="BaseWebSocket")
Bot = TypeVar("Bot", bound="BaseBot")
Event = TypeVar("Event", bound="EventClass")
Message = TypeVar("Message", bound="BaseMessage")
MessageSegment = TypeVar("MessageSegment", bound="BaseMessageSegment")
PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
Awaitable[NoReturn]]]
Matcher = TypeVar("Matcher", bound="MatcherClass")
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None],
Awaitable[NoReturn]]]

16
nonebot/utils.py Normal file
View File

@ -0,0 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import dataclasses
from nonebot.typing import overrides
class DataclassEncoder(json.JSONEncoder):
@overrides(json.JSONEncoder)
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)