nonebot2/tests/utils.py

102 lines
2.7 KiB
Python
Raw Normal View History

from typing import Type, Union, Mapping, Iterable, Optional
from pydantic import Extra, create_model
from nonebot.adapters import Event, Message, MessageSegment
2021-12-16 23:22:25 +08:00
def escape_text(s: str, *, escape_comma: bool = True) -> str:
    """Return *s* with message-markup characters replaced by character references.

    ``&`` becomes ``&amp;``, ``[`` becomes ``&#91;`` and ``]`` becomes
    ``&#93;``. When *escape_comma* is true (the default), ``,`` additionally
    becomes ``&#44;``.

    Args:
        s: Raw text to escape.
        escape_comma: Whether commas should be escaped as well.

    Returns:
        The escaped string. An empty input yields an empty string.
    """
    # ``&`` must be handled first: the references produced for ``[``, ``]``
    # and ``,`` start with ``&`` and must not be escaped a second time.
    s = s.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
    if escape_comma:
        s = s.replace(",", "&#44;")
    return s
class FakeMessageSegment(MessageSegment["FakeMessage"]):
    """Minimal ``MessageSegment`` implementation used by the test-suite.

    A ``"text"`` segment renders as its raw text. Every other segment type
    renders as a ``[fake:<type>]`` placeholder.
    """

    @classmethod
    def get_message_class(cls):
        # ``FakeMessage`` is defined later in the module; the name is only
        # resolved when this method is called.
        return FakeMessage

    def __str__(self) -> str:
        if self.type != "text":
            return f"[fake:{self.type}]"
        return self.data["text"]

    @classmethod
    def text(cls, text: str):
        """Build a ``"text"`` segment holding *text*."""
        payload = {"text": text}
        return cls("text", payload)

    @staticmethod
    def image(url: str):
        """Build an ``"image"`` segment referring to *url*."""
        payload = {"url": url}
        return FakeMessageSegment("image", payload)

    @staticmethod
    def nested(content: "FakeMessage"):
        """Build a ``"node"`` segment that wraps another message."""
        payload = {"content": content}
        return FakeMessageSegment("node", payload)

    def is_text(self) -> bool:
        """Return True only for ``"text"`` segments."""
        return self.type == "text"
2021-12-20 00:28:02 +08:00
class FakeMessage(Message[FakeMessageSegment]):
    """``Message`` implementation whose items are ``FakeMessageSegment``s."""

    @classmethod
    def get_segment_class(cls):
        return FakeMessageSegment

    @staticmethod
    def _construct(msg: Union[str, Iterable[Mapping]]):
        """Yield the segments that make up *msg*.

        A plain string becomes a single text segment. Any other iterable is
        treated as a sequence of mappings, each unpacked as keyword arguments
        into ``FakeMessageSegment``.
        """
        if not isinstance(msg, str):
            yield from (FakeMessageSegment(**seg) for seg in msg)
            return
        yield FakeMessageSegment.text(msg)

    def __add__(
        self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
    ):
        """Concatenate with *other*, escaping it first if it is a raw string."""
        if isinstance(other, str):
            other = escape_text(other)
        return super().__add__(other)
2021-12-20 00:28:02 +08:00
2021-12-16 23:22:25 +08:00
def make_fake_event(
    _base: Optional[Type[Event]] = None,
    _type: str = "message",
    _name: str = "test",
    _description: str = "test",
    _user_id: Optional[str] = "test",
    _session_id: Optional[str] = "test",
    _message: Optional[Message] = None,
    _to_me: bool = True,
    **fields,
) -> Type[Event]:
    """Build a throw-away ``Event`` subclass for tests.

    The underscore-prefixed arguments are the fixed values returned by the
    generated class's ``get_*`` and ``is_tome`` methods. Passing ``None`` for
    ``_user_id``, ``_session_id`` or ``_message`` makes the matching getter
    raise ``NotImplementedError``.

    Args:
        _base: Event class to derive from; ``Event`` when not given.
        _type: Value returned by ``get_type``.
        _name: Value returned by ``get_event_name``.
        _description: Value returned by ``get_event_description``.
        _user_id: Value returned by ``get_user_id``, or ``None`` to raise.
        _session_id: Value returned by ``get_session_id``, or ``None`` to raise.
        _message: Value returned by ``get_message``, or ``None`` to raise.
        _to_me: Value returned by ``is_tome``.
        **fields: Extra field definitions forwarded to ``pydantic.create_model``.

    Returns:
        A pydantic model class named ``"FakeEvent"`` derived from the fake
        event class. The intermediate class forbids extra fields.
    """

    def _or_not_implemented(value):
        # A missing value behaves like an event type that lacks this getter.
        if value is None:
            raise NotImplementedError
        return value

    event_base = _base if _base else Event

    class FakeEvent(event_base, extra=Extra.forbid):
        def get_type(self) -> str:
            return _type

        def get_event_name(self) -> str:
            return _name

        def get_event_description(self) -> str:
            return _description

        def get_user_id(self) -> str:
            return _or_not_implemented(_user_id)

        def get_session_id(self) -> str:
            return _or_not_implemented(_session_id)

        def get_message(self) -> "Message":
            return _or_not_implemented(_message)

        def is_tome(self) -> bool:
            return _to_me

    return create_model("FakeEvent", __base__=FakeEvent, **fields)