mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-09-21 05:12:34 +00:00
commit
b343fb8e6f
@ -1,34 +1,31 @@
|
|||||||
from typing import Dict
|
from typing import Dict, Generator
|
||||||
|
|
||||||
from nonebot.adapters import Event
|
from nonebot.adapters import Event
|
||||||
from nonebot.message import (
|
from nonebot.params import Depends
|
||||||
IgnoredException,
|
from nonebot.message import IgnoredException, event_preprocessor
|
||||||
run_preprocessor,
|
|
||||||
run_postprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
_running_matcher: Dict[str, int] = {}
|
_running_matcher: Dict[str, int] = {}
|
||||||
|
|
||||||
|
|
||||||
@run_preprocessor
|
async def matcher_mutex(event: Event) -> Generator[bool, None, None]:
|
||||||
async def preprocess(event: Event):
|
result = False
|
||||||
try:
|
try:
|
||||||
session_id = event.get_session_id()
|
session_id = event.get_session_id()
|
||||||
except Exception:
|
except Exception:
|
||||||
return
|
yield result
|
||||||
|
else:
|
||||||
current_event_id = id(event)
|
current_event_id = id(event)
|
||||||
event_id = _running_matcher.get(session_id, None)
|
event_id = _running_matcher.get(session_id, None)
|
||||||
if event_id and event_id != current_event_id:
|
if event_id:
|
||||||
raise IgnoredException("Another matcher running")
|
result = event_id != current_event_id
|
||||||
|
else:
|
||||||
_running_matcher[session_id] = current_event_id
|
_running_matcher[session_id] = current_event_id
|
||||||
|
yield result
|
||||||
|
if not result:
|
||||||
@run_postprocessor
|
|
||||||
async def postprocess(event: Event):
|
|
||||||
try:
|
|
||||||
session_id = event.get_session_id()
|
|
||||||
except Exception:
|
|
||||||
return
|
|
||||||
if session_id in _running_matcher:
|
|
||||||
del _running_matcher[session_id]
|
del _running_matcher[session_id]
|
||||||
|
|
||||||
|
|
||||||
|
@event_preprocessor
|
||||||
|
async def preprocess(mutex: bool = Depends(matcher_mutex)):
|
||||||
|
if mutex:
|
||||||
|
raise IgnoredException("Another matcher running")
|
||||||
|
36
tests/test_single_session.py
Normal file
36
tests/test_single_session.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import make_fake_event
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_matcher_mutex():
|
||||||
|
from nonebot.plugins.single_session import matcher_mutex, _running_matcher
|
||||||
|
|
||||||
|
am = asynccontextmanager(matcher_mutex)
|
||||||
|
event = make_fake_event()()
|
||||||
|
event_1 = make_fake_event()()
|
||||||
|
event_2 = make_fake_event(_session_id="test1")()
|
||||||
|
event_3 = make_fake_event(_session_id=None)()
|
||||||
|
|
||||||
|
async with am(event) as ctx:
|
||||||
|
assert ctx == False
|
||||||
|
assert not _running_matcher
|
||||||
|
|
||||||
|
async with am(event) as ctx:
|
||||||
|
async with am(event_1) as ctx_1:
|
||||||
|
assert ctx == False
|
||||||
|
assert ctx_1 == True
|
||||||
|
assert not _running_matcher
|
||||||
|
|
||||||
|
async with am(event) as ctx:
|
||||||
|
async with am(event_2) as ctx_2:
|
||||||
|
assert ctx == False
|
||||||
|
assert ctx_2 == False
|
||||||
|
assert not _running_matcher
|
||||||
|
|
||||||
|
async with am(event_3) as ctx_3:
|
||||||
|
assert ctx_3 == False
|
||||||
|
assert not _running_matcher
|
@ -50,7 +50,7 @@ def make_fake_event(
|
|||||||
_name: str = "test",
|
_name: str = "test",
|
||||||
_description: str = "test",
|
_description: str = "test",
|
||||||
_user_id: str = "test",
|
_user_id: str = "test",
|
||||||
_session_id: str = "test",
|
_session_id: Optional[str] = "test",
|
||||||
_message: Optional["Message"] = None,
|
_message: Optional["Message"] = None,
|
||||||
_to_me: bool = True,
|
_to_me: bool = True,
|
||||||
**fields,
|
**fields,
|
||||||
@ -73,7 +73,9 @@ def make_fake_event(
|
|||||||
return _user_id
|
return _user_id
|
||||||
|
|
||||||
def get_session_id(self) -> str:
|
def get_session_id(self) -> str:
|
||||||
|
if _session_id is not None:
|
||||||
return _session_id
|
return _session_id
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_message(self) -> "Message":
|
def get_message(self) -> "Message":
|
||||||
if _message is not None:
|
if _message is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user