diff --git a/nonebot/adapters/_adapter.py b/nonebot/adapters/_adapter.py index b373dced..992ee52a 100644 --- a/nonebot/adapters/_adapter.py +++ b/nonebot/adapters/_adapter.py @@ -29,11 +29,11 @@ class Adapter(abc.ABC): def config(self) -> Config: return self.driver.config - def bot_connect(self, bot: Bot): + def bot_connect(self, bot: Bot) -> None: self.driver._bot_connect(bot) self.bots[bot.self_id] = bot - def bot_disconnect(self, bot: Bot): + def bot_disconnect(self, bot: Bot) -> None: self.driver._bot_disconnect(bot) self.bots.pop(bot.self_id, None) @@ -58,7 +58,7 @@ class Adapter(abc.ABC): return await self.driver.websocket(setup) @abc.abstractmethod - async def _call_api(self, api: str, **data) -> Any: + async def _call_api(self, bot: Bot, api: str, **data) -> Any: """ :说明: diff --git a/nonebot/adapters/_bot.py b/nonebot/adapters/_bot.py index cfaae3a8..e72a0d26 100644 --- a/nonebot/adapters/_bot.py +++ b/nonebot/adapters/_bot.py @@ -100,7 +100,7 @@ class Bot(abc.ABC): if not skip_calling_api: try: - result = await self.adapter._call_api(api, **data) + result = await self.adapter._call_api(self, api, **data) except Exception as e: exception = e diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index daca96b4..8bb3951a 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -168,6 +168,8 @@ class Driver(abc.ABC): def _bot_connect(self, bot: "Bot") -> None: """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" + if bot.self_id in self._clients: + raise RuntimeError(f"Duplicate bot connection with id {bot.self_id}") self._clients[bot.self_id] = bot async def _run_hook(bot: "Bot") -> None: diff --git a/nonebot/drivers/_model.py b/nonebot/drivers/_model.py index 05d59c77..d217ab33 100644 --- a/nonebot/drivers/_model.py +++ b/nonebot/drivers/_model.py @@ -31,7 +31,7 @@ HeaderTypes = Union[ Sequence[Tuple[str, str]], ] -ContentTypes = Union[str, bytes] +ContentTypes = Union[str, bytes, None] CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] @@ -55,15 +55,15 @@ class Request: timeout: Optional[float] = None, ): # method - self.method = ( + self.method: str = ( method.decode("ascii").upper() if isinstance(method, bytes) else method.upper() ) # http version - self.version = HTTPVersion(version) + self.version: HTTPVersion = HTTPVersion(version) # timeout - self.timeout = timeout + self.timeout: Optional[float] = timeout # url if isinstance(url, tuple): @@ -79,7 +79,7 @@ class Request: if params is not None: url = url.update_query(params) - self.url = url + self.url: URL = url # headers self.headers: CIMultiDict[str] @@ -92,7 +92,7 @@ class Request: self.cookies = Cookies(cookies) # body - self.content = content + self.content: ContentTypes = content def __repr__(self) -> str: class_name = self.__class__.__name__ @@ -110,7 +110,7 @@ class Response: request: Optional[Request] = None, ): # status code - self.status_code = status_code + self.status_code: int = status_code # headers self.headers: CIMultiDict[str] @@ -120,16 +120,16 @@ class Response: self.headers = CIMultiDict() # body - self.content = content + self.content: ContentTypes = content # request - self.request = request + self.request: Optional[Request] = request class WebSocket(abc.ABC): def __init__(self, *, request: Request): # request - self.request = request + self.request: Request = request @property @abc.abstractmethod @@ -141,12 +141,12 @@ class WebSocket(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def accept(self): + async def accept(self) -> None: """接受 WebSocket 连接请求""" raise NotImplementedError @abc.abstractmethod - async def close(self, code: int = 1000): + async def close(self, code: int = 1000, reason: str = "") -> None: """关闭 WebSocket 连接请求""" raise NotImplementedError @@ -161,19 +161,19 @@ class WebSocket(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def send(self, data: str): + async def send(self, data: str) -> None: """发送一条 WebSocket text 信息""" raise NotImplementedError @abc.abstractmethod - async def send_bytes(self, data: bytes): + async def send_bytes(self, data: bytes) -> None: """发送一条 WebSocket binary 信息""" raise NotImplementedError class Cookies(MutableMapping): def __init__(self, cookies: CookieTypes = None) -> None: - self.jar = cookies if isinstance(cookies, CookieJar) else CookieJar() + self.jar: CookieJar = cookies if isinstance(cookies, CookieJar) else CookieJar() if cookies is not None and not isinstance(cookies, CookieJar): if isinstance(cookies, dict): for key, value in cookies.items(): diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 8efdf36e..9752eb26 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -333,8 +333,8 @@ class WebSocketsWS(BaseWebSocket): raise NotImplementedError @overrides(BaseWebSocket) - async def close(self, code: int = 1000): - await self.websocket.close(code) + async def close(self, code: int = 1000, reason: str = ""): + await self.websocket.close(code, reason) @overrides(BaseWebSocket) async def receive(self) -> str: @@ -374,11 +374,13 @@ class FastAPIWebSocket(BaseWebSocket): ) @overrides(BaseWebSocket) - async def accept(self): + async def accept(self) -> None: await self.websocket.accept() @overrides(BaseWebSocket) - async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): + async def close( + self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = "" + ) -> None: await self.websocket.close(code) @overrides(BaseWebSocket) diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 7527957e..9d32463f 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -249,15 +249,16 @@ class WebSocket(BaseWebSocket): @property @overrides(BaseWebSocket) def closed(self): - raise NotImplementedError + # FIXME + return True @overrides(BaseWebSocket) async def accept(self): await self.websocket.accept() @overrides(BaseWebSocket) - async def close(self, code: int = 1000): - await self.websocket.close(code) + async def close(self, code: int = 1000, reason: str = ""): + await self.websocket.close(code, reason) @overrides(BaseWebSocket) async def receive(self) -> str: diff --git a/nonebot/typing.py b/nonebot/typing.py index e5a275f8..3e28a3ea 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -17,7 +17,6 @@ .. _typing: https://docs.python.org/3/library/typing.html """ -from asyncio import Task from typing import ( TYPE_CHECKING, Any, @@ -32,6 +31,8 @@ from typing import ( ) if TYPE_CHECKING: + from asyncio import Task + from nonebot.adapters import Bot, Event from nonebot.permission import Permission @@ -249,7 +250,7 @@ T_PermissionUpdater = Callable[..., Union["Permission", Awaitable["Permission"]] PermissionUpdater 在 Matcher.pause, Matcher.reject 时被运行,用于更新会话对象权限。默认会更新为当前事件的触发对象。 """ -T_DependencyCache = Dict[Callable[..., Any], Task[Any]] +T_DependencyCache = Dict[Callable[..., Any], "Task[Any]"] """ :类型: ``Dict[Callable[..., Any], Task[Any]]`` :说明: diff --git a/nonebot/utils.py b/nonebot/utils.py index bd4a92a0..5bbcbee1 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -163,7 +163,7 @@ def logger_wrapper(logger_name: str): """ def log(level: str, message: str, exception: Optional[Exception] = None): - return logger.opt(colors=True, exception=exception).log( + logger.opt(colors=True, exception=exception).log( level, f"{escape_tag(logger_name)} | " + message )