mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🎨 improve typing
This commit is contained in:
parent
e9bc98e74d
commit
ddd96271b0
@ -6,7 +6,9 @@ try:
|
||||
del pkg_resources
|
||||
except ImportError:
|
||||
import pkgutil
|
||||
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__)
|
||||
__path__: Iterable[str] = pkgutil.extend_path(
|
||||
__path__, # type: ignore
|
||||
__name__)
|
||||
del pkgutil
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -7,28 +7,25 @@
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from copy import copy
|
||||
from functools import reduce, partial
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing_extensions import Protocol
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping,
|
||||
Optional, Iterable, Awaitable, TYPE_CHECKING)
|
||||
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
|
||||
Generic, Optional, Iterable)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.drivers import HTTPConnection, HTTPResponse
|
||||
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
|
||||
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
|
||||
|
||||
class _ApiCall(Protocol):
|
||||
|
||||
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
|
||||
async def __call__(self, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@ -37,9 +34,9 @@ class Bot(abc.ABC):
|
||||
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
||||
"""
|
||||
|
||||
driver: "Driver"
|
||||
driver: Driver
|
||||
"""Driver 对象"""
|
||||
config: "Config"
|
||||
config: Config
|
||||
"""Config 配置对象"""
|
||||
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
||||
"""
|
||||
@ -56,9 +53,8 @@ class Bot(abc.ABC):
|
||||
"""
|
||||
:参数:
|
||||
|
||||
* ``connection_type: str``: http 或者 websocket
|
||||
* ``self_id: str``: 机器人 ID
|
||||
* ``websocket: Optional[WebSocket]``: Websocket 连接对象
|
||||
* ``request: HTTPConnection``: request 连接对象
|
||||
"""
|
||||
self.self_id: str = self_id
|
||||
"""机器人 ID"""
|
||||
@ -75,7 +71,7 @@ class Bot(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def register(cls, driver: "Driver", config: "Config"):
|
||||
def register(cls, driver: Driver, config: Config):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -87,7 +83,7 @@ class Bot(abc.ABC):
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
async def check_permission(
|
||||
cls, driver: "Driver", request: HTTPConnection
|
||||
cls, driver: Driver, request: HTTPConnection
|
||||
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -97,18 +93,12 @@ class Bot(abc.ABC):
|
||||
:参数:
|
||||
|
||||
* ``driver: Driver``: Driver 对象
|
||||
* ``connection_type: str``: 连接类型
|
||||
* ``headers: dict``: 请求头
|
||||
* ``body: Optional[bytes]``: 请求数据,WebSocket 连接该部分为 None
|
||||
* ``request: HTTPConnection``: request 请求详情
|
||||
|
||||
:返回:
|
||||
|
||||
- ``str``: 连接唯一标识符,``None`` 代表连接不合法
|
||||
- ``HTTPResponse``: HTTP 上报响应
|
||||
|
||||
:异常:
|
||||
|
||||
- ``RequestDenied``: 请求非法
|
||||
- ``Optional[str]``: 连接唯一标识符,``None`` 代表连接不合法
|
||||
- ``Optional[HTTPResponse]``: HTTP 上报响应
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -210,21 +200,45 @@ class Bot(abc.ABC):
|
||||
|
||||
@classmethod
|
||||
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 api 预处理。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``bot: Bot``: 当前 bot 对象
|
||||
* ``api: str``: 调用的 api 名称
|
||||
* ``data: Dict[str, Any]``: api 调用的参数字典
|
||||
"""
|
||||
cls._calling_api_hook.add(func)
|
||||
return func
|
||||
|
||||
@classmethod
|
||||
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用 api 后处理。
|
||||
|
||||
:参数:
|
||||
|
||||
* ``bot: Bot``: 当前 bot 对象
|
||||
* ``exception: Optional[Exception]``: 调用 api 时发生的错误
|
||||
* ``api: str``: 调用的 api 名称
|
||||
* ``data: Dict[str, Any]``: api 调用的参数字典
|
||||
* ``result: Any``: api 调用的返回
|
||||
"""
|
||||
cls._called_api_hook.add(func)
|
||||
return func
|
||||
|
||||
|
||||
T_Message = TypeVar("T_Message", bound="Message")
|
||||
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
|
||||
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSegment(abc.ABC, Mapping):
|
||||
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||
"""消息段基类"""
|
||||
type: str
|
||||
"""
|
||||
@ -237,6 +251,11 @@ class MessageSegment(abc.ABC, Mapping):
|
||||
- 说明: 消息段数据
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@classmethod
|
||||
def get_message_class(cls) -> Type[T_Message]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __str__(self) -> str:
|
||||
"""该消息段所代表的 str,在命令匹配部分使用"""
|
||||
@ -248,46 +267,27 @@ class MessageSegment(abc.ABC, Mapping):
|
||||
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
|
||||
return not self == other
|
||||
|
||||
@abc.abstractmethod
|
||||
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
|
||||
T_Message]) -> T_Message:
|
||||
"""你需要在这里实现不同消息段的合并:
|
||||
比如:
|
||||
if isinstance(other, str):
|
||||
...
|
||||
elif isinstance(other, MessageSegment):
|
||||
...
|
||||
注意:需要返回一个新生成的对象
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def __add__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> T_Message:
|
||||
return self.get_message_class()(self) + other
|
||||
|
||||
@abc.abstractmethod
|
||||
def __radd__(
|
||||
self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment,
|
||||
T_Message]) -> "T_Message":
|
||||
"""你需要在这里实现不同消息段的合并:
|
||||
比如:
|
||||
if isinstance(other, str):
|
||||
...
|
||||
elif isinstance(other, MessageSegment):
|
||||
...
|
||||
注意:需要返回一个新生成的对象
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def __radd__(self, other: Union[str, Mapping,
|
||||
Iterable[Mapping]]) -> T_Message:
|
||||
return self.get_message_class()(other) + self
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
def __getitem__(self, key: str):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return setattr(self, key, value)
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
self.data[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.data.__iter__()
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
return key in self.data
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
def get(self, key: str, default: Any = None):
|
||||
return getattr(self, key, default)
|
||||
|
||||
def keys(self):
|
||||
@ -300,7 +300,7 @@ class MessageSegment(abc.ABC, Mapping):
|
||||
return self.data.items()
|
||||
|
||||
def copy(self: T_MessageSegment) -> T_MessageSegment:
|
||||
return copy(self)
|
||||
return deepcopy(self)
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_text(self) -> bool:
|
||||
@ -310,7 +310,7 @@ class MessageSegment(abc.ABC, Mapping):
|
||||
class Message(List[T_MessageSegment], abc.ABC):
|
||||
"""消息数组"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(self: T_Message,
|
||||
message: Union[str, None, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message, Any] = None,
|
||||
*args,
|
||||
@ -330,8 +330,13 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
else:
|
||||
self.extend(self._construct(message))
|
||||
|
||||
@abc.abstractmethod
|
||||
@classmethod
|
||||
def get_segment_class(cls) -> Type[T_MessageSegment]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __str__(self):
|
||||
return ''.join((str(seg) for seg in self))
|
||||
return "".join(str(seg) for seg in self)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
@ -348,30 +353,31 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
) -> Iterable[T_MessageSegment]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __add__(self: T_Message, other: Union[str, T_MessageSegment,
|
||||
T_Message]) -> T_Message:
|
||||
result = self.__class__(self)
|
||||
if isinstance(other, str):
|
||||
result.extend(self._construct(other))
|
||||
elif isinstance(other, MessageSegment):
|
||||
result.append(other)
|
||||
elif isinstance(other, Message):
|
||||
result.extend(other)
|
||||
def __add__(
|
||||
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message]
|
||||
) -> T_Message:
|
||||
result = self.copy()
|
||||
result += other
|
||||
return result
|
||||
|
||||
def __radd__(self: T_Message, other: Union[str, T_MessageSegment,
|
||||
T_Message]) -> T_Message:
|
||||
def __radd__(
|
||||
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message]
|
||||
) -> T_Message:
|
||||
result = self.__class__(other)
|
||||
return result.__add__(self)
|
||||
return result + self
|
||||
|
||||
def __iadd__(self: T_Message, other: Union[str, T_MessageSegment,
|
||||
T_Message]) -> T_Message:
|
||||
if isinstance(other, str):
|
||||
self.extend(self._construct(other))
|
||||
elif isinstance(other, MessageSegment):
|
||||
def __iadd__(
|
||||
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||
T_MessageSegment, T_Message]
|
||||
) -> T_Message:
|
||||
if isinstance(other, MessageSegment):
|
||||
self.append(other)
|
||||
elif isinstance(other, Message):
|
||||
self.extend(other)
|
||||
else:
|
||||
self.extend(self._construct(other))
|
||||
return self
|
||||
|
||||
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
|
||||
@ -385,7 +391,7 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
* ``obj: Union[str, MessageSegment]``: 要添加的消息段
|
||||
"""
|
||||
if isinstance(obj, MessageSegment):
|
||||
super().append(obj)
|
||||
super(Message, self).append(obj)
|
||||
elif isinstance(obj, str):
|
||||
self.extend(self._construct(obj))
|
||||
else:
|
||||
@ -407,33 +413,17 @@ class Message(List[T_MessageSegment], abc.ABC):
|
||||
self.append(segment)
|
||||
return self
|
||||
|
||||
def reduce(self: T_Message) -> None:
|
||||
"""
|
||||
:说明:
|
||||
def copy(self: T_Message) -> T_Message:
|
||||
return deepcopy(self)
|
||||
|
||||
缩减消息数组,即按 MessageSegment 的实现拼接相邻消息段
|
||||
"""
|
||||
index = 0
|
||||
while index < len(self):
|
||||
if index > 0 and self[index -
|
||||
1].is_text() and self[index].is_text():
|
||||
self[index - 1] += self[index]
|
||||
del self[index]
|
||||
else:
|
||||
index += 1
|
||||
|
||||
def extract_plain_text(self: T_Message) -> str:
|
||||
def extract_plain_text(self) -> str:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
提取消息内纯文本消息
|
||||
"""
|
||||
|
||||
def _concat(x: str, y: T_MessageSegment) -> str:
|
||||
return f"{x} {y}" if y.is_text() else x
|
||||
|
||||
plain_text = reduce(_concat, self, "")
|
||||
return plain_text[1:] if plain_text else plain_text
|
||||
return "".join(str(seg) for seg in self if seg.is_text())
|
||||
|
||||
|
||||
class Event(abc.ABC, BaseModel):
|
||||
|
@ -50,8 +50,8 @@ class Filter:
|
||||
def __call__(self, record):
|
||||
module = sys.modules.get(record["name"])
|
||||
if module:
|
||||
plugin_name = getattr(module, "__plugin_name__", record["name"])
|
||||
record["name"] = plugin_name
|
||||
module_name = getattr(module, "__module_name__", record["name"])
|
||||
record["name"] = module_name
|
||||
record["name"] = record["name"].split(".")[0]
|
||||
levelno = logger.level(self.level).no if isinstance(self.level,
|
||||
str) else self.level
|
||||
|
@ -11,14 +11,14 @@ from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessa
|
||||
from .utils import log, escape, unescape, _b2s
|
||||
|
||||
|
||||
class MessageSegment(BaseMessageSegment):
|
||||
class MessageSegment(BaseMessageSegment["Message"]):
|
||||
"""
|
||||
CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
||||
"""
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __init__(self, type: str, data: Dict[str, Any]) -> None:
|
||||
super().__init__(type=type, data=data)
|
||||
@classmethod
|
||||
def get_message_class(cls):
|
||||
return Message
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __str__(self) -> str:
|
||||
@ -37,7 +37,8 @@ class MessageSegment(BaseMessageSegment):
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __add__(self, other) -> "Message":
|
||||
return Message(self) + other
|
||||
return Message(self) + (MessageSegment.text(other) if isinstance(
|
||||
other, str) else other)
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __radd__(self, other) -> "Message":
|
||||
@ -234,10 +235,25 @@ class Message(BaseMessage[MessageSegment]):
|
||||
CQHTTP 协议 Message 适配。
|
||||
"""
|
||||
|
||||
def __radd__(self, other: Union[str, MessageSegment,
|
||||
"Message"]) -> "Message":
|
||||
result = MessageSegment.text(other) if isinstance(other, str) else other
|
||||
return super(Message, self).__radd__(result)
|
||||
@classmethod
|
||||
def get_segment_class(cls):
|
||||
return MessageSegment
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __add__(
|
||||
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
|
||||
"Message"]
|
||||
) -> "Message":
|
||||
return super(Message, self).__add__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __radd__(
|
||||
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
|
||||
"Message"]
|
||||
) -> "Message":
|
||||
return super(Message, self).__radd__(
|
||||
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||
|
||||
@staticmethod
|
||||
@overrides(BaseMessage)
|
||||
@ -280,10 +296,6 @@ class Message(BaseMessage[MessageSegment]):
|
||||
}
|
||||
yield MessageSegment(type_, data)
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def extract_plain_text(self) -> str:
|
||||
|
||||
def _concat(x: str, y: MessageSegment) -> str:
|
||||
return f"{x} {y.data['text']}" if y.is_text() else x
|
||||
|
||||
plain_text = reduce(_concat, self, "")
|
||||
return plain_text[1:] if plain_text else plain_text
|
||||
return "".join(seg.data["text"] for seg in self if seg.is_text())
|
||||
|
Loading…
Reference in New Issue
Block a user