🎨 improve typing

This commit is contained in:
yanyongyu 2021-06-14 19:52:35 +08:00
parent e9bc98e74d
commit ddd96271b0
4 changed files with 121 additions and 117 deletions

View File

@ -6,7 +6,9 @@ try:
del pkg_resources del pkg_resources
except ImportError: except ImportError:
import pkgutil import pkgutil
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__) __path__: Iterable[str] = pkgutil.extend_path(
__path__, # type: ignore
__name__)
del pkgutil del pkgutil
except Exception: except Exception:
pass pass

View File

@ -7,28 +7,25 @@
import abc import abc
import asyncio import asyncio
from copy import copy from copy import deepcopy
from functools import reduce, partial from functools import partial
from typing_extensions import Protocol from typing_extensions import Protocol
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping, from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
Optional, Iterable, Awaitable, TYPE_CHECKING) Generic, Optional, Iterable)
from pydantic import BaseModel from pydantic import BaseModel
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config
from nonebot.utils import DataclassEncoder from nonebot.utils import DataclassEncoder
from nonebot.drivers import HTTPConnection, HTTPResponse from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
if TYPE_CHECKING:
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
class _ApiCall(Protocol): class _ApiCall(Protocol):
def __call__(self, **kwargs: Any) -> Awaitable[Any]: async def __call__(self, **kwargs: Any) -> Any:
... ...
@ -37,9 +34,9 @@ class Bot(abc.ABC):
Bot 基类用于处理上报消息并提供 API 调用接口 Bot 基类用于处理上报消息并提供 API 调用接口
""" """
driver: "Driver" driver: Driver
"""Driver 对象""" """Driver 对象"""
config: "Config" config: Config
"""Config 配置对象""" """Config 配置对象"""
_calling_api_hook: Set[T_CallingAPIHook] = set() _calling_api_hook: Set[T_CallingAPIHook] = set()
""" """
@ -56,9 +53,8 @@ class Bot(abc.ABC):
""" """
:参数: :参数:
* ``connection_type: str``: http 或者 websocket
* ``self_id: str``: 机器人 ID * ``self_id: str``: 机器人 ID
* ``websocket: Optional[WebSocket]``: Websocket 连接对象 * ``request: HTTPConnection``: request 连接对象
""" """
self.self_id: str = self_id self.self_id: str = self_id
"""机器人 ID""" """机器人 ID"""
@ -75,7 +71,7 @@ class Bot(abc.ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def register(cls, driver: "Driver", config: "Config"): def register(cls, driver: Driver, config: Config):
""" """
:说明: :说明:
@ -87,7 +83,7 @@ class Bot(abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
async def check_permission( async def check_permission(
cls, driver: "Driver", request: HTTPConnection cls, driver: Driver, request: HTTPConnection
) -> Tuple[Optional[str], Optional[HTTPResponse]]: ) -> Tuple[Optional[str], Optional[HTTPResponse]]:
""" """
:说明: :说明:
@ -97,18 +93,12 @@ class Bot(abc.ABC):
:参数: :参数:
* ``driver: Driver``: Driver 对象 * ``driver: Driver``: Driver 对象
* ``connection_type: str``: 连接类型 * ``request: HTTPConnection``: request 请求详情
* ``headers: dict``: 请求头
* ``body: Optional[bytes]``: 请求数据WebSocket 连接该部分为 None
:返回: :返回:
- ``str``: 连接唯一标识符``None`` 代表连接不合法 - ``Optional[str]``: 连接唯一标识符``None`` 代表连接不合法
- ``HTTPResponse``: HTTP 上报响应 - ``Optional[HTTPResponse]``: HTTP 上报响应
:异常:
- ``RequestDenied``: 请求非法
""" """
raise NotImplementedError raise NotImplementedError
@ -210,21 +200,45 @@ class Bot(abc.ABC):
@classmethod @classmethod
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook: def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
"""
:说明:
调用 api 预处理
:参数:
* ``bot: Bot``: 当前 bot 对象
* ``api: str``: 调用的 api 名称
* ``data: Dict[str, Any]``: api 调用的参数字典
"""
cls._calling_api_hook.add(func) cls._calling_api_hook.add(func)
return func return func
@classmethod @classmethod
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook: def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
"""
:说明:
调用 api 后处理
:参数:
* ``bot: Bot``: 当前 bot 对象
* ``exception: Optional[Exception]``: 调用 api 时发生的错误
* ``api: str``: 调用的 api 名称
* ``data: Dict[str, Any]``: api 调用的参数字典
* ``result: Any``: api 调用的返回
"""
cls._called_api_hook.add(func) cls._called_api_hook.add(func)
return func return func
T_Message = TypeVar("T_Message", bound="Message") T_Message = TypeVar("T_Message", bound="Message")
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
@dataclass @dataclass
class MessageSegment(abc.ABC, Mapping): class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
"""消息段基类""" """消息段基类"""
type: str type: str
""" """
@ -237,6 +251,11 @@ class MessageSegment(abc.ABC, Mapping):
- 说明: 消息段数据 - 说明: 消息段数据
""" """
@abc.abstractmethod
@classmethod
def get_message_class(cls) -> Type[T_Message]:
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def __str__(self) -> str: def __str__(self) -> str:
"""该消息段所代表的 str在命令匹配部分使用""" """该消息段所代表的 str在命令匹配部分使用"""
@ -248,46 +267,27 @@ class MessageSegment(abc.ABC, Mapping):
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool: def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
return not self == other return not self == other
@abc.abstractmethod def __add__(self, other: Union[str, Mapping,
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, Iterable[Mapping]]) -> T_Message:
T_Message]) -> T_Message: return self.get_message_class()(self) + other
"""你需要在这里实现不同消息段的合并:
比如
if isinstance(other, str):
...
elif isinstance(other, MessageSegment):
...
注意需要返回一个新生成的对象
"""
raise NotImplementedError
@abc.abstractmethod def __radd__(self, other: Union[str, Mapping,
def __radd__( Iterable[Mapping]]) -> T_Message:
self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment, return self.get_message_class()(other) + self
T_Message]) -> "T_Message":
"""你需要在这里实现不同消息段的合并:
比如
if isinstance(other, str):
...
elif isinstance(other, MessageSegment):
...
注意需要返回一个新生成的对象
"""
raise NotImplementedError
def __getitem__(self, key): def __getitem__(self, key: str):
return getattr(self, key) return self.data[key]
def __setitem__(self, key, value): def __setitem__(self, key: str, value: Any):
return setattr(self, key, value) self.data[key] = value
def __iter__(self): def __iter__(self):
yield from self.data.__iter__() yield from self.data.__iter__()
def __contains__(self, key: object) -> bool: def __contains__(self, key: Any) -> bool:
return key in self.data return key in self.data
def get(self, key: str, default=None): def get(self, key: str, default: Any = None):
return getattr(self, key, default) return getattr(self, key, default)
def keys(self): def keys(self):
@ -300,7 +300,7 @@ class MessageSegment(abc.ABC, Mapping):
return self.data.items() return self.data.items()
def copy(self: T_MessageSegment) -> T_MessageSegment: def copy(self: T_MessageSegment) -> T_MessageSegment:
return copy(self) return deepcopy(self)
@abc.abstractmethod @abc.abstractmethod
def is_text(self) -> bool: def is_text(self) -> bool:
@ -310,7 +310,7 @@ class MessageSegment(abc.ABC, Mapping):
class Message(List[T_MessageSegment], abc.ABC): class Message(List[T_MessageSegment], abc.ABC):
"""消息数组""" """消息数组"""
def __init__(self, def __init__(self: T_Message,
message: Union[str, None, Mapping, Iterable[Mapping], message: Union[str, None, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message, Any] = None, T_MessageSegment, T_Message, Any] = None,
*args, *args,
@ -330,8 +330,13 @@ class Message(List[T_MessageSegment], abc.ABC):
else: else:
self.extend(self._construct(message)) self.extend(self._construct(message))
@abc.abstractmethod
@classmethod
def get_segment_class(cls) -> Type[T_MessageSegment]:
raise NotImplementedError
def __str__(self): def __str__(self):
return ''.join((str(seg) for seg in self)) return "".join(str(seg) for seg in self)
@classmethod @classmethod
def __get_validators__(cls): def __get_validators__(cls):
@ -348,30 +353,31 @@ class Message(List[T_MessageSegment], abc.ABC):
) -> Iterable[T_MessageSegment]: ) -> Iterable[T_MessageSegment]:
raise NotImplementedError raise NotImplementedError
def __add__(self: T_Message, other: Union[str, T_MessageSegment, def __add__(
T_Message]) -> T_Message: self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
result = self.__class__(self) T_MessageSegment, T_Message]
if isinstance(other, str): ) -> T_Message:
result.extend(self._construct(other)) result = self.copy()
elif isinstance(other, MessageSegment): result += other
result.append(other)
elif isinstance(other, Message):
result.extend(other)
return result return result
def __radd__(self: T_Message, other: Union[str, T_MessageSegment, def __radd__(
T_Message]) -> T_Message: self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
T_MessageSegment, T_Message]
) -> T_Message:
result = self.__class__(other) result = self.__class__(other)
return result.__add__(self) return result + self
def __iadd__(self: T_Message, other: Union[str, T_MessageSegment, def __iadd__(
T_Message]) -> T_Message: self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
if isinstance(other, str): T_MessageSegment, T_Message]
self.extend(self._construct(other)) ) -> T_Message:
elif isinstance(other, MessageSegment): if isinstance(other, MessageSegment):
self.append(other) self.append(other)
elif isinstance(other, Message): elif isinstance(other, Message):
self.extend(other) self.extend(other)
else:
self.extend(self._construct(other))
return self return self
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message: def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
@ -385,7 +391,7 @@ class Message(List[T_MessageSegment], abc.ABC):
* ``obj: Union[str, MessageSegment]``: 要添加的消息段 * ``obj: Union[str, MessageSegment]``: 要添加的消息段
""" """
if isinstance(obj, MessageSegment): if isinstance(obj, MessageSegment):
super().append(obj) super(Message, self).append(obj)
elif isinstance(obj, str): elif isinstance(obj, str):
self.extend(self._construct(obj)) self.extend(self._construct(obj))
else: else:
@ -407,33 +413,17 @@ class Message(List[T_MessageSegment], abc.ABC):
self.append(segment) self.append(segment)
return self return self
def reduce(self: T_Message) -> None: def copy(self: T_Message) -> T_Message:
""" return deepcopy(self)
:说明:
缩减消息数组即按 MessageSegment 的实现拼接相邻消息段 def extract_plain_text(self) -> str:
"""
index = 0
while index < len(self):
if index > 0 and self[index -
1].is_text() and self[index].is_text():
self[index - 1] += self[index]
del self[index]
else:
index += 1
def extract_plain_text(self: T_Message) -> str:
""" """
:说明: :说明:
提取消息内纯文本消息 提取消息内纯文本消息
""" """
def _concat(x: str, y: T_MessageSegment) -> str: return "".join(str(seg) for seg in self if seg.is_text())
return f"{x} {y}" if y.is_text() else x
plain_text = reduce(_concat, self, "")
return plain_text[1:] if plain_text else plain_text
class Event(abc.ABC, BaseModel): class Event(abc.ABC, BaseModel):

