use typing.override instead (#2193)

This commit is contained in:
Ju4tCode 2023-07-17 15:56:27 +08:00 committed by GitHub
parent 7dd7c927bf
commit 6dc87a9455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 84 additions and 83 deletions

View File

@ -15,10 +15,10 @@ FrontMatter:
description: nonebot.drivers.aiohttp 模块
"""
from typing_extensions import override
from typing import Type, AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers.none import Driver as NoneDriver
@ -38,11 +38,11 @@ class Mixin(ForwardMixin):
"""AIOHTTP Mixin"""
@property
@overrides(ForwardMixin)
@override
def type(self) -> str:
return "aiohttp"
@overrides(ForwardMixin)
@override
async def request(self, setup: Request) -> Response:
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
@ -81,7 +81,7 @@ class Mixin(ForwardMixin):
request=setup,
)
@overrides(ForwardMixin)
@override
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
if setup.version == HTTPVersion.H10:
@ -117,15 +117,15 @@ class WebSocket(BaseWebSocket):
self.websocket = websocket
@property
@overrides(BaseWebSocket)
@override
def closed(self):
return self.websocket.closed
@overrides(BaseWebSocket)
@override
async def accept(self):
raise NotImplementedError
@overrides(BaseWebSocket)
@override
async def close(self, code: int = 1000):
await self.websocket.close(code=code)
await self.session.close()
@ -136,7 +136,7 @@ class WebSocket(BaseWebSocket):
raise WebSocketClosed(self.websocket.close_code or 1006)
return msg
@overrides(BaseWebSocket)
@override
async def receive(self) -> str:
msg = await self._receive()
if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
@ -145,7 +145,7 @@ class WebSocket(BaseWebSocket):
)
return msg.data
@overrides(BaseWebSocket)
@override
async def receive_text(self) -> str:
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
@ -154,7 +154,7 @@ class WebSocket(BaseWebSocket):
)
return msg.data
@overrides(BaseWebSocket)
@override
async def receive_bytes(self) -> bytes:
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.BINARY:
@ -163,11 +163,11 @@ class WebSocket(BaseWebSocket):
)
return msg.data
@overrides(BaseWebSocket)
@override
async def send_text(self, data: str) -> None:
await self.websocket.send_str(data)
@overrides(BaseWebSocket)
@override
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send_bytes(data)

View File

@ -19,12 +19,12 @@ FrontMatter:
import logging
import contextlib
from functools import wraps
from typing_extensions import override
from typing import Any, Dict, List, Tuple, Union, Optional
from pydantic import BaseSettings
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.exception import WebSocketClosed
from nonebot.internal.driver import FileTypes
from nonebot.config import Config as NoneBotConfig
@ -106,30 +106,30 @@ class Driver(ReverseDriver):
)
@property
@overrides(ReverseDriver)
@override
def type(self) -> str:
"""驱动名称: `fastapi`"""
return "fastapi"
@property
@overrides(ReverseDriver)
@override
def server_app(self) -> FastAPI:
"""`FastAPI APP` 对象"""
return self._server_app
@property
@overrides(ReverseDriver)
@override
def asgi(self) -> FastAPI:
"""`FastAPI APP` 对象"""
return self._server_app
@property
@overrides(ReverseDriver)
@override
def logger(self) -> logging.Logger:
"""fastapi 使用的 logger"""
return logging.getLogger("fastapi")
@overrides(ReverseDriver)
@override
def setup_http_server(self, setup: HTTPServerSetup):
async def _handle(request: Request) -> Response:
return await self._handle_http(request, setup)
@ -142,7 +142,7 @@ class Driver(ReverseDriver):
include_in_schema=self.fastapi_config.fastapi_include_adapter_schema,
)
@overrides(ReverseDriver)
@override
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
async def _handle(websocket: WebSocket) -> None:
await self._handle_ws(websocket, setup)
@ -153,11 +153,11 @@ class Driver(ReverseDriver):
name=setup.name,
)
@overrides(ReverseDriver)
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)
@overrides(ReverseDriver)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)
@ -169,7 +169,7 @@ class Driver(ReverseDriver):
finally:
await self._lifespan.shutdown()
@overrides(ReverseDriver)
@override
def run(
self,
host: Optional[str] = None,
@ -268,30 +268,30 @@ class Driver(ReverseDriver):
class FastAPIWebSocket(BaseWebSocket):
"""FastAPI WebSocket Wrapper"""
@overrides(BaseWebSocket)
@override
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
@override
def closed(self) -> bool:
return (
self.websocket.client_state == WebSocketState.DISCONNECTED
or self.websocket.application_state == WebSocketState.DISCONNECTED
)
@overrides(BaseWebSocket)
@override
async def accept(self) -> None:
await self.websocket.accept()
@overrides(BaseWebSocket)
@override
async def close(
self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = ""
) -> None:
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
@override
async def receive(self) -> Union[str, bytes]:
# assert self.websocket.application_state == WebSocketState.CONNECTED
msg = await self.websocket.receive()
@ -299,21 +299,21 @@ class FastAPIWebSocket(BaseWebSocket):
raise WebSocketClosed(msg["code"])
return msg["text"] if "text" in msg else msg["bytes"]
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive_text(self) -> str:
return await self.websocket.receive_text()
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
@override
async def send_text(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})
@overrides(BaseWebSocket)
@override
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send({"type": "websocket.send", "bytes": data})

View File

@ -15,10 +15,10 @@ FrontMatter:
description: nonebot.drivers.httpx 模块
"""
from typing_extensions import override
from typing import Type, AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import (
Request,
@ -43,11 +43,11 @@ class Mixin(ForwardMixin):
"""HTTPX Mixin"""
@property
@overrides(ForwardMixin)
@override
def type(self) -> str:
return "httpx"
@overrides(ForwardMixin)
@override
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
cookies=setup.cookies.jar,
@ -72,7 +72,7 @@ class Mixin(ForwardMixin):
request=setup,
)
@overrides(ForwardMixin)
@override
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
async with super(Mixin, self).websocket(setup) as ws:

View File

@ -12,10 +12,10 @@ FrontMatter:
import signal
import asyncio
import threading
from typing_extensions import override
from nonebot.log import logger
from nonebot.consts import WINDOWS
from nonebot.typing import overrides
from nonebot.config import Env, Config
from nonebot.drivers import Driver as BaseDriver
@ -41,28 +41,28 @@ class Driver(BaseDriver):
self.force_exit: bool = False
@property
@overrides(BaseDriver)
@override
def type(self) -> str:
"""驱动名称: `none`"""
return "none"
@property
@overrides(BaseDriver)
@override
def logger(self):
"""none driver 使用的 logger"""
return logger
@overrides(BaseDriver)
@override
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个启动时执行的函数"""
return self._lifespan.on_startup(func)
@overrides(BaseDriver)
@override
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""注册一个停止时执行的函数"""
return self._lifespan.on_shutdown(func)
@overrides(BaseDriver)
@override
def run(self, *args, **kwargs):
"""启动 none driver"""
super().run(*args, **kwargs)

View File

@ -17,6 +17,7 @@ FrontMatter:
import asyncio
from functools import wraps
from typing_extensions import override
from typing import (
Any,
Dict,
@ -33,7 +34,6 @@ from typing import (
from pydantic import BaseSettings
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.exception import WebSocketClosed
from nonebot.internal.driver import FileTypes
from nonebot.config import Config as NoneBotConfig
@ -102,30 +102,30 @@ class Driver(ReverseDriver):
)
@property
@overrides(ReverseDriver)
@override
def type(self) -> str:
"""驱动名称: `quart`"""
return "quart"
@property
@overrides(ReverseDriver)
@override
def server_app(self) -> Quart:
"""`Quart` 对象"""
return self._server_app
@property
@overrides(ReverseDriver)
@override
def asgi(self):
"""`Quart` 对象"""
return self._server_app
@property
@overrides(ReverseDriver)
@override
def logger(self):
"""Quart 使用的 logger"""
return self._server_app.logger
@overrides(ReverseDriver)
@override
def setup_http_server(self, setup: HTTPServerSetup):
async def _handle() -> Response:
return await self._handle_http(setup)
@ -137,7 +137,7 @@ class Driver(ReverseDriver):
view_func=_handle,
)
@overrides(ReverseDriver)
@override
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
async def _handle() -> None:
return await self._handle_ws(setup)
@ -148,17 +148,17 @@ class Driver(ReverseDriver):
view_func=_handle,
)
@overrides(ReverseDriver)
@override
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.before_serving(func) # type: ignore
@overrides(ReverseDriver)
@override
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
return self.server_app.after_serving(func) # type: ignore
@overrides(ReverseDriver)
@override
def run(
self,
host: Optional[str] = None,
@ -262,25 +262,25 @@ class WebSocket(BaseWebSocket):
return self.websocket_ctx.websocket
@property
@overrides(BaseWebSocket)
@override
def closed(self):
# FIXME
return True
@overrides(BaseWebSocket)
@override
async def accept(self):
await self.websocket.accept()
@overrides(BaseWebSocket)
@override
async def close(self, code: int = 1000, reason: str = ""):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive(self) -> Union[str, bytes]:
return await self.websocket.receive()
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive_text(self) -> str:
msg = await self.websocket.receive()
@ -288,7 +288,7 @@ class WebSocket(BaseWebSocket):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive_bytes(self) -> bytes:
msg = await self.websocket.receive()
@ -296,11 +296,11 @@ class WebSocket(BaseWebSocket):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
@override
async def send_text(self, data: str):
await self.websocket.send(data)
@overrides(BaseWebSocket)
@override
async def send_bytes(self, data: bytes):
await self.websocket.send(data)

View File

@ -17,11 +17,10 @@ FrontMatter:
import logging
from functools import wraps
from typing_extensions import ParamSpec
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, override
from typing import Type, Union, TypeVar, Callable, Awaitable, AsyncGenerator
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
@ -63,15 +62,15 @@ class Mixin(ForwardMixin):
"""Websockets Mixin"""
@property
@overrides(ForwardMixin)
@override
def type(self) -> str:
return "websockets"
@overrides(ForwardMixin)
@override
async def request(self, setup: Request) -> Response:
return await super(Mixin, self).request(setup)
@overrides(ForwardMixin)
@override
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
connection = Connect(
@ -86,30 +85,30 @@ class Mixin(ForwardMixin):
class WebSocket(BaseWebSocket):
"""Websockets WebSocket Wrapper"""
@overrides(BaseWebSocket)
@override
def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
@override
def closed(self) -> bool:
return self.websocket.closed
@overrides(BaseWebSocket)
@override
async def accept(self):
raise NotImplementedError
@overrides(BaseWebSocket)
@override
async def close(self, code: int = 1000, reason: str = ""):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive(self) -> Union[str, bytes]:
return await self.websocket.recv()
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive_text(self) -> str:
msg = await self.websocket.recv()
@ -117,7 +116,7 @@ class WebSocket(BaseWebSocket):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
@override
@catch_closed
async def receive_bytes(self) -> bytes:
msg = await self.websocket.recv()
@ -125,11 +124,11 @@ class WebSocket(BaseWebSocket):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
@override
async def send_text(self, data: str) -> None:
await self.websocket.send(data)
@overrides(BaseWebSocket)
@override
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send(data)

View File

@ -10,7 +10,8 @@ FrontMatter:
description: nonebot.typing 模块
"""
from typing_extensions import ParamSpec, TypeAlias
import warnings
from typing_extensions import ParamSpec, TypeAlias, override
from typing import (
TYPE_CHECKING,
Any,
@ -34,14 +35,16 @@ P = ParamSpec("P")
T_Wrapped: TypeAlias = Callable[P, T]
def overrides(InterfaceClass: object) -> Callable[[T_Wrapped], T_Wrapped]:
def overrides(InterfaceClass: object):
"""标记一个方法为父类 interface 的 implement"""
def overrider(func: T_Wrapped) -> T_Wrapped:
assert func.__name__ in dir(InterfaceClass), f"Error method: {func.__name__}"
return func
return overrider
warnings.warn(
"overrides is deprecated and will be removed in a future version, "
"use @typing_extensions.override instead. "
"See [PEP 698](https://peps.python.org/pep-0698/) for more details.",
DeprecationWarning,
)
return override
# state

View File

@ -15,7 +15,7 @@ from pathlib import Path
from contextvars import copy_context
from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, get_args, get_origin
from typing_extensions import ParamSpec, get_args, override, get_origin
from typing import (
Any,
Type,
@ -33,7 +33,6 @@ from typing import (
from pydantic.typing import is_union, is_none_type
from nonebot.log import logger
from nonebot.typing import overrides
P = ParamSpec("P")
R = TypeVar("R")
@ -224,7 +223,7 @@ def resolve_dot_notation(
class DataclassEncoder(json.JSONEncoder):
"""可以序列化 {ref}`nonebot.adapters.Message`(List[Dataclass]) 的 `JSONEncoder`"""
@overrides(json.JSONEncoder)
@override
def default(self, o):
if dataclasses.is_dataclass(o):
return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)}