diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 0cde53ca..01745c28 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -6,10 +6,10 @@ """ import abc +from typing_extensions import Literal from functools import reduce, partial from dataclasses import dataclass, field -from typing import Any, Dict, Union, TypeVar, Optional, Callable, Iterable, Awaitable, Generic, TYPE_CHECKING -from typing_extensions import Literal +from typing import Any, Dict, Union, Optional, Callable, Iterable, Awaitable, TYPE_CHECKING from pydantic import BaseModel @@ -162,6 +162,14 @@ class Event(abc.ABC, BaseModel): def get_session_id(self) -> str: raise NotImplementedError + @abc.abstractmethod + def get_message(self) -> "Message": + raise NotImplementedError + + @abc.abstractmethod + def get_plaintext(self) -> str: + raise NotImplementedError + # T = TypeVar("T", bound=BaseModel) diff --git a/nonebot/adapters/cqhttp/__init__.py b/nonebot/adapters/cqhttp/__init__.py index 14635eda..d6aa23c7 100644 --- a/nonebot/adapters/cqhttp/__init__.py +++ b/nonebot/adapters/cqhttp/__init__.py @@ -10,7 +10,7 @@ CQHTTP (OneBot) v11 协议适配 https://github.com/howmanybots/onebot/blob/master/README.md """ -from .event import Event +from .event import CQHTTPEvent from .message import Message, MessageSegment from .utils import log, escape, unescape, _b2s from .bot import Bot, _check_at_me, _check_nickname, _check_reply, _handle_api_result diff --git a/nonebot/adapters/cqhttp/bot.py b/nonebot/adapters/cqhttp/bot.py index e20dd094..a96f1308 100644 --- a/nonebot/adapters/cqhttp/bot.py +++ b/nonebot/adapters/cqhttp/bot.py @@ -15,9 +15,9 @@ from nonebot.adapters import Bot as BaseBot from nonebot.exception import RequestDenied from .utils import log -from .event import Reply, CQHTTPEvent, MessageEvent from .message import Message, MessageSegment from .exception import NetworkError, ApiNotAvailable, ActionFailed +from .event import Reply, CQHTTPEvent, MessageEvent, get_event_model if TYPE_CHECKING: from nonebot.drivers import Driver, WebSocket @@ -297,7 +297,20 @@ class Bot(BaseBot): return try: - event = CQHTTPEvent.parse_obj(message) + post_type = message['post_type'] + detail_type = message.get(f"{post_type}_type") + detail_type = f".{detail_type}" if detail_type else "" + sub_type = message.get("sub_type") + sub_type = f".{sub_type}" if sub_type else "" + models = get_event_model(f".{post_type}{detail_type}{sub_type}") + for model in models: + try: + event = model.parse_obj(message) + break + except Exception as e: + log("DEBUG", "Event Parser Error", e) + else: + event = CQHTTPEvent.parse_obj(message) # Check whether user is calling me await _check_reply(self, event) diff --git a/nonebot/adapters/cqhttp/event.py b/nonebot/adapters/cqhttp/event.py index 1b7976d6..eaede5d5 100644 --- a/nonebot/adapters/cqhttp/event.py +++ b/nonebot/adapters/cqhttp/event.py @@ -1,7 +1,9 @@ -from typing import Optional +import inspect from typing_extensions import Literal +from typing import Type, List, Optional from pydantic import BaseModel +from pygtrie import StringTrie from nonebot.adapters import Event from nonebot.utils import escape_tag from nonebot.typing import overrides @@ -210,6 +212,7 @@ from .message import Message class CQHTTPEvent(Event): + __event__ = "" time: int self_id: int post_type: Literal["message", "notice", "request", "meta_event"] @@ -226,6 +229,14 @@ class CQHTTPEvent(Event): def get_event_description(self) -> str: return str(self.dict()) + @overrides(Event) + def get_message(self) -> Message: + raise ValueError("Event has no message!") + + @overrides(Event) + def get_plaintext(self) -> str: + raise ValueError("Event has no message!") + # Models class Sender(BaseModel): @@ -284,6 +295,7 @@ class Status(BaseModel): # Message Events class MessageEvent(CQHTTPEvent): + __event__ = "message" post_type: Literal["message"] sub_type: str user_id: int @@ -302,8 +314,17 @@ class MessageEvent(CQHTTPEvent): return f"{self.post_type}.{self.message_type}" + (f".{sub_type}" if sub_type else "") + @overrides(CQHTTPEvent) + def get_message(self) -> Message: + return self.message + + @overrides(CQHTTPEvent) + def get_plaintext(self) -> str: + return self.message.extract_plain_text() + class PrivateMessageEvent(MessageEvent): + __event__ = "message.private" message_type: Literal["private"] @overrides(CQHTTPEvent) @@ -316,6 +337,7 @@ class PrivateMessageEvent(MessageEvent): class GroupMessageEvent(MessageEvent): + __event__ = "message.group" message_type: Literal["group"] group_id: int anonymous: Anonymous @@ -333,6 +355,7 @@ class GroupMessageEvent(MessageEvent): # Notice Events class NoticeEvent(CQHTTPEvent): + __event__ = "notice" post_type: Literal["notice"] notice_type: str @@ -344,6 +367,7 @@ class NoticeEvent(CQHTTPEvent): class GroupUploadNoticeEvent(NoticeEvent): + __event__ = "notice.group_upload" notice_type: Literal["group_upload"] user_id: int group_id: int @@ -351,6 +375,7 @@ class GroupUploadNoticeEvent(NoticeEvent): class GroupAdminNoticeEvent(NoticeEvent): + __event__ = "notice.group_admin" notice_type: Literal["group_admin"] sub_type: str user_id: int @@ -358,6 +383,7 @@ class GroupAdminNoticeEvent(NoticeEvent): class GroupDecreaseNoticeEvent(NoticeEvent): + __event__ = "notice.group_decrease" notice_type: Literal["group_decrease"] sub_type: str user_id: int @@ -366,6 +392,7 @@ class GroupDecreaseNoticeEvent(NoticeEvent): class GroupIncreaseNoticeEvent(NoticeEvent): + __event__ = "notice.group_increase" notice_type: Literal["group_increase"] sub_type: str user_id: int @@ -374,6 +401,7 @@ class GroupIncreaseNoticeEvent(NoticeEvent): class GroupBanNoticeEvent(NoticeEvent): + __event__ = "notice.group_ban" notice_type: Literal["group_ban"] sub_type: str user_id: int @@ -383,11 +411,13 @@ class GroupBanNoticeEvent(NoticeEvent): class FriendAddNoticeEvent(NoticeEvent): + __event__ = "notice.friend_add" notice_type: Literal["friend_add"] user_id: int class GroupRecallNoticeEvent(NoticeEvent): + __event__ = "notice.group_recall" notice_type: Literal["group_recall"] user_id: int group_id: int @@ -396,12 +426,14 @@ class GroupRecallNoticeEvent(NoticeEvent): class FriendRecallNoticeEvent(NoticeEvent): + __event__ = "notice.friend_recall" notice_type: Literal["friend_recall"] user_id: int message_id: int class NotifyEvent(NoticeEvent): + __event__ = "notice.notify" notice_type: Literal["notify"] sub_type: str user_id: int @@ -409,19 +441,26 @@ class NotifyEvent(NoticeEvent): class PokeNotifyEvent(NotifyEvent): + __event__ = "notice.notify.poke" + sub_type: Literal["poke"] target_id: int class LuckyKingNotifyEvent(NotifyEvent): + __event__ = "notice.notify.lucky_king" + sub_type: Literal["lucky_king"] target_id: int class HonorNotifyEvent(NotifyEvent): + __event__ = "notice.notify.honor" + sub_type: Literal["honor"] honor_type: str # Request Events class RequestEvent(CQHTTPEvent): + __event__ = "request" post_type: Literal["request"] request_type: str @@ -433,6 +472,7 @@ class RequestEvent(CQHTTPEvent): class FriendRequestEvent(RequestEvent): + __event__ = "request.friend" request_type: Literal["friend"] user_id: int comment: str @@ -440,6 +480,7 @@ class FriendRequestEvent(RequestEvent): class GroupRequestEvent(RequestEvent): + __event__ = "request.group" request_type: Literal["group"] sub_type: str group_id: int @@ -450,6 +491,7 @@ class GroupRequestEvent(RequestEvent): # Meta Events class MetaEvent(CQHTTPEvent): + __event__ = "meta_event" post_type: Literal["meta_event"] meta_event_type: str @@ -465,11 +507,26 @@ class MetaEvent(CQHTTPEvent): class LifecycleMetaEvent(MetaEvent): + __event__ = "meta_event.lifecycle" meta_event_type: Literal["lifecycle"] sub_type: str class HeartbeatMetaEvent(MetaEvent): + __event__ = "meta_event.heartbeat" meta_event_type: Literal["heartbeat"] status: Status interval: int + + +_t = StringTrie(separator=".") + +model = None +for model in globals().values(): + if not inspect.isclass(model) or not issubclass(model, CQHTTPEvent): + continue + _t["." + model.__event__] = model + + +def get_event_model(event_name) -> List[Type[CQHTTPEvent]]: + return [model.value for model in _t.prefixes("." + event_name)][::-1] diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 293d9767..af4448e2 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -203,7 +203,8 @@ class Matcher(metaclass=MatcherMeta): - ``bool``: 是否满足匹配规则 """ - return (event.type == (cls.type or event.type) and + event_type = event.get_type() + return (event_type == (cls.type or event_type) and await cls.rule(bot, event, state)) @classmethod @@ -303,7 +304,7 @@ class Matcher(metaclass=MatcherMeta): if parser: await parser(bot, event, state) else: - state[state["_current_key"]] = str(event.message) + state[state["_current_key"]] = str(event.get_message()) cls.handlers.append(_key_getter) cls.handlers.append(_key_parser) @@ -427,7 +428,8 @@ class Matcher(metaclass=MatcherMeta): Matcher.new( self.type, Rule(), - USER(event.user_id, perm=self.permission), # type:ignore + USER(event.get_session_id(), + perm=self.permission), # type:ignore self.handlers, temp=True, priority=0, @@ -439,7 +441,8 @@ class Matcher(metaclass=MatcherMeta): Matcher.new( self.type, Rule(), - USER(event.user_id, perm=self.permission), # type:ignore + USER(event.get_session_id(), + perm=self.permission), # type:ignore self.handlers, temp=True, priority=0, diff --git a/nonebot/permission.py b/nonebot/permission.py index d5f12db8..6f2ccd20 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -81,19 +81,19 @@ class Permission: async def _message(bot: "Bot", event: "Event") -> bool: - return event.type == "message" + return event.get_type() == "message" async def _notice(bot: "Bot", event: "Event") -> bool: - return event.type == "notice" + return event.get_type() == "notice" async def _request(bot: "Bot", event: "Event") -> bool: - return event.type == "request" + return event.get_type() == "request" async def _metaevent(bot: "Bot", event: "Event") -> bool: - return event.type == "meta_event" + return event.get_type() == "meta_event" MESSAGE = Permission(_message) @@ -114,7 +114,7 @@ METAEVENT = Permission(_metaevent) """ -def USER(*user: int, perm: Permission = Permission()): +def USER(*user: str, perm: Permission = Permission()): """ :说明: @@ -122,104 +122,94 @@ def USER(*user: int, perm: Permission = Permission()): :参数: - * ``*user: int``: 白名单 + * ``*user: str``: 白名单 * ``perm: Permission``: 需要同时满足的权限 """ async def _user(bot: "Bot", event: "Event") -> bool: - return event.type == "message" and event.user_id in user and await perm( - bot, event) + return event.get_type() == "message" and event.get_session_id( + ) in user and await perm(bot, event) return Permission(_user) -async def _private(bot: "Bot", event: "Event") -> bool: - return event.type == "message" and event.detail_type == "private" +# async def _private(bot: "Bot", event: "Event") -> bool: +# return event.get_type() == "message" and event.detail_type == "private" +# async def _private_friend(bot: "Bot", event: "Event") -> bool: +# return (event.get_type() == "message" and event.detail_type == "private" and +# event.sub_type == "friend") -async def _private_friend(bot: "Bot", event: "Event") -> bool: - return (event.type == "message" and event.detail_type == "private" and - event.sub_type == "friend") +# async def _private_group(bot: "Bot", event: "Event") -> bool: +# return (event.get_type() == "message" and event.detail_type == "private" and +# event.sub_type == "group") +# async def _private_other(bot: "Bot", event: "Event") -> bool: +# return (event.get_type() == "message" and event.detail_type == "private" and +# event.sub_type == "other") -async def _private_group(bot: "Bot", event: "Event") -> bool: - return (event.type == "message" and event.detail_type == "private" and - event.sub_type == "group") +# PRIVATE = Permission(_private) +# """ +# - **说明**: 匹配任意私聊消息类型事件 +# """ +# PRIVATE_FRIEND = Permission(_private_friend) +# """ +# - **说明**: 匹配任意好友私聊消息类型事件 +# """ +# PRIVATE_GROUP = Permission(_private_group) +# """ +# - **说明**: 匹配任意群临时私聊消息类型事件 +# """ +# PRIVATE_OTHER = Permission(_private_other) +# """ +# - **说明**: 匹配任意其他私聊消息类型事件 +# """ +# async def _group(bot: "Bot", event: "Event") -> bool: +# return event.get_type() == "message" and event.detail_type == "group" -async def _private_other(bot: "Bot", event: "Event") -> bool: - return (event.type == "message" and event.detail_type == "private" and - event.sub_type == "other") +# async def _group_member(bot: "Bot", event: "Event") -> bool: +# return (event.get_type() == "message" and event.detail_type == "group" and +# event.sender.get("role") == "member") +# async def _group_admin(bot: "Bot", event: "Event") -> bool: +# return (event.get_type() == "message" and event.detail_type == "group" and +# event.sender.get("role") == "admin") -PRIVATE = Permission(_private) -""" -- **说明**: 匹配任意私聊消息类型事件 -""" -PRIVATE_FRIEND = Permission(_private_friend) -""" -- **说明**: 匹配任意好友私聊消息类型事件 -""" -PRIVATE_GROUP = Permission(_private_group) -""" -- **说明**: 匹配任意群临时私聊消息类型事件 -""" -PRIVATE_OTHER = Permission(_private_other) -""" -- **说明**: 匹配任意其他私聊消息类型事件 -""" +# async def _group_owner(bot: "Bot", event: "Event") -> bool: +# return (event.get_type() == "message" and event.detail_type == "group" and +# event.sender.get("role") == "owner") +# GROUP = Permission(_group) +# """ +# - **说明**: 匹配任意群聊消息类型事件 +# """ +# GROUP_MEMBER = Permission(_group_member) +# """ +# - **说明**: 匹配任意群员群聊消息类型事件 -async def _group(bot: "Bot", event: "Event") -> bool: - return event.type == "message" and event.detail_type == "group" +# \:\:\:warning 警告 +# 该权限通过 event.sender 进行判断且不包含管理员以及群主! +# \:\:\: +# """ +# GROUP_ADMIN = Permission(_group_admin) +# """ +# - **说明**: 匹配任意群管理员群聊消息类型事件 +# """ +# GROUP_OWNER = Permission(_group_owner) +# """ +# - **说明**: 匹配任意群主群聊消息类型事件 +# """ +# async def _superuser(bot: "Bot", event: "Event") -> bool: +# return event.get_type( +# ) == "message" and event.user_id in bot.config.superusers -async def _group_member(bot: "Bot", event: "Event") -> bool: - return (event.type == "message" and event.detail_type == "group" and - event.sender.get("role") == "member") - - -async def _group_admin(bot: "Bot", event: "Event") -> bool: - return (event.type == "message" and event.detail_type == "group" and - event.sender.get("role") == "admin") - - -async def _group_owner(bot: "Bot", event: "Event") -> bool: - return (event.type == "message" and event.detail_type == "group" and - event.sender.get("role") == "owner") - - -GROUP = Permission(_group) -""" -- **说明**: 匹配任意群聊消息类型事件 -""" -GROUP_MEMBER = Permission(_group_member) -""" -- **说明**: 匹配任意群员群聊消息类型事件 - -\:\:\:warning 警告 -该权限通过 event.sender 进行判断且不包含管理员以及群主! -\:\:\: -""" -GROUP_ADMIN = Permission(_group_admin) -""" -- **说明**: 匹配任意群管理员群聊消息类型事件 -""" -GROUP_OWNER = Permission(_group_owner) -""" -- **说明**: 匹配任意群主群聊消息类型事件 -""" - - -async def _superuser(bot: "Bot", event: "Event") -> bool: - return event.type == "message" and event.user_id in bot.config.superusers - - -SUPERUSER = Permission(_superuser) -""" -- **说明**: 匹配任意超级用户消息类型事件 -""" -EVERYBODY = MESSAGE -""" -- **说明**: 匹配任意消息类型事件 -""" +# SUPERUSER = Permission(_superuser) +# """ +# - **说明**: 匹配任意超级用户消息类型事件 +# """ +# EVERYBODY = MESSAGE +# """ +# - **说明**: 匹配任意消息类型事件 +# """ diff --git a/nonebot/rule.py b/nonebot/rule.py index 99f96112..31902603 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -119,7 +119,7 @@ class TrieRule: @classmethod def get_value(cls, bot: "Bot", event: "Event", state: State) -> Tuple[Dict[str, Any], Dict[str, Any]]: - if event.type != "message": + if event.get_type() != "message": state["_prefix"] = {"raw_command": None, "command": None} state["_suffix"] = {"raw_command": None, "command": None} return { @@ -132,12 +132,14 @@ class TrieRule: prefix = None suffix = None - message = event.message[0] - if message.type == "text": - prefix = cls.prefix.longest_prefix(str(message).lstrip()) - message_r = event.message[-1] - if message_r.type == "text": - suffix = cls.suffix.longest_prefix(str(message_r).rstrip()[::-1]) + message = event.get_message() + message_seg = message[0] + if message_seg.type == "text": + prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) + message_seg_r = message[-1] + if message_seg_r.type == "text": + suffix = cls.suffix.longest_prefix( + str(message_seg_r).rstrip()[::-1]) state["_prefix"] = { "raw_command": prefix.key, @@ -181,7 +183,10 @@ def startswith(msg: str) -> Rule: """ async def _startswith(bot: "Bot", event: "Event", state: State) -> bool: - return event.plain_text.startswith(msg) + if event.get_type() != "message": + return False + text = event.get_plaintext() + return text.startswith(msg) return Rule(_startswith) @@ -198,7 +203,9 @@ def endswith(msg: str) -> Rule: """ async def _endswith(bot: "Bot", event: "Event", state: State) -> bool: - return event.plain_text.endswith(msg) + if event.get_type() != "message": + return False + return event.get_plaintext().endswith(msg) return Rule(_endswith) @@ -215,8 +222,10 @@ def keyword(*keywords: str) -> Rule: """ async def _keyword(bot: "Bot", event: "Event", state: State) -> bool: - return bool(event.plain_text and - any(keyword in event.plain_text for keyword in keywords)) + if event.get_type() != "message": + return False + text = event.get_plaintext() + return bool(text and any(keyword in text for keyword in keywords)) return Rule(_keyword) @@ -287,7 +296,9 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: pattern = re.compile(regex, flags) async def _regex(bot: "Bot", event: "Event", state: State) -> bool: - matched = pattern.search(str(event.message)) + if event.get_type() != "message": + return False + matched = pattern.search(str(event.get_message())) if matched: state["_matched"] = matched.group() return True @@ -298,18 +309,20 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: return Rule(_regex) -def to_me() -> Rule: - """ - :说明: +# def to_me() -> Rule: +# """ +# :说明: - 通过 ``event.to_me`` 判断消息是否是发送给机器人 +# 通过 ``event.to_me`` 判断消息是否是发送给机器人 - :参数: +# :参数: - * 无 - """ +# * 无 +# """ - async def _to_me(bot: "Bot", event: "Event", state: State) -> bool: - return bool(event.to_me) +# async def _to_me(bot: "Bot", event: "Event", state: State) -> bool: +# if event.get_type() != "message": +# return False +# return bool(event.to_me) - return Rule(_to_me) +# return Rule(_to_me)