mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
websocket api
This commit is contained in:
parent
0e73d4ce20
commit
e7f9b2c229
@ -6,9 +6,10 @@ import importlib
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.typing import Union, Optional, NoReturn
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.drivers import BaseDriver
|
||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||
from nonebot.typing import Union, Optional, NoReturn
|
||||
|
||||
_driver: Optional[BaseDriver] = None
|
||||
|
||||
@ -40,6 +41,8 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
|
||||
Driver = getattr(importlib.import_module(config.driver), "Driver")
|
||||
_driver = Driver(env, config)
|
||||
|
||||
_driver.register_adapter("cqhttp", CQBot)
|
||||
|
||||
|
||||
def run(host: Optional[IPv4Address] = None,
|
||||
port: Optional[int] = None,
|
||||
|
@ -2,27 +2,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import abc
|
||||
from functools import reduce
|
||||
from functools import reduce, partial
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket
|
||||
from nonebot.typing import Driver, WebSocket
|
||||
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
|
||||
|
||||
|
||||
class BaseBot(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self,
|
||||
driver: Driver,
|
||||
connection_type: str,
|
||||
config: Config,
|
||||
self_id: int,
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: WebSocket = None):
|
||||
self.driver = driver
|
||||
self.connection_type = connection_type
|
||||
self.config = config
|
||||
self.self_id = self_id
|
||||
self.websocket = websocket
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
|
||||
return partial(self.call_api, name)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
@ -37,6 +43,7 @@ class BaseBot(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: improve event
|
||||
class BaseEvent(abc.ABC):
|
||||
|
||||
def __init__(self, raw_event: dict):
|
||||
|
@ -2,14 +2,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import re
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
|
||||
from nonebot.config import Config
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.exception import ApiNotAvailable
|
||||
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
|
||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
|
||||
|
||||
|
||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||
@ -38,18 +41,60 @@ def _b2s(b: bool) -> str:
|
||||
return str(b).lower()
|
||||
|
||||
|
||||
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
|
||||
if isinstance(result, dict):
|
||||
if result.get("status") == "failed":
|
||||
raise ActionFailed(retcode=result.get("retcode"))
|
||||
return result.get("data")
|
||||
|
||||
|
||||
class ResultStore:
|
||||
_seq = 1
|
||||
_futures: Dict[int, asyncio.Future] = {}
|
||||
|
||||
@classmethod
|
||||
def get_seq(cls) -> int:
|
||||
s = cls._seq
|
||||
cls._seq = (cls._seq + 1) % sys.maxsize
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def add_result(cls, result: Dict[str, Any]):
|
||||
if isinstance(result.get("echo"), dict) and \
|
||||
isinstance(result["echo"].get("seq"), int):
|
||||
future = cls._futures.get(result["echo"]["seq"])
|
||||
if future:
|
||||
future.set_result(result)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
cls._futures[seq] = future
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise NetworkError("WebSocket API call timeout")
|
||||
finally:
|
||||
del cls._futures[seq]
|
||||
|
||||
|
||||
class Bot(BaseBot):
|
||||
|
||||
def __init__(self,
|
||||
driver: Driver,
|
||||
connection_type: str,
|
||||
config: Config,
|
||||
self_id: int,
|
||||
self_id: str,
|
||||
*,
|
||||
websocket: WebSocket = None):
|
||||
if connection_type not in ["http", "websocket"]:
|
||||
raise ValueError("Unsupported connection type")
|
||||
|
||||
super().__init__(connection_type, config, self_id, websocket=websocket)
|
||||
super().__init__(driver,
|
||||
connection_type,
|
||||
config,
|
||||
self_id,
|
||||
websocket=websocket)
|
||||
|
||||
@property
|
||||
@overrides(BaseBot)
|
||||
@ -61,16 +106,29 @@ class Bot(BaseBot):
|
||||
if not message:
|
||||
return
|
||||
|
||||
# TODO: convert message into event
|
||||
event = Event(message)
|
||||
|
||||
await handle_event(self, event)
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, data: dict):
|
||||
# TODO: Call API
|
||||
async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
|
||||
if "self_id" in data:
|
||||
self_id = str(data.pop("self_id"))
|
||||
bot = self.driver.bots[self_id]
|
||||
return await bot.call_api(api, **data)
|
||||
|
||||
if self.type == "websocket":
|
||||
pass
|
||||
seq = ResultStore.get_seq()
|
||||
await self.websocket.send({
|
||||
"action": api,
|
||||
"params": data,
|
||||
"echo": {
|
||||
"seq": seq
|
||||
}
|
||||
})
|
||||
return _handle_api_result(await ResultStore.fetch(
|
||||
seq, self.config.api_timeout))
|
||||
|
||||
elif self.type == "http":
|
||||
api_root = self.config.api_root.get(self.self_id)
|
||||
if not api_root:
|
||||
@ -82,14 +140,19 @@ class Bot(BaseBot):
|
||||
if self.config.access_token:
|
||||
headers["Authorization"] = "Bearer " + self.config.access_token
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(api_root + api)
|
||||
try:
|
||||
async with httpx.AsyncClient(headers=headers) as client:
|
||||
response = await client.post(api_root + api, json=data)
|
||||
|
||||
if 200 <= response.status_code < 300:
|
||||
# TODO: handle http api response
|
||||
return ...
|
||||
raise httpx.HTTPError(
|
||||
"<HttpFailed {0.status_code} for url: {0.url}>", response)
|
||||
result = response.json()
|
||||
return _handle_api_result(result)
|
||||
raise NetworkError(f"HTTP request received unexpected "
|
||||
f"status code: {response.status_code}")
|
||||
except httpx.InvalidURL:
|
||||
raise NetworkError("API root url invalid")
|
||||
except httpx.HTTPError:
|
||||
raise NetworkError("HTTP request failed")
|
||||
|
||||
|
||||
class Event(BaseEvent):
|
||||
|
@ -98,7 +98,8 @@ class Config(BaseConfig):
|
||||
debug: bool = False
|
||||
|
||||
# bot connection configs
|
||||
api_root: Dict[int, str] = {}
|
||||
api_root: Dict[str, str] = {}
|
||||
api_timeout: float = 60.
|
||||
access_token: Optional[str] = None
|
||||
|
||||
# bot runtime configs
|
||||
|
@ -5,16 +5,21 @@ import abc
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.typing import Bot, Dict, Optional, Callable
|
||||
from nonebot.typing import Bot, Dict, Type, Optional, Callable
|
||||
|
||||
|
||||
class BaseDriver(abc.ABC):
|
||||
_adapters: Dict[str, Type[Bot]] = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, env: Env, config: Config):
|
||||
self.env = env.environment
|
||||
self.config = config
|
||||
self._clients: Dict[int, Bot] = {}
|
||||
self._clients: Dict[str, Bot] = {}
|
||||
|
||||
@classmethod
|
||||
def register_adapter(cls, name: str, adapter: Type[Bot]):
|
||||
cls._adapters[name] = adapter
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@ -32,7 +37,7 @@ class BaseDriver(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def bots(self) -> Dict[int, Bot]:
|
||||
def bots(self) -> Dict[str, Bot]:
|
||||
return self._clients
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -106,14 +106,15 @@ class Driver(BaseDriver):
|
||||
adapter: str,
|
||||
response: Response,
|
||||
data: dict = Body(...),
|
||||
x_self_id: int = Header(None),
|
||||
x_self_id: str = Header(None),
|
||||
access_token: str = OAuth2PasswordBearer(
|
||||
"/", auto_error=False)):
|
||||
# TODO: Check authorization
|
||||
|
||||
# Create Bot Object
|
||||
if adapter == "cqhttp":
|
||||
bot = CQBot("http", self.config, x_self_id)
|
||||
if adapter in self._adapters:
|
||||
BotClass = self._adapters[adapter]
|
||||
bot = BotClass(self, "http", self.config, x_self_id)
|
||||
else:
|
||||
response.status_code = status.HTTP_404_NOT_FOUND
|
||||
return {"status": 404, "message": "adapter not found"}
|
||||
@ -125,7 +126,7 @@ class Driver(BaseDriver):
|
||||
async def _handle_ws_reverse(self,
|
||||
adapter: str,
|
||||
websocket: FastAPIWebSocket,
|
||||
x_self_id: int = Header(None),
|
||||
x_self_id: str = Header(None),
|
||||
access_token: str = OAuth2PasswordBearer(
|
||||
"/", auto_error=False)):
|
||||
websocket = WebSocket(websocket)
|
||||
@ -134,7 +135,8 @@ class Driver(BaseDriver):
|
||||
|
||||
# Create Bot Object
|
||||
if adapter == "coolq":
|
||||
bot = CQBot("websocket",
|
||||
bot = CQBot(self,
|
||||
"websocket",
|
||||
self.config,
|
||||
x_self_id,
|
||||
websocket=websocket)
|
||||
|
@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from nonebot.typing import Optional
|
||||
|
||||
|
||||
class IgnoredException(Exception):
|
||||
"""
|
||||
@ -33,3 +35,21 @@ class FinishedException(Exception):
|
||||
class ApiNotAvailable(Exception):
|
||||
"""Api is not available"""
|
||||
pass
|
||||
|
||||
|
||||
class NetworkError(Exception):
|
||||
"""There is something error with the network"""
|
||||
pass
|
||||
|
||||
|
||||
class ActionFailed(Exception):
|
||||
"""The action call returned a failed response"""
|
||||
|
||||
def __init__(self, retcode: Optional[int]):
|
||||
self.retcode = retcode
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ActionFailed, retcode={self.retcode}>"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import typing
|
||||
from functools import wraps
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
@ -150,6 +151,10 @@ class Matcher:
|
||||
|
||||
for _ in range(len(self.handlers)):
|
||||
handler = self.handlers.pop(0)
|
||||
annotation = typing.get_type_hints(handler)
|
||||
BotType = annotation.get("bot")
|
||||
if BotType and not isinstance(bot, BotType):
|
||||
continue
|
||||
await handler(bot, event, self.state)
|
||||
except RejectedException:
|
||||
self.handlers.insert(0, handler) # type: ignore
|
||||
|
@ -35,5 +35,5 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]]
|
||||
|
||||
Matcher = TypeVar("Matcher", bound="MatcherClass")
|
||||
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None],
|
||||
Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None],
|
||||
Awaitable[NoReturn]]]
|
||||
|
@ -2,9 +2,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.event import Event
|
||||
from nonebot.typing import Event
|
||||
from nonebot.plugin import on_message
|
||||
from nonebot.adapters.cqhttp import Message
|
||||
from nonebot.adapters.cqhttp import Bot, Message
|
||||
|
||||
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
|
||||
|
||||
@ -12,13 +12,13 @@ test_matcher = on_message(Rule(), state={"default": 1})
|
||||
|
||||
|
||||
@test_matcher.handle()
|
||||
async def test_handler(bot, event: Event, state: dict):
|
||||
async def test_handler(bot: Bot, event: Event, state: dict):
|
||||
print("Test Matcher Received:", event)
|
||||
print("Current State:", state)
|
||||
state["message1"] = event.get("raw_message")
|
||||
state["event"] = event
|
||||
|
||||
|
||||
@test_matcher.receive()
|
||||
async def test_receive(bot, event: Event, state: dict):
|
||||
async def test_receive(bot: Bot, event: Event, state: dict):
|
||||
print("Test Matcher Received next time:", event)
|
||||
print("Current State:", state)
|
||||
|
Loading…
Reference in New Issue
Block a user