add startup shutdown deco

This commit is contained in:
yanyongyu 2020-08-11 10:44:05 +08:00
parent 2d90c35df6
commit b32d4a24d1
5 changed files with 36 additions and 10 deletions

View File

@ -95,11 +95,12 @@ class BaseMessageSegment(abc.ABC):
class BaseMessage(list, abc.ABC):
def __init__(self,
message: Union[str, BaseMessageSegment, "BaseMessage"] = None,
message: Union[str, dict, list, BaseMessageSegment,
"BaseMessage"] = None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
if isinstance(message, str):
if isinstance(message, (str, dict, list)):
self.extend(self._construct(message))
elif isinstance(message, BaseMessage):
self.extend(message)
@ -111,7 +112,7 @@ class BaseMessage(list, abc.ABC):
@staticmethod
@abc.abstractmethod
def _construct(msg: str) -> Iterable[BaseMessageSegment]:
def _construct(msg: Union[str, dict, list]) -> Iterable[BaseMessageSegment]:
raise NotImplementedError
def __add__(

View File

@ -5,12 +5,11 @@ import re
import httpx
# from nonebot.event import Event
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.exception import ApiNotAvailable
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
def escape(s: str, *, escape_comma: bool = True) -> str:
@ -98,6 +97,10 @@ class Bot(BaseBot):
class Event(BaseEvent):
def __init__(self, raw_event: dict):
super().__init__(raw_event)
@property
@overrides(BaseEvent)
def type(self):
@ -286,7 +289,14 @@ class Message(BaseMessage):
@staticmethod
@overrides(BaseMessage)
def _construct(msg: str) -> Iterable[MessageSegment]:
def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
if isinstance(msg, dict):
yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, list):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
return
def _iter_message() -> Iterable[Tuple[str, str]]:
text_begin = 0

View File

@ -5,7 +5,7 @@ import abc
from ipaddress import IPv4Address
from nonebot.config import Env, Config
from nonebot.typing import Bot, Dict, Optional
from nonebot.typing import Bot, Dict, Optional, Callable
class BaseDriver(abc.ABC):
@ -35,6 +35,14 @@ class BaseDriver(abc.ABC):
def bots(self) -> Dict[int, Bot]:
return self._clients
@abc.abstractmethod
def on_startup(self, func: Callable) -> Callable:
raise NotImplementedError
@abc.abstractmethod
def on_shutdown(self, func: Callable) -> Callable:
raise NotImplementedError
@abc.abstractmethod
def run(self,
host: Optional[IPv4Address] = None,

View File

@ -14,9 +14,9 @@ from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
from nonebot.log import logger
from nonebot.config import Env, Config
from nonebot.utils import DataclassEncoder
from nonebot.typing import Optional, overrides
from nonebot.adapters.cqhttp import Bot as CQBot
from nonebot.drivers import BaseDriver, BaseWebSocket
from nonebot.typing import Optional, Callable, overrides
class Driver(BaseDriver):
@ -38,7 +38,7 @@ class Driver(BaseDriver):
@property
@overrides(BaseDriver)
def server_app(self):
def server_app(self) -> FastAPI:
return self._server_app
@property
@ -51,6 +51,14 @@ class Driver(BaseDriver):
def logger(self) -> logging.Logger:
return logging.getLogger("fastapi")
@overrides(BaseDriver)
def on_startup(self, func: Callable) -> Callable:
return self.server_app.on_event("startup")(func)
@overrides(BaseDriver)
def on_shutdown(self, func: Callable) -> Callable:
return self.server_app.on_event("shutdown")(func)
@overrides(BaseDriver)
def run(self,
host: Optional[IPv4Address] = None,

View File

@ -18,7 +18,6 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
async def handle_event(bot: Bot, event: Event):
# TODO: PreProcess
coros = []
for preprocessor in _event_preprocessors:
coros.append(preprocessor(bot, event))