🔀 Merge pull request #711

Fix: single_session potential bug
This commit is contained in:
Ju4tCode 2022-01-20 10:59:10 +08:00 committed by GitHub
commit b343fb8e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 26 deletions

View File

@ -1,34 +1,31 @@
from typing import Dict
from typing import Dict, Generator
from nonebot.adapters import Event
from nonebot.message import (
IgnoredException,
run_preprocessor,
run_postprocessor,
)
from nonebot.params import Depends
from nonebot.message import IgnoredException, event_preprocessor
_running_matcher: Dict[str, int] = {}
@run_preprocessor
async def preprocess(event: Event):
async def matcher_mutex(event: Event) -> Generator[bool, None, None]:
result = False
try:
session_id = event.get_session_id()
except Exception:
return
current_event_id = id(event)
event_id = _running_matcher.get(session_id, None)
if event_id and event_id != current_event_id:
yield result
else:
current_event_id = id(event)
event_id = _running_matcher.get(session_id, None)
if event_id:
result = event_id != current_event_id
else:
_running_matcher[session_id] = current_event_id
yield result
if not result:
del _running_matcher[session_id]
@event_preprocessor
async def preprocess(mutex: bool = Depends(matcher_mutex)):
if mutex:
raise IgnoredException("Another matcher running")
_running_matcher[session_id] = current_event_id
@run_postprocessor
async def postprocess(event: Event):
try:
session_id = event.get_session_id()
except Exception:
return
if session_id in _running_matcher:
del _running_matcher[session_id]

View File

@ -0,0 +1,36 @@
from contextlib import asynccontextmanager
import pytest
from utils import make_fake_event
@pytest.mark.asyncio
async def test_matcher_mutex():
from nonebot.plugins.single_session import matcher_mutex, _running_matcher
am = asynccontextmanager(matcher_mutex)
event = make_fake_event()()
event_1 = make_fake_event()()
event_2 = make_fake_event(_session_id="test1")()
event_3 = make_fake_event(_session_id=None)()
async with am(event) as ctx:
assert ctx == False
assert not _running_matcher
async with am(event) as ctx:
async with am(event_1) as ctx_1:
assert ctx == False
assert ctx_1 == True
assert not _running_matcher
async with am(event) as ctx:
async with am(event_2) as ctx_2:
assert ctx == False
assert ctx_2 == False
assert not _running_matcher
async with am(event_3) as ctx_3:
assert ctx_3 == False
assert not _running_matcher

View File

@ -50,7 +50,7 @@ def make_fake_event(
_name: str = "test",
_description: str = "test",
_user_id: str = "test",
_session_id: str = "test",
_session_id: Optional[str] = "test",
_message: Optional["Message"] = None,
_to_me: bool = True,
**fields,
@ -73,7 +73,9 @@ def make_fake_event(
return _user_id
def get_session_id(self) -> str:
return _session_id
if _session_id is not None:
return _session_id
raise NotImplementedError
def get_message(self) -> "Message":
if _message is not None: