From 59f58c15f7cedf16acdbb3a89bae17247ffcc3ab Mon Sep 17 00:00:00 2001 From: many Date: Wed, 31 Mar 2021 19:23:02 +0200 Subject: [PATCH] Implement attribute criterion * Implement WordLevelIterator * Implement QueryLevelIterator * Implement set algorithm based on iterators Not tested + Some TODO to fix --- milli/src/search/criteria/attribute.rs | 354 +++++++++++++++++++++++-- milli/src/search/criteria/final.rs | 4 +- milli/src/search/criteria/mod.rs | 52 +++- milli/src/search/criteria/proximity.rs | 4 +- milli/src/search/criteria/typo.rs | 4 +- milli/src/search/criteria/words.rs | 4 +- milli/src/tree_level.rs | 4 + 7 files changed, 394 insertions(+), 32 deletions(-) diff --git a/milli/src/search/criteria/attribute.rs b/milli/src/search/criteria/attribute.rs index 31c11e7bb..af336c21f 100644 --- a/milli/src/search/criteria/attribute.rs +++ b/milli/src/search/criteria/attribute.rs @@ -1,17 +1,17 @@ -use std::cmp; +use std::{cmp::{self, Ordering}, collections::BinaryHeap}; use std::collections::{BTreeMap, HashMap, btree_map}; use std::mem::take; use roaring::RoaringBitmap; -use crate::{search::build_dfa}; +use crate::{TreeLevel, search::build_dfa}; use crate::search::criteria::Query; use crate::search::query_tree::{Operation, QueryKind}; use crate::search::WordDerivationsCache; use super::{Criterion, CriterionResult, Context, resolve_query_tree}; pub struct Attribute<'t> { - ctx: &'t dyn Context, + ctx: &'t dyn Context<'t>, query_tree: Option, candidates: Option, bucket_candidates: RoaringBitmap, @@ -21,7 +21,7 @@ pub struct Attribute<'t> { } impl<'t> Attribute<'t> { - pub fn new(ctx: &'t dyn Context, parent: Box) -> Self { + pub fn new(ctx: &'t dyn Context<'t>, parent: Box) -> Self { Attribute { ctx, query_tree: None, @@ -51,23 +51,27 @@ impl<'t> Criterion for Attribute<'t> { flatten_query_tree(&qt) }); - let current_buckets = match self.current_buckets.as_mut() { - Some(current_buckets) => current_buckets, - None => { - let new_buckets = linear_compute_candidates(self.ctx, flattened_query_tree, candidates)?; - self.current_buckets.get_or_insert(new_buckets.into_iter()) - }, - }; + let found_candidates = if candidates.len() < 1000 { + let current_buckets = match self.current_buckets.as_mut() { + Some(current_buckets) => current_buckets, + None => { + let new_buckets = linear_compute_candidates(self.ctx, flattened_query_tree, candidates)?; + self.current_buckets.get_or_insert(new_buckets.into_iter()) + }, + }; - let found_candidates = match current_buckets.next() { - Some((_score, candidates)) => candidates, - None => { - return Ok(Some(CriterionResult { - query_tree: self.query_tree.take(), - candidates: self.candidates.take(), - bucket_candidates: take(&mut self.bucket_candidates), - })); - }, + match current_buckets.next() { + Some((_score, candidates)) => candidates, + None => { + return Ok(Some(CriterionResult { + query_tree: self.query_tree.take(), + candidates: self.candidates.take(), + bucket_candidates: take(&mut self.bucket_candidates), + })); + }, + } + } else { + set_compute_candidates(self.ctx, flattened_query_tree, candidates)? }; candidates.difference_with(&found_candidates); @@ -114,6 +118,316 @@ impl<'t> Criterion for Attribute<'t> { } } +struct WordLevelIterator<'t, 'q> { + inner: Box> + 't>, + level: TreeLevel, + interval_size: u32, + word: &'q str, + in_prefix_cache: bool, + inner_next: Option<(u32, u32, RoaringBitmap)>, + current_interval: Option<(u32, u32)>, +} + +impl<'t, 'q> WordLevelIterator<'t, 'q> { + fn new(ctx: &'t dyn Context<'t>, query: &'q Query) -> heed::Result> { + // TODO make it typo/prefix tolerant + let word = query.kind.word(); + let in_prefix_cache = query.prefix && ctx.in_prefix_cache(word); + match ctx.word_position_last_level(word, in_prefix_cache)? { + Some(level) => { + let interval_size = 4u32.pow(Into::::into(level.clone()) as u32); + let inner = ctx.word_position_iterator(word, level, in_prefix_cache, None, None)?; + Ok(Some(Self { inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None })) + }, + None => Ok(None), + } + } + + fn dig(&self, ctx: &'t dyn Context<'t>, level: &TreeLevel) -> heed::Result { + let level = level.min(&self.level).clone(); + let interval_size = 4u32.pow(Into::::into(level.clone()) as u32); + let word = self.word; + let in_prefix_cache = self.in_prefix_cache; + // TODO try to dig starting from the current interval + // let left = self.current_interval.map(|(left, _)| left); + let inner = ctx.word_position_iterator(word, level, in_prefix_cache, None, None)?; + + Ok(Self {inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None}) + } + + fn next(&mut self) -> heed::Result> { + fn is_next_interval(last_right: u32, next_left: u32) -> bool { last_right + 1 == next_left } + + let inner_next = match self.inner_next.take() { + Some(inner_next) => Some(inner_next), + None => self.inner.next().transpose()?.map(|((_, _, left, right), docids)| (left, right, docids)), + }; + + match inner_next { + Some((left, right, docids)) => { + match self.current_interval { + Some((last_left, last_right)) if !is_next_interval(last_right, left) => { + let blank_left = last_left + self.interval_size; + let blank_right = last_right + self.interval_size; + self.current_interval = Some((blank_left, blank_right)); + self.inner_next = Some((left, right, docids)); + Ok(Some((blank_left, blank_right, RoaringBitmap::new()))) + }, + _ => { + self.current_interval = Some((left, right)); + Ok(Some((left, right, docids))) + } + } + }, + None => Ok(None), + } + } +} + +struct QueryLevelIterator<'t, 'q> { + previous: Option>>, + inner: Vec>, + level: TreeLevel, + accumulator: Vec>, + previous_accumulator: Vec>, +} + +impl<'t, 'q> QueryLevelIterator<'t, 'q> { + fn new(ctx: &'t dyn Context<'t>, queries: &'q Vec) -> heed::Result> { + let mut inner = Vec::with_capacity(queries.len()); + for query in queries { + if let Some(word_level_iterator) = WordLevelIterator::new(ctx, query)? { + inner.push(word_level_iterator); + } + } + + let highest = inner.iter().max_by_key(|wli| wli.level).map(|wli| wli.level.clone()); + match highest { + Some(level) => Ok(Some(Self { + previous: None, + inner, + level, + accumulator: vec![], + previous_accumulator: vec![], + })), + None => Ok(None), + } + } + + fn previous(&mut self, previous: QueryLevelIterator<'t, 'q>) -> &Self { + self.previous = Some(Box::new(previous)); + self + } + + fn dig(&self, ctx: &'t dyn Context<'t>) -> heed::Result { + let (level, previous) = match &self.previous { + Some(previous) => { + let previous = previous.dig(ctx)?; + (previous.level.min(self.level), Some(Box::new(previous))) + }, + None => (self.level.saturating_sub(1), None), + }; + + let mut inner = Vec::with_capacity(self.inner.len()); + for word_level_iterator in self.inner.iter() { + inner.push(word_level_iterator.dig(ctx, &level)?); + } + + Ok(Self {previous, inner, level, accumulator: vec![], previous_accumulator: vec![]}) + } + + + + fn inner_next(&mut self, level: TreeLevel) -> heed::Result> { + let mut accumulated: Option<(u32, u32, RoaringBitmap)> = None; + let u8_level = Into::::into(level); + let interval_size = 4u32.pow(u8_level as u32); + for wli in self.inner.iter_mut() { + let wli_u8_level = Into::::into(wli.level.clone()); + let accumulated_count = 4u32.pow((u8_level - wli_u8_level) as u32); + for _ in 0..accumulated_count { + if let Some((next_left, _, next_docids)) = wli.next()? { + accumulated = accumulated.take().map( + |(acc_left, acc_right, mut acc_docids)| { + acc_docids.union_with(&next_docids); + (acc_left, acc_right, acc_docids) + } + ).or_else(|| Some((next_left, next_left + interval_size, next_docids))); + } + } + } + + Ok(accumulated) + } + + fn next(&mut self) -> heed::Result<(TreeLevel, Option<(u32, u32, RoaringBitmap)>)> { + let previous_result = match self.previous.as_mut() { + Some(previous) => { + Some(previous.next()?) + }, + None => None, + }; + + match previous_result { + Some((previous_level, previous_next)) => { + let inner_next = self.inner_next(previous_level)?; + self.accumulator.push(inner_next); + self.previous_accumulator.push(previous_next); + // TODO @many clean firsts intervals of both accumulators when both RoaringBitmap are empty, + // WARNING the cleaned intervals count needs to be kept to skip at the end + let mut merged_interval = None; + for current in self.accumulator.iter().rev().zip(self.previous_accumulator.iter()) { + if let (Some((left_a, right_a, a)), Some((left_b, right_b, b))) = current { + let (_, _, merged_docids) = merged_interval.get_or_insert_with(|| (left_a + left_b, right_a + right_b, RoaringBitmap::new())); + merged_docids.union_with(&(a & b)); + } + } + Ok((previous_level, merged_interval)) + }, + None => { + let level = self.level.clone(); + let next_interval = self.inner_next(level.clone())?; + self.accumulator = vec![next_interval.clone()]; + Ok((level, next_interval)) + } + } + } +} + +struct Branch<'t, 'q> { + query_level_iterator: QueryLevelIterator<'t, 'q>, + last_result: Option<(u32, u32, RoaringBitmap)>, + tree_level: TreeLevel, + branch_size: u32, +} + +impl<'t, 'q> Branch<'t, 'q> { + fn cmp(&self, other: &Self) -> Ordering { + fn compute_rank(left: u32, branch_size: u32) -> u32 { left.saturating_sub((1..branch_size).sum()) / branch_size } + match (&self.last_result, &other.last_result) { + (Some((s_left, _, _)), Some((o_left, _, _))) => { + // we compute a rank form the left interval. + let self_rank = compute_rank(*s_left, self.branch_size); + let other_rank = compute_rank(*o_left, other.branch_size); + let left_cmp = self_rank.cmp(&other_rank).reverse(); + // on level: higher is better, + // we want to reduce highest levels first. + let level_cmp = self.tree_level.cmp(&other.tree_level); + + left_cmp.then(level_cmp) + }, + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } +} + +impl<'t, 'q> Ord for Branch<'t, 'q> { + fn cmp(&self, other: &Self) -> Ordering { + self.cmp(other) + } +} + +impl<'t, 'q> PartialOrd for Branch<'t, 'q> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl<'t, 'q> PartialEq for Branch<'t, 'q> { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl<'t, 'q> Eq for Branch<'t, 'q> {} + +fn initialize_query_level_iterators<'t, 'q>( + ctx: &'t dyn Context<'t>, + branches: &'q Vec>>, +) -> heed::Result>> { + + let mut positions = BinaryHeap::with_capacity(branches.len()); + for branch in branches { + let mut branch_positions = Vec::with_capacity(branch.len()); + for query in branch { + match QueryLevelIterator::new(ctx, query)? { + Some(qli) => branch_positions.push(qli), + None => { + // the branch seems to be invalid, so we skip it. + branch_positions.clear(); + break; + }, + } + } + // QueryLevelIterator need to be sorted by level and folded in descending order. + branch_positions.sort_unstable_by_key(|qli| qli.level); + let folded_query_level_iterators = branch_positions + .into_iter() + .rev() + .fold(None, |fold: Option, qli| match fold { + Some(mut fold) => { + fold.previous(qli); + Some(fold) + }, + None => Some(qli), + }); + + if let Some(mut folded_query_level_iterators) = folded_query_level_iterators { + let (tree_level, last_result) = folded_query_level_iterators.next()?; + let branch = Branch { + last_result, + tree_level, + query_level_iterator: folded_query_level_iterators, + branch_size: branch.len() as u32, + }; + positions.push(branch); + } + } + + Ok(positions) +} + +fn set_compute_candidates<'t>( + ctx: &'t dyn Context<'t>, + branches: &Vec>>, + allowed_candidates: &RoaringBitmap, +) -> anyhow::Result +{ + let mut branches_heap = initialize_query_level_iterators(ctx, branches)?; + let lowest_level = TreeLevel::min_value(); + + while let Some(mut branch) = branches_heap.peek_mut() { + let is_lowest_level = branch.tree_level == lowest_level; + match branch.last_result.as_mut() { + Some((_, _, candidates)) => { + candidates.intersect_with(&allowed_candidates); + if candidates.len() > 0 && is_lowest_level { + // we have candidates, but we can't dig deeper, return candidates. + return Ok(std::mem::take(candidates)); + } else if candidates.len() > 0 { + // we have candidates, lets dig deeper in levels. + let mut query_level_iterator = branch.query_level_iterator.dig(ctx)?; + let (tree_level, last_result) = query_level_iterator.next()?; + branch.query_level_iterator = query_level_iterator; + branch.tree_level = tree_level; + branch.last_result = last_result; + } else { + // we don't have candidates, get next interval. + let (_, last_result) = branch.query_level_iterator.next()?; + branch.last_result = last_result; + } + }, + // None = no candidates to find. + None => return Ok(RoaringBitmap::new()), + } + } + + // we made all iterations without finding anything. + Ok(RoaringBitmap::new()) +} + fn linear_compute_candidates( ctx: &dyn Context, branches: &Vec>>, diff --git a/milli/src/search/criteria/final.rs b/milli/src/search/criteria/final.rs index fe224ef94..d3c394467 100644 --- a/milli/src/search/criteria/final.rs +++ b/milli/src/search/criteria/final.rs @@ -19,13 +19,13 @@ pub struct FinalResult { } pub struct Final<'t> { - ctx: &'t dyn Context, + ctx: &'t dyn Context<'t>, parent: Box, wdcache: WordDerivationsCache, } impl<'t> Final<'t> { - pub fn new(ctx: &'t dyn Context, parent: Box) -> Final<'t> { + pub fn new(ctx: &'t dyn Context<'t>, parent: Box) -> Final<'t> { Final { ctx, parent, wdcache: WordDerivationsCache::new() } } diff --git a/milli/src/search/criteria/mod.rs b/milli/src/search/criteria/mod.rs index 5e75be6ce..b972a0b2c 100644 --- a/milli/src/search/criteria/mod.rs +++ b/milli/src/search/criteria/mod.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use anyhow::bail; use roaring::RoaringBitmap; -use crate::search::{word_derivations, WordDerivationsCache}; +use crate::{TreeLevel, search::{word_derivations, WordDerivationsCache}}; use crate::{Index, DocumentId}; use super::query_tree::{Operation, Query, QueryKind}; @@ -64,7 +64,7 @@ impl Default for Candidates { } } -pub trait Context { +pub trait Context<'c> { fn documents_ids(&self) -> heed::Result; fn word_docids(&self, word: &str) -> heed::Result>; fn word_prefix_docids(&self, word: &str) -> heed::Result>; @@ -73,6 +73,8 @@ pub trait Context { fn words_fst<'t>(&self) -> &'t fst::Set>; fn in_prefix_cache(&self, word: &str) -> bool; fn docid_words_positions(&self, docid: DocumentId) -> heed::Result>; + fn word_position_iterator(&self, word: &str, level: TreeLevel, in_prefix_cache: bool, left: Option, right: Option) -> heed::Result> + 'c>>; + fn word_position_last_level(&self, word: &str, in_prefix_cache: bool) -> heed::Result>; } pub struct CriteriaBuilder<'t> { rtxn: &'t heed::RoTxn<'t>, @@ -81,7 +83,7 @@ pub struct CriteriaBuilder<'t> { words_prefixes_fst: fst::Set>, } -impl<'a> Context for CriteriaBuilder<'a> { +impl<'c> Context<'c> for CriteriaBuilder<'c> { fn documents_ids(&self) -> heed::Result { self.index.documents_ids(self.rtxn) } @@ -120,6 +122,40 @@ impl<'a> Context for CriteriaBuilder<'a> { } Ok(words_positions) } + + fn word_position_iterator(&self, word: &str, level: TreeLevel, in_prefix_cache: bool, left: Option, right: Option) -> heed::Result> + 'c>> { + let range = { + let left = left.unwrap_or(u32::min_value()); + let right = right.unwrap_or(u32::max_value()); + let left = (word, level, left, left); + let right = (word, level, right, right); + left..=right + }; + let db = match in_prefix_cache { + true => self.index.word_prefix_level_position_docids, + false => self.index.word_level_position_docids, + }; + + Ok(Box::new(db.range(self.rtxn, &range)?)) + } + + fn word_position_last_level(&self, word: &str, in_prefix_cache: bool) -> heed::Result> { + let range = { + let left = (word, TreeLevel::min_value(), u32::min_value(), u32::min_value()); + let right = (word, TreeLevel::max_value(), u32::max_value(), u32::max_value()); + left..=right + }; + let db = match in_prefix_cache { + true => self.index.word_prefix_level_position_docids, + false => self.index.word_level_position_docids, + }; + let last_level = db + .remap_data_type::() + .range(self.rtxn, &range)?.last().transpose()? + .map(|((_, level, _, _), _)| level); + + Ok(last_level) + } } impl<'t> CriteriaBuilder<'t> { @@ -354,7 +390,7 @@ pub mod test { docid_words: HashMap>, } - impl<'a> Context for TestContext<'a> { + impl<'c> Context<'c> for TestContext<'c> { fn documents_ids(&self) -> heed::Result { Ok(self.word_docids.iter().fold(RoaringBitmap::new(), |acc, (_, docids)| acc | docids)) } @@ -397,6 +433,14 @@ pub mod test { Ok(HashMap::new()) } } + + fn word_position_iterator(&self, _word: &str, _level: TreeLevel, _in_prefix_cache: bool, _left: Option, _right: Option) -> heed::Result> + 'c>> { + todo!() + } + + fn word_position_last_level(&self, _word: &str, _in_prefix_cache: bool) -> heed::Result> { + todo!() + } } impl<'a> Default for TestContext<'a> { diff --git a/milli/src/search/criteria/proximity.rs b/milli/src/search/criteria/proximity.rs index dc1daafb2..ca412bf28 100644 --- a/milli/src/search/criteria/proximity.rs +++ b/milli/src/search/criteria/proximity.rs @@ -13,7 +13,7 @@ use super::{Criterion, CriterionResult, Context, query_docids, query_pair_proxim type Cache = HashMap<(Operation, u8), Vec<(Query, Query, RoaringBitmap)>>; pub struct Proximity<'t> { - ctx: &'t dyn Context, + ctx: &'t dyn Context<'t>, /// ((max_proximity, query_tree), allowed_candidates) state: Option<(Option<(usize, Operation)>, RoaringBitmap)>, proximity: u8, @@ -24,7 +24,7 @@ pub struct Proximity<'t> { } impl<'t> Proximity<'t> { - pub fn new(ctx: &'t dyn Context, parent: Box) -> Self { + pub fn new(ctx: &'t dyn Context<'t>, parent: Box) -> Self { Proximity { ctx, state: None, diff --git a/milli/src/search/criteria/typo.rs b/milli/src/search/criteria/typo.rs index 40b06afc4..bf58fa258 100644 --- a/milli/src/search/criteria/typo.rs +++ b/milli/src/search/criteria/typo.rs @@ -9,7 +9,7 @@ use crate::search::{word_derivations, WordDerivationsCache}; use super::{Candidates, Criterion, CriterionResult, Context, query_docids, query_pair_proximity_docids}; pub struct Typo<'t> { - ctx: &'t dyn Context, + ctx: &'t dyn Context<'t>, query_tree: Option<(usize, Operation)>, number_typos: u8, candidates: Candidates, @@ -19,7 +19,7 @@ pub struct Typo<'t> { } impl<'t> Typo<'t> { - pub fn new(ctx: &'t dyn Context, parent: Box) -> Self { + pub fn new(ctx: &'t dyn Context<'t>, parent: Box) -> Self { Typo { ctx, query_tree: None, diff --git a/milli/src/search/criteria/words.rs b/milli/src/search/criteria/words.rs index 5bb9d8d90..047b3c5f0 100644 --- a/milli/src/search/criteria/words.rs +++ b/milli/src/search/criteria/words.rs @@ -8,7 +8,7 @@ use crate::search::query_tree::Operation; use super::{resolve_query_tree, Criterion, CriterionResult, Context, WordDerivationsCache}; pub struct Words<'t> { - ctx: &'t dyn Context, + ctx: &'t dyn Context<'t>, query_trees: Vec, candidates: Option, bucket_candidates: RoaringBitmap, @@ -17,7 +17,7 @@ pub struct Words<'t> { } impl<'t> Words<'t> { - pub fn new(ctx: &'t dyn Context, parent: Box) -> Self { + pub fn new(ctx: &'t dyn Context<'t>, parent: Box) -> Self { Words { ctx, query_trees: Vec::default(), diff --git a/milli/src/tree_level.rs b/milli/src/tree_level.rs index 7ce2904e2..b69316cf6 100644 --- a/milli/src/tree_level.rs +++ b/milli/src/tree_level.rs @@ -21,6 +21,10 @@ impl TreeLevel { pub const fn min_value() -> TreeLevel { TreeLevel(0) } + + pub fn saturating_sub(&self, lhs: u8) -> TreeLevel { + TreeLevel(self.0.saturating_sub(lhs)) + } } impl Into for TreeLevel {