diff --git a/milli/src/search/criteria/attribute.rs b/milli/src/search/criteria/attribute.rs index af336c21f..87f9d4dde 100644 --- a/milli/src/search/criteria/attribute.rs +++ b/milli/src/search/criteria/attribute.rs @@ -1,4 +1,4 @@ -use std::{cmp::{self, Ordering}, collections::BinaryHeap}; +use std::{borrow::Cow, cmp::{self, Ordering}, collections::BinaryHeap}; use std::collections::{BTreeMap, HashMap, btree_map}; use std::mem::take; @@ -7,7 +7,7 @@ use roaring::RoaringBitmap; use crate::{TreeLevel, search::build_dfa}; use crate::search::criteria::Query; use crate::search::query_tree::{Operation, QueryKind}; -use crate::search::WordDerivationsCache; +use crate::search::{word_derivations, WordDerivationsCache}; use super::{Criterion, CriterionResult, Context, resolve_query_tree}; pub struct Attribute<'t> { @@ -71,7 +71,7 @@ impl<'t> Criterion for Attribute<'t> { }, } } else { - set_compute_candidates(self.ctx, flattened_query_tree, candidates)? + set_compute_candidates(self.ctx, flattened_query_tree, candidates, wdcache)? }; candidates.difference_with(&found_candidates); @@ -122,21 +122,18 @@ struct WordLevelIterator<'t, 'q> { inner: Box> + 't>, level: TreeLevel, interval_size: u32, - word: &'q str, + word: Cow<'q, str>, in_prefix_cache: bool, inner_next: Option<(u32, u32, RoaringBitmap)>, current_interval: Option<(u32, u32)>, } impl<'t, 'q> WordLevelIterator<'t, 'q> { - fn new(ctx: &'t dyn Context<'t>, query: &'q Query) -> heed::Result> { - // TODO make it typo/prefix tolerant - let word = query.kind.word(); - let in_prefix_cache = query.prefix && ctx.in_prefix_cache(word); - match ctx.word_position_last_level(word, in_prefix_cache)? { + fn new(ctx: &'t dyn Context<'t>, word: Cow<'q, str>, in_prefix_cache: bool) -> heed::Result> { + match ctx.word_position_last_level(&word, in_prefix_cache)? { Some(level) => { let interval_size = 4u32.pow(Into::::into(level.clone()) as u32); - let inner = ctx.word_position_iterator(word, level, in_prefix_cache, None, None)?; + let inner = ctx.word_position_iterator(&word, level, in_prefix_cache, None, None)?; Ok(Some(Self { inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None })) }, None => Ok(None), @@ -146,11 +143,11 @@ impl<'t, 'q> WordLevelIterator<'t, 'q> { fn dig(&self, ctx: &'t dyn Context<'t>, level: &TreeLevel) -> heed::Result { let level = level.min(&self.level).clone(); let interval_size = 4u32.pow(Into::::into(level.clone()) as u32); - let word = self.word; + let word = self.word.clone(); let in_prefix_cache = self.in_prefix_cache; // TODO try to dig starting from the current interval // let left = self.current_interval.map(|(left, _)| left); - let inner = ctx.word_position_iterator(word, level, in_prefix_cache, None, None)?; + let inner = ctx.word_position_iterator(&word, level, in_prefix_cache, None, None)?; Ok(Self {inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None}) } @@ -193,11 +190,33 @@ struct QueryLevelIterator<'t, 'q> { } impl<'t, 'q> QueryLevelIterator<'t, 'q> { - fn new(ctx: &'t dyn Context<'t>, queries: &'q Vec) -> heed::Result> { + fn new(ctx: &'t dyn Context<'t>, queries: &'q Vec, wdcache: &mut WordDerivationsCache) -> anyhow::Result> { let mut inner = Vec::with_capacity(queries.len()); for query in queries { - if let Some(word_level_iterator) = WordLevelIterator::new(ctx, query)? { - inner.push(word_level_iterator); + match &query.kind { + QueryKind::Exact { word, .. } => { + if !query.prefix || ctx.in_prefix_cache(&word) { + let word = Cow::Borrowed(query.kind.word()); + if let Some(word_level_iterator) = WordLevelIterator::new(ctx, word, query.prefix)? { + inner.push(word_level_iterator); + } + } else { + for (word, _) in word_derivations(&word, true, 0, ctx.words_fst(), wdcache)? { + let word = Cow::Owned(word.to_owned()); + if let Some(word_level_iterator) = WordLevelIterator::new(ctx, word, false)? { + inner.push(word_level_iterator); + } + } + } + }, + QueryKind::Tolerant { typo, word } => { + for (word, _) in word_derivations(&word, query.prefix, *typo, ctx.words_fst(), wdcache)? { + let word = Cow::Owned(word.to_owned()); + if let Some(word_level_iterator) = WordLevelIterator::new(ctx, word, false)? { + inner.push(word_level_iterator); + } + } + } } } @@ -346,13 +365,14 @@ impl<'t, 'q> Eq for Branch<'t, 'q> {} fn initialize_query_level_iterators<'t, 'q>( ctx: &'t dyn Context<'t>, branches: &'q Vec>>, -) -> heed::Result>> { + wdcache: &mut WordDerivationsCache, +) -> anyhow::Result>> { let mut positions = BinaryHeap::with_capacity(branches.len()); for branch in branches { let mut branch_positions = Vec::with_capacity(branch.len()); for query in branch { - match QueryLevelIterator::new(ctx, query)? { + match QueryLevelIterator::new(ctx, query, wdcache)? { Some(qli) => branch_positions.push(qli), None => { // the branch seems to be invalid, so we skip it. @@ -393,9 +413,10 @@ fn set_compute_candidates<'t>( ctx: &'t dyn Context<'t>, branches: &Vec>>, allowed_candidates: &RoaringBitmap, + wdcache: &mut WordDerivationsCache, ) -> anyhow::Result { - let mut branches_heap = initialize_query_level_iterators(ctx, branches)?; + let mut branches_heap = initialize_query_level_iterators(ctx, branches, wdcache)?; let lowest_level = TreeLevel::min_value(); while let Some(mut branch) = branches_heap.peek_mut() {