✏️ add typing

This commit is contained in:
yanyongyu 2021-12-21 18:22:14 +08:00
parent b900133ab4
commit 9b2fa46921
8 changed files with 35 additions and 29 deletions

View File

@ -29,11 +29,11 @@ class Adapter(abc.ABC):
def config(self) -> Config: def config(self) -> Config:
return self.driver.config return self.driver.config
def bot_connect(self, bot: Bot): def bot_connect(self, bot: Bot) -> None:
self.driver._bot_connect(bot) self.driver._bot_connect(bot)
self.bots[bot.self_id] = bot self.bots[bot.self_id] = bot
def bot_disconnect(self, bot: Bot): def bot_disconnect(self, bot: Bot) -> None:
self.driver._bot_disconnect(bot) self.driver._bot_disconnect(bot)
self.bots.pop(bot.self_id, None) self.bots.pop(bot.self_id, None)
@ -58,7 +58,7 @@ class Adapter(abc.ABC):
return await self.driver.websocket(setup) return await self.driver.websocket(setup)
@abc.abstractmethod @abc.abstractmethod
async def _call_api(self, api: str, **data) -> Any: async def _call_api(self, bot: Bot, api: str, **data) -> Any:
""" """
:说明: :说明:

View File

@ -100,7 +100,7 @@ class Bot(abc.ABC):
if not skip_calling_api: if not skip_calling_api:
try: try:
result = await self.adapter._call_api(api, **data) result = await self.adapter._call_api(self, api, **data)
except Exception as e: except Exception as e:
exception = e exception = e

View File

@ -168,6 +168,8 @@ class Driver(abc.ABC):
def _bot_connect(self, bot: "Bot") -> None: def _bot_connect(self, bot: "Bot") -> None:
"""在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" """在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
if bot.self_id in self._clients:
raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}")
self._clients[bot.self_id] = bot self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None: async def _run_hook(bot: "Bot") -> None:

View File

