mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-30 17:15:08 +08:00
add types
This commit is contained in:
parent
00913f1a8f
commit
9e33a605a6
@ -6,14 +6,14 @@ import importlib
|
|||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import Optional
|
from nonebot.typing import Union, Optional, NoReturn
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.drivers import BaseDriver
|
from nonebot.drivers import BaseDriver
|
||||||
|
|
||||||
_driver: Optional[BaseDriver] = None
|
_driver: Optional[BaseDriver] = None
|
||||||
|
|
||||||
|
|
||||||
def get_driver() -> BaseDriver:
|
def get_driver() -> Union[NoReturn, BaseDriver]:
|
||||||
if _driver is None:
|
if _driver is None:
|
||||||
raise ValueError("NoneBot has not been initialized.")
|
raise ValueError("NoneBot has not been initialized.")
|
||||||
return _driver
|
return _driver
|
||||||
@ -38,7 +38,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
|
|||||||
logger.debug(f"Loaded config: {config.dict()}")
|
logger.debug(f"Loaded config: {config.dict()}")
|
||||||
|
|
||||||
Driver = getattr(importlib.import_module(config.driver), "Driver")
|
Driver = getattr(importlib.import_module(config.driver), "Driver")
|
||||||
_driver = Driver(config)
|
_driver = Driver(env, config)
|
||||||
|
|
||||||
|
|
||||||
def run(host: Optional[IPv4Address] = None,
|
def run(host: Optional[IPv4Address] = None,
|
||||||
|
@ -3,13 +3,10 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
# from pydantic.dataclasses import dataclass # dataclass with validation
|
|
||||||
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.drivers import BaseWebSocket
|
from nonebot.typing import Dict, Union, Iterable, WebSocket
|
||||||
from nonebot.typing import Dict, Union, Iterable, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBot(abc.ABC):
|
class BaseBot(abc.ABC):
|
||||||
@ -20,7 +17,7 @@ class BaseBot(abc.ABC):
|
|||||||
config: Config,
|
config: Config,
|
||||||
self_id: int,
|
self_id: int,
|
||||||
*,
|
*,
|
||||||
websocket: BaseWebSocket = None):
|
websocket: WebSocket = None):
|
||||||
self.connection_type = connection_type
|
self.connection_type = connection_type
|
||||||
self.config = config
|
self.config = config
|
||||||
self.self_id = self_id
|
self.self_id = self_id
|
||||||
@ -43,7 +40,7 @@ class BaseBot(abc.ABC):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BaseMessageSegment(abc.ABC):
|
class BaseMessageSegment(abc.ABC):
|
||||||
type: str
|
type: str
|
||||||
data: Dict[str, str] = {}
|
data: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -8,10 +8,9 @@ import httpx
|
|||||||
from nonebot.event import Event
|
from nonebot.event import Event
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.drivers import BaseWebSocket
|
|
||||||
from nonebot.exception import ApiNotAvailable
|
from nonebot.exception import ApiNotAvailable
|
||||||
from nonebot.typing import Tuple, Iterable, Optional, overrides
|
|
||||||
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
|
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
|
||||||
|
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
|
||||||
|
|
||||||
|
|
||||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||||
@ -47,7 +46,7 @@ class Bot(BaseBot):
|
|||||||
config: Config,
|
config: Config,
|
||||||
self_id: int,
|
self_id: int,
|
||||||
*,
|
*,
|
||||||
websocket: BaseWebSocket = None):
|
websocket: WebSocket = None):
|
||||||
if connection_type not in ["http", "websocket"]:
|
if connection_type not in ["http", "websocket"]:
|
||||||
raise ValueError("Unsupported connection type")
|
raise ValueError("Unsupported connection type")
|
||||||
|
|
||||||
|
@ -4,17 +4,17 @@
|
|||||||
import abc
|
import abc
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.adapters import BaseBot
|
from nonebot.typing import Bot, Dict, Optional
|
||||||
from nonebot.typing import Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDriver(abc.ABC):
|
class BaseDriver(abc.ABC):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __init__(self, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
|
self.env = env.environment
|
||||||
self.config = config
|
self.config = config
|
||||||
self._clients: Dict[int, BaseBot] = {}
|
self._clients: Dict[int, Bot] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@ -32,7 +32,7 @@ class BaseDriver(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bots(self) -> Dict[int, BaseBot]:
|
def bots(self) -> Dict[int, Bot]:
|
||||||
return self._clients
|
return self._clients
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@ -59,7 +59,6 @@ class BaseWebSocket(object):
|
|||||||
self._websocket = websocket
|
self._websocket = websocket
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
|
||||||
def websocket(self):
|
def websocket(self):
|
||||||
return self._websocket
|
return self._websocket
|
||||||
|
|
||||||
|
@ -1,26 +1,28 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from starlette.websockets import WebSocketDisconnect
|
from starlette.websockets import WebSocketDisconnect
|
||||||
from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket
|
from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.adapters import BaseBot
|
from nonebot.utils import DataclassEncoder
|
||||||
|
from nonebot.typing import Optional, overrides
|
||||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||||
from nonebot.typing import Dict, Optional, overrides
|
|
||||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||||
|
|
||||||
|
|
||||||
class Driver(BaseDriver):
|
class Driver(BaseDriver):
|
||||||
|
|
||||||
def __init__(self, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
super().__init__(config)
|
super().__init__(env, config)
|
||||||
|
|
||||||
self._server_app = FastAPI(
|
self._server_app = FastAPI(
|
||||||
debug=config.debug,
|
debug=config.debug,
|
||||||
@ -94,21 +96,28 @@ class Driver(BaseDriver):
|
|||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
async def _handle_http(self,
|
async def _handle_http(self,
|
||||||
adapter: str,
|
adapter: str,
|
||||||
|
response: Response,
|
||||||
data: dict = Body(...),
|
data: dict = Body(...),
|
||||||
|
x_self_id: int = Header(None),
|
||||||
access_token: str = OAuth2PasswordBearer(
|
access_token: str = OAuth2PasswordBearer(
|
||||||
"/", auto_error=False)):
|
"/", auto_error=False)):
|
||||||
# TODO: Check authorization
|
# TODO: Check authorization
|
||||||
logger.debug(f"Received message: {data}")
|
|
||||||
|
# Create Bot Object
|
||||||
if adapter == "cqhttp":
|
if adapter == "cqhttp":
|
||||||
bot = CQBot("http", self.config)
|
bot = CQBot("http", self.config, x_self_id)
|
||||||
await bot.handle_message(data)
|
else:
|
||||||
|
response.status_code = status.HTTP_404_NOT_FOUND
|
||||||
|
return {"status": 404, "message": "adapter not found"}
|
||||||
|
|
||||||
|
await bot.handle_message(data)
|
||||||
return {"status": 200, "message": "success"}
|
return {"status": 200, "message": "success"}
|
||||||
|
|
||||||
@overrides(BaseDriver)
|
@overrides(BaseDriver)
|
||||||
async def _handle_ws_reverse(self,
|
async def _handle_ws_reverse(self,
|
||||||
adapter: str,
|
adapter: str,
|
||||||
websocket: FastAPIWebSocket,
|
websocket: FastAPIWebSocket,
|
||||||
self_id: int = Header(None),
|
x_self_id: int = Header(None),
|
||||||
access_token: str = OAuth2PasswordBearer(
|
access_token: str = OAuth2PasswordBearer(
|
||||||
"/", auto_error=False)):
|
"/", auto_error=False)):
|
||||||
websocket = WebSocket(websocket)
|
websocket = WebSocket(websocket)
|
||||||
@ -117,13 +126,16 @@ class Driver(BaseDriver):
|
|||||||
|
|
||||||
# Create Bot Object
|
# Create Bot Object
|
||||||
if adapter == "coolq":
|
if adapter == "coolq":
|
||||||
bot = CQBot("websocket", self.config, self_id, websocket=websocket)
|
bot = CQBot("websocket",
|
||||||
|
self.config,
|
||||||
|
x_self_id,
|
||||||
|
websocket=websocket)
|
||||||
else:
|
else:
|
||||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||||
return
|
return
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
self._clients[self_id] = bot
|
self._clients[x_self_id] = bot
|
||||||
|
|
||||||
while not websocket.closed:
|
while not websocket.closed:
|
||||||
data = await websocket.receive()
|
data = await websocket.receive()
|
||||||
@ -133,7 +145,7 @@ class Driver(BaseDriver):
|
|||||||
|
|
||||||
await bot.handle_message(data)
|
await bot.handle_message(data)
|
||||||
|
|
||||||
del self._clients[self_id]
|
del self._clients[x_self_id]
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
@ -172,4 +184,5 @@ class WebSocket(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: dict) -> None:
|
async def send(self, data: dict) -> None:
|
||||||
await self.websocket.send_json(data)
|
text = json.dumps(data, cls=DataclassEncoder)
|
||||||
|
await self.websocket.send({"type": "websocket.send", "text": text})
|
||||||
|
@ -5,10 +5,9 @@ from functools import wraps
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from nonebot.event import Event
|
|
||||||
from nonebot.typing import Handler
|
|
||||||
from nonebot.rule import Rule, user
|
from nonebot.rule import Rule, user
|
||||||
from nonebot.typing import Type, List, Dict, Optional, Callable
|
from nonebot.typing import Bot, Event, Handler
|
||||||
|
from nonebot.typing import Type, List, Dict, Optional, NoReturn
|
||||||
from nonebot.exception import PausedException, RejectedException, FinishedException
|
from nonebot.exception import PausedException, RejectedException, FinishedException
|
||||||
|
|
||||||
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
|
matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list)
|
||||||
@ -66,7 +65,7 @@ class Matcher:
|
|||||||
return NewMatcher
|
return NewMatcher
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_rule(cls, bot, event: Event) -> bool:
|
def check_rule(cls, bot: Bot, event: Event) -> bool:
|
||||||
"""检查 Matcher 的 Rule 是否成立
|
"""检查 Matcher 的 Rule 是否成立
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -98,7 +97,7 @@ class Matcher:
|
|||||||
|
|
||||||
def _decorator(func: Handler) -> Handler:
|
def _decorator(func: Handler) -> Handler:
|
||||||
|
|
||||||
async def _handler(bot, event: Event, state: dict):
|
async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn:
|
||||||
raise PausedException
|
raise PausedException
|
||||||
|
|
||||||
cls.handlers.append(_handler)
|
cls.handlers.append(_handler)
|
||||||
@ -144,7 +143,7 @@ class Matcher:
|
|||||||
# raise RejectedException
|
# raise RejectedException
|
||||||
|
|
||||||
# 运行handlers
|
# 运行handlers
|
||||||
async def run(self, bot, event):
|
async def run(self, bot: Bot, event: Event):
|
||||||
try:
|
try:
|
||||||
# if self.parser:
|
# if self.parser:
|
||||||
# await self.parser(event, state) # type: ignore
|
# await self.parser(event, state) # type: ignore
|
||||||
|
@ -5,20 +5,19 @@ import asyncio
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.event import Event
|
|
||||||
from nonebot.matcher import matchers
|
from nonebot.matcher import matchers
|
||||||
from nonebot.typing import Set, Callable
|
|
||||||
from nonebot.exception import IgnoredException
|
from nonebot.exception import IgnoredException
|
||||||
|
from nonebot.typing import Bot, Set, Event, PreProcessor
|
||||||
|
|
||||||
_event_preprocessors: Set[Callable] = set()
|
_event_preprocessors: Set[PreProcessor] = set()
|
||||||
|
|
||||||
|
|
||||||
def event_preprocessor(func: Callable) -> Callable:
|
def event_preprocessor(func: PreProcessor) -> PreProcessor:
|
||||||
_event_preprocessors.add(func)
|
_event_preprocessors.add(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
async def handle_event(bot, event: Event):
|
async def handle_event(bot: Bot, event: Event):
|
||||||
# TODO: PreProcess
|
# TODO: PreProcess
|
||||||
coros = []
|
coros = []
|
||||||
for preprocessor in _event_preprocessors:
|
for preprocessor in _event_preprocessors:
|
||||||
|
@ -2,14 +2,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import TYPE_CHECKING
|
from typing import NoReturn, TYPE_CHECKING
|
||||||
from typing import Any, Set, List, Dict, Type, Tuple, Mapping
|
from typing import Any, Set, List, Dict, Type, Tuple, Mapping
|
||||||
from typing import Union, Optional, Iterable, Callable, Awaitable
|
from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable
|
||||||
|
|
||||||
# import some modules needed when checking types
|
# import some modules needed when checking types
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nonebot.adapters import BaseBot as Bot
|
from nonebot.event import Event as EventClass
|
||||||
from nonebot.event import Event
|
from nonebot.matcher import Matcher as MatcherClass
|
||||||
|
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||||
|
from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment
|
||||||
|
|
||||||
|
|
||||||
def overrides(InterfaceClass: object):
|
def overrides(InterfaceClass: object):
|
||||||
@ -22,4 +24,17 @@ def overrides(InterfaceClass: object):
|
|||||||
return overrider
|
return overrider
|
||||||
|
|
||||||
|
|
||||||
Handler = Callable[["Bot", "Event", dict], Awaitable[None]]
|
Driver = TypeVar("Driver", bound="BaseDriver")
|
||||||
|
WebSocket = TypeVar("WebSocket", bound="BaseWebSocket")
|
||||||
|
|
||||||
|
Bot = TypeVar("Bot", bound="BaseBot")
|
||||||
|
Event = TypeVar("Event", bound="EventClass")
|
||||||
|
Message = TypeVar("Message", bound="BaseMessage")
|
||||||
|
MessageSegment = TypeVar("MessageSegment", bound="BaseMessageSegment")
|
||||||
|
|
||||||
|
PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
|
||||||
|
Awaitable[NoReturn]]]
|
||||||
|
|
||||||
|
Matcher = TypeVar("Matcher", bound="MatcherClass")
|
||||||
|
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None],
|
||||||
|
Awaitable[NoReturn]]]
|
||||||
|
16
nonebot/utils.py
Normal file
16
nonebot/utils.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import json
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
|
|
||||||
|
class DataclassEncoder(json.JSONEncoder):
|
||||||
|
|
||||||
|
@overrides(json.JSONEncoder)
|
||||||
|
def default(self, o):
|
||||||
|
if dataclasses.is_dataclass(o):
|
||||||
|
return dataclasses.asdict(o)
|
||||||
|
return super().default(o)
|
Loading…
Reference in New Issue
Block a user