diff --git a/milli/src/search/criteria/attribute.rs b/milli/src/search/criteria/attribute.rs index 2672169de..745d8cdb0 100644 --- a/milli/src/search/criteria/attribute.rs +++ b/milli/src/search/criteria/attribute.rs @@ -101,7 +101,7 @@ impl<'t> Criterion for Attribute<'t> { }, (Some(qt), None) => { let query_tree_candidates = resolve_query_tree(self.ctx, &qt, &mut HashMap::new(), wdcache)?; - self.bucket_candidates.union_with(&query_tree_candidates); + self.bucket_candidates |= &query_tree_candidates; self.candidates = Some(query_tree_candidates); }, (None, Some(_)) => { @@ -123,7 +123,7 @@ impl<'t> Criterion for Attribute<'t> { Some(CriterionResult { query_tree, candidates, bucket_candidates }) => { self.query_tree = query_tree; self.candidates = candidates; - self.bucket_candidates.union_with(&bucket_candidates); + self.bucket_candidates |= bucket_candidates; self.flattened_query_tree = None; self.current_buckets = None; }, @@ -160,14 +160,12 @@ impl<'t, 'q> WordLevelIterator<'t, 'q> { } } - fn dig(&self, ctx: &'t dyn Context<'t>, level: &TreeLevel) -> heed::Result { + fn dig(&self, ctx: &'t dyn Context<'t>, level: &TreeLevel, left_interval: Option) -> heed::Result { let level = level.min(&self.level).clone(); let interval_size = 4u32.pow(Into::::into(level.clone()) as u32); let word = self.word.clone(); let in_prefix_cache = self.in_prefix_cache; - // TODO try to dig starting from the current interval - // let left = self.current_interval.map(|(left, _)| left); - let inner = ctx.word_position_iterator(&word, level, in_prefix_cache, None, None)?; + let inner = ctx.word_position_iterator(&word, level, in_prefix_cache, left_interval, None)?; Ok(Self {inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None}) } @@ -209,6 +207,7 @@ struct QueryLevelIterator<'t, 'q> { level: TreeLevel, accumulator: Vec>, parent_accumulator: Vec>, + interval_to_skip: usize, } impl<'t, 'q> QueryLevelIterator<'t, 'q> { @@ -250,6 +249,7 @@ impl<'t, 'q> QueryLevelIterator<'t, 'q> { level, accumulator: vec![], parent_accumulator: vec![], + interval_to_skip: 0, })), None => Ok(None), } @@ -270,16 +270,15 @@ impl<'t, 'q> QueryLevelIterator<'t, 'q> { None => (self.level.saturating_sub(1), None), }; + let left_interval = self.accumulator.get(self.interval_to_skip).map(|opt| opt.as_ref().map(|(left, _, _)| *left)).flatten(); let mut inner = Vec::with_capacity(self.inner.len()); for word_level_iterator in self.inner.iter() { - inner.push(word_level_iterator.dig(ctx, &level)?); + inner.push(word_level_iterator.dig(ctx, &level, left_interval)?); } - Ok(Self {parent, inner, level, accumulator: vec![], parent_accumulator: vec![]}) + Ok(Self {parent, inner, level, accumulator: vec![], parent_accumulator: vec![], interval_to_skip: 0}) } - - fn inner_next(&mut self, level: TreeLevel) -> heed::Result> { let mut accumulated: Option<(u32, u32, RoaringBitmap)> = None; let u8_level = Into::::into(level); @@ -289,12 +288,13 @@ impl<'t, 'q> QueryLevelIterator<'t, 'q> { let accumulated_count = 4u32.pow((u8_level - wli_u8_level) as u32); for _ in 0..accumulated_count { if let Some((next_left, _, next_docids)) = wli.next()? { - accumulated = accumulated.take().map( - |(acc_left, acc_right, mut acc_docids)| { - acc_docids.union_with(&next_docids); - (acc_left, acc_right, acc_docids) - } - ).or_else(|| Some((next_left, next_left + interval_size, next_docids))); + accumulated = match accumulated.take(){ + Some((acc_left, acc_right, mut acc_docids)) => { + acc_docids |= next_docids; + Some((acc_left, acc_right, acc_docids)) + }, + None => Some((next_left, next_left + interval_size, next_docids)), + }; } } } @@ -304,35 +304,59 @@ impl<'t, 'q> QueryLevelIterator<'t, 'q> { /// return the next meta-interval created from inner WordLevelIterators, /// and from eventual chainned QueryLevelIterator. - fn next(&mut self) -> heed::Result<(TreeLevel, Option<(u32, u32, RoaringBitmap)>)> { + fn next(&mut self, allowed_candidates: &RoaringBitmap, tree_level: TreeLevel) -> heed::Result> { let parent_result = match self.parent.as_mut() { Some(parent) => { - Some(parent.next()?) + Some(parent.next(allowed_candidates, tree_level)?) }, None => None, }; match parent_result { - Some((parent_level, parent_next)) => { - let inner_next = self.inner_next(parent_level)?; + Some(parent_next) => { + let inner_next = self.inner_next(tree_level)?; + self.interval_to_skip += self.accumulator.iter().zip(self.parent_accumulator.iter()).skip(self.interval_to_skip).take_while(|current| { + match current { + (Some((_, _, inner)), Some((_, _, parent))) => { + inner.is_disjoint(allowed_candidates) && parent.is_empty() + }, + (Some((_, _, inner)), None) => { + inner.is_disjoint(allowed_candidates) + }, + (None, Some((_, _, parent))) => { + parent.is_empty() + }, + (None, None) => true, + } + }).count(); self.accumulator.push(inner_next); self.parent_accumulator.push(parent_next); - // TODO @many clean firsts intervals of both accumulators when both RoaringBitmap are empty, - // WARNING the cleaned intervals count needs to be kept to skip at the end - let mut merged_interval = None; - for current in self.accumulator.iter().rev().zip(self.parent_accumulator.iter()) { + let mut merged_interval: Option<(u32, u32, RoaringBitmap)> = None; + + for current in self.accumulator.iter().rev().zip(self.parent_accumulator.iter()).skip(self.interval_to_skip) { if let (Some((left_a, right_a, a)), Some((left_b, right_b, b))) = current { - let (_, _, merged_docids) = merged_interval.get_or_insert_with(|| (left_a + left_b, right_a + right_b, RoaringBitmap::new())); - merged_docids.union_with(&(a & b)); + match merged_interval.as_mut() { + Some((_, _, merged_docids)) => *merged_docids |= a & b, + None => merged_interval = Some((left_a + left_b, right_a + right_b, a & b)), + } } } - Ok((parent_level, merged_interval)) + Ok(merged_interval) }, None => { - let level = self.level.clone(); - let next_interval = self.inner_next(level.clone())?; - self.accumulator = vec![next_interval.clone()]; - Ok((level, next_interval)) + let level = self.level; + match self.inner_next(level)? { + Some((left, right, mut candidates)) => { + self.accumulator = vec![Some((left, right, RoaringBitmap::new()))]; + candidates &= allowed_candidates; + Ok(Some((left, right, candidates))) + + }, + None => { + self.accumulator = vec![None]; + Ok(None) + }, + } } } } @@ -346,17 +370,31 @@ struct Branch<'t, 'q> { } impl<'t, 'q> Branch<'t, 'q> { - fn next(&mut self) -> heed::Result { - match self.query_level_iterator.next()? { - (tree_level, Some(last_result)) => { + fn next(&mut self, allowed_candidates: &RoaringBitmap) -> heed::Result { + let tree_level = self.query_level_iterator.level; + match self.query_level_iterator.next(allowed_candidates, tree_level)? { + Some(last_result) => { self.last_result = last_result; self.tree_level = tree_level; Ok(true) }, - (_, None) => Ok(false), + None => Ok(false), } } + fn dig(&mut self, ctx: &'t dyn Context<'t>) -> heed::Result<()> { + self.query_level_iterator = self.query_level_iterator.dig(ctx)?; + Ok(()) + } + + fn lazy_next(&mut self) { + let u8_level = Into::::into(self.tree_level.clone()); + let interval_size = 4u32.pow(u8_level as u32); + let (left, right, _) = self.last_result; + + self.last_result = (left + interval_size, right + interval_size, RoaringBitmap::new()); + } + fn compute_rank(&self) -> u32 { // we compute a rank from the left interval. let (left, _, _) = self.last_result; @@ -367,11 +405,11 @@ impl<'t, 'q> Branch<'t, 'q> { let self_rank = self.compute_rank(); let other_rank = other.compute_rank(); let left_cmp = self_rank.cmp(&other_rank).reverse(); - // on level: higher is better, - // we want to reduce highest levels first. - let level_cmp = self.tree_level.cmp(&other.tree_level); + // on level: lower is better, + // we want to dig faster into levels on interesting branches. + let level_cmp = self.tree_level.cmp(&other.tree_level).reverse(); - left_cmp.then(level_cmp) + left_cmp.then(level_cmp).then(self.last_result.2.len().cmp(&other.last_result.2.len())) } } @@ -398,6 +436,7 @@ impl<'t, 'q> Eq for Branch<'t, 'q> {} fn initialize_query_level_iterators<'t, 'q>( ctx: &'t dyn Context<'t>, branches: &'q Vec>>, + allowed_candidates: &RoaringBitmap, wdcache: &mut WordDerivationsCache, ) -> anyhow::Result>> { @@ -418,7 +457,6 @@ fn initialize_query_level_iterators<'t, 'q>( branch_positions.sort_unstable_by_key(|qli| qli.level); let folded_query_level_iterators = branch_positions .into_iter() - .rev() .fold(None, |fold: Option, mut qli| match fold { Some(fold) => { qli.parent(fold); @@ -428,7 +466,8 @@ fn initialize_query_level_iterators<'t, 'q>( }); if let Some(mut folded_query_level_iterators) = folded_query_level_iterators { - let (tree_level, last_result) = folded_query_level_iterators.next()?; + let tree_level = folded_query_level_iterators.level; + let last_result = folded_query_level_iterators.next(allowed_candidates, tree_level)?; if let Some(last_result) = last_result { let branch = Branch { last_result, @@ -451,48 +490,43 @@ fn set_compute_candidates<'t>( wdcache: &mut WordDerivationsCache, ) -> anyhow::Result> { - let mut branches_heap = initialize_query_level_iterators(ctx, branches, wdcache)?; + let mut branches_heap = initialize_query_level_iterators(ctx, branches, allowed_candidates, wdcache)?; let lowest_level = TreeLevel::min_value(); let mut final_candidates: Option<(u32, RoaringBitmap)> = None; + let mut allowed_candidates = allowed_candidates.clone(); while let Some(mut branch) = branches_heap.peek_mut() { let is_lowest_level = branch.tree_level == lowest_level; let branch_rank = branch.compute_rank(); - let (_, _, candidates) = &mut branch.last_result; - candidates.intersect_with(&allowed_candidates); + // if current is worst than best we break to return + // candidates that correspond to the best rank + if let Some((best_rank, _)) = final_candidates { if branch_rank > best_rank { break; } } + let _left = branch.last_result.0; + let candidates = take(&mut branch.last_result.2); if candidates.is_empty() { // we don't have candidates, get next interval. - if !branch.next()? { PeekMut::pop(branch); } + if !branch.next(&allowed_candidates)? { PeekMut::pop(branch); } } else if is_lowest_level { - // we have candidates, but we can't dig deeper, return candidates. + // we have candidates, but we can't dig deeper. + allowed_candidates -= &candidates; final_candidates = match final_candidates.take() { + // we add current candidates to best candidates Some((best_rank, mut best_candidates)) => { - // if current is worst than best we break to return - // candidates that correspond to the best rank - if branch_rank > best_rank { - final_candidates = Some((best_rank, best_candidates)); - break; - // else we add current candidates to best candidates - // and we fetch the next page - } else { - best_candidates.union_with(candidates); - if !branch.next()? { PeekMut::pop(branch); } - Some((best_rank, best_candidates)) - } + best_candidates |= candidates; + branch.lazy_next(); + Some((best_rank, best_candidates)) }, // we take current candidates as best candidates - // and we fetch the next page None => { - let candidates = take(candidates); - if !branch.next()? { PeekMut::pop(branch); } + branch.lazy_next(); Some((branch_rank, candidates)) }, }; } else { // we have candidates, lets dig deeper in levels. - branch.query_level_iterator = branch.query_level_iterator.dig(ctx)?; - if !branch.next()? { PeekMut::pop(branch); } + branch.dig(ctx)?; + if !branch.next(&allowed_candidates)? { PeekMut::pop(branch); } } }