mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ use typing.override instead (#2193)
This commit is contained in:
parent
7dd7c927bf
commit
6dc87a9455
@ -15,10 +15,10 @@ FrontMatter:
|
||||
description: nonebot.drivers.aiohttp 模块
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Type, AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
@ -38,11 +38,11 @@ class Mixin(ForwardMixin):
|
||||
"""AIOHTTP Mixin"""
|
||||
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
def type(self) -> str:
|
||||
return "aiohttp"
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if setup.version == HTTPVersion.H10:
|
||||
version = aiohttp.HttpVersion10
|
||||
@ -81,7 +81,7 @@ class Mixin(ForwardMixin):
|
||||
request=setup,
|
||||
)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||
if setup.version == HTTPVersion.H10:
|
||||
@ -117,15 +117,15 @@ class WebSocket(BaseWebSocket):
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
def closed(self):
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def accept(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def close(self, code: int = 1000):
|
||||
await self.websocket.close(code=code)
|
||||
await self.session.close()
|
||||
@ -136,7 +136,7 @@ class WebSocket(BaseWebSocket):
|
||||
raise WebSocketClosed(self.websocket.close_code or 1006)
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def receive(self) -> str:
|
||||
msg = await self._receive()
|
||||
if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||
@ -145,7 +145,7 @@ class WebSocket(BaseWebSocket):
|
||||
)
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def receive_text(self) -> str:
|
||||
msg = await self._receive()
|
||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
||||
@ -154,7 +154,7 @@ class WebSocket(BaseWebSocket):
|
||||
)
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self._receive()
|
||||
if msg.type != aiohttp.WSMsgType.BINARY:
|
||||
@ -163,11 +163,11 @@ class WebSocket(BaseWebSocket):
|
||||
)
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.websocket.send_str(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send_bytes(data)
|
||||
|
||||
|
@ -19,12 +19,12 @@ FrontMatter:
|
||||
import logging
|
||||
import contextlib
|
||||
from functools import wraps
|
||||
from typing_extensions import override
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
|
||||
from pydantic import BaseSettings
|
||||
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.internal.driver import FileTypes
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
@ -106,30 +106,30 @@ class Driver(ReverseDriver):
|
||||
)
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def type(self) -> str:
|
||||
"""驱动名称: `fastapi`"""
|
||||
return "fastapi"
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def server_app(self) -> FastAPI:
|
||||
"""`FastAPI APP` 对象"""
|
||||
return self._server_app
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def asgi(self) -> FastAPI:
|
||||
"""`FastAPI APP` 对象"""
|
||||
return self._server_app
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def logger(self) -> logging.Logger:
|
||||
"""fastapi 使用的 logger"""
|
||||
return logging.getLogger("fastapi")
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def setup_http_server(self, setup: HTTPServerSetup):
|
||||
async def _handle(request: Request) -> Response:
|
||||
return await self._handle_http(request, setup)
|
||||
@ -142,7 +142,7 @@ class Driver(ReverseDriver):
|
||||
include_in_schema=self.fastapi_config.fastapi_include_adapter_schema,
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
|
||||
async def _handle(websocket: WebSocket) -> None:
|
||||
await self._handle_ws(websocket, setup)
|
||||
@ -153,11 +153,11 @@ class Driver(ReverseDriver):
|
||||
name=setup.name,
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
return self._lifespan.on_startup(func)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
return self._lifespan.on_shutdown(func)
|
||||
|
||||
@ -169,7 +169,7 @@ class Driver(ReverseDriver):
|
||||
finally:
|
||||
await self._lifespan.shutdown()
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
@ -268,30 +268,30 @@ class Driver(ReverseDriver):
|
||||
class FastAPIWebSocket(BaseWebSocket):
|
||||
"""FastAPI WebSocket Wrapper"""
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
|
||||
super().__init__(request=request)
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
def closed(self) -> bool:
|
||||
return (
|
||||
self.websocket.client_state == WebSocketState.DISCONNECTED
|
||||
or self.websocket.application_state == WebSocketState.DISCONNECTED
|
||||
)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def accept(self) -> None:
|
||||
await self.websocket.accept()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def close(
|
||||
self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = ""
|
||||
) -> None:
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
# assert self.websocket.application_state == WebSocketState.CONNECTED
|
||||
msg = await self.websocket.receive()
|
||||
@ -299,21 +299,21 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
raise WebSocketClosed(msg["code"])
|
||||
return msg["text"] if "text" in msg else msg["bytes"]
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive_text(self) -> str:
|
||||
return await self.websocket.receive_text()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive_bytes(self) -> bytes:
|
||||
return await self.websocket.receive_bytes()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.websocket.send({"type": "websocket.send", "text": data})
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send({"type": "websocket.send", "bytes": data})
|
||||
|
||||
|
@ -15,10 +15,10 @@ FrontMatter:
|
||||
description: nonebot.drivers.httpx 模块
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
from typing import Type, AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
from nonebot.drivers import (
|
||||
Request,
|
||||
@ -43,11 +43,11 @@ class Mixin(ForwardMixin):
|
||||
"""HTTPX Mixin"""
|
||||
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
def type(self) -> str:
|
||||
return "httpx"
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
async with httpx.AsyncClient(
|
||||
cookies=setup.cookies.jar,
|
||||
@ -72,7 +72,7 @@ class Mixin(ForwardMixin):
|
||||
request=setup,
|
||||
)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
|
||||
async with super(Mixin, self).websocket(setup) as ws:
|
||||
|
@ -12,10 +12,10 @@ FrontMatter:
|
||||
import signal
|
||||
import asyncio
|
||||
import threading
|
||||
from typing_extensions import override
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.consts import WINDOWS
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import Driver as BaseDriver
|
||||
|
||||
@ -41,28 +41,28 @@ class Driver(BaseDriver):
|
||||
self.force_exit: bool = False
|
||||
|
||||
@property
|
||||
@overrides(BaseDriver)
|
||||
@override
|
||||
def type(self) -> str:
|
||||
"""驱动名称: `none`"""
|
||||
return "none"
|
||||
|
||||
@property
|
||||
@overrides(BaseDriver)
|
||||
@override
|
||||
def logger(self):
|
||||
"""none driver 使用的 logger"""
|
||||
return logger
|
||||
|
||||
@overrides(BaseDriver)
|
||||
@override
|
||||
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""注册一个启动时执行的函数"""
|
||||
return self._lifespan.on_startup(func)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
@override
|
||||
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
|
||||
"""注册一个停止时执行的函数"""
|
||||
return self._lifespan.on_shutdown(func)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
@override
|
||||
def run(self, *args, **kwargs):
|
||||
"""启动 none driver"""
|
||||
super().run(*args, **kwargs)
|
||||
|
@ -17,6 +17,7 @@ FrontMatter:
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing_extensions import override
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
@ -33,7 +34,6 @@ from typing import (
|
||||
from pydantic import BaseSettings
|
||||
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.internal.driver import FileTypes
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
@ -102,30 +102,30 @@ class Driver(ReverseDriver):
|
||||
)
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def type(self) -> str:
|
||||
"""驱动名称: `quart`"""
|
||||
return "quart"
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def server_app(self) -> Quart:
|
||||
"""`Quart` 对象"""
|
||||
return self._server_app
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def asgi(self):
|
||||
"""`Quart` 对象"""
|
||||
return self._server_app
|
||||
|
||||
@property
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def logger(self):
|
||||
"""Quart 使用的 logger"""
|
||||
return self._server_app.logger
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def setup_http_server(self, setup: HTTPServerSetup):
|
||||
async def _handle() -> Response:
|
||||
return await self._handle_http(setup)
|
||||
@ -137,7 +137,7 @@ class Driver(ReverseDriver):
|
||||
view_func=_handle,
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
|
||||
async def _handle() -> None:
|
||||
return await self._handle_ws(setup)
|
||||
@ -148,17 +148,17 @@ class Driver(ReverseDriver):
|
||||
view_func=_handle,
|
||||
)
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
|
||||
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
|
||||
return self.server_app.before_serving(func) # type: ignore
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def on_shutdown(self, func: _AsyncCallable) -> _AsyncCallable:
|
||||
"""参考文档: [`Startup and Shutdown`](https://pgjones.gitlab.io/quart/how_to_guides/startup_shutdown.html)"""
|
||||
return self.server_app.after_serving(func) # type: ignore
|
||||
|
||||
@overrides(ReverseDriver)
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
@ -262,25 +262,25 @@ class WebSocket(BaseWebSocket):
|
||||
return self.websocket_ctx.websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
def closed(self):
|
||||
# FIXME
|
||||
return True
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def accept(self):
|
||||
await self.websocket.accept()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def close(self, code: int = 1000, reason: str = ""):
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
return await self.websocket.receive()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive_text(self) -> str:
|
||||
msg = await self.websocket.receive()
|
||||
@ -288,7 +288,7 @@ class WebSocket(BaseWebSocket):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self.websocket.receive()
|
||||
@ -296,11 +296,11 @@ class WebSocket(BaseWebSocket):
|
||||
raise TypeError("WebSocket received unexpected frame type: str")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_text(self, data: str):
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_bytes(self, data: bytes):
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
@ -17,11 +17,10 @@ FrontMatter:
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing_extensions import ParamSpec
|
||||
from contextlib import asynccontextmanager
|
||||
from typing_extensions import ParamSpec, override
|
||||
from typing import Type, Union, TypeVar, Callable, Awaitable, AsyncGenerator
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.log import LoguruHandler
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.exception import WebSocketClosed
|
||||
@ -63,15 +62,15 @@ class Mixin(ForwardMixin):
|
||||
"""Websockets Mixin"""
|
||||
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
def type(self) -> str:
|
||||
return "websockets"
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
return await super(Mixin, self).request(setup)
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@override
|
||||
@asynccontextmanager
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||
connection = Connect(
|
||||
@ -86,30 +85,30 @@ class Mixin(ForwardMixin):
|
||||
class WebSocket(BaseWebSocket):
|
||||
"""Websockets WebSocket Wrapper"""
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
|
||||
super().__init__(request=request)
|
||||
self.websocket = websocket
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
def closed(self) -> bool:
|
||||
return self.websocket.closed
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def accept(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def close(self, code: int = 1000, reason: str = ""):
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
return await self.websocket.recv()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive_text(self) -> str:
|
||||
msg = await self.websocket.recv()
|
||||
@ -117,7 +116,7 @@ class WebSocket(BaseWebSocket):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
@catch_closed
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self.websocket.recv()
|
||||
@ -125,11 +124,11 @@ class WebSocket(BaseWebSocket):
|
||||
raise TypeError("WebSocket received unexpected frame type: str")
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@override
|
||||
async def send_bytes(self, data: bytes) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
|
@ -10,7 +10,8 @@ FrontMatter:
|
||||
description: nonebot.typing 模块
|
||||
"""
|
||||
|
||||
from typing_extensions import ParamSpec, TypeAlias
|
||||
import warnings
|
||||
from typing_extensions import ParamSpec, TypeAlias, override
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -34,14 +35,16 @@ P = ParamSpec("P")
|
||||
T_Wrapped: TypeAlias = Callable[P, T]
|
||||
|
||||
|
||||
def overrides(InterfaceClass: object) -> Callable[[T_Wrapped], T_Wrapped]:
|
||||
def overrides(InterfaceClass: object):
|
||||
"""标记一个方法为父类 interface 的 implement"""
|
||||
|
||||
def overrider(func: T_Wrapped) -> T_Wrapped:
|
||||
assert func.__name__ in dir(InterfaceClass), f"Error method: {func.__name__}"
|
||||
return func
|
||||
|
||||
return overrider
|
||||
warnings.warn(
|
||||
"overrides is deprecated and will be removed in a future version, "
|
||||
"use @typing_extensions.override instead. "
|
||||
"See [PEP 698](https://peps.python.org/pep-0698/) for more details.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return override
|
||||
|
||||
|
||||
# state
|
||||
|
@ -15,7 +15,7 @@ from pathlib import Path
|
||||
from contextvars import copy_context
|
||||
from functools import wraps, partial
|
||||
from contextlib import asynccontextmanager
|
||||
from typing_extensions import ParamSpec, get_args, get_origin
|
||||
from typing_extensions import ParamSpec, get_args, override, get_origin
|
||||
from typing import (
|
||||
Any,
|
||||
Type,
|
||||
@ -33,7 +33,6 @@ from typing import (
|
||||
from pydantic.typing import is_union, is_none_type
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import overrides
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@ -224,7 +223,7 @@ def resolve_dot_notation(
|
||||
class DataclassEncoder(json.JSONEncoder):
|
||||
"""可以序列化 {ref}`nonebot.adapters.Message`(List[Dataclass]) 的 `JSONEncoder`"""
|
||||
|
||||
@overrides(json.JSONEncoder)
|
||||
@override
|
||||
def default(self, o):
|
||||
if dataclasses.is_dataclass(o):
|
||||
return {f.name: getattr(o, f.name) for f in dataclasses.fields(o)}
|
||||
|
Loading…
Reference in New Issue
Block a user