diff --git a/nonebot/plugins/single_session.py b/nonebot/plugins/single_session.py index e76e1a1b..3bed8ea1 100644 --- a/nonebot/plugins/single_session.py +++ b/nonebot/plugins/single_session.py @@ -1,34 +1,33 @@ -from typing import Dict - +from typing import Generator, Dict from nonebot.adapters import Event from nonebot.message import ( IgnoredException, - run_preprocessor, - run_postprocessor, + event_preprocessor ) +from nonebot.params import Depends _running_matcher: Dict[str, int] = {} -@run_preprocessor -async def preprocess(event: Event): +async def matcher_mutex(event: Event) -> Generator[bool, None, None]: + result = False try: session_id = event.get_session_id() except Exception: - return - current_event_id = id(event) - event_id = _running_matcher.get(session_id, None) - if event_id and event_id != current_event_id: + yield result + else: + current_event_id = id(event) + event_id = _running_matcher.get(session_id, None) + if event_id : + result = event_id != current_event_id + else: + _running_matcher[session_id] = current_event_id + yield result + if result: + del _running_matcher[session_id] + +@event_preprocessor +async def preprocess(mutex: bool = Depends(matcher_mutex)): + if mutex: raise IgnoredException("Another matcher running") - _running_matcher[session_id] = current_event_id - - -@run_postprocessor -async def postprocess(event: Event): - try: - session_id = event.get_session_id() - except Exception: - return - if session_id in _running_matcher: - del _running_matcher[session_id] diff --git a/tests/test_single_session.py b/tests/test_single_session.py new file mode 100644 index 00000000..f0a696d5 --- /dev/null +++ b/tests/test_single_session.py @@ -0,0 +1,28 @@ +from contextlib import asynccontextmanager + +import pytest + +from utils import make_fake_event + + +@pytest.mark.asyncio +async def test_matcher_mutex(): + from nonebot.plugins.single_session import matcher_mutex + + am = asynccontextmanager(matcher_mutex) + event = make_fake_event()() + event_1 = make_fake_event()() + event_2 = make_fake_event(_session_id="test1")() + + async with am(event) as ctx: + assert ctx == False + + async with am(event) as ctx: + async with am(event_1) as ctx_1: + assert ctx == False + assert ctx_1 == True + + async with am(event) as ctx: + async with am(event_2) as ctx_2: + assert ctx == False + assert ctx_2 == False