mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-02-17 16:20:05 +08:00
🎨 improve typing
This commit is contained in:
parent
e9bc98e74d
commit
ddd96271b0
@ -6,7 +6,9 @@ try:
|
|||||||
del pkg_resources
|
del pkg_resources
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import pkgutil
|
import pkgutil
|
||||||
__path__: Iterable[str] = pkgutil.extend_path(__path__, __name__)
|
__path__: Iterable[str] = pkgutil.extend_path(
|
||||||
|
__path__, # type: ignore
|
||||||
|
__name__)
|
||||||
del pkgutil
|
del pkgutil
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
@ -7,28 +7,25 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from copy import copy
|
from copy import deepcopy
|
||||||
from functools import reduce, partial
|
from functools import partial
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import (Any, Set, List, Dict, Tuple, Union, TypeVar, Mapping,
|
from typing import (Any, Set, List, Dict, Type, Tuple, Union, TypeVar, Mapping,
|
||||||
Optional, Iterable, Awaitable, TYPE_CHECKING)
|
Generic, Optional, Iterable)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
|
from nonebot.config import Config
|
||||||
from nonebot.utils import DataclassEncoder
|
from nonebot.utils import DataclassEncoder
|
||||||
from nonebot.drivers import HTTPConnection, HTTPResponse
|
from nonebot.drivers import Driver, HTTPConnection, HTTPResponse
|
||||||
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
|
from nonebot.typing import T_CallingAPIHook, T_CalledAPIHook
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from nonebot.config import Config
|
|
||||||
from nonebot.drivers import Driver, WebSocket
|
|
||||||
|
|
||||||
|
|
||||||
class _ApiCall(Protocol):
|
class _ApiCall(Protocol):
|
||||||
|
|
||||||
def __call__(self, **kwargs: Any) -> Awaitable[Any]:
|
async def __call__(self, **kwargs: Any) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@ -37,9 +34,9 @@ class Bot(abc.ABC):
|
|||||||
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
Bot 基类。用于处理上报消息,并提供 API 调用接口。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
driver: "Driver"
|
driver: Driver
|
||||||
"""Driver 对象"""
|
"""Driver 对象"""
|
||||||
config: "Config"
|
config: Config
|
||||||
"""Config 配置对象"""
|
"""Config 配置对象"""
|
||||||
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
_calling_api_hook: Set[T_CallingAPIHook] = set()
|
||||||
"""
|
"""
|
||||||
@ -56,9 +53,8 @@ class Bot(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``connection_type: str``: http 或者 websocket
|
|
||||||
* ``self_id: str``: 机器人 ID
|
* ``self_id: str``: 机器人 ID
|
||||||
* ``websocket: Optional[WebSocket]``: Websocket 连接对象
|
* ``request: HTTPConnection``: request 连接对象
|
||||||
"""
|
"""
|
||||||
self.self_id: str = self_id
|
self.self_id: str = self_id
|
||||||
"""机器人 ID"""
|
"""机器人 ID"""
|
||||||
@ -75,7 +71,7 @@ class Bot(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, driver: "Driver", config: "Config"):
|
def register(cls, driver: Driver, config: Config):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
@ -87,7 +83,7 @@ class Bot(abc.ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def check_permission(
|
async def check_permission(
|
||||||
cls, driver: "Driver", request: HTTPConnection
|
cls, driver: Driver, request: HTTPConnection
|
||||||
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
|
) -> Tuple[Optional[str], Optional[HTTPResponse]]:
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
@ -97,18 +93,12 @@ class Bot(abc.ABC):
|
|||||||
:参数:
|
:参数:
|
||||||
|
|
||||||
* ``driver: Driver``: Driver 对象
|
* ``driver: Driver``: Driver 对象
|
||||||
* ``connection_type: str``: 连接类型
|
* ``request: HTTPConnection``: request 请求详情
|
||||||
* ``headers: dict``: 请求头
|
|
||||||
* ``body: Optional[bytes]``: 请求数据,WebSocket 连接该部分为 None
|
|
||||||
|
|
||||||
:返回:
|
:返回:
|
||||||
|
|
||||||
- ``str``: 连接唯一标识符,``None`` 代表连接不合法
|
- ``Optional[str]``: 连接唯一标识符,``None`` 代表连接不合法
|
||||||
- ``HTTPResponse``: HTTP 上报响应
|
- ``Optional[HTTPResponse]``: HTTP 上报响应
|
||||||
|
|
||||||
:异常:
|
|
||||||
|
|
||||||
- ``RequestDenied``: 请求非法
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -210,21 +200,45 @@ class Bot(abc.ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
|
def on_calling_api(cls, func: T_CallingAPIHook) -> T_CallingAPIHook:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
调用 api 预处理。
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``bot: Bot``: 当前 bot 对象
|
||||||
|
* ``api: str``: 调用的 api 名称
|
||||||
|
* ``data: Dict[str, Any]``: api 调用的参数字典
|
||||||
|
"""
|
||||||
cls._calling_api_hook.add(func)
|
cls._calling_api_hook.add(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
|
def on_called_api(cls, func: T_CalledAPIHook) -> T_CalledAPIHook:
|
||||||
|
"""
|
||||||
|
:说明:
|
||||||
|
|
||||||
|
调用 api 后处理。
|
||||||
|
|
||||||
|
:参数:
|
||||||
|
|
||||||
|
* ``bot: Bot``: 当前 bot 对象
|
||||||
|
* ``exception: Optional[Exception]``: 调用 api 时发生的错误
|
||||||
|
* ``api: str``: 调用的 api 名称
|
||||||
|
* ``data: Dict[str, Any]``: api 调用的参数字典
|
||||||
|
* ``result: Any``: api 调用的返回
|
||||||
|
"""
|
||||||
cls._called_api_hook.add(func)
|
cls._called_api_hook.add(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
T_Message = TypeVar("T_Message", bound="Message")
|
T_Message = TypeVar("T_Message", bound="Message")
|
||||||
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment")
|
T_MessageSegment = TypeVar("T_MessageSegment", bound="MessageSegment[Message]")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageSegment(abc.ABC, Mapping):
|
class MessageSegment(Mapping, abc.ABC, Generic[T_Message]):
|
||||||
"""消息段基类"""
|
"""消息段基类"""
|
||||||
type: str
|
type: str
|
||||||
"""
|
"""
|
||||||
@ -237,6 +251,11 @@ class MessageSegment(abc.ABC, Mapping):
|
|||||||
- 说明: 消息段数据
|
- 说明: 消息段数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
@classmethod
|
||||||
|
def get_message_class(cls) -> Type[T_Message]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""该消息段所代表的 str,在命令匹配部分使用"""
|
"""该消息段所代表的 str,在命令匹配部分使用"""
|
||||||
@ -248,46 +267,27 @@ class MessageSegment(abc.ABC, Mapping):
|
|||||||
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
|
def __ne__(self: T_MessageSegment, other: T_MessageSegment) -> bool:
|
||||||
return not self == other
|
return not self == other
|
||||||
|
|
||||||
@abc.abstractmethod
|
def __add__(self, other: Union[str, Mapping,
|
||||||
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
|
Iterable[Mapping]]) -> T_Message:
|
||||||
T_Message]) -> T_Message:
|
return self.get_message_class()(self) + other
|
||||||
"""你需要在这里实现不同消息段的合并:
|
|
||||||
比如:
|
|
||||||
if isinstance(other, str):
|
|
||||||
...
|
|
||||||
elif isinstance(other, MessageSegment):
|
|
||||||
...
|
|
||||||
注意:需要返回一个新生成的对象
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
def __radd__(self, other: Union[str, Mapping,
|
||||||
def __radd__(
|
Iterable[Mapping]]) -> T_Message:
|
||||||
self: T_MessageSegment, other: Union[str, dict, list, T_MessageSegment,
|
return self.get_message_class()(other) + self
|
||||||
T_Message]) -> "T_Message":
|
|
||||||
"""你需要在这里实现不同消息段的合并:
|
|
||||||
比如:
|
|
||||||
if isinstance(other, str):
|
|
||||||
...
|
|
||||||
elif isinstance(other, MessageSegment):
|
|
||||||
...
|
|
||||||
注意:需要返回一个新生成的对象
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str):
|
||||||
return getattr(self, key)
|
return self.data[key]
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key: str, value: Any):
|
||||||
return setattr(self, key, value)
|
self.data[key] = value
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
yield from self.data.__iter__()
|
yield from self.data.__iter__()
|
||||||
|
|
||||||
def __contains__(self, key: object) -> bool:
|
def __contains__(self, key: Any) -> bool:
|
||||||
return key in self.data
|
return key in self.data
|
||||||
|
|
||||||
def get(self, key: str, default=None):
|
def get(self, key: str, default: Any = None):
|
||||||
return getattr(self, key, default)
|
return getattr(self, key, default)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
@ -300,7 +300,7 @@ class MessageSegment(abc.ABC, Mapping):
|
|||||||
return self.data.items()
|
return self.data.items()
|
||||||
|
|
||||||
def copy(self: T_MessageSegment) -> T_MessageSegment:
|
def copy(self: T_MessageSegment) -> T_MessageSegment:
|
||||||
return copy(self)
|
return deepcopy(self)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def is_text(self) -> bool:
|
def is_text(self) -> bool:
|
||||||
@ -310,7 +310,7 @@ class MessageSegment(abc.ABC, Mapping):
|
|||||||
class Message(List[T_MessageSegment], abc.ABC):
|
class Message(List[T_MessageSegment], abc.ABC):
|
||||||
"""消息数组"""
|
"""消息数组"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self: T_Message,
|
||||||
message: Union[str, None, Mapping, Iterable[Mapping],
|
message: Union[str, None, Mapping, Iterable[Mapping],
|
||||||
T_MessageSegment, T_Message, Any] = None,
|
T_MessageSegment, T_Message, Any] = None,
|
||||||
*args,
|
*args,
|
||||||
@ -330,8 +330,13 @@ class Message(List[T_MessageSegment], abc.ABC):
|
|||||||
else:
|
else:
|
||||||
self.extend(self._construct(message))
|
self.extend(self._construct(message))
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
@classmethod
|
||||||
|
def get_segment_class(cls) -> Type[T_MessageSegment]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return ''.join((str(seg) for seg in self))
|
return "".join(str(seg) for seg in self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_validators__(cls):
|
def __get_validators__(cls):
|
||||||
@ -348,30 +353,31 @@ class Message(List[T_MessageSegment], abc.ABC):
|
|||||||
) -> Iterable[T_MessageSegment]:
|
) -> Iterable[T_MessageSegment]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __add__(self: T_Message, other: Union[str, T_MessageSegment,
|
def __add__(
|
||||||
T_Message]) -> T_Message:
|
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||||
result = self.__class__(self)
|
T_MessageSegment, T_Message]
|
||||||
if isinstance(other, str):
|
) -> T_Message:
|
||||||
result.extend(self._construct(other))
|
result = self.copy()
|
||||||
elif isinstance(other, MessageSegment):
|
result += other
|
||||||
result.append(other)
|
|
||||||
elif isinstance(other, Message):
|
|
||||||
result.extend(other)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __radd__(self: T_Message, other: Union[str, T_MessageSegment,
|
def __radd__(
|
||||||
T_Message]) -> T_Message:
|
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||||
|
T_MessageSegment, T_Message]
|
||||||
|
) -> T_Message:
|
||||||
result = self.__class__(other)
|
result = self.__class__(other)
|
||||||
return result.__add__(self)
|
return result + self
|
||||||
|
|
||||||
def __iadd__(self: T_Message, other: Union[str, T_MessageSegment,
|
def __iadd__(
|
||||||
T_Message]) -> T_Message:
|
self: T_Message, other: Union[str, Mapping, Iterable[Mapping],
|
||||||
if isinstance(other, str):
|
T_MessageSegment, T_Message]
|
||||||
self.extend(self._construct(other))
|
) -> T_Message:
|
||||||
elif isinstance(other, MessageSegment):
|
if isinstance(other, MessageSegment):
|
||||||
self.append(other)
|
self.append(other)
|
||||||
elif isinstance(other, Message):
|
elif isinstance(other, Message):
|
||||||
self.extend(other)
|
self.extend(other)
|
||||||
|
else:
|
||||||
|
self.extend(self._construct(other))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
|
def append(self: T_Message, obj: Union[str, T_MessageSegment]) -> T_Message:
|
||||||
@ -385,7 +391,7 @@ class Message(List[T_MessageSegment], abc.ABC):
|
|||||||
* ``obj: Union[str, MessageSegment]``: 要添加的消息段
|
* ``obj: Union[str, MessageSegment]``: 要添加的消息段
|
||||||
"""
|
"""
|
||||||
if isinstance(obj, MessageSegment):
|
if isinstance(obj, MessageSegment):
|
||||||
super().append(obj)
|
super(Message, self).append(obj)
|
||||||
elif isinstance(obj, str):
|
elif isinstance(obj, str):
|
||||||
self.extend(self._construct(obj))
|
self.extend(self._construct(obj))
|
||||||
else:
|
else:
|
||||||
@ -407,33 +413,17 @@ class Message(List[T_MessageSegment], abc.ABC):
|
|||||||
self.append(segment)
|
self.append(segment)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def reduce(self: T_Message) -> None:
|
def copy(self: T_Message) -> T_Message:
|
||||||
"""
|
return deepcopy(self)
|
||||||
:说明:
|
|
||||||
|
|
||||||
缩减消息数组,即按 MessageSegment 的实现拼接相邻消息段
|
def extract_plain_text(self) -> str:
|
||||||
"""
|
|
||||||
index = 0
|
|
||||||
while index < len(self):
|
|
||||||
if index > 0 and self[index -
|
|
||||||
1].is_text() and self[index].is_text():
|
|
||||||
self[index - 1] += self[index]
|
|
||||||
del self[index]
|
|
||||||
else:
|
|
||||||
index += 1
|
|
||||||
|
|
||||||
def extract_plain_text(self: T_Message) -> str:
|
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
提取消息内纯文本消息
|
提取消息内纯文本消息
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _concat(x: str, y: T_MessageSegment) -> str:
|
return "".join(str(seg) for seg in self if seg.is_text())
|
||||||
return f"{x} {y}" if y.is_text() else x
|
|
||||||
|
|
||||||
plain_text = reduce(_concat, self, "")
|
|
||||||
return plain_text[1:] if plain_text else plain_text
|
|
||||||
|
|
||||||
|
|
||||||
class Event(abc.ABC, BaseModel):
|
class Event(abc.ABC, BaseModel):
|
||||||
|
@ -50,8 +50,8 @@ class Filter:
|
|||||||
def __call__(self, record):
|
def __call__(self, record):
|
||||||
module = sys.modules.get(record["name"])
|
module = sys.modules.get(record["name"])
|
||||||
if module:
|
if module:
|
||||||
plugin_name = getattr(module, "__plugin_name__", record["name"])
|
module_name = getattr(module, "__module_name__", record["name"])
|
||||||
record["name"] = plugin_name
|
record["name"] = module_name
|
||||||
record["name"] = record["name"].split(".")[0]
|
record["name"] = record["name"].split(".")[0]
|
||||||
levelno = logger.level(self.level).no if isinstance(self.level,
|
levelno = logger.level(self.level).no if isinstance(self.level,
|
||||||
str) else self.level
|
str) else self.level
|
||||||
|
@ -11,14 +11,14 @@ from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessa
|
|||||||
from .utils import log, escape, unescape, _b2s
|
from .utils import log, escape, unescape, _b2s
|
||||||
|
|
||||||
|
|
||||||
class MessageSegment(BaseMessageSegment):
|
class MessageSegment(BaseMessageSegment["Message"]):
|
||||||
"""
|
"""
|
||||||
CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
CQHTTP 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@overrides(BaseMessageSegment)
|
@classmethod
|
||||||
def __init__(self, type: str, data: Dict[str, Any]) -> None:
|
def get_message_class(cls):
|
||||||
super().__init__(type=type, data=data)
|
return Message
|
||||||
|
|
||||||
@overrides(BaseMessageSegment)
|
@overrides(BaseMessageSegment)
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@ -37,7 +37,8 @@ class MessageSegment(BaseMessageSegment):
|
|||||||
|
|
||||||
@overrides(BaseMessageSegment)
|
@overrides(BaseMessageSegment)
|
||||||
def __add__(self, other) -> "Message":
|
def __add__(self, other) -> "Message":
|
||||||
return Message(self) + other
|
return Message(self) + (MessageSegment.text(other) if isinstance(
|
||||||
|
other, str) else other)
|
||||||
|
|
||||||
@overrides(BaseMessageSegment)
|
@overrides(BaseMessageSegment)
|
||||||
def __radd__(self, other) -> "Message":
|
def __radd__(self, other) -> "Message":
|
||||||
@ -234,10 +235,25 @@ class Message(BaseMessage[MessageSegment]):
|
|||||||
CQHTTP 协议 Message 适配。
|
CQHTTP 协议 Message 适配。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __radd__(self, other: Union[str, MessageSegment,
|
@classmethod
|
||||||
"Message"]) -> "Message":
|
def get_segment_class(cls):
|
||||||
result = MessageSegment.text(other) if isinstance(other, str) else other
|
return MessageSegment
|
||||||
return super(Message, self).__radd__(result)
|
|
||||||
|
@overrides(BaseMessage)
|
||||||
|
def __add__(
|
||||||
|
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
|
||||||
|
"Message"]
|
||||||
|
) -> "Message":
|
||||||
|
return super(Message, self).__add__(
|
||||||
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||||
|
|
||||||
|
@overrides(BaseMessage)
|
||||||
|
def __radd__(
|
||||||
|
self, other: Union[str, Mapping, Iterable[Mapping], MessageSegment,
|
||||||
|
"Message"]
|
||||||
|
) -> "Message":
|
||||||
|
return super(Message, self).__radd__(
|
||||||
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@overrides(BaseMessage)
|
@overrides(BaseMessage)
|
||||||
@ -280,10 +296,6 @@ class Message(BaseMessage[MessageSegment]):
|
|||||||
}
|
}
|
||||||
yield MessageSegment(type_, data)
|
yield MessageSegment(type_, data)
|
||||||
|
|
||||||
|
@overrides(BaseMessage)
|
||||||
def extract_plain_text(self) -> str:
|
def extract_plain_text(self) -> str:
|
||||||
|
return "".join(seg.data["text"] for seg in self if seg.is_text())
|
||||||
def _concat(x: str, y: MessageSegment) -> str:
|
|
||||||
return f"{x} {y.data['text']}" if y.is_text() else x
|
|
||||||
|
|
||||||
plain_text = reduce(_concat, self, "")
|
|
||||||
return plain_text[1:] if plain_text else plain_text
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user