diff --git a/nonebot/__init__.py b/nonebot/__init__.py index e28f87a1..6252db7f 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -6,14 +6,14 @@ import importlib from ipaddress import IPv4Address from nonebot.log import logger -from nonebot.typing import Optional +from nonebot.typing import Union, Optional, NoReturn from nonebot.config import Env, Config from nonebot.drivers import BaseDriver _driver: Optional[BaseDriver] = None -def get_driver() -> BaseDriver: +def get_driver() -> Union[NoReturn, BaseDriver]: if _driver is None: raise ValueError("NoneBot has not been initialized.") return _driver @@ -38,7 +38,7 @@ def init(*, _env_file: Optional[str] = None, **kwargs): logger.debug(f"Loaded config: {config.dict()}") Driver = getattr(importlib.import_module(config.driver), "Driver") - _driver = Driver(config) + _driver = Driver(env, config) def run(host: Optional[IPv4Address] = None, diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 5bd2e96f..1ec65fa7 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -3,13 +3,10 @@ import abc from functools import reduce -from dataclasses import dataclass - -# from pydantic.dataclasses import dataclass # dataclass with validation +from dataclasses import dataclass, field from nonebot.config import Config -from nonebot.drivers import BaseWebSocket -from nonebot.typing import Dict, Union, Iterable, Optional +from nonebot.typing import Dict, Union, Iterable, WebSocket class BaseBot(abc.ABC): @@ -20,7 +17,7 @@ class BaseBot(abc.ABC): config: Config, self_id: int, *, - websocket: BaseWebSocket = None): + websocket: WebSocket = None): self.connection_type = connection_type self.config = config self.self_id = self_id @@ -43,7 +40,7 @@ class BaseBot(abc.ABC): @dataclass class BaseMessageSegment(abc.ABC): type: str - data: Dict[str, str] = {} + data: Dict[str, str] = field(default_factory=lambda: {}) @abc.abstractmethod def __str__(self): diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 89dafcb3..1d69ce52 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -8,10 +8,9 @@ import httpx from nonebot.event import Event from nonebot.config import Config from nonebot.message import handle_event -from nonebot.drivers import BaseWebSocket from nonebot.exception import ApiNotAvailable -from nonebot.typing import Tuple, Iterable, Optional, overrides from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment +from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket def escape(s: str, *, escape_comma: bool = True) -> str: @@ -47,7 +46,7 @@ class Bot(BaseBot): config: Config, self_id: int, *, - websocket: BaseWebSocket = None): + websocket: WebSocket = None): if connection_type not in ["http", "websocket"]: raise ValueError("Unsupported connection type") diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 75fe6ba3..440079cb 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -4,17 +4,17 @@ import abc from ipaddress import IPv4Address -from nonebot.config import Config -from nonebot.adapters import BaseBot -from nonebot.typing import Dict, Optional +from nonebot.config import Env, Config +from nonebot.typing import Bot, Dict, Optional class BaseDriver(abc.ABC): @abc.abstractmethod - def __init__(self, config: Config): + def __init__(self, env: Env, config: Config): + self.env = env.environment self.config = config - self._clients: Dict[int, BaseBot] = {} + self._clients: Dict[int, Bot] = {} @property @abc.abstractmethod @@ -32,7 +32,7 @@ class BaseDriver(abc.ABC): raise NotImplementedError @property - def bots(self) -> Dict[int, BaseBot]: + def bots(self) -> Dict[int, Bot]: return self._clients @abc.abstractmethod @@ -59,7 +59,6 @@ class BaseWebSocket(object): self._websocket = websocket @property - @abc.abstractmethod def websocket(self): return self._websocket diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index c0ee663d..85b62780 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -1,26 +1,28 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import json import logging from ipaddress import IPv4Address import uvicorn +from fastapi import FastAPI, status from fastapi.security import OAuth2PasswordBearer from starlette.websockets import WebSocketDisconnect -from fastapi import Body, status, Header, FastAPI, WebSocket as FastAPIWebSocket +from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket from nonebot.log import logger -from nonebot.config import Config -from nonebot.adapters import BaseBot +from nonebot.config import Env, Config +from nonebot.utils import DataclassEncoder +from nonebot.typing import Optional, overrides from nonebot.adapters.cqhttp import Bot as CQBot -from nonebot.typing import Dict, Optional, overrides from nonebot.drivers import BaseDriver, BaseWebSocket class Driver(BaseDriver): - def __init__(self, config: Config): - super().__init__(config) + def __init__(self, env: Env, config: Config): + super().__init__(env, config) self._server_app = FastAPI( debug=config.debug, @@ -94,21 +96,28 @@ class Driver(BaseDriver): @overrides(BaseDriver) async def _handle_http(self, adapter: str, + response: Response, data: dict = Body(...), + x_self_id: int = Header(None), access_token: str = OAuth2PasswordBearer( "/", auto_error=False)): # TODO: Check authorization - logger.debug(f"Received message: {data}") + + # Create Bot Object if adapter == "cqhttp": - bot = CQBot("http", self.config) - await bot.handle_message(data) + bot = CQBot("http", self.config, x_self_id) + else: + response.status_code = status.HTTP_404_NOT_FOUND + return {"status": 404, "message": "adapter not found"} + + await bot.handle_message(data) return {"status": 200, "message": "success"} @overrides(BaseDriver) async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket, - self_id: int = Header(None), + x_self_id: int = Header(None), access_token: str = OAuth2PasswordBearer( "/", auto_error=False)): websocket = WebSocket(websocket) @@ -117,13 +126,16 @@ class Driver(BaseDriver): # Create Bot Object if adapter == "coolq": - bot = CQBot("websocket", self.config, self_id, websocket=websocket) + bot = CQBot("websocket", + self.config, + x_self_id, + websocket=websocket) else: await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) return await websocket.accept() - self._clients[self_id] = bot + self._clients[x_self_id] = bot while not websocket.closed: data = await websocket.receive() @@ -133,7 +145,7 @@ class Driver(BaseDriver): await bot.handle_message(data) - del self._clients[self_id] + del self._clients[x_self_id] class WebSocket(BaseWebSocket): @@ -172,4 +184,5 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def send(self, data: dict) -> None: - await self.websocket.send_json(data) + text = json.dumps(data, cls=DataclassEncoder) + await self.websocket.send({"type": "websocket.send", "text": text}) diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 25f75535..bd77cfe6 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -5,10 +5,9 @@ from functools import wraps from datetime import datetime from collections import defaultdict -from nonebot.event import Event -from nonebot.typing import Handler from nonebot.rule import Rule, user -from nonebot.typing import Type, List, Dict, Optional, Callable +from nonebot.typing import Bot, Event, Handler +from nonebot.typing import Type, List, Dict, Optional, NoReturn from nonebot.exception import PausedException, RejectedException, FinishedException matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) @@ -66,7 +65,7 @@ class Matcher: return NewMatcher @classmethod - def check_rule(cls, bot, event: Event) -> bool: + def check_rule(cls, bot: Bot, event: Event) -> bool: """检查 Matcher 的 Rule 是否成立 Args: @@ -98,7 +97,7 @@ class Matcher: def _decorator(func: Handler) -> Handler: - async def _handler(bot, event: Event, state: dict): + async def _handler(bot: Bot, event: Event, state: dict) -> NoReturn: raise PausedException cls.handlers.append(_handler) @@ -144,7 +143,7 @@ class Matcher: # raise RejectedException # 运行handlers - async def run(self, bot, event): + async def run(self, bot: Bot, event: Event): try: # if self.parser: # await self.parser(event, state) # type: ignore diff --git a/nonebot/message.py b/nonebot/message.py index b199c8c1..fe85cbb9 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -5,20 +5,19 @@ import asyncio from datetime import datetime from nonebot.log import logger -from nonebot.event import Event from nonebot.matcher import matchers -from nonebot.typing import Set, Callable from nonebot.exception import IgnoredException +from nonebot.typing import Bot, Set, Event, PreProcessor -_event_preprocessors: Set[Callable] = set() +_event_preprocessors: Set[PreProcessor] = set() -def event_preprocessor(func: Callable) -> Callable: +def event_preprocessor(func: PreProcessor) -> PreProcessor: _event_preprocessors.add(func) return func -async def handle_event(bot, event: Event): +async def handle_event(bot: Bot, event: Event): # TODO: PreProcess coros = [] for preprocessor in _event_preprocessors: diff --git a/nonebot/typing.py b/nonebot/typing.py index cfee1152..72a628f0 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -2,14 +2,16 @@ # -*- coding: utf-8 -*- from types import ModuleType -from typing import TYPE_CHECKING +from typing import NoReturn, TYPE_CHECKING from typing import Any, Set, List, Dict, Type, Tuple, Mapping -from typing import Union, Optional, Iterable, Callable, Awaitable +from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable # import some modules needed when checking types if TYPE_CHECKING: - from nonebot.adapters import BaseBot as Bot - from nonebot.event import Event + from nonebot.event import Event as EventClass + from nonebot.matcher import Matcher as MatcherClass + from nonebot.drivers import BaseDriver, BaseWebSocket + from nonebot.adapters import BaseBot, BaseMessage, BaseMessageSegment def overrides(InterfaceClass: object): @@ -22,4 +24,17 @@ def overrides(InterfaceClass: object): return overrider -Handler = Callable[["Bot", "Event", dict], Awaitable[None]] +Driver = TypeVar("Driver", bound="BaseDriver") +WebSocket = TypeVar("WebSocket", bound="BaseWebSocket") + +Bot = TypeVar("Bot", bound="BaseBot") +Event = TypeVar("Event", bound="EventClass") +Message = TypeVar("Message", bound="BaseMessage") +MessageSegment = TypeVar("MessageSegment", bound="BaseMessageSegment") + +PreProcessor = Callable[[Bot, Event], Union[Awaitable[None], + Awaitable[NoReturn]]] + +Matcher = TypeVar("Matcher", bound="MatcherClass") +Handler = Callable[["Bot", Event, dict], Union[Awaitable[None], + Awaitable[NoReturn]]] diff --git a/nonebot/utils.py b/nonebot/utils.py new file mode 100644 index 00000000..c1c0ddf2 --- /dev/null +++ b/nonebot/utils.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import json +import dataclasses + +from nonebot.typing import overrides + + +class DataclassEncoder(json.JSONEncoder): + + @overrides(json.JSONEncoder) + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + return super().default(o)