View File

@ -50,8 +50,8 @@ class Filter:
def __call__(self, record): def __call__(self, record):
module = sys.modules.get(record["name"]) module = sys.modules.get(record["name"])
if module: if module:
plugin_name = getattr(module, "__plugin_name__", record["name"]) module_name = getattr(module, "__module_name__", record["name"])
record["name"] = plugin_name record["name"] = module_name
record["name"] = record["name"].split(".")[0] record["name"] = record["name"].split(".")[0]
levelno = logger.level(self.level).no if isinstance(self.level, levelno = logger.level(self.level).no if isinstance(self.level,
str) else self.level str) else self.level

View File

@ -11,14 +11,14 @@ from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessa
from .utils import log, escape, unescape, _b2s from .utils import log, escape, unescape, _b2s
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment["Message"]):
""" """
CQHTTP 协议 MessageSegment 适配具体方法参考协议消息段类型或源码 CQHTTP 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
""" """
@overrides(BaseMessageSegment) @classmethod
def __init__(self, type: str, data: Dict[str, Any]) -> None: def get_message_class(cls):
super().__init__(type=type, data=data) return Message
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __str__(self) -> str: def __str__(self) -> str:
@ -37,7 +37,8 @@ class MessageSegment(BaseMessageSegment):
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
return Message(self) + other return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message": def __radd__(self, other) -> "Message":
@ -234,10 +235,25 @@ class Message(BaseMessage[MessageSegment]):
CQHTTP 协议 Message 适配 CQHTTP 协议 Message 适配
""" """
def __radd__(self, other: Union[str, MessageSegment, @classmethod
"Message"]) -> "Message": def get_segment_class(cls):
result = MessageSegment.text(other) if isinstance(other, str) else other return MessageSegment
return super(Message, self).__radd__(result)
@overrides(BaseMessage)
def __add__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
@overrides(BaseMessage)
def __radd__(
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
"Message"]
) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
@staticmethod @staticmethod
@overrides(BaseMessage) @overrides(BaseMessage)
@ -280,10 +296,6 @@ class Message(BaseMessage[MessageSegment]):
} }
yield MessageSegment(type_, data) yield MessageSegment(type_, data)
@overrides(BaseMessage)
def extract_plain_text(self) -> str: def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text())
def _concat(x: str, y: MessageSegment) -> str:
return f"{x} {y.data['text']}" if y.is_text() else x
plain_text = reduce(_concat, self, "")
return plain_text[1:] if plain_text else plain_text