mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 09:38:21 +08:00
add startup shutdown deco
This commit is contained in:
parent
2d90c35df6
commit
b32d4a24d1
@ -95,11 +95,12 @@ class BaseMessageSegment(abc.ABC):
|
||||
class BaseMessage(list, abc.ABC):
|
||||
|
||||
def __init__(self,
|
||||
message: Union[str, BaseMessageSegment, "BaseMessage"] = None,
|
||||
message: Union[str, dict, list, BaseMessageSegment,
|
||||
"BaseMessage"] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if isinstance(message, str):
|
||||
if isinstance(message, (str, dict, list)):
|
||||
self.extend(self._construct(message))
|
||||
elif isinstance(message, BaseMessage):
|
||||
self.extend(message)
|
||||
@ -111,7 +112,7 @@ class BaseMessage(list, abc.ABC):
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def _construct(msg: str) -> Iterable[BaseMessageSegment]:
|
||||
def _construct(msg: Union[str, dict, list]) -> Iterable[BaseMessageSegment]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __add__(
|
||||
|
@ -5,12 +5,11 @@ import re
|
||||
|
||||
import httpx
|
||||
|
||||
# from nonebot.event import Event
|
||||
from nonebot.config import Config
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.exception import ApiNotAvailable
|
||||
from nonebot.typing import Tuple, Iterable, Optional, overrides, WebSocket
|
||||
from nonebot.adapters import BaseBot, BaseEvent, BaseMessage, BaseMessageSegment
|
||||
from nonebot.typing import Union, Tuple, Iterable, Optional, overrides, WebSocket
|
||||
|
||||
|
||||
def escape(s: str, *, escape_comma: bool = True) -> str:
|
||||
@ -98,6 +97,10 @@ class Bot(BaseBot):
|
||||
|
||||
class Event(BaseEvent):
|
||||
|
||||
def __init__(self, raw_event: dict):
|
||||
|
||||
super().__init__(raw_event)
|
||||
|
||||
@property
|
||||
@overrides(BaseEvent)
|
||||
def type(self):
|
||||
@ -286,7 +289,14 @@ class Message(BaseMessage):
|
||||
|
||||
@staticmethod
|
||||
@overrides(BaseMessage)
|
||||
def _construct(msg: str) -> Iterable[MessageSegment]:
|
||||
def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
|
||||
if isinstance(msg, dict):
|
||||
yield MessageSegment(msg["type"], msg.get("data") or {})
|
||||
return
|
||||
elif isinstance(msg, list):
|
||||
for seg in msg:
|
||||
yield MessageSegment(seg["type"], seg.get("data") or {})
|
||||
return
|
||||
|
||||
def _iter_message() -> Iterable[Tuple[str, str]]:
|
||||
text_begin = 0
|
||||
|
@ -5,7 +5,7 @@ import abc
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.typing import Bot, Dict, Optional
|
||||
from nonebot.typing import Bot, Dict, Optional, Callable
|
||||
|
||||
|
||||
class BaseDriver(abc.ABC):
|
||||
@ -35,6 +35,14 @@ class BaseDriver(abc.ABC):
|
||||
def bots(self) -> Dict[int, Bot]:
|
||||
return self._clients
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_shutdown(self, func: Callable) -> Callable:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self,
|
||||
host: Optional[IPv4Address] = None,
|
||||
|
@ -14,9 +14,9 @@ from fastapi import Body, Header, Response, WebSocket as FastAPIWebSocket
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
from nonebot.utils import DataclassEncoder
|
||||
from nonebot.typing import Optional, overrides
|
||||
from nonebot.adapters.cqhttp import Bot as CQBot
|
||||
from nonebot.drivers import BaseDriver, BaseWebSocket
|
||||
from nonebot.typing import Optional, Callable, overrides
|
||||
|
||||
|
||||
class Driver(BaseDriver):
|
||||
@ -38,7 +38,7 @@ class Driver(BaseDriver):
|
||||
|
||||
@property
|
||||
@overrides(BaseDriver)
|
||||
def server_app(self):
|
||||
def server_app(self) -> FastAPI:
|
||||
return self._server_app
|
||||
|
||||
@property
|
||||
@ -51,6 +51,14 @@ class Driver(BaseDriver):
|
||||
def logger(self) -> logging.Logger:
|
||||
return logging.getLogger("fastapi")
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def on_startup(self, func: Callable) -> Callable:
|
||||
return self.server_app.on_event("startup")(func)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def on_shutdown(self, func: Callable) -> Callable:
|
||||
return self.server_app.on_event("shutdown")(func)
|
||||
|
||||
@overrides(BaseDriver)
|
||||
def run(self,
|
||||
host: Optional[IPv4Address] = None,
|
||||
|
@ -18,7 +18,6 @@ def event_preprocessor(func: PreProcessor) -> PreProcessor:
|
||||
|
||||
|
||||
async def handle_event(bot: Bot, event: Event):
|
||||
# TODO: PreProcess
|
||||
coros = []
|
||||
for preprocessor in _event_preprocessors:
|
||||
coros.append(preprocessor(bot, event))
|
||||
|
Loading…
Reference in New Issue
Block a user