add typings

This commit is contained in:
yanyongyu 2020-08-08 23:08:01 +08:00
parent 332aac6497
commit 00913f1a8f
4 changed files with 66 additions and 43 deletions

View File

@ -3,6 +3,9 @@
import abc import abc
from functools import reduce from functools import reduce
from dataclasses import dataclass
# from pydantic.dataclasses import dataclass # dataclass with validation
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import BaseWebSocket from nonebot.drivers import BaseWebSocket
@ -37,51 +40,65 @@ class BaseBot(abc.ABC):
raise NotImplementedError raise NotImplementedError
class BaseMessageSegment(dict): @dataclass
class BaseMessageSegment(abc.ABC):
def __init__(self, type: str
type_: Optional[str] = None, data: Dict[str, str] = {}
data: Optional[Dict[str, str]] = None):
super().__init__()
if type_:
self.type = type_
self.data = data
else:
raise ValueError('The "type" field cannot be empty')
@abc.abstractmethod
def __str__(self): def __str__(self):
raise NotImplementedError raise NotImplementedError
def __getitem__(self, item): @abc.abstractmethod
if item not in ("type", "data"): def __add__(self, other):
raise KeyError(f'Key "{item}" is not allowed') raise NotImplementedError
return super().__getitem__(item)
def __setitem__(self, key, value):
if key not in ("type", "data"):
raise KeyError(f'Key "{key}" is not allowed')
return super().__setitem__(key, value)
# TODO: __eq__ __add__
@property
def type(self) -> str:
return self["type"]
@type.setter
def type(self, value: str):
self["type"] = value
@property
def data(self) -> Dict[str, str]:
return self["data"]
@data.setter
def data(self, data: Optional[Dict[str, str]]):
self["data"] = data or {}
class BaseMessage(list): # class BaseMessageSegment(dict):
# def __init__(self,
# type_: Optional[str] = None,
# data: Optional[Dict[str, str]] = None):
# super().__init__()
# if type_:
# self.type = type_
# self.data = data
# else:
# raise ValueError('The "type" field cannot be empty')
# def __str__(self):
# raise NotImplementedError
# def __getitem__(self, item):
# if item not in ("type", "data"):
# raise KeyError(f'Key "{item}" is not allowed')
# return super().__getitem__(item)
# def __setitem__(self, key, value):
# if key not in ("type", "data"):
# raise KeyError(f'Key "{key}" is not allowed')
# return super().__setitem__(key, value)
# # TODO: __eq__ __add__
# @property
# def type(self) -> str:
# return self["type"]
# @type.setter
# def type(self, value: str):
# self["type"] = value
# @property
# def data(self) -> Dict[str, str]:
# return self["data"]
# @data.setter
# def data(self, data: Optional[Dict[str, str]]):
# self["data"] = data or {}
class BaseMessage(list, abc.ABC):
def __init__(self, def __init__(self,
message: Union[str, BaseMessageSegment, "BaseMessage"] = None, message: Union[str, BaseMessageSegment, "BaseMessage"] = None,
@ -99,6 +116,7 @@ class BaseMessage(list):
return ''.join((str(seg) for seg in self)) return ''.join((str(seg) for seg in self))
@staticmethod @staticmethod
@abc.abstractmethod
def _construct(msg: str) -> Iterable[BaseMessageSegment]: def _construct(msg: str) -> Iterable[BaseMessageSegment]:
raise NotImplementedError raise NotImplementedError

View File

@ -99,6 +99,7 @@ class Bot(BaseBot):
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment)
def __str__(self): def __str__(self):
type_ = self.type type_ = self.type
data = self.data.copy() data = self.data.copy()
@ -116,6 +117,10 @@ class MessageSegment(BaseMessageSegment):
params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()]) params = ",".join([f"{k}={escape(str(v))}" for k, v in data.items()])
return f"[CQ:{type_}{',' if params else ''}{params}]" return f"[CQ:{type_}{',' if params else ''}{params}]"
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + other
@staticmethod @staticmethod
def anonymous(ignore_failure: bool = False) -> "MessageSegment": def anonymous(ignore_failure: bool = False) -> "MessageSegment":
return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)}) return MessageSegment("anonymous", {"ignore": _b2s(ignore_failure)})
@ -248,6 +253,7 @@ class MessageSegment(BaseMessageSegment):
class Message(BaseMessage): class Message(BaseMessage):
@staticmethod @staticmethod
@overrides(BaseMessage)
def _construct(msg: str) -> Iterable[MessageSegment]: def _construct(msg: str) -> Iterable[MessageSegment]:
def _iter_message() -> Iterable[Tuple[str, str]]: def _iter_message() -> Iterable[Tuple[str, str]]:

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import ABC
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Any, Set, List, Dict, Type, Tuple, Mapping from typing import Any, Set, List, Dict, Type, Tuple, Mapping
@ -13,9 +12,9 @@ if TYPE_CHECKING:
from nonebot.event import Event from nonebot.event import Event
def overrides(InterfaceClass: ABC): def overrides(InterfaceClass: object):
def overrider(func): def overrider(func: Callable) -> Callable:
assert func.__name__ in dir( assert func.__name__ in dir(
InterfaceClass), f"Error method: {func.__name__}" InterfaceClass), f"Error method: {func.__name__}"
return func return func

View File

@ -25,7 +25,7 @@ fastapi = "^0.58.1"
uvicorn = "^0.11.5" uvicorn = "^0.11.5"
pydantic = { extras = ["dotenv"], version = "^1.5.1" } pydantic = { extras = ["dotenv"], version = "^1.5.1" }
apscheduler = { version = "^3.6.3", optional = true } apscheduler = { version = "^3.6.3", optional = true }
nonebot-test = { version = "^0.1.0", optional = true } # nonebot-test = { version = "^0.1.0", optional = true }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
yapf = "^0.30.0" yapf = "^0.30.0"