websocket api

This commit is contained in:
yanyongyu 2020-08-13 15:23:04 +08:00
parent 0e73d4ce20
commit e7f9b2c229
10 changed files with 141 additions and 35 deletions

View File

@ -6,9 +6,10 @@ import importlib
from ipaddress import IPv4Address
from nonebot.log import logger
from nonebot.typing import Union, Optional, NoReturn
from nonebot.config import Env, Config
from nonebot.drivers import BaseDriver
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.typing import Union, Optional, NoReturn
_driver: Optional[BaseDriver] = None
@ -40,6 +41,8 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
Driver = getattr(importlib.import_module(config.driver), "Driver")
_driver = Driver(env, config)
_driver.register_adapter("cqhttp", CQBot)
def run(host: Optional[IPv4Address] = None,
port: Optional[int] = None,

View File

@ -2,27 +2,33 @@
# -*- coding: utf-8 -*-
import abc
from functools import reduce
from functools import reduce, partial
from dataclasses import dataclass, field
from nonebot.config import Config
from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket
from nonebot.typing import Driver, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
class BaseBot(abc.ABC):
@abc.abstractmethod
def __init__(self,
driver: Driver,
connection_type: str,
config: Config,
self_id: int,
self_id: str,
*,
websocket: WebSocket = None):
self.driver = driver
self.connection_type = connection_type
self.config = config
self.self_id = self_id
self.websocket = websocket
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
return partial(self.call_api, name)
@property
@abc.abstractmethod
def type(self) -> str:
@ -37,6 +43,7 @@ class BaseBot(abc.ABC):
raise NotImplementedError
# TODO: improve event
class BaseEvent(abc.ABC):
def __init__(self, raw_event: dict):

View File

@ -2,14 +2,17 @@
# -*- coding: utf-8 -*-
import re
import sys
import asyncio
import httpx
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.exception import ApiNotAvailable
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str:
@ -38,18 +41,60 @@ def _b2s(b: bool) -> str:
return str(b).lower()
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
if isinstance(result, dict):
if result.get("status") == "failed":
raise ActionFailed(retcode=result.get("retcode"))
return result.get("data")
class ResultStore:
_seq = 1
_futures: Dict[int, asyncio.Future] = {}
@classmethod
def get_seq(cls) -> int:
s = cls._seq
cls._seq = (cls._seq + 1) % sys.maxsize
return s
@classmethod
def add_result(cls, result: Dict[str, Any]):
if isinstance(result.get("echo"), dict) and \
isinstance(result["echo"].get("seq"), int):
future = cls._futures.get(result["echo"]["seq"])
if future:
future.set_result(result)
@classmethod
async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]:
future = asyncio.get_event_loop().create_future()
cls._futures[seq] = future
try:
return await asyncio.wait_for(future, timeout)
except asyncio.TimeoutError:
raise NetworkError("WebSocket API call timeout")
finally:
del cls._futures[seq]
class Bot(BaseBot):
def __init__(self,
driver: Driver,
connection_type: str,
config: Config,
self_id: int,
self_id: str,
*,
websocket: WebSocket = None):
if connection_type not in ["http", "websocket"]:
raise ValueError("Unsupported connection type")
super().__init__(connection_type, config, self_id, websocket=websocket)
super().__init__(driver,
connection_type,
config,
self_id,
websocket=websocket)
@property
@overrides(BaseBot)
@ -61,16 +106,29 @@ class Bot(BaseBot):
if not message:
return
# TODO: convert message into event
event = Event(message)
await handle_event(self, event)
@overrides(BaseBot)
async def call_api(self, api: str, data: dict):
# TODO: Call API
async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
if "self_id" in data:
self_id = str(data.pop("self_id"))
bot = self.driver.bots[self_id]
return await bot.call_api(api, **data)
if self.type == "websocket":
pass
seq = ResultStore.get_seq()
await self.websocket.send({
"action": api,
"params": data,
"echo": {
"seq": seq
}
})
return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout))
elif self.type == "http":
api_root = self.config.api_root.get(self.self_id)
if not api_root:
@ -82,14 +140,19 @@ class Bot(BaseBot):
if self.config.access_token:
headers["Authorization"] = "Bearer " + self.config.access_token
async with httpx.AsyncClient() as client:
response = await client.post(api_root + api)
try:
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(api_root + api, json=data)
if 200 <= response.status_code < 300:
# TODO: handle http api response
return ...
raise httpx.HTTPError(
"<HttpFailed {0.status_code} for url: {0.url}>", response)
result = response.json()
return _handle_api_result(result)
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
class Event(BaseEvent):

