diff --git a/milli/src/search/fst_utils.rs b/milli/src/search/fst_utils.rs new file mode 100644 index 000000000..b488e6c19 --- /dev/null +++ b/milli/src/search/fst_utils.rs @@ -0,0 +1,187 @@ +/// This mod is necessary until https://github.com/BurntSushi/fst/pull/137 gets merged. +/// All credits for this code go to BurntSushi. +use fst::Automaton; + +pub struct StartsWith<A>(pub A); + +/// The `Automaton` state for `StartsWith<A>`. +pub struct StartsWithState<A: Automaton>(pub StartsWithStateKind<A>); + +impl<A: Automaton> Clone for StartsWithState<A> +where + A::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +/// The inner state of a `StartsWithState<A>`. +pub enum StartsWithStateKind<A: Automaton> { + /// Sink state that is reached when the automaton has matched the prefix. + Done, + /// State in which the automaton is while it hasn't matched the prefix. + Running(A::State), +} + +impl<A: Automaton> Clone for StartsWithStateKind<A> +where + A::State: Clone, +{ + fn clone(&self) -> Self { + match self { + StartsWithStateKind::Done => StartsWithStateKind::Done, + StartsWithStateKind::Running(inner) => StartsWithStateKind::Running(inner.clone()), + } + } +} + +impl<A: Automaton> Automaton for StartsWith<A> { + type State = StartsWithState<A>; + + fn start(&self) -> StartsWithState<A> { + StartsWithState({ + let inner = self.0.start(); + if self.0.is_match(&inner) { + StartsWithStateKind::Done + } else { + StartsWithStateKind::Running(inner) + } + }) + } + fn is_match(&self, state: &StartsWithState<A>) -> bool { + match state.0 { + StartsWithStateKind::Done => true, + StartsWithStateKind::Running(_) => false, + } + } + fn can_match(&self, state: &StartsWithState<A>) -> bool { + match state.0 { + StartsWithStateKind::Done => true, + StartsWithStateKind::Running(ref inner) => self.0.can_match(inner), + } + } + fn will_always_match(&self, state: &StartsWithState<A>) -> bool { + match state.0 { + StartsWithStateKind::Done => true, + StartsWithStateKind::Running(_) => false, + } + } + fn accept(&self, state: &StartsWithState<A>, byte: u8) ->
StartsWithState<A> { + StartsWithState(match state.0 { + StartsWithStateKind::Done => StartsWithStateKind::Done, + StartsWithStateKind::Running(ref inner) => { + let next_inner = self.0.accept(inner, byte); + if self.0.is_match(&next_inner) { + StartsWithStateKind::Done + } else { + StartsWithStateKind::Running(next_inner) + } + } + }) + } +} +/// An automaton that matches when one of its component automata match. +#[derive(Clone, Debug)] +pub struct Union<A, B>(pub A, pub B); + +/// The `Automaton` state for `Union<A, B>`. +pub struct UnionState<A: Automaton, B: Automaton>(pub A::State, pub B::State); + +impl<A: Automaton, B: Automaton> Clone for UnionState<A, B> +where + A::State: Clone, + B::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } +} + +impl<A: Automaton, B: Automaton> Automaton for Union<A, B> { + type State = UnionState<A, B>; + fn start(&self) -> UnionState<A, B> { + UnionState(self.0.start(), self.1.start()) + } + fn is_match(&self, state: &UnionState<A, B>) -> bool { + self.0.is_match(&state.0) || self.1.is_match(&state.1) + } + fn can_match(&self, state: &UnionState<A, B>) -> bool { + self.0.can_match(&state.0) || self.1.can_match(&state.1) + } + fn will_always_match(&self, state: &UnionState<A, B>) -> bool { + self.0.will_always_match(&state.0) || self.1.will_always_match(&state.1) + } + fn accept(&self, state: &UnionState<A, B>, byte: u8) -> UnionState<A, B> { + UnionState(self.0.accept(&state.0, byte), self.1.accept(&state.1, byte)) + } +} +/// An automaton that matches when both of its component automata match. +#[derive(Clone, Debug)] +pub struct Intersection<A, B>(pub A, pub B); + +/// The `Automaton` state for `Intersection<A, B>`.
+pub struct IntersectionState<A: Automaton, B: Automaton>(pub A::State, pub B::State); + +impl<A: Automaton, B: Automaton> Clone for IntersectionState<A, B> +where + A::State: Clone, + B::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } +} + +impl<A: Automaton, B: Automaton> Automaton for Intersection<A, B> { + type State = IntersectionState<A, B>; + fn start(&self) -> IntersectionState<A, B> { + IntersectionState(self.0.start(), self.1.start()) + } + fn is_match(&self, state: &IntersectionState<A, B>) -> bool { + self.0.is_match(&state.0) && self.1.is_match(&state.1) + } + fn can_match(&self, state: &IntersectionState<A, B>) -> bool { + self.0.can_match(&state.0) && self.1.can_match(&state.1) + } + fn will_always_match(&self, state: &IntersectionState<A, B>) -> bool { + self.0.will_always_match(&state.0) && self.1.will_always_match(&state.1) + } + fn accept(&self, state: &IntersectionState<A, B>, byte: u8) -> IntersectionState<A, B> { + IntersectionState(self.0.accept(&state.0, byte), self.1.accept(&state.1, byte)) + } +} +/// An automaton that matches exactly when the automaton it wraps does not. +#[derive(Clone, Debug)] +pub struct Complement<A>(pub A); + +/// The `Automaton` state for `Complement<A>`.
+pub struct ComplementState<A: Automaton>(pub A::State); + +impl<A: Automaton> Clone for ComplementState<A> +where + A::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl<A: Automaton> Automaton for Complement<A> { + type State = ComplementState<A>; + fn start(&self) -> ComplementState<A> { + ComplementState(self.0.start()) + } + fn is_match(&self, state: &ComplementState<A>) -> bool { + !self.0.is_match(&state.0) + } + fn can_match(&self, state: &ComplementState<A>) -> bool { + !self.0.will_always_match(&state.0) + } + fn will_always_match(&self, state: &ComplementState<A>) -> bool { + !self.0.can_match(&state.0) + } + fn accept(&self, state: &ComplementState<A>, byte: u8) -> ComplementState<A> { + ComplementState(self.0.accept(&state.0, byte)) + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 7c8722187..40e4bca24 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -7,7 +7,8 @@ use std::str::Utf8Error; use std::time::Instant; use distinct::{Distinct, DocIter, FacetDistinct, NoopDistinct}; -use fst::{IntoStreamer, Streamer}; +use fst::automaton::Str; +use fst::{Automaton, IntoStreamer, Streamer}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; use log::debug; use meilisearch_tokenizer::{Analyzer, AnalyzerConfig}; @@ -15,6 +16,7 @@ use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; pub use self::facet::{FacetDistribution, FacetNumberIter, Filter}; +use self::fst_utils::{Complement, Intersection, StartsWith, Union}; pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; use crate::error::UserError; @@ -29,6 +31,7 @@ static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true)); mod criteria; mod distinct; mod facet; +mod fst_utils; mod matching_words; mod query_tree; @@ -284,20 +287,66 @@ pub fn word_derivations<'c>( Entry::Occupied(entry) => Ok(entry.into_mut()), Entry::Vacant(entry) => { let mut derived_words = Vec::new(); - let dfa = build_dfa(word, max_typo, is_prefix); - let mut stream
= fst.search_with_state(&dfa).into_stream(); + if max_typo == 0 { + if is_prefix { + let prefix = Str::new(word).starts_with(); + let mut stream = fst.search(prefix).into_stream(); - while let Some((word, state)) = stream.next() { - let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state); - derived_words.push((word.to_string(), distance.to_u8())); + while let Some(word) = stream.next() { + let word = std::str::from_utf8(word)?; + derived_words.push((word.to_string(), 0)); + } + } else if fst.contains(word) { + derived_words.push((word.to_string(), 0)); + } + } else { + if max_typo == 1 { + let dfa = build_dfa(word, 1, is_prefix); + let starts = StartsWith(Str::new(get_first(word))); + let mut stream = + fst.search_with_state(Intersection(starts, &dfa)).into_stream(); + + while let Some((word, state)) = stream.next() { + let word = std::str::from_utf8(word)?; + let d = dfa.distance(state.1); + derived_words.push((word.to_string(), d.to_u8())); + } + } else { + let starts = StartsWith(Str::new(get_first(word))); + let first = Intersection(build_dfa(word, 1, is_prefix), Complement(&starts)); + let second_dfa = build_dfa(word, 2, is_prefix); + let second = Intersection(&second_dfa, &starts); + let automaton = Union(first, &second); + + let mut stream = fst.search_with_state(automaton).into_stream(); + + while let Some((found_word, state)) = stream.next() { + let found_word = std::str::from_utf8(found_word)?; + // in the case the typo is on the first letter, we know the number of typo + // is two + if get_first(found_word) != get_first(word) { + derived_words.push((found_word.to_string(), 2)); + } else { + // Else, we know that it is the second dfa that matched and compute the + // correct distance + let d = second_dfa.distance((state.1).0); + derived_words.push((found_word.to_string(), d.to_u8())); + } + } + } } - Ok(entry.insert(derived_words)) } } } +fn get_first(s: &str) -> &str { + match s.chars().next() { + Some(c) => &s[..c.len_utf8()], + None =>
panic!("unexpected empty query"), + } +} + pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA { let lev = match typos { 0 => &LEVDIST0, diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs index 237bb9be2..f3ee99d9e 100644 --- a/milli/src/search/query_tree.rs +++ b/milli/src/search/query_tree.rs @@ -260,12 +260,12 @@ fn split_best_frequency(ctx: &impl Context, word: &str) -> heed::Result QueryKind { +fn typos(word: String, authorize_typos: bool, max_typos: u8) -> QueryKind { if authorize_typos { match word.chars().count() { 0..=4 => QueryKind::exact(word), - 5..=8 => QueryKind::tolerant(1, word), - _ => QueryKind::tolerant(2, word), + 5..=8 => QueryKind::tolerant(1.min(max_typos), word), + _ => QueryKind::tolerant(2.min(max_typos), word), } } else { QueryKind::exact(word) @@ -316,8 +316,10 @@ fn create_query_tree( if let Some(child) = split_best_frequency(ctx, &word)? { children.push(child); } - children - .push(Operation::Query(Query { prefix, kind: typos(word, authorize_typos) })); + children.push(Operation::Query(Query { + prefix, + kind: typos(word, authorize_typos, 2), + })); Ok(Operation::or(false, children)) } // create a CONSECUTIVE operation wrapping all word in the phrase @@ -363,8 +365,10 @@ fn create_query_tree( .collect(); let mut operations = synonyms(ctx, &words)?.unwrap_or_default(); let concat = words.concat(); - let query = - Query { prefix: is_prefix, kind: typos(concat, authorize_typos) }; + let query = Query { + prefix: is_prefix, + kind: typos(concat, authorize_typos, 1), + }; operations.push(Operation::Query(query)); and_op_children.push(Operation::or(false, operations)); } @@ -655,7 +659,7 @@ mod test { ]), Operation::Query(Query { prefix: true, - kind: QueryKind::tolerant(2, "heyfriends".to_string()), + kind: QueryKind::tolerant(1, "heyfriends".to_string()), }), ], ); @@ -688,7 +692,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "heyfriends".to_string()), + 
kind: QueryKind::tolerant(1, "heyfriends".to_string()), }), ], ); @@ -753,7 +757,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "helloworld".to_string()), + kind: QueryKind::tolerant(1, "helloworld".to_string()), }), ], ); @@ -851,7 +855,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "newyorkcity".to_string()), + kind: QueryKind::tolerant(1, "newyorkcity".to_string()), }), ], ), @@ -925,7 +929,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "wordsplitfish".to_string()), + kind: QueryKind::tolerant(1, "wordsplitfish".to_string()), }), ], ); @@ -1045,7 +1049,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "heymyfriend".to_string()), + kind: QueryKind::tolerant(1, "heymyfriend".to_string()), }), ], ), diff --git a/milli/tests/assets/test_set.ndjson b/milli/tests/assets/test_set.ndjson index 9a0fe5b0a..6383d274e 100644 --- a/milli/tests/assets/test_set.ndjson +++ b/milli/tests/assets/test_set.ndjson @@ -8,7 +8,7 @@ {"id":"H","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":4,"sort_by_rank":1,"geo_rank":202182,"title":"world hello day","description":"holiday observed on november 21 to express that conflicts should be resolved through communication rather than the use of force","tag":"green","_geo": { "lat": 48.875617484531965, "lng": 2.346747821504194 },"":""} {"id":"I","word_rank":0,"typo_rank":0,"proximity_rank":8,"attribute_rank":338,"exact_rank":3,"asc_desc_rank":3,"sort_by_rank":0,"geo_rank":740667,"title":"hello world song","description":"hello world is a song written by tom douglas tony lane and david lee and recorded by american country music group lady antebellum","tag":"blue","_geo": { "lat": 43.973998070351065, "lng": 3.4661837318345032 },"":""} 
{"id":"J","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":1,"exact_rank":3,"asc_desc_rank":2,"sort_by_rank":1,"geo_rank":739020,"title":"hello cruel world","description":"hello cruel world is an album by new zealand band tall dwarfs","tag":"green","_geo": { "lat": 43.98920130353838, "lng": 3.480519311627928 },"":""} -{"id":"K","word_rank":0,"typo_rank":2,"proximity_rank":9,"attribute_rank":670,"exact_rank":5,"asc_desc_rank":1,"sort_by_rank":2,"geo_rank":738830,"title":"ello creation system","description":"in few word ello was a construction toy created by the american company mattel to engage girls in construction play","tag":"red","_geo": { "lat": 43.99155030238669, "lng": 3.503453528249425 },"":""} +{"id":"K","word_rank":0,"typo_rank":2,"proximity_rank":9,"attribute_rank":670,"exact_rank":5,"asc_desc_rank":1,"sort_by_rank":2,"geo_rank":738830,"title":"hallo creation system","description":"in few word hallo was a construction toy created by the american company mattel to engage girls in construction play","tag":"red","_geo": { "lat": 43.99155030238669, "lng": 3.503453528249425 },"":""} {"id":"L","word_rank":0,"typo_rank":0,"proximity_rank":2,"attribute_rank":250,"exact_rank":4,"asc_desc_rank":0,"sort_by_rank":0,"geo_rank":737861,"title":"good morning world","description":"good morning world is an american sitcom broadcast on cbs tv during the 1967 1968 season","tag":"blue","_geo": { "lat": 44.000507750283695, "lng": 3.5116812040621572 },"":""} {"id":"M","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":0,"asc_desc_rank":0,"sort_by_rank":2,"geo_rank":739203,"title":"hello world america","description":"a perfect match for a perfect engine using the query hello world america","tag":"red","_geo": { "lat": 43.99150729038736, "lng": 3.606143957295055 },"":""} {"id":"N","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":1,"asc_desc_rank":4,"sort_by_rank":1,"geo_rank":9499586,"title":"hello world 
america unleashed","description":"a very good match for a very good engine using the query hello world america","tag":"green","_geo": { "lat": 35.511540843367115, "lng": 138.764368875787 },"":""} diff --git a/milli/tests/search/query_criteria.rs b/milli/tests/search/query_criteria.rs index 0dcbf660e..ef080db9f 100644 --- a/milli/tests/search/query_criteria.rs +++ b/milli/tests/search/query_criteria.rs @@ -61,6 +61,7 @@ test_criterion!( vec![Attribute], vec![] ); +test_criterion!(typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Typo], vec![]); test_criterion!( attribute_disallow_typo, DISALLOW_OPTIONAL_WORDS,