implement check_reply check_atme check_nickname

This commit is contained in:
StarHeartHunt 2021-07-09 16:01:32 +08:00
parent 70e424b58f
commit b1c5013088
2 changed files with 128 additions and 20 deletions

View File

@ -1,5 +1,6 @@
import json
import httpx import httpx
import json
import re
from aiocache import cached, Cache from aiocache import cached, Cache
from aiocache.serializers import PickleSerializer from aiocache.serializers import PickleSerializer
@ -12,7 +13,7 @@ from nonebot.adapters import Bot as BaseBot
from nonebot.drivers import Driver, HTTPRequest, HTTPResponse from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
from .config import Config as FeishuConfig from .config import Config as FeishuConfig
from .event import Event, GroupMessageEvent, PrivateMessageEvent, get_event_model from .event import Event, GroupMessageEvent, MessageEvent, PrivateMessageEvent, Reply, get_event_model
from .exception import ActionFailed, ApiNotAvailable, NetworkError from .exception import ActionFailed, ApiNotAvailable, NetworkError
from .message import Message, MessageSegment, MessageSerializer from .message import Message, MessageSegment, MessageSerializer
from .utils import log, AESCipher from .utils import log, AESCipher
@ -25,45 +26,99 @@ async def _check_reply(bot: "Bot", event: "Event"):
""" """
:说明: :说明:
检查消息中存在的回复去除并赋值 ``event.reply``, ``event.to_me`` 检查是否回复bot消息赋值 ``event.reply``, ``event.to_me``
:参数: :参数:
* ``bot: Bot``: Bot 对象 * ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
""" """
#TODO:实现该函数 if not isinstance(event, MessageEvent):
... return
if event.event.message.parent_id:
ret = await bot.call_api(
f"im/v1/messages/{event.event.message.parent_id}", method="GET")
event.reply = Reply.parse_obj(ret["items"][0])
if event.reply.sender.sender_type == "app":
event.to_me = True
return
def _check_at_me(bot: "Bot", event: "Event"): def _check_at_me(bot: "Bot", event: "Event"):
""" """
:说明: :说明:
检查消息开头或结尾是否存在 @机器人去除并赋值 ``event.to_me`` 检查消息开头或结尾是否存在 @机器人去除并赋值 ``event.reply``, ``event.to_me``
:参数: :参数:
* ``bot: Bot``: Bot 对象 * ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
""" """
#TODO:实现该函数 if not isinstance(event, MessageEvent):
... return
message = event.get_message()
# ensure message not empty
if not message:
message.append(MessageSegment.text(""))
if event.event.message.chat_type == "p2p":
event.to_me = True
for index, segment in enumerate(message):
if segment.type == "at" and segment.data.get(
"user_name") in bot.config.nickname:
event.to_me = True
del event.event.message.content[index]
return
elif segment.type == "text" and segment.data.get("mentions"):
for mention in segment.data["mentions"].values():
if mention["name"] in bot.config.nickname:
event.to_me = True
segment.data["text"] = segment.data["text"].replace(
f"@{mention['name']}", "")
segment.data["text"] = segment.data["text"].lstrip()
break
else:
continue
break
if not message:
message.append(MessageSegment.text(""))
def _check_nickname(bot: "Bot", event: "Event"): def _check_nickname(bot: "Bot", event: "Event"):
""" """
:说明: :说明:
检查消息开头是否存在去除并赋值 ``event.to_me`` 检查消息开头是否存在昵称去除并赋值 ``event.to_me``
:参数: :参数:
* ``bot: Bot``: Bot 对象 * ``bot: Bot``: Bot 对象
* ``event: Event``: Event 对象 * ``event: Event``: Event 对象
""" """
#TODO:实现该函数 if not isinstance(event, MessageEvent):
... return
first_msg_seg = event.get_message()[0]
if first_msg_seg.type != "text":
return
first_text = first_msg_seg.data["text"]
nicknames = set(filter(lambda n: n, bot.config.nickname))
if nicknames:
# check if the user is calling me with my nickname
nickname_regex = "|".join(nicknames)
m = re.search(rf"^({nickname_regex})([\s,]*|$)", first_text,
re.IGNORECASE)
if m:
nickname = m.group(1)
log("DEBUG", f"User is calling me {nickname}")
event.to_me = True
first_msg_seg.data["text"] = first_text[m.end():]
def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any: def _handle_api_result(result: Optional[Dict[str, Any]]) -> Any:
@ -188,6 +243,11 @@ class Bot(BaseBot):
log("DEBUG", "Event Parser Error", e) log("DEBUG", "Event Parser Error", e)
else: else:
event = Event.parse_obj(data) event = Event.parse_obj(data)
_check_at_me(self, event)
await _check_reply(self, event)
_check_nickname(self, event)
await handle_event(self, event) await handle_event(self, event)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
@ -232,19 +292,20 @@ class Bot(BaseBot):
raise ApiNotAvailable raise ApiNotAvailable
headers = {} headers = {}
if self.feishu_config.tenant_access_token is None: self.feishu_config.tenant_access_token = await self._fetch_tenant_access_token(
self.feishu_config.tenant_access_token = await self._fetch_tenant_access_token( )
)
headers[ headers[
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token "Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
try: try:
async with httpx.AsyncClient(headers=headers) as client: async with httpx.AsyncClient(
response = await client.post( timeout=self.config.api_timeout) as client:
self.api_root + api, response = await client.send(
json=data.get("body", {}), httpx.Request(data["method"],
params=data.get("query", {}), self.api_root + api,
timeout=self.config.api_timeout) json=data.get("body", {}),
params=data.get("query", {}),
headers=headers))
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
result = response.json() result = response.json()
return _handle_api_result(result) return _handle_api_result(result)
@ -303,6 +364,7 @@ class Bot(BaseBot):
msg_type, content = MessageSerializer(msg).serialize() msg_type, content = MessageSerializer(msg).serialize()
params = { params = {
"method": "POST",
"query": { "query": {
"receive_id_type": receive_id_type "receive_id_type": receive_id_type
}, },

View File

@ -77,6 +77,13 @@ class Sender(BaseModel):
tenant_key: str tenant_key: str
class ReplySender(BaseModel):
id: str
id_type: str
sender_type: str
tenant_key: str
class Mention(BaseModel): class Mention(BaseModel):
key: str key: str
id: UserId id: UserId
@ -84,6 +91,37 @@ class Mention(BaseModel):
tenant_key: str tenant_key: str
class ReplyMention(BaseModel):
id: str
id_type: str
key: str
name: str
tenant_key: str
class MessageBody(BaseModel):
content: str
class Reply(BaseModel):
message_id: str
root_id: Optional[str]
parent_id: Optional[str]
msg_type: str
create_time: str
update_time: str
deleted: bool
updated: bool
chat_id: str
sender: ReplySender
body: MessageBody
mentions: List[ReplyMention]
upper_message_id: Optional[str]
class Config:
extra = "allow"
class EventMessage(BaseModel): class EventMessage(BaseModel):
message_id: str message_id: str
root_id: Optional[str] root_id: Optional[str]
@ -128,6 +166,14 @@ class MessageEvent(Event):
__event__ = "im.message.receive_v1" __event__ = "im.message.receive_v1"
event: MessageEventDetail event: MessageEventDetail
to_me: bool = False
reply: Optional[Reply]
"""
:说明: 消息是否与机器人有关
:类型: ``bool``
"""
@overrides(Event) @overrides(Event)
def get_type(self) -> Literal["message", "notice", "meta_event"]: def get_type(self) -> Literal["message", "notice", "meta_event"]:
return "message" return "message"