diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 6252db7f..c4b93b2c 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -6,9 +6,10 @@ import importlib from ipaddress import IPv4Address from nonebot.log import logger -from nonebot.typing import Union, Optional, NoReturn from nonebot.config import Env, Config from nonebot.drivers import BaseDriver +from nonebot.adapters.cqhttp import Bot as CQBot +from nonebot.typing import Union, Optional, NoReturn _driver: Optional[BaseDriver] = None @@ -40,6 +41,8 @@ def init(*, _env_file: Optional[str] = None, **kwargs): Driver = getattr(importlib.import_module(config.driver), "Driver") _driver = Driver(env, config) + _driver.register_adapter("cqhttp", CQBot) + def run(host: Optional[IPv4Address] = None, port: Optional[int] = None, diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 1ce77bc0..04183a24 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -2,27 +2,33 @@ # -*- coding: utf-8 -*- import abc -from functools import reduce +from functools import reduce, partial from dataclasses import dataclass, field from nonebot.config import Config -from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket +from nonebot.typing import Driver, WebSocket +from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable class BaseBot(abc.ABC): @abc.abstractmethod def __init__(self, + driver: Driver, connection_type: str, config: Config, - self_id: int, + self_id: str, *, websocket: WebSocket = None): + self.driver = driver self.connection_type = connection_type self.config = config self.self_id = self_id self.websocket = websocket + def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]: + return partial(self.call_api, name) + @property @abc.abstractmethod def type(self) -> str: @@ -37,6 +43,7 @@ class BaseBot(abc.ABC): raise NotImplementedError +# TODO: improve event class BaseEvent(abc.ABC): def __init__(self, raw_event: dict): diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index ce68977a..c1beaf7f 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -2,14 +2,17 @@ # -*- coding: utf-8 -*- import re +import sys +import asyncio import httpx from nonebot.config import Config from nonebot.message import handle_event -from nonebot.exception import ApiNotAvailable +from nonebot.typing import overrides, Driver, WebSocket, NoReturn +from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional +from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment -from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket def escape(s: str, *, escape_comma: bool = True) -> str: @@ -38,18 +41,60 @@ def _b2s(b: bool) -> str: return str(b).lower() +def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: + if isinstance(result, dict): + if result.get("status") == "failed": + raise ActionFailed(retcode=result.get("retcode")) + return result.get("data") + + +class ResultStore: + _seq = 1 + _futures: Dict[int, asyncio.Future] = {} + + @classmethod + def get_seq(cls) -> int: + s = cls._seq + cls._seq = (cls._seq + 1) % sys.maxsize + return s + + @classmethod + def add_result(cls, result: Dict[str, Any]): + if isinstance(result.get("echo"), dict) and \ + isinstance(result["echo"].get("seq"), int): + future = cls._futures.get(result["echo"]["seq"]) + if future: + future.set_result(result) + + @classmethod + async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]: + future = asyncio.get_event_loop().create_future() + cls._futures[seq] = future + try: + return await asyncio.wait_for(future, timeout) + except asyncio.TimeoutError: + raise NetworkError("WebSocket API call timeout") + finally: + del cls._futures[seq] + + class Bot(BaseBot): def __init__(self, + driver: Driver, connection_type: str, config: Config, - self_id: int, + self_id: str, *, websocket: WebSocket = None): if connection_type not in ["http", "websocket"]: raise ValueError("Unsupported connection type") - super().__init__(connection_type, config, self_id, websocket=websocket) + super().__init__(driver, + connection_type, + config, + self_id, + websocket=websocket) @property @overrides(BaseBot) @@ -61,16 +106,29 @@ class Bot(BaseBot): if not message: return - # TODO: convert message into event event = Event(message) await handle_event(self, event) @overrides(BaseBot) - async def call_api(self, api: str, data: dict): - # TODO: Call API + async def call_api(self, api: str, **data) -> Union[Any, NoReturn]: + if "self_id" in data: + self_id = str(data.pop("self_id")) + bot = self.driver.bots[self_id] + return await bot.call_api(api, **data) + if self.type == "websocket": - pass + seq = ResultStore.get_seq() + await self.websocket.send({ + "action": api, + "params": data, + "echo": { + "seq": seq + } + }) + return _handle_api_result(await ResultStore.fetch( + seq, self.config.api_timeout)) + elif self.type == "http": api_root = self.config.api_root.get(self.self_id) if not api_root: @@ -82,14 +140,19 @@ class Bot(BaseBot): if self.config.access_token: headers["Authorization"] = "Bearer " + self.config.access_token - async with httpx.AsyncClient() as client: - response = await client.post(api_root + api) + try: + async with httpx.AsyncClient(headers=headers) as client: + response = await client.post(api_root + api, json=data) - if 200 <= response.status_code < 300: - # TODO: handle http api response - return ... - raise httpx.HTTPError( - "", response) + if 200 <= response.status_code < 300: + result = response.json() + return _handle_api_result(result) + raise NetworkError(f"HTTP request received unexpected " + f"status code: {response.status_code}") + except httpx.InvalidURL: + raise NetworkError("API root url invalid") + except httpx.HTTPError: + raise NetworkError("HTTP request failed") class Event(BaseEvent): diff --git a/nonebot/config.py b/nonebot/config.py index e236d42b..34da187d 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -98,7 +98,8 @@ class Config(BaseConfig): debug: bool = False # bot connection configs - api_root: Dict[int, str] = {} + api_root: Dict[str, str] = {} + api_timeout: float = 60. access_token: Optional[str] = None # bot runtime configs diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 4c642935..30e657c7 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -5,16 +5,21 @@ import abc from ipaddress import IPv4Address from nonebot.config import Env, Config -from nonebot.typing import Bot, Dict, Optional, Callable +from nonebot.typing import Bot, Dict, Type, Optional, Callable class BaseDriver(abc.ABC): + _adapters: Dict[str, Type[Bot]] = {} @abc.abstractmethod def __init__(self, env: Env, config: Config): self.env = env.environment self.config = config - self._clients: Dict[int, Bot] = {} + self._clients: Dict[str, Bot] = {} + + @classmethod + def register_adapter(cls, name: str, adapter: Type[Bot]): + cls._adapters[name] = adapter @property @abc.abstractmethod @@ -32,7 +37,7 @@ class BaseDriver(abc.ABC): raise NotImplementedError @property - def bots(self) -> Dict[int, Bot]: + def bots(self) -> Dict[str, Bot]: return self._clients @abc.abstractmethod diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 2c4d2e42..60f23261 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -106,14 +106,15 @@ class Driver(BaseDriver): adapter: str, response: Response, data: dict = Body(...), - x_self_id: int = Header(None), + x_self_id: str = Header(None), access_token: str = OAuth2PasswordBearer( "/", auto_error=False)): # TODO: Check authorization # Create Bot Object - if adapter == "cqhttp": - bot = CQBot("http", self.config, x_self_id) + if adapter in self._adapters: + BotClass = self._adapters[adapter] + bot = BotClass(self, "http", self.config, x_self_id) else: response.status_code = status.HTTP_404_NOT_FOUND return {"status": 404, "message": "adapter not found"} @@ -125,7 +126,7 @@ class Driver(BaseDriver): async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket, - x_self_id: int = Header(None), + x_self_id: str = Header(None), access_token: str = OAuth2PasswordBearer( "/", auto_error=False)): websocket = WebSocket(websocket) @@ -134,7 +135,8 @@ class Driver(BaseDriver): # Create Bot Object if adapter == "coolq": - bot = CQBot("websocket", + bot = CQBot(self, + "websocket", self.config, x_self_id, websocket=websocket) diff --git a/nonebot/exception.py b/nonebot/exception.py index de001dad..cc485b7b 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from nonebot.typing import Optional + class IgnoredException(Exception): """ @@ -33,3 +35,21 @@ class FinishedException(Exception): class ApiNotAvailable(Exception): """Api is not available""" pass + + +class NetworkError(Exception): + """There is something error with the network""" + pass + + +class ActionFailed(Exception): + """The action call returned a failed response""" + + def __init__(self, retcode: Optional[int]): + self.retcode = retcode + + def __repr__(self): + return f"" + + def __str__(self): + return self.__repr__() diff --git a/nonebot/matcher.py b/nonebot/matcher.py index bd77cfe6..fda4926b 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import typing from functools import wraps from datetime import datetime from collections import defaultdict @@ -150,6 +151,10 @@ class Matcher: for _ in range(len(self.handlers)): handler = self.handlers.pop(0) + annotation = typing.get_type_hints(handler) + BotType = annotation.get("bot") + if BotType and not isinstance(bot, BotType): + continue await handler(bot, event, self.state) except RejectedException: self.handlers.insert(0, handler) # type: ignore diff --git a/nonebot/typing.py b/nonebot/typing.py index f3f6ec0e..816b62e2 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -35,5 +35,5 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None], Awaitable[NoReturn]]] Matcher = TypeVar("Matcher", bound="MatcherClass") -Handler = Callable[["Bot", Event, dict], Union[Awaitable[None], - Awaitable[NoReturn]]] +Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None], + Awaitable[NoReturn]]] diff --git a/tests/test_plugins/test_matcher.py b/tests/test_plugins/test_matcher.py index 6f5c4e37..fb36f821 100644 --- a/tests/test_plugins/test_matcher.py +++ b/tests/test_plugins/test_matcher.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- from nonebot.rule import Rule -from nonebot.event import Event +from nonebot.typing import Event from nonebot.plugin import on_message -from nonebot.adapters.cqhttp import Message +from nonebot.adapters.cqhttp import Bot, Message print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]"))) @@ -12,13 +12,13 @@ test_matcher = on_message(Rule(), state={"default": 1}) @test_matcher.handle() -async def test_handler(bot, event: Event, state: dict): +async def test_handler(bot: Bot, event: Event, state: dict): print("Test Matcher Received:", event) print("Current State:", state) - state["message1"] = event.get("raw_message") + state["event"] = event @test_matcher.receive() -async def test_receive(bot, event: Event, state: dict): +async def test_receive(bot: Bot, event: Event, state: dict): print("Test Matcher Received next time:", event) print("Current State:", state)