View File

@ -98,7 +98,8 @@ class Config(BaseConfig):
debug: bool = False
# bot connection configs
api_root: Dict[int, str] = {}
api_root: Dict[str, str] = {}
api_timeout: float = 60.
access_token: Optional[str] = None
# bot runtime configs

View File

@ -5,16 +5,21 @@ import abc
from ipaddress import IPv4Address
from nonebot.config import Env, Config
from nonebot.typing import Bot, Dict, Optional, Callable
from nonebot.typing import Bot, Dict, Type, Optional, Callable
class BaseDriver(abc.ABC):
_adapters: Dict[str, Type[Bot]] = {}
@abc.abstractmethod
def __init__(self, env: Env, config: Config):
self.env = env.environment
self.config = config
self._clients: Dict[int, Bot] = {}
self._clients: Dict[str, Bot] = {}
@classmethod
def register_adapter(cls, name: str, adapter: Type[Bot]):
cls._adapters[name] = adapter
@property
@abc.abstractmethod
@ -32,7 +37,7 @@ class BaseDriver(abc.ABC):
raise NotImplementedError
@property
def bots(self) -> Dict[int, Bot]:
def bots(self) -> Dict[str, Bot]:
return self._clients
@abc.abstractmethod

View File

@ -106,14 +106,15 @@ class Driver(BaseDriver):
adapter: str,
response: Response,
data: dict = Body(...),
x_self_id: int = Header(None),
x_self_id: str = Header(None),
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
# TODO: Check authorization
# Create Bot Object
if adapter == "cqhttp":
bot = CQBot("http", self.config, x_self_id)
if adapter in self._adapters:
BotClass = self._adapters[adapter]
bot = BotClass(self, "http", self.config, x_self_id)
else:
response.status_code = status.HTTP_404_NOT_FOUND
return {"status": 404, "message": "adapter not found"}
@ -125,7 +126,7 @@ class Driver(BaseDriver):
async def _handle_ws_reverse(self,
adapter: str,
websocket: FastAPIWebSocket,
x_self_id: int = Header(None),
x_self_id: str = Header(None),
access_token: str = OAuth2PasswordBearer(
"/", auto_error=False)):
websocket = WebSocket(websocket)
@ -134,7 +135,8 @@ class Driver(BaseDriver):
# Create Bot Object
if adapter == "coolq":
bot = CQBot("websocket",
bot = CQBot(self,
"websocket",
self.config,
x_self_id,
websocket=websocket)

View File

@ -1,6 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from nonebot.typing import Optional
class IgnoredException(Exception):
"""
@ -33,3 +35,21 @@ class FinishedException(Exception):
class ApiNotAvailable(Exception):
"""Api is not available"""
pass
class NetworkError(Exception):
"""There is something error with the network"""
pass
class ActionFailed(Exception):
"""The action call returned a failed response"""
def __init__(self, retcode: Optional[int]):
self.retcode = retcode
def __repr__(self):
return f"<ActionFailed, retcode={self.retcode}>"
def __str__(self):
return self.__repr__()

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import typing
from functools import wraps
from datetime import datetime
from collections import defaultdict
@ -150,6 +151,10 @@ class Matcher:
for _ in range(len(self.handlers)):
handler = self.handlers.pop(0)
annotation = typing.get_type_hints(handler)
BotType = annotation.get("bot")
if BotType and not isinstance(bot, BotType):
continue
await handler(bot, event, self.state)
except RejectedException:
self.handlers.insert(0, handler) # type: ignore

View File

@ -35,5 +35,5 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
Awaitable[NoReturn]]]
Matcher = TypeVar("Matcher", bound="MatcherClass")
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None],
Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None],
Awaitable[NoReturn]]]

View File

@ -2,9 +2,9 @@
# -*- coding: utf-8 -*-
from nonebot.rule import Rule
from nonebot.event import Event
from nonebot.typing import Event
from nonebot.plugin import on_message
from nonebot.adapters.cqhttp import Message
from nonebot.adapters.cqhttp import Bot, Message
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
@ -12,13 +12,13 @@ test_matcher = on_message(Rule(), state={"default": 1})
@test_matcher.handle()
async def test_handler(bot, event: Event, state: dict):
async def test_handler(bot: Bot, event: Event, state: dict):
print("Test Matcher Received:", event)
print("Current State:", state)
state["message1"] = event.get("raw_message")
state["event"] = event
@test_matcher.receive()
async def test_receive(bot, event: Event, state: dict):
async def test_receive(bot: Bot, event: Event, state: dict):
print("Test Matcher Received next time:", event)
print("Current State:", state)