mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 04:56:57 +08:00
websocket api
This commit is contained in:
parent
0e73d4ce20
commit
e7f9b2c229
@ -6,9 +6,10 @@ import importlib
|
|||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.typing import Union, Optional, NoReturn
|
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.drivers import BaseDriver
|
from nonebot.drivers import BaseDriver
|
||||||
|
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||||
|
from nonebot.typing import Union, Optional, NoReturn
|
||||||
|
|
||||||
_driver: Optional[BaseDriver] = None
|
_driver: Optional[BaseDriver] = None
|
||||||
|
|
||||||
@ -40,6 +41,8 @@ def init(*, _env_file: Optional[str] = None, **kwargs):
|
|||||||
Driver = getattr(importlib.import_module(config.driver), "Driver")
|
Driver = getattr(importlib.import_module(config.driver), "Driver")
|
||||||
_driver = Driver(env, config)
|
_driver = Driver(env, config)
|
||||||
|
|
||||||
|
_driver.register_adapter("cqhttp", CQBot)
|
||||||
|
|
||||||
|
|
||||||
def run(host: Optional[IPv4Address] = None,
|
def run(host: Optional[IPv4Address] = None,
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
|
@ -2,27 +2,33 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from functools import reduce
|
from functools import reduce, partial
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.typing import Dict, Union, Optional, Iterable, WebSocket
|
from nonebot.typing import Driver, WebSocket
|
||||||
|
from nonebot.typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable
|
||||||
|
|
||||||
|
|
||||||
class BaseBot(abc.ABC):
|
class BaseBot(abc.ABC):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
driver: Driver,
|
||||||
connection_type: str,
|
connection_type: str,
|
||||||
config: Config,
|
config: Config,
|
||||||
self_id: int,
|
self_id: str,
|
||||||
*,
|
*,
|
||||||
websocket: WebSocket = None):
|
websocket: WebSocket = None):
|
||||||
|
self.driver = driver
|
||||||
self.connection_type = connection_type
|
self.connection_type = connection_type
|
||||||
self.config = config
|
self.config = config
|
||||||
self.self_id = self_id
|
self.self_id = self_id
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
|
||||||
|
return partial(self.call_api, name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@ -37,6 +43,7 @@ class BaseBot(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: improve event
|
||||||
class BaseEvent(abc.ABC):
|
class BaseEvent(abc.ABC):
|
||||||
|
|
||||||
def __init__(self, raw_event: dict):
|
def __init__(self, raw_event: dict):
|
||||||
|
@ -2,14 +2,17 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from nonebot.config import Config
|
from nonebot.config import Config
|
||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.exception import ApiNotAvailable
|
from nonebot.typing import overrides, Driver, WebSocket, NoReturn
|
||||||
|
from nonebot.typing import Any, Dict, Union, Tuple, Iterable, Optional
|
||||||
|
from nonebot.exception import NetworkError, ActionFailed, ApiNotAvailable
|
||||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||||
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
|
|
||||||
|
|
||||||
|
|
||||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||||
@ -38,18 +41,60 @@ def _b2s(b: bool) -> str:
|
|||||||
return str(b).lower()
|
return str(b).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
if result.get("status") == "failed":
|
||||||
|
raise ActionFailed(retcode=result.get("retcode"))
|
||||||
|
return result.get("data")
|
||||||
|
|
||||||
|
|
||||||
|
class ResultStore:
|
||||||
|
_seq = 1
|
||||||
|
_futures: Dict[int, asyncio.Future] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_seq(cls) -> int:
|
||||||
|
s = cls._seq
|
||||||
|
cls._seq = (cls._seq + 1) % sys.maxsize
|
||||||
|
return s
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_result(cls, result: Dict[str, Any]):
|
||||||
|
if isinstance(result.get("echo"), dict) and \
|
||||||
|
isinstance(result["echo"].get("seq"), int):
|
||||||
|
future = cls._futures.get(result["echo"]["seq"])
|
||||||
|
if future:
|
||||||
|
future.set_result(result)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch(cls, seq: int, timeout: float) -> Dict[str, Any]:
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
cls._futures[seq] = future
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(future, timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise NetworkError("WebSocket API call timeout")
|
||||||
|
finally:
|
||||||
|
del cls._futures[seq]
|
||||||
|
|
||||||
|
|
||||||
class Bot(BaseBot):
|
class Bot(BaseBot):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
driver: Driver,
|
||||||
connection_type: str,
|
connection_type: str,
|
||||||
config: Config,
|
config: Config,
|
||||||
self_id: int,
|
self_id: str,
|
||||||
*,
|
*,
|
||||||
websocket: WebSocket = None):
|
websocket: WebSocket = None):
|
||||||
if connection_type not in ["http", "websocket"]:
|
if connection_type not in ["http", "websocket"]:
|
||||||
raise ValueError("Unsupported connection type")
|
raise ValueError("Unsupported connection type")
|
||||||
|
|
||||||
super().__init__(connection_type, config, self_id, websocket=websocket)
|
super().__init__(driver,
|
||||||
|
connection_type,
|
||||||
|
config,
|
||||||
|
self_id,
|
||||||
|
websocket=websocket)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
@ -61,16 +106,29 @@ class Bot(BaseBot):
|
|||||||
if not message:
|
if not message:
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: convert message into event
|
|
||||||
event = Event(message)
|
event = Event(message)
|
||||||
|
|
||||||
await handle_event(self, event)
|
await handle_event(self, event)
|
||||||
|
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def call_api(self, api: str, data: dict):
|
async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
|
||||||
# TODO: Call API
|
if "self_id" in data:
|
||||||
|
self_id = str(data.pop("self_id"))
|
||||||
|
bot = self.driver.bots[self_id]
|
||||||
|
return await bot.call_api(api, **data)
|
||||||
|
|
||||||
if self.type == "websocket":
|
if self.type == "websocket":
|
||||||
pass
|
seq = ResultStore.get_seq()
|
||||||
|
await self.websocket.send({
|
||||||
|
"action": api,
|
||||||
|
"params": data,
|
||||||
|
"echo": {
|
||||||
|
"seq": seq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return _handle_api_result(await ResultStore.fetch(
|
||||||
|
seq, self.config.api_timeout))
|
||||||
|
|
||||||
elif self.type == "http":
|
elif self.type == "http":
|
||||||
api_root = self.config.api_root.get(self.self_id)
|
api_root = self.config.api_root.get(self.self_id)
|
||||||
if not api_root:
|
if not api_root:
|
||||||
@ -82,14 +140,19 @@ class Bot(BaseBot):
|
|||||||
if self.config.access_token:
|
if self.config.access_token:
|
||||||
headers["Authorization"] = "Bearer " + self.config.access_token
|
headers["Authorization"] = "Bearer " + self.config.access_token
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
try:
|
||||||
response = await client.post(api_root + api)
|
async with httpx.AsyncClient(headers=headers) as client:
|
||||||
|
response = await client.post(api_root + api, json=data)
|
||||||
|
|
||||||
if 200 <= response.status_code < 300:
|
if 200 <= response.status_code < 300:
|
||||||
# TODO: handle http api response
|
result = response.json()
|
||||||
return ...
|
return _handle_api_result(result)
|
||||||
raise httpx.HTTPError(
|
raise NetworkError(f"HTTP request received unexpected "
|
||||||
"<HttpFailed {0.status_code} for url: {0.url}>", response)
|
f"status code: {response.status_code}")
|
||||||
|
except httpx.InvalidURL:
|
||||||
|
raise NetworkError("API root url invalid")
|
||||||
|
except httpx.HTTPError:
|
||||||
|
raise NetworkError("HTTP request failed")
|
||||||
|
|
||||||
|
|
||||||
class Event(BaseEvent):
|
class Event(BaseEvent):
|
||||||
|
@ -98,7 +98,8 @@ class Config(BaseConfig):
|
|||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
|
||||||
# bot connection configs
|
# bot connection configs
|
||||||
api_root: Dict[int, str] = {}
|
api_root: Dict[str, str] = {}
|
||||||
|
api_timeout: float = 60.
|
||||||
access_token: Optional[str] = None
|
access_token: Optional[str] = None
|
||||||
|
|
||||||
# bot runtime configs
|
# bot runtime configs
|
||||||
|
@ -5,16 +5,21 @@ import abc
|
|||||||
from ipaddress import IPv4Address
|
from ipaddress import IPv4Address
|
||||||
|
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
from nonebot.typing import Bot, Dict, Optional, Callable
|
from nonebot.typing import Bot, Dict, Type, Optional, Callable
|
||||||
|
|
||||||
|
|
||||||
class BaseDriver(abc.ABC):
|
class BaseDriver(abc.ABC):
|
||||||
|
_adapters: Dict[str, Type[Bot]] = {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def __init__(self, env: Env, config: Config):
|
def __init__(self, env: Env, config: Config):
|
||||||
self.env = env.environment
|
self.env = env.environment
|
||||||
self.config = config
|
self.config = config
|
||||||
self._clients: Dict[int, Bot] = {}
|
self._clients: Dict[str, Bot] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_adapter(cls, name: str, adapter: Type[Bot]):
|
||||||
|
cls._adapters[name] = adapter
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@ -32,7 +37,7 @@ class BaseDriver(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bots(self) -> Dict[int, Bot]:
|
def bots(self) -> Dict[str, Bot]:
|
||||||
return self._clients
|
return self._clients
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -106,14 +106,15 @@ class Driver(BaseDriver):
|
|||||||
adapter: str,
|
adapter: str,
|
||||||
response: Response,
|
response: Response,
|
||||||
data: dict = Body(...),
|
data: dict = Body(...),
|
||||||
x_self_id: int = Header(None),
|
x_self_id: str = Header(None),
|
||||||
access_token: str = OAuth2PasswordBearer(
|
access_token: str = OAuth2PasswordBearer(
|
||||||
"/", auto_error=False)):
|
"/", auto_error=False)):
|
||||||
# TODO: Check authorization
|
# TODO: Check authorization
|
||||||
|
|
||||||
# Create Bot Object
|
# Create Bot Object
|
||||||
if adapter == "cqhttp":
|
if adapter in self._adapters:
|
||||||
bot = CQBot("http", self.config, x_self_id)
|
BotClass = self._adapters[adapter]
|
||||||
|
bot = BotClass(self, "http", self.config, x_self_id)
|
||||||
else:
|
else:
|
||||||
response.status_code = status.HTTP_404_NOT_FOUND
|
response.status_code = status.HTTP_404_NOT_FOUND
|
||||||
return {"status": 404, "message": "adapter not found"}
|
return {"status": 404, "message": "adapter not found"}
|
||||||
@ -125,7 +126,7 @@ class Driver(BaseDriver):
|
|||||||
async def _handle_ws_reverse(self,
|
async def _handle_ws_reverse(self,
|
||||||
adapter: str,
|
adapter: str,
|
||||||
websocket: FastAPIWebSocket,
|
websocket: FastAPIWebSocket,
|
||||||
x_self_id: int = Header(None),
|
x_self_id: str = Header(None),
|
||||||
access_token: str = OAuth2PasswordBearer(
|
access_token: str = OAuth2PasswordBearer(
|
||||||
"/", auto_error=False)):
|
"/", auto_error=False)):
|
||||||
websocket = WebSocket(websocket)
|
websocket = WebSocket(websocket)
|
||||||
@ -134,7 +135,8 @@ class Driver(BaseDriver):
|
|||||||
|
|
||||||
# Create Bot Object
|
# Create Bot Object
|
||||||
if adapter == "coolq":
|
if adapter == "coolq":
|
||||||
bot = CQBot("websocket",
|
bot = CQBot(self,
|
||||||
|
"websocket",
|
||||||
self.config,
|
self.config,
|
||||||
x_self_id,
|
x_self_id,
|
||||||
websocket=websocket)
|
websocket=websocket)
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from nonebot.typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class IgnoredException(Exception):
|
class IgnoredException(Exception):
|
||||||
"""
|
"""
|
||||||
@ -33,3 +35,21 @@ class FinishedException(Exception):
|
|||||||
class ApiNotAvailable(Exception):
|
class ApiNotAvailable(Exception):
|
||||||
"""Api is not available"""
|
"""Api is not available"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkError(Exception):
|
||||||
|
"""There is something error with the network"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ActionFailed(Exception):
|
||||||
|
"""The action call returned a failed response"""
|
||||||
|
|
||||||
|
def __init__(self, retcode: Optional[int]):
|
||||||
|
self.retcode = retcode
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<ActionFailed, retcode={self.retcode}>"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import typing
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -150,6 +151,10 @@ class Matcher:
|
|||||||
|
|
||||||
for _ in range(len(self.handlers)):
|
for _ in range(len(self.handlers)):
|
||||||
handler = self.handlers.pop(0)
|
handler = self.handlers.pop(0)
|
||||||
|
annotation = typing.get_type_hints(handler)
|
||||||
|
BotType = annotation.get("bot")
|
||||||
|
if BotType and not isinstance(bot, BotType):
|
||||||
|
continue
|
||||||
await handler(bot, event, self.state)
|
await handler(bot, event, self.state)
|
||||||
except RejectedException:
|
except RejectedException:
|
||||||
self.handlers.insert(0, handler) # type: ignore
|
self.handlers.insert(0, handler) # type: ignore
|
||||||
|
@ -35,5 +35,5 @@ PreProcessor = Callable[[Bot, Event], Union[Awaitable[None],
|
|||||||
Awaitable[NoReturn]]]
|
Awaitable[NoReturn]]]
|
||||||
|
|
||||||
Matcher = TypeVar("Matcher", bound="MatcherClass")
|
Matcher = TypeVar("Matcher", bound="MatcherClass")
|
||||||
Handler = Callable[["Bot", Event, dict], Union[Awaitable[None],
|
Handler = Callable[[Bot, Event, Dict[Any, Any]], Union[Awaitable[None],
|
||||||
Awaitable[NoReturn]]]
|
Awaitable[NoReturn]]]
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
from nonebot.rule import Rule
|
from nonebot.rule import Rule
|
||||||
from nonebot.event import Event
|
from nonebot.typing import Event
|
||||||
from nonebot.plugin import on_message
|
from nonebot.plugin import on_message
|
||||||
from nonebot.adapters.cqhttp import Message
|
from nonebot.adapters.cqhttp import Bot, Message
|
||||||
|
|
||||||
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
|
print(repr(Message("asdfasdf[CQ:at,qq=123][CQ:at,qq=all]")))
|
||||||
|
|
||||||
@ -12,13 +12,13 @@ test_matcher = on_message(Rule(), state={"default": 1})
|
|||||||
|
|
||||||
|
|
||||||
@test_matcher.handle()
|
@test_matcher.handle()
|
||||||
async def test_handler(bot, event: Event, state: dict):
|
async def test_handler(bot: Bot, event: Event, state: dict):
|
||||||
print("Test Matcher Received:", event)
|
print("Test Matcher Received:", event)
|
||||||
print("Current State:", state)
|
print("Current State:", state)
|
||||||
state["message1"] = event.get("raw_message")
|
state["event"] = event
|
||||||
|
|
||||||
|
|
||||||
@test_matcher.receive()
|
@test_matcher.receive()
|
||||||
async def test_receive(bot, event: Event, state: dict):
|
async def test_receive(bot: Bot, event: Event, state: dict):
|
||||||
print("Test Matcher Received next time:", event)
|
print("Test Matcher Received next time:", event)
|
||||||
print("Current State:", state)
|
print("Current State:", state)
|
||||||
|
Loading…
Reference in New Issue
Block a user