websocket api

This commit is contained in:
yanyongyu 2020-08-13 15:23:04 +08:00
parent 0e73d4ce20
commit e7f9b2c229
10 changed files with 141 additions and 35 deletions

View File

@ -6,9 +6,10 @@ import importlib
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.log import logger from nonebot.log import logger
from nonebot.typing import Union, Optional, NoReturn
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import BaseDriver from nonebot.drivers import BaseDriver
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.typing import Union, Optional, NoReturn
_driver: Optional[BaseDriver] = None _driver: Optional[BaseDriver] = None
@ -40,6 +41,8 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
Driver = getattr(importlib.import_module(config.driver), "Driver") Driver = getattr(importlib.import_module(config.driver), "Driver")
_driver = Driver(env, config) _driver = Driver(env, config)
_driver.register_adapter("cqhttp", CQBot)
def run(host: Optional[IPv4Address] = None, def run(host: Optional[IPv4Address] = None,
port: Optional[int] = None, port: Optional[int] = None,

View File

@ -2,27 +2,33 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import abc import abc
from functools import reduce from functools import reduce, partial
from dataclasses import dataclass, field from dataclasses import dataclass, field
from nonebot.config import Config from nonebot.config import Config
from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket from nonebot.typing import Driver, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
class BaseBot(abc.ABC): class BaseBot(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __init__(self, def __init__(self,
driver: Driver,
connection_type: str, connection_type: str,
config: Config, config: Config,
self_id: int, self_id: str,
*, *,
websocket: WebSocket = None): websocket: WebSocket = None):
self.driver = driver
self.connection_type = connection_type self.connection_type = connection_type
self.config = config self.config = config
self.self_id = self_id self.self_id = self_id
self.websocket = websocket self.websocket = websocket
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
return partial(self.call_api, name)
@property @property
@abc.abstractmethod @abc.abstractmethod
def type(self) -> str: def type(self) -> str:
@ -37,6 +43,7 @@ class BaseBot(abc.ABC):
raise NotImplementedError raise NotImplementedError
# TODO: improve event
class BaseEvent(abc.ABC): class BaseEvent(abc.ABC):
def __init__(self, raw_event: dict): def __init__(self, raw_event: dict):

View File

@ -2,14 +2,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
import sys
import asyncio
import httpx import httpx
from nonebot.config import Config from nonebot.config import Config
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.exception import ApiNotAvailable from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str: def escape(s: str, *, escape_comma: bool = True) -> str:
@ -38,18 +41,60 @@ def _b2s(b: bool) -> str:
return str(b).lower() return str(b).lower()
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
if isinstance(result, dict):
if result.get("status") == "failed":
raise ActionFailed(retcode=result.get("retcode"))
return result.get("data")
class ResultStore:
_seq = 1
_futures: Dict[int, asyncio.Future] = {}
@classmethod
def get_seq(cls) -> int:
s = cls._seq
cls._seq = (cls._seq + 1) % sys.maxsize
return s
@classmethod
def add_result(cls, result: Dict[str, Any]):
if isinstance(result.get("echo"), dict) and \
isinstance(result["echo"].get("seq"), int):
future = cls._futures.get(result["echo"]["seq"])
if future:
future.set_result(result)
@classmethod
async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]:
future = asyncio.get_event_loop().create_future()
cls._futures[seq] = future
try:
return await asyncio.wait_for(future, timeout)
except asyncio.TimeoutError:
raise NetworkError("WebSocket API call timeout")
finally:
del cls._futures[seq]
class Bot(BaseBot): class Bot(BaseBot):
def __init__(self, def __init__(self,
driver: Driver,
connection_type: str, connection_type: str,
config: Config, config: Config,
self_id: int, self_id: str,
*, *,
websocket: WebSocket = None): websocket: WebSocket = None):
if connection_type not in ["http", "websocket"]: if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type") raise ValueError("Unsupported connection type")
super().__init__(connection_type, config, self_id, websocket=websocket) super().__init__(driver,
connection_type,
config,
self_id,
websocket=websocket)
@property @property
@overrides(BaseBot) @overrides(BaseBot)
@ -61,16 +106,29 @@ class Bot(BaseBot):
if not message: if not message:
return return
# TODO: convert message into event
event = Event(message) event = Event(message)
await handle_event(self, event) await handle_event(self, event)
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, data: dict): async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
# TODO: Call API if "self_id" in data:
self_id = str(data.pop("self_id"))
bot = self.driver.bots[self_id]
return await bot.call_api(api, **data)
if self.type == "websocket": if self.type == "websocket":
pass seq = ResultStore.get_seq()
await self.websocket.send({
"action": api,
"params": data,
"echo": {
"seq": seq
}
})
return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout))
elif self.type == "http": elif self.type == "http":
api_root = self.config.api_root.get(self.self_id) api_root = self.config.api_root.get(self.self_id)
if not api_root: if not api_root:
@ -82,14 +140,19 @@ class Bot(BaseBot):
if self.config.access_token: if self.config.access_token:
headers["Authorization"] = "Bearer " + self.config.access_token headers["Authorization"] = "Bearer " + self.config.access_token
async with httpx.AsyncClient() as client: try:
response = await client.post(api_root + api) async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(api_root + api, json=data)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
# TODO: handle http api response result = response.json()
return ... return _handle_api_result(result)
raise httpx.HTTPError( raise NetworkError(f"HTTP request received unexpected "
"<HttpFailed {0.status_code} for url: {0.url}>", response) f"status code: {response.status_code}")
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
class Event(BaseEvent): class Event(BaseEvent):

View File

