diff --git a/src/search.rs b/src/search/facet.rs
similarity index 55%
rename from src/search.rs
rename to src/search/facet.rs
index 5cd998ffe..22352ab48 100644
--- a/src/search.rs
+++ b/src/search/facet.rs
@@ -1,33 +1,21 @@
-use std::borrow::Cow;
-use std::collections::{HashMap, HashSet};
 use std::error::Error as StdError;
-use std::fmt::{self, Debug};
+use std::fmt::Debug;
 use std::ops::Bound::{self, Unbounded, Included, Excluded};
 use std::str::FromStr;
 
 use anyhow::{bail, ensure, Context};
-use fst::{IntoStreamer, Streamer};
 use heed::types::{ByteSlice, DecodeIgnore};
-use levenshtein_automata::DFA;
-use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder;
 use log::debug;
 use num_traits::Bounded;
-use once_cell::sync::Lazy;
-use roaring::bitmap::RoaringBitmap;
+use roaring::RoaringBitmap;
 
 use crate::facet::FacetType;
 use crate::heed_codec::facet::{FacetLevelValueI64Codec, FacetLevelValueF64Codec};
-use crate::heed_codec::CboRoaringBitmapCodec;
-use crate::mdfs::Mdfs;
-use crate::query_tokens::{QueryTokens, QueryToken};
-use crate::{Index, DocumentId};
+use crate::{Index, CboRoaringBitmapCodec};
 
-// Building these factories is not free.
-static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
-static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
-static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));
+use self::FacetCondition::*;
+use self::FacetOperator::*;
 
-// TODO support also floats
 #[derive(Debug, Copy, Clone, PartialEq)]
 pub enum FacetOperator<T> {
     GreaterThan(T),
@@ -52,8 +40,6 @@ impl FacetCondition {
         string: &str,
     ) -> anyhow::Result<Option<FacetCondition>>
     {
-        use FacetCondition::*;
-
         let fields_ids_map = index.fields_ids_map(rtxn)?;
         let faceted_fields = index.faceted_fields(rtxn)?;
 
@@ -80,8 +66,6 @@
     ) -> anyhow::Result<FacetOperator<T>>
     where T::Err: Send + Sync + StdError + 'static,
     {
-        use FacetOperator::*;
-
         match iter.next() {
             Some(">") => {
                 let param = iter.next().context("missing parameter")?;
@@ -228,8 +212,6 @@ impl FacetCondition {
         KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>,
         KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>,
     {
-        use FacetOperator::*;
-
         // Make sure we always bound the ranges with the field id and the level,
         // as the facets values are all in the same database and prefixed by the
         // field id and the level.
@@ -259,7 +241,7 @@ impl FacetCondition {
         }
     }
 
-    fn evaluate(
+    pub fn evaluate(
         &self,
         rtxn: &heed::RoTxn,
         db: heed::Database<ByteSlice, CboRoaringBitmapCodec>,
@@ -275,208 +257,3 @@ impl FacetCondition {
         }
     }
 }
-
-pub struct Search<'a> {
-    query: Option<String>,
-    facet_condition: Option<FacetCondition>,
-    offset: usize,
-    limit: usize,
-    rtxn: &'a heed::RoTxn<'a>,
-    index: &'a Index,
-}
-
-impl<'a> Search<'a> {
-    pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
-        Search { query: None, facet_condition: None, offset: 0, limit: 20, rtxn, index }
-    }
-
-    pub fn query(&mut self, query: impl Into<String>) -> &mut Search<'a> {
-        self.query = Some(query.into());
-        self
-    }
-
-    pub fn offset(&mut self, offset: usize) -> &mut Search<'a> {
-        self.offset = offset;
-        self
-    }
-
-    pub fn limit(&mut self, limit: usize) -> &mut Search<'a> {
-        self.limit = limit;
-        self
-    }
-
-    pub fn facet_condition(&mut self, condition: FacetCondition) -> &mut Search<'a> {
-        self.facet_condition = Some(condition);
-        self
-    }
-
-    /// Extracts the query words from the query string and returns the DFAs accordingly.
-    /// TODO introduce settings for the number of typos regarding the words lengths.
-    fn generate_query_dfas(query: &str) -> Vec<(String, bool, DFA)> {
-        let (lev0, lev1, lev2) = (&LEVDIST0, &LEVDIST1, &LEVDIST2);
-
-        let words: Vec<_> = QueryTokens::new(query).collect();
-        let ends_with_whitespace = query.chars().last().map_or(false, char::is_whitespace);
-        let number_of_words = words.len();
-
-        words.into_iter().enumerate().map(|(i, word)| {
-            let (word, quoted) = match word {
-                QueryToken::Free(word) => (word.to_lowercase(), word.len() <= 3),
-                QueryToken::Quoted(word) => (word.to_lowercase(), true),
-            };
-            let is_last = i + 1 == number_of_words;
-            let is_prefix = is_last && !ends_with_whitespace && !quoted;
-            let lev = match word.len() {
-                0..=4 => if quoted { lev0 } else { lev0 },
-                5..=8 => if quoted { lev0 } else { lev1 },
-                _ => if quoted { lev0 } else { lev2 },
-            };
-
-            let dfa = if is_prefix {
-                lev.build_prefix_dfa(&word)
-            } else {
-                lev.build_dfa(&word)
-            };
-
-            (word, is_prefix, dfa)
-        })
-        .collect()
-    }
-
-    /// Fetch the words from the given FST related to the given DFAs along with
-    /// the associated documents ids.
-    fn fetch_words_docids(
-        &self,
-        fst: &fst::Set<Cow<[u8]>>,
-        dfas: Vec<(String, bool, DFA)>,
-    ) -> anyhow::Result<Vec<(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)>>
-    {
-        // A Vec storing all the derived words from the original query words, associated
-        // with the distance from the original word and the docids where the words appears.
-        let mut derived_words = Vec::<(HashMap::<String, (u8, RoaringBitmap)>, RoaringBitmap)>::with_capacity(dfas.len());
-
-        for (_word, _is_prefix, dfa) in dfas {
-
-            let mut acc_derived_words = HashMap::new();
-            let mut unions_docids = RoaringBitmap::new();
-            let mut stream = fst.search_with_state(&dfa).into_stream();
-            while let Some((word, state)) = stream.next() {
-
-                let word = std::str::from_utf8(word)?;
-                let docids = self.index.word_docids.get(self.rtxn, word)?.unwrap();
-                let distance = dfa.distance(state);
-                unions_docids.union_with(&docids);
-                acc_derived_words.insert(word.to_string(), (distance.to_u8(), docids));
-            }
-            derived_words.push((acc_derived_words, unions_docids));
-        }
-
-        Ok(derived_words)
-    }
-
-    /// Returns the set of docids that contains all of the query words.
-    fn compute_candidates(
-        derived_words: &[(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)],
-    ) -> RoaringBitmap
-    {
-        // We sort the derived words by inverse popularity, this way intersections are faster.
-        let mut derived_words: Vec<_> = derived_words.iter().collect();
-        derived_words.sort_unstable_by_key(|(_, docids)| docids.len());
-
-        // we do a union between all the docids of each of the derived words,
-        // we got N unions (the number of original query words), we then intersect them.
-        let mut candidates = RoaringBitmap::new();
-
-        for (i, (_, union_docids)) in derived_words.iter().enumerate() {
-            if i == 0 {
-                candidates = union_docids.clone();
-            } else {
-                candidates.intersect_with(&union_docids);
-            }
-        }
-
-        candidates
-    }
-
-    pub fn execute(&self) -> anyhow::Result<SearchResult> {
-        let limit = self.limit;
-        let fst = self.index.words_fst(self.rtxn)?;
-
-        // Construct the DFAs related to the query words.
-        let derived_words = match self.query.as_deref().map(Self::generate_query_dfas) {
-            Some(dfas) if !dfas.is_empty() => Some(self.fetch_words_docids(&fst, dfas)?),
-            _otherwise => None,
-        };
-
-        // We create the original candidates with the facet conditions results.
-        let facet_db = self.index.facet_field_id_value_docids;
-        let facet_candidates = match self.facet_condition {
-            Some(condition) => Some(condition.evaluate(self.rtxn, facet_db)?),
-            None => None,
-        };
-
-        debug!("facet candidates: {:?}", facet_candidates);
-
-        let (candidates, derived_words) = match (facet_candidates, derived_words) {
-            (Some(mut facet_candidates), Some(derived_words)) => {
-                let words_candidates = Self::compute_candidates(&derived_words);
-                facet_candidates.intersect_with(&words_candidates);
-                (facet_candidates, derived_words)
-            },
-            (None, Some(derived_words)) => {
-                (Self::compute_candidates(&derived_words), derived_words)
-            },
-            (Some(facet_candidates), None) => {
-                // If the query is not set or results in no DFAs but
-                // there is some facet conditions we return a placeholder.
-                let documents_ids = facet_candidates.iter().take(limit).collect();
-                return Ok(SearchResult { documents_ids, ..Default::default() })
-            },
-            (None, None) => {
-                // If the query is not set or results in no DFAs we return a placeholder.
-                let documents_ids = self.index.documents_ids(self.rtxn)?.iter().take(limit).collect();
-                return Ok(SearchResult { documents_ids, ..Default::default() })
-            },
-        };
-
-        debug!("candidates: {:?}", candidates);
-
-        // The mana depth first search is a revised DFS that explore
-        // solutions in the order of their proximities.
-        let mut mdfs = Mdfs::new(self.index, self.rtxn, &derived_words, candidates);
-        let mut documents = Vec::new();
-
-        // We execute the Mdfs iterator until we find enough documents.
-        while documents.iter().map(RoaringBitmap::len).sum::<u64>() < limit as u64 {
-            match mdfs.next().transpose()? {
-                Some((proximity, answer)) => {
-                    debug!("answer with a proximity of {}: {:?}", proximity, answer);
-                    documents.push(answer);
-                },
-                None => break,
-            }
-        }
-
-        let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect();
-        let documents_ids = documents.into_iter().flatten().take(limit).collect();
-        Ok(SearchResult { found_words, documents_ids })
-    }
-}
-
-impl fmt::Debug for Search<'_> {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        f.debug_struct("Search")
-            .field("query", &self.query)
-            .field("facet_condition", &self.facet_condition)
-            .field("offset", &self.offset)
-            .field("limit", &self.limit)
-            .finish()
-    }
-}
-
-#[derive(Default)]
-pub struct SearchResult {
-    pub found_words: HashSet<String>,
-    // TODO those documents ids should be associated with their criteria scores.
-    pub documents_ids: Vec<DocumentId>,
-}
diff --git a/src/search/mod.rs b/src/search/mod.rs
new file mode 100644
index 000000000..8ee8461a8
--- /dev/null
+++ b/src/search/mod.rs
@@ -0,0 +1,228 @@
+use std::borrow::Cow;
+use std::collections::{HashMap, HashSet};
+use std::fmt;
+
+use fst::{IntoStreamer, Streamer};
+use levenshtein_automata::DFA;
+use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder;
+use log::debug;
+use once_cell::sync::Lazy;
+use roaring::bitmap::RoaringBitmap;
+
+use crate::mdfs::Mdfs;
+use crate::query_tokens::{QueryTokens, QueryToken};
+use crate::{Index, DocumentId};
+
+pub use self::facet::FacetCondition;
+
+// Building these factories is not free.
+static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
+static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
+static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));
+
+mod facet;
+
+pub struct Search<'a> {
+    query: Option<String>,
+    facet_condition: Option<FacetCondition>,
+    offset: usize,
+    limit: usize,
+    rtxn: &'a heed::RoTxn<'a>,
+    index: &'a Index,
+}
+
+impl<'a> Search<'a> {
+    pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
+        Search { query: None, facet_condition: None, offset: 0, limit: 20, rtxn, index }
+    }
+
+    pub fn query(&mut self, query: impl Into<String>) -> &mut Search<'a> {
+        self.query = Some(query.into());
+        self
+    }
+
+    pub fn offset(&mut self, offset: usize) -> &mut Search<'a> {
+        self.offset = offset;
+        self
+    }
+
+    pub fn limit(&mut self, limit: usize) -> &mut Search<'a> {
+        self.limit = limit;
+        self
+    }
+
+    pub fn facet_condition(&mut self, condition: FacetCondition) -> &mut Search<'a> {
+        self.facet_condition = Some(condition);
+        self
+    }
+
+    /// Extracts the query words from the query string and returns the DFAs accordingly.
+    /// TODO introduce settings for the number of typos regarding the words lengths.
+    fn generate_query_dfas(query: &str) -> Vec<(String, bool, DFA)> {
+        let (lev0, lev1, lev2) = (&LEVDIST0, &LEVDIST1, &LEVDIST2);
+
+        let words: Vec<_> = QueryTokens::new(query).collect();
+        let ends_with_whitespace = query.chars().last().map_or(false, char::is_whitespace);
+        let number_of_words = words.len();
+
+        words.into_iter().enumerate().map(|(i, word)| {
+            let (word, quoted) = match word {
+                QueryToken::Free(word) => (word.to_lowercase(), word.len() <= 3),
+                QueryToken::Quoted(word) => (word.to_lowercase(), true),
+            };
+            let is_last = i + 1 == number_of_words;
+            let is_prefix = is_last && !ends_with_whitespace && !quoted;
+            let lev = match word.len() {
+                0..=4 => if quoted { lev0 } else { lev0 },
+                5..=8 => if quoted { lev0 } else { lev1 },
+                _ => if quoted { lev0 } else { lev2 },
+            };
+
+            let dfa = if is_prefix {
+                lev.build_prefix_dfa(&word)
+            } else {
+                lev.build_dfa(&word)
+            };
+
+            (word, is_prefix, dfa)
+        })
+        .collect()
+    }
+
+    /// Fetch the words from the given FST related to the given DFAs along with
+    /// the associated documents ids.
+    fn fetch_words_docids(
+        &self,
+        fst: &fst::Set<Cow<[u8]>>,
+        dfas: Vec<(String, bool, DFA)>,
+    ) -> anyhow::Result<Vec<(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)>>
+    {
+        // A Vec storing all the derived words from the original query words, associated
+        // with the distance from the original word and the docids where the words appears.
+        let mut derived_words = Vec::<(HashMap::<String, (u8, RoaringBitmap)>, RoaringBitmap)>::with_capacity(dfas.len());
+
+        for (_word, _is_prefix, dfa) in dfas {
+
+            let mut acc_derived_words = HashMap::new();
+            let mut unions_docids = RoaringBitmap::new();
+            let mut stream = fst.search_with_state(&dfa).into_stream();
+            while let Some((word, state)) = stream.next() {
+
+                let word = std::str::from_utf8(word)?;
+                let docids = self.index.word_docids.get(self.rtxn, word)?.unwrap();
+                let distance = dfa.distance(state);
+                unions_docids.union_with(&docids);
+                acc_derived_words.insert(word.to_string(), (distance.to_u8(), docids));
+            }
+            derived_words.push((acc_derived_words, unions_docids));
+        }
+
+        Ok(derived_words)
+    }
+
+    /// Returns the set of docids that contains all of the query words.
+    fn compute_candidates(
+        derived_words: &[(HashMap<String, (u8, RoaringBitmap)>, RoaringBitmap)],
+    ) -> RoaringBitmap
+    {
+        // We sort the derived words by inverse popularity, this way intersections are faster.
+        let mut derived_words: Vec<_> = derived_words.iter().collect();
+        derived_words.sort_unstable_by_key(|(_, docids)| docids.len());
+
+        // we do a union between all the docids of each of the derived words,
+        // we got N unions (the number of original query words), we then intersect them.
+        let mut candidates = RoaringBitmap::new();
+
+        for (i, (_, union_docids)) in derived_words.iter().enumerate() {
+            if i == 0 {
+                candidates = union_docids.clone();
+            } else {
+                candidates.intersect_with(&union_docids);
+            }
+        }
+
+        candidates
+    }
+
+    pub fn execute(&self) -> anyhow::Result<SearchResult> {
+        let limit = self.limit;
+        let fst = self.index.words_fst(self.rtxn)?;
+
+        // Construct the DFAs related to the query words.
+        let derived_words = match self.query.as_deref().map(Self::generate_query_dfas) {
+            Some(dfas) if !dfas.is_empty() => Some(self.fetch_words_docids(&fst, dfas)?),
+            _otherwise => None,
+        };
+
+        // We create the original candidates with the facet conditions results.
+        let facet_db = self.index.facet_field_id_value_docids;
+        let facet_candidates = match self.facet_condition {
+            Some(condition) => Some(condition.evaluate(self.rtxn, facet_db)?),
+            None => None,
+        };
+
+        debug!("facet candidates: {:?}", facet_candidates);
+
+        let (candidates, derived_words) = match (facet_candidates, derived_words) {
+            (Some(mut facet_candidates), Some(derived_words)) => {
+                let words_candidates = Self::compute_candidates(&derived_words);
+                facet_candidates.intersect_with(&words_candidates);
+                (facet_candidates, derived_words)
+            },
+            (None, Some(derived_words)) => {
+                (Self::compute_candidates(&derived_words), derived_words)
+            },
+            (Some(facet_candidates), None) => {
+                // If the query is not set or results in no DFAs but
+                // there is some facet conditions we return a placeholder.
+                let documents_ids = facet_candidates.iter().take(limit).collect();
+                return Ok(SearchResult { documents_ids, ..Default::default() })
+            },
+            (None, None) => {
+                // If the query is not set or results in no DFAs we return a placeholder.
+                let documents_ids = self.index.documents_ids(self.rtxn)?.iter().take(limit).collect();
+                return Ok(SearchResult { documents_ids, ..Default::default() })
+            },
+        };
+
+        debug!("candidates: {:?}", candidates);
+
+        // The mana depth first search is a revised DFS that explore
+        // solutions in the order of their proximities.
+        let mut mdfs = Mdfs::new(self.index, self.rtxn, &derived_words, candidates);
+        let mut documents = Vec::new();
+
+        // We execute the Mdfs iterator until we find enough documents.
+        while documents.iter().map(RoaringBitmap::len).sum::<u64>() < limit as u64 {
+            match mdfs.next().transpose()? {
+                Some((proximity, answer)) => {
+                    debug!("answer with a proximity of {}: {:?}", proximity, answer);
+                    documents.push(answer);
+                },
+                None => break,
+            }
+        }
+
+        let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect();
+        let documents_ids = documents.into_iter().flatten().take(limit).collect();
+        Ok(SearchResult { found_words, documents_ids })
+    }
+}
+
+impl fmt::Debug for Search<'_> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("Search")
+            .field("query", &self.query)
+            .field("facet_condition", &self.facet_condition)
+            .field("offset", &self.offset)
+            .field("limit", &self.limit)
+            .finish()
+    }
+}
+
+#[derive(Default)]
+pub struct SearchResult {
+    pub found_words: HashSet<String>,
+    // TODO those documents ids should be associated with their criteria scores.
+    pub documents_ids: Vec<DocumentId>,
+}
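
For context, a minimal caller-side sketch of how the relocated pieces are expected to fit together after this change: `FacetCondition` is re-exported from the private `search::facet` module, `FacetCondition::evaluate` is made `pub` so `Search::execute` in the new `mod.rs` can call it against `index.facet_field_id_value_docids`, and the `Search` builder itself is unchanged. This is illustrative only, not part of the diff: the `FacetCondition::from_str(rtxn, index, …)` call and the "price > 10.5" filter string are assumptions based on the partially visible constructor signature and operator parsing in the first hunks, and `facet_filtered_search` is a hypothetical helper.

// Illustrative sketch (not part of the diff), written as if inside this crate.
// Assumptions: the parser partially visible above is reachable as
// `FacetCondition::from_str(rtxn, index, s)`, and `Search`/`SearchResult`
// stay reachable through `crate::search`.
use crate::search::{FacetCondition, Search, SearchResult};
use crate::{DocumentId, Index};

fn facet_filtered_search(index: &Index, rtxn: &heed::RoTxn) -> anyhow::Result<Vec<DocumentId>> {
    let mut search = Search::new(rtxn, index);
    search.query("hello world").limit(10);

    // Hypothetical facet filter string; the parser returns an Option, hence the `if let`.
    if let Some(condition) = FacetCondition::from_str(rtxn, index, "price > 10.5")? {
        // `facet_condition` stores the parsed condition; `execute` later calls the
        // now-public `FacetCondition::evaluate` on the facet docids database.
        search.facet_condition(condition);
    }

    let SearchResult { documents_ids, .. } = search.execute()?;
    Ok(documents_ids)
}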