check for reply

This commit is contained in:
yanyongyu 2020-08-28 11:54:21 +08:00
parent 43bd9d0a96
commit def5caedbc
5 changed files with 40 additions and 9 deletions

View File

@ -150,6 +150,16 @@ class BaseEvent(abc.ABC):
def message(self, value) -> None: def message(self, value) -> None:
raise NotImplementedError raise NotImplementedError
@property
@abc.abstractmethod
def reply(self) -> Optional[dict]:
raise NotImplementedError
@reply.setter
@abc.abstractmethod
def reply(self, value) -> None:
raise NotImplementedError
@property @property
@abc.abstractmethod @abc.abstractmethod
def raw_message(self) -> Optional[str]: def raw_message(self) -> Optional[str]:

View File

@ -57,6 +57,19 @@ def _b2s(b: Optional[bool]) -> Optional[str]:
return b if b is None else str(b).lower() return b if b is None else str(b).lower()
async def _check_reply(bot: "Bot", event: "Event"):
if event.type != "message":
return
first_msg_seg = event.message[0]
if first_msg_seg.type == "reply":
msg_id = first_msg_seg.data["id"]
event.reply = await bot.get_msg(message_id=msg_id)
if event.reply["sender"]["user_id"] == event.self_id:
event.to_me = True
del event.message[0]
def _check_at_me(bot: "Bot", event: "Event"): def _check_at_me(bot: "Bot", event: "Event"):
if event.type != "message": if event.type != "message":
return return
@ -64,7 +77,6 @@ def _check_at_me(bot: "Bot", event: "Event"):
if event.detail_type == "private": if event.detail_type == "private":
event.to_me = True event.to_me = True
else: else:
event.to_me = False
at_me_seg = MessageSegment.at(event.self_id) at_me_seg = MessageSegment.at(event.self_id)
# check the first segment # check the first segment
@ -150,7 +162,7 @@ class ResultStore:
try: try:
return await asyncio.wait_for(future, timeout) return await asyncio.wait_for(future, timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise NetworkError("WebSocket API call timeout") raise NetworkError("WebSocket API call timeout") from None
finally: finally:
del cls._futures[seq] del cls._futures[seq]
@ -190,7 +202,7 @@ class Bot(BaseBot):
event = Event(message) event = Event(message)
# Check whether user is calling me # Check whether user is calling me
# TODO: Check reply await _check_reply(self, event)
_check_at_me(self, event) _check_at_me(self, event)
_check_nickname(self, event) _check_nickname(self, event)
@ -205,7 +217,7 @@ class Bot(BaseBot):
return await bot.call_api(api, **data) return await bot.call_api(api, **data)
log("DEBUG", f"Calling API <y>{api}</y>") log("DEBUG", f"Calling API <y>{api}</y>")
if self.type == "websocket": if self.connection_type == "websocket":
seq = ResultStore.get_seq() seq = ResultStore.get_seq()
await self.websocket.send({ await self.websocket.send({
"action": api, "action": api,
@ -217,7 +229,7 @@ class Bot(BaseBot):
return _handle_api_result(await ResultStore.fetch( return _handle_api_result(await ResultStore.fetch(
seq, self.config.api_timeout)) seq, self.config.api_timeout))
elif self.type == "http": elif self.connection_type == "http":
api_root = self.config.api_root.get(self.self_id) api_root = self.config.api_root.get(self.self_id)
if not api_root: if not api_root:
raise ApiNotAvailable raise ApiNotAvailable
@ -377,6 +389,16 @@ class Event(BaseEvent):
def message(self, value) -> None: def message(self, value) -> None:
self._raw_event["message"] = value self._raw_event["message"] = value
@property
@overrides(BaseEvent)
def reply(self) -> Optional[dict]:
return self._raw_event.get("reply")
@reply.setter
@overrides(BaseEvent)
def reply(self, value) -> None:
self._raw_event["reply"] = value
@property @property
@overrides(BaseEvent) @overrides(BaseEvent)
def raw_message(self) -> Optional[str]: def raw_message(self) -> Optional[str]:

View File

@ -174,10 +174,10 @@ class Config(BaseConfig):
API_ROOT={"123456": "http://127.0.0.1:5700"} API_ROOT={"123456": "http://127.0.0.1:5700"}
""" """
api_timeout: Optional[float] = 60. api_timeout: Optional[float] = 30.
""" """
- 类型: ``Optional[float]`` - 类型: ``Optional[float]``
- 默认值: ``60.`` - 默认值: ``30.``
- 说明: - 说明:
API 请求超时时间单位: API 请求超时时间单位:
""" """

View File

@ -28,7 +28,7 @@ class MatcherMeta(type):
f"temp={self.temp}>") # type: ignore f"temp={self.temp}>") # type: ignore
def __str__(self) -> str: def __str__(self) -> str:
return self.__repr__() return repr(self)
class Matcher(metaclass=MatcherMeta): class Matcher(metaclass=MatcherMeta):

View File

@ -34,7 +34,6 @@ async def _run_matcher(Matcher: Type[Matcher], bot: Bot, event: Event,
f"<r><bg #f8bbd0>Rule check failed for {Matcher}.</bg #f8bbd0></r>") f"<r><bg #f8bbd0>Rule check failed for {Matcher}.</bg #f8bbd0></r>")
return return
# TODO: log matcher
logger.info(f"Event will be handled by {Matcher}") logger.info(f"Event will be handled by {Matcher}")
matcher = Matcher() matcher = Matcher()