diff --git a/meilidb-core/src/automaton/mod.rs b/meilidb-core/src/automaton/mod.rs index 1bfb41e67..4b5fa0604 100644 --- a/meilidb-core/src/automaton/mod.rs +++ b/meilidb-core/src/automaton/mod.rs @@ -2,7 +2,7 @@ mod dfa; mod query_enhancer; use std::cmp::Reverse; -use std::vec; +use std::{cmp, vec}; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::DFA; @@ -18,7 +18,7 @@ use self::query_enhancer::QueryEnhancerBuilder; const NGRAMS: usize = 3; pub struct AutomatonProducer { - automatons: Vec>, + automatons: Vec, } impl AutomatonProducer { @@ -26,19 +26,26 @@ impl AutomatonProducer { reader: &heed::RoTxn, query: &str, main_store: store::Main, + postings_list_store: store::PostingsLists, synonyms_store: store::Synonyms, ) -> MResult<(AutomatonProducer, QueryEnhancer)> { let (automatons, query_enhancer) = - generate_automatons(reader, query, main_store, synonyms_store)?; + generate_automatons(reader, query, main_store, postings_list_store, synonyms_store)?; Ok((AutomatonProducer { automatons }, query_enhancer)) } - pub fn into_iter(self) -> vec::IntoIter> { + pub fn into_iter(self) -> vec::IntoIter { self.automatons.into_iter() } } +#[derive(Debug)] +pub enum AutomatonGroup { + Normal(Vec), + PhraseQuery(Vec), +} + #[derive(Debug)] pub struct Automaton { pub index: usize, @@ -102,12 +109,42 @@ pub fn normalize_str(string: &str) -> String { string } +fn split_best_frequency<'a>( + reader: &heed::RoTxn, + word: &'a str, + postings_lists_store: store::PostingsLists, +) -> MResult> { + let chars = word.char_indices().skip(1); + let mut best = None; + + for (i, _) in chars { + let (left, right) = word.split_at(i); + + let left_freq = postings_lists_store + .postings_list(reader, left.as_ref())? + .map_or(0, |i| i.len()); + + let right_freq = postings_lists_store + .postings_list(reader, right.as_ref())? + .map_or(0, |i| i.len()); + + let min_freq = cmp::min(left_freq, right_freq); + if min_freq != 0 && best.map_or(true, |(old, _, _)| min_freq > old) { + best = Some((min_freq, left, right)); + } + } + + Ok(best.map(|(_, l, r)| (l, r))) +} + fn generate_automatons( reader: &heed::RoTxn, query: &str, main_store: store::Main, + postings_lists_store: store::PostingsLists, synonym_store: store::Synonyms, -) -> MResult<(Vec>, QueryEnhancer)> { +) -> MResult<(Vec, QueryEnhancer)> +{ let has_end_whitespace = query.chars().last().map_or(false, char::is_whitespace); let query_words: Vec<_> = split_query_string(query).map(str::to_lowercase).collect(); let synonyms = match main_store.synonyms_fst(reader)? { @@ -136,7 +173,7 @@ fn generate_automatons( original_automatons.push(automaton); } - automatons.push(original_automatons); + automatons.push(AutomatonGroup::Normal(original_automatons)); for n in 1..=NGRAMS { let mut ngrams = query_words.windows(n).enumerate().peekable(); @@ -188,13 +225,25 @@ fn generate_automatons( Automaton::non_exact(automaton_index, n, synonym) }; automaton_index += 1; - automatons.push(vec![automaton]); + automatons.push(AutomatonGroup::Normal(vec![automaton])); } } } } - if n != 1 { + if n == 1 { + if let Some((left, right)) = split_best_frequency(reader, &normalized, postings_lists_store)? { + let a = Automaton::exact(automaton_index, 1, left); + enhancer_builder.declare(query_range.clone(), automaton_index, &[left]); + automaton_index += 1; + + let b = Automaton::exact(automaton_index, 1, right); + enhancer_builder.declare(query_range.clone(), automaton_index, &[left]); + automaton_index += 1; + + automatons.push(AutomatonGroup::PhraseQuery(vec![a, b])); + } + } else { // automaton of concatenation of query words let concat = ngram_slice.concat(); let normalized = normalize_str(&concat); @@ -204,15 +253,18 @@ fn generate_automatons( let automaton = Automaton::exact(automaton_index, n, &normalized); automaton_index += 1; - automatons.push(vec![automaton]); + automatons.push(AutomatonGroup::Normal(vec![automaton])); } } } // order automatons, the most important first, // we keep the original automatons at the front. - automatons[1..].sort_by_key(|a| { - let a = a.first().unwrap(); + automatons[1..].sort_by_key(|group| { + let a = match group { + AutomatonGroup::Normal(group) => group.first().unwrap(), + AutomatonGroup::PhraseQuery(group) => group.first().unwrap(), + }; (Reverse(a.is_exact), a.ngram) }); diff --git a/meilidb-core/src/query_builder.rs b/meilidb-core/src/query_builder.rs index 02c299a2a..21b28e663 100644 --- a/meilidb-core/src/query_builder.rs +++ b/meilidb-core/src/query_builder.rs @@ -8,7 +8,7 @@ use fst::{IntoStreamer, Streamer}; use sdset::SetBuf; use slice_group_by::{GroupBy, GroupByMut}; -use crate::automaton::{Automaton, AutomatonProducer, QueryEnhancer}; +use crate::automaton::{Automaton, AutomatonGroup, AutomatonProducer, QueryEnhancer}; use crate::distinct_map::{BufferedDistinctMap, DistinctMap}; use crate::raw_document::{raw_documents_from, RawDocument}; use crate::{criterion::Criteria, Document, DocumentId, Highlight, TmpMatch}; @@ -138,7 +138,7 @@ fn multiword_rewrite_matches( fn fetch_raw_documents( reader: &heed::RoTxn, - automatons: &[Automaton], + automatons_groups: &[AutomatonGroup], query_enhancer: &QueryEnhancer, searchables: Option<&ReorderedAttrs>, main_store: store::Main, @@ -148,52 +148,127 @@ fn fetch_raw_documents( let mut matches = Vec::new(); let mut highlights = Vec::new(); - for automaton in automatons { - let Automaton { - index, - is_exact, - query_len, - .. - } = automaton; - let dfa = automaton.dfa(); + for group in automatons_groups { + match group { + AutomatonGroup::Normal(automatons) => { + for automaton in automatons { + let Automaton { index, is_exact, query_len, .. } = automaton; + let dfa = automaton.dfa(); - let words = match main_store.words_fst(reader)? { - Some(words) => words, - None => return Ok(Vec::new()), - }; - - let mut stream = words.search(&dfa).into_stream(); - while let Some(input) = stream.next() { - let distance = dfa.eval(input).to_u8(); - let is_exact = *is_exact && distance == 0 && input.len() == *query_len; - - let doc_indexes = match postings_lists_store.postings_list(reader, input)? { - Some(doc_indexes) => doc_indexes, - None => continue, - }; - - matches.reserve(doc_indexes.len()); - highlights.reserve(doc_indexes.len()); - - for di in doc_indexes.as_ref() { - let attribute = searchables.map_or(Some(di.attribute), |r| r.get(di.attribute)); - if let Some(attribute) = attribute { - let match_ = TmpMatch { - query_index: *index as u32, - distance, - attribute, - word_index: di.word_index, - is_exact, + let words = match main_store.words_fst(reader)? { + Some(words) => words, + None => return Ok(Vec::new()), }; - let highlight = Highlight { - attribute: di.attribute, - char_index: di.char_index, - char_length: di.char_length, + let mut stream = words.search(&dfa).into_stream(); + while let Some(input) = stream.next() { + let distance = dfa.eval(input).to_u8(); + let is_exact = *is_exact && distance == 0 && input.len() == *query_len; + + let doc_indexes = match postings_lists_store.postings_list(reader, input)? { + Some(doc_indexes) => doc_indexes, + None => continue, + }; + + matches.reserve(doc_indexes.len()); + highlights.reserve(doc_indexes.len()); + + for di in doc_indexes.as_ref() { + let attribute = searchables.map_or(Some(di.attribute), |r| r.get(di.attribute)); + if let Some(attribute) = attribute { + let match_ = TmpMatch { + query_index: *index as u32, + distance, + attribute, + word_index: di.word_index, + is_exact, + }; + + let highlight = Highlight { + attribute: di.attribute, + char_index: di.char_index, + char_length: di.char_length, + }; + + matches.push((di.document_id, match_)); + highlights.push((di.document_id, highlight)); + } + } + } + } + }, + AutomatonGroup::PhraseQuery(automatons) => { + let mut tmp_matches = Vec::new(); + let phrase_query_len = automatons.len(); + + for (id, automaton) in automatons.into_iter().enumerate() { + let Automaton { index, is_exact, query_len, .. } = automaton; + let dfa = automaton.dfa(); + + let words = match main_store.words_fst(reader)? { + Some(words) => words, + None => return Ok(Vec::new()), }; - matches.push((di.document_id, match_)); - highlights.push((di.document_id, highlight)); + let mut stream = words.search(&dfa).into_stream(); + while let Some(input) = stream.next() { + let distance = dfa.eval(input).to_u8(); + let is_exact = *is_exact && distance == 0 && input.len() == *query_len; + + let doc_indexes = match postings_lists_store.postings_list(reader, input)? { + Some(doc_indexes) => doc_indexes, + None => continue, + }; + + tmp_matches.reserve(doc_indexes.len()); + + for di in doc_indexes.as_ref() { + let attribute = searchables.map_or(Some(di.attribute), |r| r.get(di.attribute)); + if let Some(attribute) = attribute { + let match_ = TmpMatch { + query_index: *index as u32, + distance, + attribute, + word_index: di.word_index, + is_exact, + }; + + let highlight = Highlight { + attribute: di.attribute, + char_index: di.char_index, + char_length: di.char_length, + }; + + tmp_matches.push((di.document_id, id, match_, highlight)); + } + } + } + } + + tmp_matches.sort_unstable_by_key(|(id, _, m, _)| (*id, m.attribute, m.word_index)); + for group in tmp_matches.linear_group_by_key(|(id, _, m, _)| (*id, m.attribute)) { + for window in group.windows(2) { + let (ida, ia, ma, ha) = window[0]; + let (idb, ib, mb, hb) = window[1]; + + debug_assert_eq!(ida, idb); + + // if matches must follow and actually follows themselves + if ia + 1 == ib && ma.word_index + 1 == mb.word_index { + + // TODO we must make it work for phrase query longer than 2 + // if the second match is the last phrase query word + if ib + 1 == phrase_query_len { + // insert first match + matches.push((ida, ma)); + highlights.push((ida, ha)); + + // insert second match + matches.push((idb, mb)); + highlights.push((idb, hb)); + } + } + } } } } @@ -368,14 +443,14 @@ where let mut raw_documents_processed = Vec::with_capacity(range.len()); let (automaton_producer, query_enhancer) = - AutomatonProducer::new(reader, query, main_store, synonyms_store)?; + AutomatonProducer::new(reader, query, main_store, postings_lists_store, synonyms_store)?; let automaton_producer = automaton_producer.into_iter(); let mut automatons = Vec::new(); // aggregate automatons groups by groups after time for auts in automaton_producer { - automatons.extend(auts); + automatons.push(auts); // we must retrieve the documents associated // with the current automatons @@ -481,14 +556,14 @@ where let mut raw_documents_processed = Vec::new(); let (automaton_producer, query_enhancer) = - AutomatonProducer::new(reader, query, main_store, synonyms_store)?; + AutomatonProducer::new(reader, query, main_store, postings_lists_store, synonyms_store)?; let automaton_producer = automaton_producer.into_iter(); let mut automatons = Vec::new(); // aggregate automatons groups by groups after time for auts in automaton_producer { - automatons.extend(auts); + automatons.push(auts); // we must retrieve the documents associated // with the current automatons @@ -1697,4 +1772,71 @@ mod tests { }); assert_matches!(iter.next(), None); } + + #[test] + fn simple_phrase_query_splitting() { + let store = TempDatabase::from_iter(vec![ + ("search", &[doc_index(0, 0)][..]), + ("engine", &[doc_index(0, 1)][..]), + + ("search", &[doc_index(1, 0)][..]), + ("slow", &[doc_index(1, 1)][..]), + ("engine", &[doc_index(1, 2)][..]), + ]); + + let env = &store.database.env; + let reader = env.read_txn().unwrap(); + + let builder = store.query_builder(); + let results = builder.query(&reader, "searchengine", 0..20).unwrap(); + let mut iter = results.into_iter(); + + assert_matches!(iter.next(), Some(Document { id: DocumentId(0), matches, .. }) => { + let mut iter = matches.into_iter(); + assert_matches!(iter.next(), Some(TmpMatch { query_index: 0, word_index: 0, distance: 0, .. })); // search + assert_matches!(iter.next(), Some(TmpMatch { query_index: 0, word_index: 1, distance: 0, .. })); // engine + assert_matches!(iter.next(), None); + }); + assert_matches!(iter.next(), None); + } + + #[test] + fn harder_phrase_query_splitting() { + let store = TempDatabase::from_iter(vec![ + ("search", &[doc_index(0, 0)][..]), + ("search", &[doc_index(0, 1)][..]), + ("engine", &[doc_index(0, 2)][..]), + + ("search", &[doc_index(1, 0)][..]), + ("slow", &[doc_index(1, 1)][..]), + ("search", &[doc_index(1, 2)][..]), + ("engine", &[doc_index(1, 3)][..]), + + ("search", &[doc_index(1, 0)][..]), + ("search", &[doc_index(1, 1)][..]), + ("slow", &[doc_index(1, 2)][..]), + ("engine", &[doc_index(1, 3)][..]), + ]); + + let env = &store.database.env; + let reader = env.read_txn().unwrap(); + + let builder = store.query_builder(); + let results = builder.query(&reader, "searchengine", 0..20).unwrap(); + let mut iter = results.into_iter(); + + assert_matches!(iter.next(), Some(Document { id: DocumentId(0), matches, .. }) => { + let mut iter = matches.into_iter(); + assert_matches!(iter.next(), Some(TmpMatch { query_index: 0, word_index: 1, distance: 0, .. })); // search + assert_matches!(iter.next(), Some(TmpMatch { query_index: 0, word_index: 2, distance: 0, .. })); // engine + assert_matches!(iter.next(), None); + }); + assert_matches!(iter.next(), Some(Document { id: DocumentId(1), matches, .. }) => { + let mut iter = matches.into_iter(); + assert_matches!(iter.next(), Some(TmpMatch { query_index: 0, word_index: 2, distance: 0, .. })); // search + assert_matches!(iter.next(), Some(TmpMatch { query_index: 0, word_index: 3, distance: 0, .. })); // engine + assert_matches!(iter.next(), None); + }); + assert_matches!(iter.next(), None); + } }