From ddd96271b0c18b1831bf4126b9bce942eca9513c Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 14 Jun 2021 19:52:35 +0800 Subject: [PATCH] :art: improve typing --- nonebot/adapters/__init__.py | 4 +- nonebot/adapters/_base.py | 188 +++++++++--------- nonebot/log.py | 4 +- .../nonebot/adapters/cqhttp/message.py | 42 ++-- 4 files changed, 121 insertions(+), 117 deletions(-) diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 5bcc2b02..a424855c 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -6,7 +6,9 @@ try: del pkg_resources except ImportError: import pkgutil - __path__: Iterable[str] = pkgutil.extend_path(__path__, __name__) + __path__: Iterable[str] = pkgutil.extend_path( + __path__, # type: ignore + __name__) del pkgutil except Exception: pass diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 39ff913d..c4846088 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -7,28 +7,25 @@ import abc import asyncio -from copy import copy -from functools import reduce, partial +from copy import deepcopy +from functools import partial from typing_extensions import Protocol from dataclasses import dataclass, field -from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping, - Optional, Iterable, Awaitable, TYPE_CHECKING) +from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping, + Generic, Optional, Iterable) from pydantic import BaseModel from nonebot.log import logger +from nonebot.config import Config from nonebot.utils import DataclassEncoder -from nonebot.drivers import HTTPConnection, HTTPResponse +from nonebot.drivers import Driver, HTTPConnection, HTTPResponse from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook -if TYPE_CHECKING: - from nonebot.config import Config - from nonebot.drivers import Driver, WebSocket - class _ApiCall(Protocol): - def __call__(self, **kwargs: Any) -> Awaitable[Any]: + async def __call__(self, **kwargs: Any) -> Any: ... @@ -37,9 +34,9 @@ class Bot(abc.ABC): Bot 基类。用于处理上报消息,并提供 API 调用接口。 """ - driver: "Driver" + driver: Driver """Driver 对象""" - config: "Config" + config: Config """Config 配置对象""" _calling_api_hook: Set[T_CallingAPIHook] = set() """ @@ -56,9 +53,8 @@ class Bot(abc.ABC): """ :参数: - * ``connection_type: str``: http 或者 websocket * ``self_id: str``: 机器人 ID - * ``websocket: Optional[WebSocket]``: Websocket 连接对象 + * ``request: HTTPConnection``: request 连接对象 """ self.self_id: str = self_id """机器人 ID""" @@ -75,7 +71,7 @@ class Bot(abc.ABC): raise NotImplementedError @classmethod - def register(cls, driver: "Driver", config: "Config"): + def register(cls, driver: Driver, config: Config): """ :说明: @@ -87,7 +83,7 @@ class Bot(abc.ABC): @classmethod @abc.abstractmethod async def check_permission( - cls, driver: "Driver", request: HTTPConnection + cls, driver: Driver, request: HTTPConnection ) -> Tuple[Optional[str], Optional[HTTPResponse]]: """ :说明: @@ -97,18 +93,12 @@ class Bot(abc.ABC): :参数: * ``driver: Driver``: Driver 对象 - * ``connection_type: str``: 连接类型 - * ``headers: dict``: 请求头 - * ``body: Optional[bytes]``: 请求数据,WebSocket 连接该部分为 None + * ``request: HTTPConnection``: request 请求详情 :返回: - - ``str``: 连接唯一标识符,``None`` 代表连接不合法 - - ``HTTPResponse``: HTTP 上报响应 - - :异常: - - - ``RequestDenied``: 请求非法 + - ``Optional[str]``: 连接唯一标识符,``None`` 代表连接不合法 + - ``Optional[HTTPResponse]``: HTTP 上报响应 """ raise NotImplementedError @@ -210,21 +200,45 @@ class Bot(abc.ABC): @classmethod def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook: + """ + :说明: + + 调用 api 预处理。 + + :参数: + + * ``bot: Bot``: 当前 bot 对象 + * ``api: str``: 调用的 api 名称 + * ``data: Dict[str, Any]``: api 调用的参数字典 + """ cls._calling_api_hook.add(func) return func @classmethod def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook: + """ + :说明: + + 调用 api 后处理。 + + :参数: + + * ``bot: Bot``: 当前 bot 对象 + * ``exception: Optional[Exception]``: 调用 api 时发生的错误 + * ``api: str``: 调用的 api 名称 + * ``data: Dict[str, Any]``: api 调用的参数字典 + * ``result: Any``: api 调用的返回 + """ cls._called_api_hook.add(func) return func T_Message = TypeVar("T_Message", bound="Message") -T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment") +T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]") @dataclass -class MessageSegment(abc.ABC, Mapping): +class MessageSegment(Mapping, abc.ABC, Generic[T_Message]): """消息段基类""" type: str """ @@ -237,6 +251,11 @@ class MessageSegment(abc.ABC, Mapping): - 说明: 消息段数据 """ + @abc.abstractmethod + @classmethod + def get_message_class(cls) -> Type[T_Message]: + raise NotImplementedError + @abc.abstractmethod def __str__(self) -> str: """该消息段所代表的 str,在命令匹配部分使用""" @@ -248,46 +267,27 @@ class MessageSegment(abc.ABC, Mapping): def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool: return not self == other - @abc.abstractmethod - def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, - T_Message]) -> T_Message: - """你需要在这里实现不同消息段的合并: - 比如: - if isinstance(other, str): - ... - elif isinstance(other, MessageSegment): - ... - 注意:需要返回一个新生成的对象 - """ - raise NotImplementedError + def __add__(self, other: Union[str, Mapping, + Iterable[Mapping]]) -> T_Message: + return self.get_message_class()(self) + other - @abc.abstractmethod - def __radd__( - self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment, - T_Message]) -> "T_Message": - """你需要在这里实现不同消息段的合并: - 比如: - if isinstance(other, str): - ... - elif isinstance(other, MessageSegment): - ... - 注意:需要返回一个新生成的对象 - """ - raise NotImplementedError + def __radd__(self, other: Union[str, Mapping, + Iterable[Mapping]]) -> T_Message: + return self.get_message_class()(other) + self - def __getitem__(self, key): - return getattr(self, key) + def __getitem__(self, key: str): + return self.data[key] - def __setitem__(self, key, value): - return setattr(self, key, value) + def __setitem__(self, key: str, value: Any): + self.data[key] = value def __iter__(self): yield from self.data.__iter__() - def __contains__(self, key: object) -> bool: + def __contains__(self, key: Any) -> bool: return key in self.data - def get(self, key: str, default=None): + def get(self, key: str, default: Any = None): return getattr(self, key, default) def keys(self): @@ -300,7 +300,7 @@ class MessageSegment(abc.ABC, Mapping): return self.data.items() def copy(self: T_MessageSegment) -> T_MessageSegment: - return copy(self) + return deepcopy(self) @abc.abstractmethod def is_text(self) -> bool: @@ -310,7 +310,7 @@ class MessageSegment(abc.ABC, Mapping): class Message(List[T_MessageSegment], abc.ABC): """消息数组""" - def __init__(self, + def __init__(self: T_Message, message: Union[str, None, Mapping, Iterable[Mapping], T_MessageSegment, T_Message, Any] = None, *args, @@ -330,8 +330,13 @@ class Message(List[T_MessageSegment], abc.ABC): else: self.extend(self._construct(message)) + @abc.abstractmethod + @classmethod + def get_segment_class(cls) -> Type[T_MessageSegment]: + raise NotImplementedError + def __str__(self): - return ''.join((str(seg) for seg in self)) + return "".join(str(seg) for seg in self) @classmethod def __get_validators__(cls): @@ -348,30 +353,31 @@ class Message(List[T_MessageSegment], abc.ABC): ) -> Iterable[T_MessageSegment]: raise NotImplementedError - def __add__(self: T_Message, other: Union[str, T_MessageSegment, - T_Message]) -> T_Message: - result = self.__class__(self) - if isinstance(other, str): - result.extend(self._construct(other)) - elif isinstance(other, MessageSegment): - result.append(other) - elif isinstance(other, Message): - result.extend(other) + def __add__( + self: T_Message, other: Union[str, Mapping, Iterable[Mapping], + T_MessageSegment, T_Message] + ) -> T_Message: + result = self.copy() + result += other return result - def __radd__(self: T_Message, other: Union[str, T_MessageSegment, - T_Message]) -> T_Message: + def __radd__( + self: T_Message, other: Union[str, Mapping, Iterable[Mapping], + T_MessageSegment, T_Message] + ) -> T_Message: result = self.__class__(other) - return result.__add__(self) + return result + self - def __iadd__(self: T_Message, other: Union[str, T_MessageSegment, - T_Message]) -> T_Message: - if isinstance(other, str): - self.extend(self._construct(other)) - elif isinstance(other, MessageSegment): + def __iadd__( + self: T_Message, other: Union[str, Mapping, Iterable[Mapping], + T_MessageSegment, T_Message] + ) -> T_Message: + if isinstance(other, MessageSegment): self.append(other) elif isinstance(other, Message): self.extend(other) + else: + self.extend(self._construct(other)) return self def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message: @@ -385,7 +391,7 @@ class Message(List[T_MessageSegment], abc.ABC): * ``obj: Union[str, MessageSegment]``: 要添加的消息段 """ if isinstance(obj, MessageSegment): - super().append(obj) + super(Message, self).append(obj) elif isinstance(obj, str): self.extend(self._construct(obj)) else: @@ -407,33 +413,17 @@ class Message(List[T_MessageSegment], abc.ABC): self.append(segment) return self - def reduce(self: T_Message) -> None: - """ - :说明: + def copy(self: T_Message) -> T_Message: + return deepcopy(self) - 缩减消息数组,即按 MessageSegment 的实现拼接相邻消息段 - """ - index = 0 - while index < len(self): - if index > 0 and self[index - - 1].is_text() and self[index].is_text(): - self[index - 1] += self[index] - del self[index] - else: - index += 1 - - def extract_plain_text(self: T_Message) -> str: + def extract_plain_text(self) -> str: """ :说明: 提取消息内纯文本消息 """ - def _concat(x: str, y: T_MessageSegment) -> str: - return f"{x} {y}" if y.is_text() else x - - plain_text = reduce(_concat, self, "") - return plain_text[1:] if plain_text else plain_text + return "".join(str(seg) for seg in self if seg.is_text()) class Event(abc.ABC, BaseModel): diff --git a/nonebot/log.py b/nonebot/log.py index 7151041c..0acb8a20 100644 --- a/nonebot/log.py +++ b/nonebot/log.py @@ -50,8 +50,8 @@ class Filter: def __call__(self, record): module = sys.modules.get(record["name"]) if module: - plugin_name = getattr(module, "__plugin_name__", record["name"]) - record["name"] = plugin_name + module_name = getattr(module, "__module_name__", record["name"]) + record["name"] = module_name record["name"] = record["name"].split(".")[0] levelno = logger.level(self.level).no if isinstance(self.level, str) else self.level diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py index df1a8b89..48affc8a 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py @@ -11,14 +11,14 @@ from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessa from .utils import log, escape, unescape, _b2s -class MessageSegment(BaseMessageSegment): +class MessageSegment(BaseMessageSegment["Message"]): """ CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 """ - @overrides(BaseMessageSegment) - def __init__(self, type: str, data: Dict[str, Any]) -> None: - super().__init__(type=type, data=data) + @classmethod + def get_message_class(cls): + return Message @overrides(BaseMessageSegment) def __str__(self) -> str: @@ -37,7 +37,8 @@ class MessageSegment(BaseMessageSegment): @overrides(BaseMessageSegment) def __add__(self, other) -> "Message": - return Message(self) + other + return Message(self) + (MessageSegment.text(other) if isinstance( + other, str) else other) @overrides(BaseMessageSegment) def __radd__(self, other) -> "Message": @@ -234,10 +235,25 @@ class Message(BaseMessage[MessageSegment]): CQHTTP 协议 Message 适配。 """ - def __radd__(self, other: Union[str, MessageSegment, - "Message"]) -> "Message": - result = MessageSegment.text(other) if isinstance(other, str) else other - return super(Message, self).__radd__(result) + @classmethod + def get_segment_class(cls): + return MessageSegment + + @overrides(BaseMessage) + def __add__( + self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment, + "Message"] + ) -> "Message": + return super(Message, self).__add__( + MessageSegment.text(other) if isinstance(other, str) else other) + + @overrides(BaseMessage) + def __radd__( + self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment, + "Message"] + ) -> "Message": + return super(Message, self).__radd__( + MessageSegment.text(other) if isinstance(other, str) else other) @staticmethod @overrides(BaseMessage) @@ -280,10 +296,6 @@ class Message(BaseMessage[MessageSegment]): } yield MessageSegment(type_, data) + @overrides(BaseMessage) def extract_plain_text(self) -> str: - - def _concat(x: str, y: MessageSegment) -> str: - return f"{x} {y.data['text']}" if y.is_text() else x - - plain_text = reduce(_concat, self, "") - return plain_text[1:] if plain_text else plain_text + return "".join(seg.data["text"] for seg in self if seg.is_text())