diff --git a/milli/src/search/fst_utils.rs b/milli/src/search/fst_utils.rs
new file mode 100644
index 000000000..b488e6c19
--- /dev/null
+++ b/milli/src/search/fst_utils.rs
@@ -0,0 +1,187 @@
+//! This module is necessary until https://github.com/BurntSushi/fst/pull/137 gets merged.
+//! All credits for this code go to BurntSushi.
+use fst::Automaton;
+
+/// An automaton that matches any input that begins with something
+/// the wrapped automaton `A` matches.
+pub struct StartsWith<A>(pub A);
+
+/// The `Automaton` state for `StartsWith<A>`.
+pub struct StartsWithState<A: Automaton>(pub StartsWithStateKind<A>);
+
+/// Implemented by hand so that only `A::State` (and not `A` itself) has to be `Clone`.
+impl<A: Automaton> Clone for StartsWithState<A>
+where
+    A::State: Clone,
+{
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+/// The inner state of a `StartsWithState<A>`.
+pub enum StartsWithStateKind<A: Automaton> {
+    /// Sink state that is reached when the automaton has matched the prefix.
+    Done,
+    /// State in which the automaton is while it hasn't matched the prefix.
+    /// Carries the current state of the wrapped automaton.
+    Running(A::State),
+}
+
+/// Implemented by hand so that only `A::State` (and not `A` itself) has to be `Clone`.
+impl<A: Automaton> Clone for StartsWithStateKind<A>
+where
+    A::State: Clone,
+{
+    fn clone(&self) -> Self {
+        match self {
+            StartsWithStateKind::Done => StartsWithStateKind::Done,
+            StartsWithStateKind::Running(inner) => StartsWithStateKind::Running(inner.clone()),
+        }
+    }
+}
+
+/// Drives the wrapped automaton until it first reports a match, then stays in the
+/// `Done` sink state, which accepts whatever bytes follow.
+impl<A: Automaton> Automaton for StartsWith<A> {
+    type State = StartsWithState<A>;
+
+    /// Starts the wrapped automaton; begins directly in `Done` when its start
+    /// state is already a match.
+    fn start(&self) -> StartsWithState<A> {
+        StartsWithState({
+            let inner = self.0.start();
+            if self.0.is_match(&inner) {
+                StartsWithStateKind::Done
+            } else {
+                StartsWithStateKind::Running(inner)
+            }
+        })
+    }
+
+    /// Only the `Done` state is a match.
+    fn is_match(&self, state: &StartsWithState<A>) -> bool {
+        match state.0 {
+            StartsWithStateKind::Done => true,
+            StartsWithStateKind::Running(_) => false,
+        }
+    }
+
+    /// `Done` can always match; while running, defer to the wrapped automaton.
+    fn can_match(&self, state: &StartsWithState<A>) -> bool {
+        match state.0 {
+            StartsWithStateKind::Done => true,
+            StartsWithStateKind::Running(ref inner) => self.0.can_match(inner),
+        }
+    }
+
+    /// `Done` is a sink, so once it is reached every continuation matches.
+    fn will_always_match(&self, state: &StartsWithState<A>) -> bool {
+        match state.0 {
+            StartsWithStateKind::Done => true,
+            StartsWithStateKind::Running(_) => false,
+        }
+    }
+
+    /// Feeds `byte` to the wrapped automaton and switches to `Done` as soon as
+    /// the resulting inner state is a match. `Done` ignores further input.
+    fn accept(&self, state: &StartsWithState<A>, byte: u8) -> StartsWithState<A> {
+        StartsWithState(match state.0 {
+            StartsWithStateKind::Done => StartsWithStateKind::Done,
+            StartsWithStateKind::Running(ref inner) => {
+                let next_inner = self.0.accept(inner, byte);
+                if self.0.is_match(&next_inner) {
+                    StartsWithStateKind::Done
+                } else {
+                    StartsWithStateKind::Running(next_inner)
+                }
+            }
+        })
+    }
+}
+/// An automaton that matches when at least one of its component automata matches.
+#[derive(Clone, Debug)]
+pub struct Union<A, B>(pub A, pub B);
+
+/// The `Automaton` state for `Union<A, B>`: the states of both components, side by side.
+pub struct UnionState<A: Automaton, B: Automaton>(pub A::State, pub B::State);
+
+/// Implemented by hand so that only the component states have to be `Clone`.
+impl<A: Automaton, B: Automaton> Clone for UnionState<A, B>
+where
+    A::State: Clone,
+    B::State: Clone,
+{
+    fn clone(&self) -> Self {
+        Self(self.0.clone(), self.1.clone())
+    }
+}
+
+/// Runs both components in lockstep and combines their answers with a logical OR.
+impl<A: Automaton, B: Automaton> Automaton for Union<A, B> {
+    type State = UnionState<A, B>;
+
+    /// Pairs the start states of both components.
+    fn start(&self) -> UnionState<A, B> {
+        UnionState(self.0.start(), self.1.start())
+    }
+
+    /// Matches when either component matches.
+    fn is_match(&self, state: &UnionState<A, B>) -> bool {
+        self.0.is_match(&state.0) || self.1.is_match(&state.1)
+    }
+
+    /// Can still match as long as either component can.
+    fn can_match(&self, state: &UnionState<A, B>) -> bool {
+        self.0.can_match(&state.0) || self.1.can_match(&state.1)
+    }
+
+    /// Always matches as soon as either component always matches.
+    fn will_always_match(&self, state: &UnionState<A, B>) -> bool {
+        self.0.will_always_match(&state.0) || self.1.will_always_match(&state.1)
+    }
+
+    /// Advances both components with the same byte.
+    fn accept(&self, state: &UnionState<A, B>, byte: u8) -> UnionState<A, B> {
+        UnionState(self.0.accept(&state.0, byte), self.1.accept(&state.1, byte))
+    }
+}
+/// An automaton that matches when both of its component automata match.
+#[derive(Clone, Debug)]
+pub struct Intersection<A, B>(pub A, pub B);
+
+/// The `Automaton` state for `Intersection<A, B>`: the states of both components, side by side.
+pub struct IntersectionState<A: Automaton, B: Automaton>(pub A::State, pub B::State);
+
+/// Implemented by hand so that only the component states have to be `Clone`.
+impl<A: Automaton, B: Automaton> Clone for IntersectionState<A, B>
+where
+    A::State: Clone,
+    B::State: Clone,
+{
+    fn clone(&self) -> Self {
+        Self(self.0.clone(), self.1.clone())
+    }
+}
+
+/// Runs both components in lockstep and combines their answers with a logical AND.
+impl<A: Automaton, B: Automaton> Automaton for Intersection<A, B> {
+    type State = IntersectionState<A, B>;
+
+    /// Pairs the start states of both components.
+    fn start(&self) -> IntersectionState<A, B> {
+        IntersectionState(self.0.start(), self.1.start())
+    }
+
+    /// Matches only when both components match.
+    fn is_match(&self, state: &IntersectionState<A, B>) -> bool {
+        self.0.is_match(&state.0) && self.1.is_match(&state.1)
+    }
+
+    /// Can still match only while both components can.
+    fn can_match(&self, state: &IntersectionState<A, B>) -> bool {
+        self.0.can_match(&state.0) && self.1.can_match(&state.1)
+    }
+
+    /// Always matches only when both components always match.
+    fn will_always_match(&self, state: &IntersectionState<A, B>) -> bool {
+        self.0.will_always_match(&state.0) && self.1.will_always_match(&state.1)
+    }
+
+    /// Advances both components with the same byte.
+    fn accept(&self, state: &IntersectionState<A, B>, byte: u8) -> IntersectionState<A, B> {
+        IntersectionState(self.0.accept(&state.0, byte), self.1.accept(&state.1, byte))
+    }
+}
+/// An automaton that matches exactly when the automaton it wraps does not.
+#[derive(Clone, Debug)]
+pub struct Complement<A>(pub A);
+
+/// The `Automaton` state for `Complement<A>`: the state of the wrapped automaton.
+pub struct ComplementState<A: Automaton>(pub A::State);
+
+/// Implemented by hand so that only `A::State` (and not `A` itself) has to be `Clone`.
+impl<A: Automaton> Clone for ComplementState<A>
+where
+    A::State: Clone,
+{
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+/// Runs the wrapped automaton unchanged and negates its answers.
+impl<A: Automaton> Automaton for Complement<A> {
+    type State = ComplementState<A>;
+
+    /// Wraps the start state of the inner automaton.
+    fn start(&self) -> ComplementState<A> {
+        ComplementState(self.0.start())
+    }
+
+    /// Matches exactly when the inner automaton does not.
+    fn is_match(&self, state: &ComplementState<A>) -> bool {
+        !self.0.is_match(&state.0)
+    }
+
+    /// The complement can match unless the inner automaton will always match.
+    fn can_match(&self, state: &ComplementState<A>) -> bool {
+        !self.0.will_always_match(&state.0)
+    }
+
+    /// The complement always matches once the inner automaton can no longer match.
+    fn will_always_match(&self, state: &ComplementState<A>) -> bool {
+        !self.0.can_match(&state.0)
+    }
+
+    /// Advances the inner automaton with `byte`.
+    fn accept(&self, state: &ComplementState<A>, byte: u8) -> ComplementState<A> {
+        ComplementState(self.0.accept(&state.0, byte))
+    }
+}
diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs
index 7c8722187..40e4bca24 100644
--- a/milli/src/search/mod.rs
+++ b/milli/src/search/mod.rs
@@ -7,7 +7,8 @@ use std::str::Utf8Error;
use std::time::Instant;
use distinct::{Distinct, DocIter, FacetDistinct, NoopDistinct};
-use fst::{IntoStreamer, Streamer};
+use fst::automaton::Str;
+use fst::{Automaton, IntoStreamer, Streamer};
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
use log::debug;
use meilisearch_tokenizer::{Analyzer, AnalyzerConfig};
@@ -15,6 +16,7 @@ use once_cell::sync::Lazy;
use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, FacetNumberIter, Filter};
+use self::fst_utils::{Complement, Intersection, StartsWith, Union};
pub use self::matching_words::MatchingWords;
use self::query_tree::QueryTreeBuilder;
use crate::error::UserError;
@@ -29,6 +31,7 @@ static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true));
mod criteria;
mod distinct;
mod facet;
+mod fst_utils;
mod matching_words;
mod query_tree;
@@ -284,20 +287,66 @@ pub fn word_derivations<'c>(
Entry::Occupied(entry) => Ok(entry.into_mut()),
Entry::Vacant(entry) => {
let mut derived_words = Vec::new();
- let dfa = build_dfa(word, max_typo, is_prefix);
- let mut stream = fst.search_with_state(&dfa).into_stream();
+ if max_typo == 0 {
+ if is_prefix {
+ let prefix = Str::new(word).starts_with();
+ let mut stream = fst.search(prefix).into_stream();
- while let Some((word, state)) = stream.next() {
- let word = std::str::from_utf8(word)?;
- let distance = dfa.distance(state);
- derived_words.push((word.to_string(), distance.to_u8()));
+ while let Some(word) = stream.next() {
+ let word = std::str::from_utf8(word)?;
+ derived_words.push((word.to_string(), 0));
+ }
+ } else if fst.contains(word) {
+ derived_words.push((word.to_string(), 0));
+ }
+ } else {
+ if max_typo == 1 {
+ let dfa = build_dfa(word, 1, is_prefix);
+ let starts = StartsWith(Str::new(get_first(word)));
+ let mut stream =
+ fst.search_with_state(Intersection(starts, &dfa)).into_stream();
+
+ while let Some((word, state)) = stream.next() {
+ let word = std::str::from_utf8(word)?;
+ let d = dfa.distance(state.1);
+ derived_words.push((word.to_string(), d.to_u8()));
+ }
+ } else {
+ let starts = StartsWith(Str::new(get_first(word)));
+ let first = Intersection(build_dfa(word, 1, is_prefix), Complement(&starts));
+ let second_dfa = build_dfa(word, 2, is_prefix);
+ let second = Intersection(&second_dfa, &starts);
+ let automaton = Union(first, &second);
+
+ let mut stream = fst.search_with_state(automaton).into_stream();
+
+                        // Record every word yielded by the FST together with its typo count.
+                        // The derivation is `found_word` (the FST hit), not `word` (the query).
+                        while let Some((found_word, state)) = stream.next() {
+                            let found_word = std::str::from_utf8(found_word)?;
+                            // in the case the typo is on the first letter, we know the number of typo
+                            // is two
+                            if get_first(found_word) != get_first(word) {
+                                derived_words.push((found_word.to_string(), 2));
+                            } else {
+                                // Else, we know that it is the second dfa that matched and compute the
+                                // correct distance. `state.1` is the state of `second` (an
+                                // `Intersection`), whose `.0` is the state of `second_dfa`.
+                                let d = second_dfa.distance((state.1).0);
+                                derived_words.push((found_word.to_string(), d.to_u8()));
+                            }
+                        }
+ }
}
-
Ok(entry.insert(derived_words))
}
}
}
+/// Returns the first character of `s` as a string slice borrowed from `s`.
+///
+/// # Panics
+///
+/// Panics with "unexpected empty query" when `s` is empty.
+fn get_first(s: &str) -> &str {
+    let first = s.chars().next().expect("unexpected empty query");
+    // Slice on the character's UTF-8 width so multi-byte characters stay intact.
+    &s[..first.len_utf8()]
+}
+
pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA {
let lev = match typos {
0 => &LEVDIST0,
diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs
index 237bb9be2..f3ee99d9e 100644
--- a/milli/src/search/query_tree.rs
+++ b/milli/src/search/query_tree.rs
@@ -260,12 +260,12 @@ fn split_best_frequency(ctx: &impl Context, word: &str) -> heed::Result