@ -31,7 +31,7 @@ HeaderTypes = Union[
Sequence[Tuple[str, str]], Sequence[Tuple[str, str]],
] ]
ContentTypes = Union[str, bytes] ContentTypes = Union[str, bytes, None]
CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
@ -55,15 +55,15 @@ class Request:
timeout: Optional[float] = None, timeout: Optional[float] = None,
): ):
# method # method
self.method = ( self.method: str = (
method.decode("ascii").upper() method.decode("ascii").upper()
if isinstance(method, bytes) if isinstance(method, bytes)
else method.upper() else method.upper()
) )
# http version # http version
self.version = HTTPVersion(version) self.version: HTTPVersion = HTTPVersion(version)
# timeout # timeout
self.timeout = timeout self.timeout: Optional[float] = timeout
# url # url
if isinstance(url, tuple): if isinstance(url, tuple):
@ -79,7 +79,7 @@ class Request:
if params is not None: if params is not None:
url = url.update_query(params) url = url.update_query(params)
self.url = url self.url: URL = url
# headers # headers
self.headers: CIMultiDict[str] self.headers: CIMultiDict[str]
@ -92,7 +92,7 @@ class Request:
self.cookies = Cookies(cookies) self.cookies = Cookies(cookies)
# body # body
self.content = content self.content: ContentTypes = content
def __repr__(self) -> str: def __repr__(self) -> str:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -110,7 +110,7 @@ class Response:
request: Optional[Request] = None, request: Optional[Request] = None,
): ):
# status code # status code
self.status_code = status_code self.status_code: int = status_code
# headers # headers
self.headers: CIMultiDict[str] self.headers: CIMultiDict[str]
@ -120,16 +120,16 @@ class Response:
self.headers = CIMultiDict() self.headers = CIMultiDict()
# body # body
self.content = content self.content: ContentTypes = content
# request # request
self.request = request self.request: Optional[Request] = request
class WebSocket(abc.ABC): class WebSocket(abc.ABC):
def __init__(self, *, request: Request): def __init__(self, *, request: Request):
# request # request
self.request = request self.request: Request = request
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -141,12 +141,12 @@ class WebSocket(abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def accept(self): async def accept(self) -> None:
"""接受 WebSocket 连接请求""" """接受 WebSocket 连接请求"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def close(self, code: int = 1000): async def close(self, code: int = 1000, reason: str = "") -> None:
"""关闭 WebSocket 连接请求""" """关闭 WebSocket 连接请求"""
raise NotImplementedError raise NotImplementedError
@ -161,19 +161,19 @@ class WebSocket(abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def send(self, data: str): async def send(self, data: str) -> None:
"""发送一条 WebSocket text 信息""" """发送一条 WebSocket text 信息"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def send_bytes(self, data: bytes): async def send_bytes(self, data: bytes) -> None:
"""发送一条 WebSocket binary 信息""" """发送一条 WebSocket binary 信息"""
raise NotImplementedError raise NotImplementedError
class Cookies(MutableMapping): class Cookies(MutableMapping):
def __init__(self, cookies: CookieTypes = None) -> None: def __init__(self, cookies: CookieTypes = None) -> None:
self.jar = cookies if isinstance(cookies, CookieJar) else CookieJar() self.jar: CookieJar = cookies if isinstance(cookies, CookieJar) else CookieJar()
if cookies is not None and not isinstance(cookies, CookieJar): if cookies is not None and not isinstance(cookies, CookieJar):
if isinstance(cookies, dict): if isinstance(cookies, dict):
for key, value in cookies.items(): for key, value in cookies.items():

View File

@ -333,8 +333,8 @@ class WebSocketsWS(BaseWebSocket):
raise NotImplementedError raise NotImplementedError
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self, code: int = 1000): async def close(self, code: int = 1000, reason: str = ""):
await self.websocket.close(code) await self.websocket.close(code, reason)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> str: async def receive(self) -> str:
@ -374,11 +374,13 @@ class FastAPIWebSocket(BaseWebSocket):
) )
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self) -> None:
await self.websocket.accept() await self.websocket.accept()
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): async def close(
self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = ""
) -> None:
await self.websocket.close(code) await self.websocket.close(code)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)

View File

@ -249,15 +249,16 @@ class WebSocket(BaseWebSocket):
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self): def closed(self):
raise NotImplementedError # FIXME
return True
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
await self.websocket.accept() await self.websocket.accept()
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self, code: int = 1000): async def close(self, code: int = 1000, reason: str = ""):
await self.websocket.close(code) await self.websocket.close(code, reason)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> str: async def receive(self) -> str:

View File

@ -17,7 +17,6 @@
.. _typing: .. _typing:
https://docs.python.org/3/library/typing.html https://docs.python.org/3/library/typing.html
""" """
from asyncio import Task
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -32,6 +31,8 @@ from typing import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from asyncio import Task
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
@ -249,7 +250,7 @@ T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]]
PermissionUpdater Matcher.pause, Matcher.reject 时被运行用于更新会话对象权限默认会更新为当前事件的触发对象 PermissionUpdater Matcher.pause, Matcher.reject 时被运行用于更新会话对象权限默认会更新为当前事件的触发对象
""" """
T_DependencyCache = Dict[Callable[..., Any], Task[Any]] T_DependencyCache = Dict[Callable[..., Any], "Task[Any]"]
""" """
:类型: ``Dict[Callable[..., Any], Task[Any]]`` :类型: ``Dict[Callable[..., Any], Task[Any]]``
:说明: :说明:

View File

@ -163,7 +163,7 @@ def logger_wrapper(logger_name: str):
""" """
def log(level: str, message: str, exception: Optional[Exception] = None): def log(level: str, message: str, exception: Optional[Exception] = None):
return logger.opt(colors=True, exception=exception).log( logger.opt(colors=True, exception=exception).log(
level, f"<m>{escape_tag(logger_name)}</m> | " + message level, f"<m>{escape_tag(logger_name)}</m> | " + message
) )