Fix the return of equal candidates in different pages

This commit is contained in:
many 2021-04-13 15:06:12 +02:00
parent 0efa011e09
commit 2b036449be
No known key found for this signature in database
GPG Key ID: 2CEF23B75189EACA

View File

@ -1,5 +1,6 @@
use std::{borrow::Cow, cmp::{self, Ordering}, collections::BinaryHeap}; use std::{borrow::Cow, cmp::{self, Ordering}, collections::BinaryHeap};
use std::collections::{BTreeMap, HashMap, btree_map}; use std::collections::{BTreeMap, HashMap, btree_map};
use std::collections::binary_heap::PeekMut;
use std::mem::take; use std::mem::take;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
@ -332,13 +333,26 @@ struct Branch<'t, 'q> {
} }
impl<'t, 'q> Branch<'t, 'q> { impl<'t, 'q> Branch<'t, 'q> {
fn cmp(&self, other: &Self) -> Ordering { fn next(&mut self) -> heed::Result<bool> {
let compute_rank = |left: u32, branch_size: u32| left.saturating_sub((0..branch_size).sum()) / branch_size; match self.query_level_iterator.next()? {
let (s_left, _, _) = self.last_result; (tree_level, Some(last_result)) => {
let (o_left, _, _) = other.last_result; self.last_result = last_result;
self.tree_level = tree_level;
Ok(true)
},
(_, None) => Ok(false),
}
}
fn compute_rank(&self) -> u32 {
// we compute a rank from the left interval. // we compute a rank from the left interval.
let self_rank = compute_rank(s_left, self.branch_size); let (left, _, _) = self.last_result;
let other_rank = compute_rank(o_left, other.branch_size); left.saturating_sub((0..self.branch_size).sum()) * 60 / self.branch_size
}
fn cmp(&self, other: &Self) -> Ordering {
let self_rank = self.compute_rank();
let other_rank = other.compute_rank();
let left_cmp = self_rank.cmp(&other_rank).reverse(); let left_cmp = self_rank.cmp(&other_rank).reverse();
// on level: higher is better, // on level: higher is better,
// we want to reduce highest levels first. // we want to reduce highest levels first.
@ -426,44 +440,53 @@ fn set_compute_candidates<'t>(
{ {
let mut branches_heap = initialize_query_level_iterators(ctx, branches, wdcache)?; let mut branches_heap = initialize_query_level_iterators(ctx, branches, wdcache)?;
let lowest_level = TreeLevel::min_value(); let lowest_level = TreeLevel::min_value();
let mut final_candidates = None; let mut final_candidates: Option<(u32, RoaringBitmap)> = None;
while let Some(mut branch) = branches_heap.peek_mut() { while let Some(mut branch) = branches_heap.peek_mut() {
let is_lowest_level = branch.tree_level == lowest_level; let is_lowest_level = branch.tree_level == lowest_level;
let branch_rank = branch.compute_rank();
let (_, _, candidates) = &mut branch.last_result; let (_, _, candidates) = &mut branch.last_result;
candidates.intersect_with(&allowed_candidates); candidates.intersect_with(&allowed_candidates);
if candidates.is_empty() { if candidates.is_empty() {
// we don't have candidates, get next interval. // we don't have candidates, get next interval.
match branch.query_level_iterator.next()? { if !branch.next()? { PeekMut::pop(branch); }
(_, Some(last_result)) => {
branch.last_result = last_result;
},
// TODO clean up this
(_, None) => { std::collections::binary_heap::PeekMut::<'_, Branch<'_, '_>>::pop(branch); },
}
} }
else if is_lowest_level { else if is_lowest_level {
// we have candidates, but we can't dig deeper, return candidates. // we have candidates, but we can't dig deeper, return candidates.
final_candidates = Some(take(candidates)); final_candidates = match final_candidates.take() {
Some((best_rank, mut best_candidates)) => {
// if current is worst than best we break to return
// candidates that correspond to the best rank
if branch_rank > best_rank {
final_candidates = Some((best_rank, best_candidates));
break; break;
// else we add current candidates to best candidates
// and we fetch the next page
} else {
best_candidates.union_with(candidates);
if !branch.next()? { PeekMut::pop(branch); }
Some((best_rank, best_candidates))
}
},
// we take current candidates as best candidates
// and we fetch the next page
None => {
let candidates = take(candidates);
if !branch.next()? { PeekMut::pop(branch); }
Some((branch_rank, candidates))
},
};
} else { } else {
// we have candidates, lets dig deeper in levels. // we have candidates, lets dig deeper in levels.
let mut query_level_iterator = branch.query_level_iterator.dig(ctx)?; branch.query_level_iterator = branch.query_level_iterator.dig(ctx)?;
match query_level_iterator.next()? { if !branch.next()? { PeekMut::pop(branch); }
(tree_level, Some(last_result)) => {
branch.query_level_iterator = query_level_iterator;
branch.tree_level = tree_level;
branch.last_result = last_result;
},
// TODO clean up this
(_, None) => { std::collections::binary_heap::PeekMut::<'_, Branch<'_, '_>>::pop(branch); },
}
} }
} }
Ok(final_candidates) Ok(final_candidates.map(|(_rank, candidates)| {
candidates
}))
} }
fn linear_compute_candidates( fn linear_compute_candidates(