@ -98,7 +98,8 @@ class Config(BaseConfig):
debug: bool = False debug: bool = False
# bot connection configs # bot connection configs
api_root: Dict[int, str] = {} api_root: Dict[str, str] = {}
api_timeout: float = 60.
access_token: Optional[str] = None access_token: Optional[str] = None
# bot runtime configs # bot runtime configs

View File

@ -5,16 +5,21 @@ import abc
from ipaddress import IPv4Address from ipaddress import IPv4Address
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.typing import Bot, Dict, Optional, Callable from nonebot.typing import Bot, Dict, Type, Optional, Callable
class BaseDriver(abc.ABC): class BaseDriver(abc.ABC):
_adapters: Dict[str, Type[Bot]] = {}
@abc.abstractmethod @abc.abstractmethod
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
self.env = env.environment self.env = env.environment
self.config = config self.config = config
self._clients: Dict[int, Bot] = {} self._clients: Dict[str, Bot] = {}
@classmethod
def register_adapter(cls, name: str, adapter: Type[Bot]):
cls._adapters[name] = adapter
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -32,7 +37,7 @@ class BaseDriver(abc.ABC):
raise NotImplementedError raise NotImplementedError
@property @property
def bots(self) -> Dict[int, Bot]: def bots(self) -> Dict[str, Bot]:
return self._clients return self._clients
@abc.abstractmethod @abc.abstractmethod

View File

@ -106,14 +106,15 @@ class Driver(BaseDriver):
adapter: str, adapter: str,
response: Response, response: Response,
data: dict = Body(...), data: dict = Body(...),
x_self_id: int = Header(None), x_self_id: str = Header(None),
access_token: str = OAuth2PasswordBearer( access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)): "/", auto_error=False)):
# TODO: Check authorization # TODO: Check authorization
# Create Bot Object # Create Bot Object
if adapter == "cqhttp": if adapter in self._adapters:
bot = CQBot("http", self.config, x_self_id) BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id)
else: else:
response.status_code = status.HTTP_404_NOT_FOUND response.status_code = status.HTTP_404_NOT_FOUND
return {"status": 404, "message": "adapter not found"} return {"status": 404, "message": "adapter not found"}
@ -125,7 +126,7 @@ class Driver(BaseDriver):
async def _handle_ws_reverse(self, async def _handle_ws_reverse(self,
adapter: str, adapter: str,
websocket: FastAPIWebSocket, websocket: FastAPIWebSocket,
x_self_id: int = Header(None), x_self_id: str = Header(None),
access_token: str = OAuth2PasswordBearer( access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)): "/", auto_error=False)):
websocket = WebSocket(websocket) websocket = WebSocket(websocket)
@ -134,7 +135,8 @@ class Driver(BaseDriver):
# Create Bot Object # Create Bot Object
if adapter == "coolq": if adapter == "coolq":
bot = CQBot("websocket", bot = CQBot(self,
"websocket",
self.config, self.config,
x_self_id, x_self_id,
websocket=websocket) websocket=websocket)

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from nonebot.typing import Optional
class IgnoredException(Exception): class IgnoredException(Exception):
""" """
@ -33,3 +35,21 @@ class FinishedException(Exception):
class ApiNotAvailable(Exception): class ApiNotAvailable(Exception):
"""Api is not available""" """Api is not available"""
pass pass
class NetworkError(Exception):
"""There is something error with the network"""
pass
class ActionFailed(Exception):
"""The action call returned a failed response"""
def __init__(self, retcode: Optional[int]):
self.retcode = retcode
def __repr__(self):
return f"<ActionFailed, retcode={self.retcode}>"
def __str__(self):
return self.__repr__()

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import typing
from functools import wraps from functools import wraps
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
@ -150,6 +151,10 @@ class Matcher:
for _ in range(len(self.handlers)): for _ in range(len(self.handlers)):
handler = self.handlers.pop(0) handler = self.handlers.pop(0)
annotation = typing.get_type_hints(handler)
BotType = annotation.get("bot")
if BotType and not isinstance(bot, BotType):
continue
await handler(bot, event, self.state) await handler(bot, event, self.state)
except RejectedException: except RejectedException:
self.handlers.insert(0, handler) # type: ignore self.handlers.insert(0, handler) # type: ignore

View File

@ -35,5 +35,5 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
Awaitable[NoReturn]]] Awaitable[NoReturn]]]
Matcher = TypeVar("Matcher", bound="MatcherClass") Matcher = TypeVar("Matcher", bound="MatcherClass")
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None], Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None],
Awaitable[NoReturn]]] Awaitable[NoReturn]]]

View File

@ -2,9 +2,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.event import Event from nonebot.typing import Event
from nonebot.plugin import on_message from nonebot.plugin import on_message
from nonebot.adapters.cqhttp import Message from nonebot.adapters.cqhttp import Bot, Message
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]"))) print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
@ -12,13 +12,13 @@ test_matcher = on_message(Rule(), state={"default": 1})
@test_matcher.handle() @test_matcher.handle()
async def test_handler(bot, event: Event, state: dict): async def test_handler(bot: Bot, event: Event, state: dict):
print("Test Matcher Received:", event) print("Test Matcher Received:", event)
print("Current State:", state) print("Current State:", state)
state["message1"] = event.get("raw_message") state["event"] = event
@test_matcher.receive() @test_matcher.receive()
async def test_receive(bot, event: Event, state: dict): async def test_receive(bot: Bot, event: Event, state: dict):
print("Test Matcher Received next time:", event) print("Test Matcher Received next time:", event)
print("Current State:", state) print("Current State:", state)