From b32d4a24d1c1af90ab62993543f989c63381c277 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Tue, 11 Aug 2020 10:44:05 +0800 Subject: [PATCH] add startup shutdown deco --- nonebot/adapters/__init__.py | 7 ++++--- nonebot/adapters/cqhttp.py | 16 +++++++++++++--- nonebot/drivers/__init__.py | 10 +++++++++- nonebot/drivers/fastapi.py | 12 ++++++++++-- nonebot/message.py | 1 - 5 files changed, 36 insertions(+), 10 deletions(-) diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index f02ceae6..f7e6a3bd 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -95,11 +95,12 @@ class BaseMessageSegment(abc.ABC): class BaseMessage(list, abc.ABC): def __init__(self, - message: Union[str, BaseMessageSegment, "BaseMessage"] = None, + message: Union[str, dict, list, BaseMessageSegment, + "BaseMessage"] = None, *args, **kwargs): super().__init__(*args, **kwargs) - if isinstance(message, str): + if isinstance(message, (str, dict, list)): self.extend(self._construct(message)) elif isinstance(message, BaseMessage): self.extend(message) @@ -111,7 +112,7 @@ class BaseMessage(list, abc.ABC): @staticmethod @abc.abstractmethod - def _construct(msg: str) -> Iterable[BaseMessageSegment]: + def _construct(msg: Union[str, dict, list]) -> Iterable[BaseMessageSegment]: raise NotImplementedError def __add__( diff --git a/nonebot/adapters/cqhttp.py b/nonebot/adapters/cqhttp.py index 0b8683fd..ee109468 100644 --- a/nonebot/adapters/cqhttp.py +++ b/nonebot/adapters/cqhttp.py @@ -5,12 +5,11 @@ import re import httpx -# from nonebot.event import Event from nonebot.config import Config from nonebot.message import handle_event from nonebot.exception import ApiNotAvailable -from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment +from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket def escape(s: str, *, escape_comma: bool = True) -> str: @@ -98,6 +97,10 @@ class Bot(BaseBot): class Event(BaseEvent): + def __init__(self, raw_event: dict): + + super().__init__(raw_event) + @property @overrides(BaseEvent) def type(self): @@ -286,7 +289,14 @@ class Message(BaseMessage): @staticmethod @overrides(BaseMessage) - def _construct(msg: str) -> Iterable[MessageSegment]: + def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]: + if isinstance(msg, dict): + yield MessageSegment(msg["type"], msg.get("data") or {}) + return + elif isinstance(msg, list): + for seg in msg: + yield MessageSegment(seg["type"], seg.get("data") or {}) + return def _iter_message() -> Iterable[Tuple[str, str]]: text_begin = 0 diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 440079cb..4c642935 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -5,7 +5,7 @@ import abc from ipaddress import IPv4Address from nonebot.config import Env, Config -from nonebot.typing import Bot, Dict, Optional +from nonebot.typing import Bot, Dict, Optional, Callable class BaseDriver(abc.ABC): @@ -35,6 +35,14 @@ class BaseDriver(abc.ABC): def bots(self) -> Dict[int, Bot]: return self._clients + @abc.abstractmethod + def on_startup(self, func: Callable) -> Callable: + raise NotImplementedError + + @abc.abstractmethod + def on_shutdown(self, func: Callable) -> Callable: + raise NotImplementedError + @abc.abstractmethod def run(self, host: Optional[IPv4Address] = None, diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 85b62780..2c4d2e42 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -14,9 +14,9 @@ from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket from nonebot.log import logger from nonebot.config import Env, Config from nonebot.utils import DataclassEncoder -from nonebot.typing import Optional, overrides from nonebot.adapters.cqhttp import Bot as CQBot from nonebot.drivers import BaseDriver, BaseWebSocket +from nonebot.typing import Optional, Callable, overrides class Driver(BaseDriver): @@ -38,7 +38,7 @@ class Driver(BaseDriver): @property @overrides(BaseDriver) - def server_app(self): + def server_app(self) -> FastAPI: return self._server_app @property @@ -51,6 +51,14 @@ class Driver(BaseDriver): def logger(self) -> logging.Logger: return logging.getLogger("fastapi") + @overrides(BaseDriver) + def on_startup(self, func: Callable) -> Callable: + return self.server_app.on_event("startup")(func) + + @overrides(BaseDriver) + def on_shutdown(self, func: Callable) -> Callable: + return self.server_app.on_event("shutdown")(func) + @overrides(BaseDriver) def run(self, host: Optional[IPv4Address] = None, diff --git a/nonebot/message.py b/nonebot/message.py index fe85cbb9..751d40e5 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -18,7 +18,6 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor: async def handle_event(bot: Bot, event: Event): - # TODO: PreProcess coros = [] for preprocessor in _event_preprocessors: coros.append(preprocessor(bot, event))