diff --git a/milli/src/search/criteria/attribute.rs b/milli/src/search/criteria/attribute.rs index af3e08af1..8d150730f 100644 --- a/milli/src/search/criteria/attribute.rs +++ b/milli/src/search/criteria/attribute.rs @@ -1,5 +1,6 @@ use std::{borrow::Cow, cmp::{self, Ordering}, collections::BinaryHeap}; use std::collections::{BTreeMap, HashMap, btree_map}; +use std::collections::binary_heap::PeekMut; use std::mem::take; use roaring::RoaringBitmap; @@ -332,13 +333,26 @@ struct Branch<'t, 'q> { } impl<'t, 'q> Branch<'t, 'q> { - fn cmp(&self, other: &Self) -> Ordering { - let compute_rank = |left: u32, branch_size: u32| left.saturating_sub((0..branch_size).sum()) / branch_size; - let (s_left, _, _) = self.last_result; - let (o_left, _, _) = other.last_result; + fn next(&mut self) -> heed::Result { + match self.query_level_iterator.next()? { + (tree_level, Some(last_result)) => { + self.last_result = last_result; + self.tree_level = tree_level; + Ok(true) + }, + (_, None) => Ok(false), + } + } + + fn compute_rank(&self) -> u32 { // we compute a rank from the left interval. - let self_rank = compute_rank(s_left, self.branch_size); - let other_rank = compute_rank(o_left, other.branch_size); + let (left, _, _) = self.last_result; + left.saturating_sub((0..self.branch_size).sum()) * 60 / self.branch_size + } + + fn cmp(&self, other: &Self) -> Ordering { + let self_rank = self.compute_rank(); + let other_rank = other.compute_rank(); let left_cmp = self_rank.cmp(&other_rank).reverse(); // on level: higher is better, // we want to reduce highest levels first. 
@@ -426,44 +440,53 @@ fn set_compute_candidates<'t>( { let mut branches_heap = initialize_query_level_iterators(ctx, branches, wdcache)?; let lowest_level = TreeLevel::min_value(); - let mut final_candidates = None; + let mut final_candidates: Option<(u32, RoaringBitmap)> = None; while let Some(mut branch) = branches_heap.peek_mut() { let is_lowest_level = branch.tree_level == lowest_level; + let branch_rank = branch.compute_rank(); let (_, _, candidates) = &mut branch.last_result; candidates.intersect_with(&allowed_candidates); if candidates.is_empty() { // we don't have candidates, get next interval. - match branch.query_level_iterator.next()? { - (_, Some(last_result)) => { - branch.last_result = last_result; - }, - // TODO clean up this - (_, None) => { std::collections::binary_heap::PeekMut::<'_, Branch<'_, '_>>::pop(branch); }, - } - + if !branch.next()? { PeekMut::pop(branch); } } else if is_lowest_level { // we have candidates, but we can't dig deeper, return candidates. - final_candidates = Some(take(candidates)); - break; + final_candidates = match final_candidates.take() { + Some((best_rank, mut best_candidates)) => { + // if current is worse than best we break to return + // candidates that correspond to the best rank + if branch_rank > best_rank { + final_candidates = Some((best_rank, best_candidates)); + break; + // else we add current candidates to best candidates + // and we fetch the next page + } else { + best_candidates.union_with(candidates); + if !branch.next()? { PeekMut::pop(branch); } + Some((best_rank, best_candidates)) + } + }, + // we take current candidates as best candidates + // and we fetch the next page + None => { + let candidates = take(candidates); + if !branch.next()? { PeekMut::pop(branch); } + Some((branch_rank, candidates)) + }, + }; } else { // we have candidates, lets dig deeper in levels. - let mut query_level_iterator = branch.query_level_iterator.dig(ctx)?; - match query_level_iterator.next()? 
{ - (tree_level, Some(last_result)) => { - branch.query_level_iterator = query_level_iterator; - branch.tree_level = tree_level; - branch.last_result = last_result; - }, - // TODO clean up this - (_, None) => { std::collections::binary_heap::PeekMut::<'_, Branch<'_, '_>>::pop(branch); }, - } + branch.query_level_iterator = branch.query_level_iterator.dig(ctx)?; + if !branch.next()? { PeekMut::pop(branch); } } } - Ok(final_candidates) + Ok(final_candidates.map(|(_rank, candidates)| { + candidates + })) } fn linear_compute_candidates(