From 55e6cb9c7b179181e1e131265b0a66da76a76250 Mon Sep 17 00:00:00 2001 From: mpostma Date: Thu, 20 Jan 2022 18:35:11 +0100 Subject: [PATCH] typos on first letter count as 2 --- Cargo.toml | 3 +++ milli/src/search/mod.rs | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6b3e12f07..9b97dee88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,6 @@ opt-level = 3 opt-level = 3 [profile.test.build-override] opt-level = 3 + +[patch.crates-io] +fst = { path = "/Users/mpostma/Documents/code/rust/fst/" } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 7c8722187..6b2e50c94 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -7,7 +7,8 @@ use std::str::Utf8Error; use std::time::Instant; use distinct::{Distinct, DocIter, FacetDistinct, NoopDistinct}; -use fst::{IntoStreamer, Streamer}; +use fst::automaton::Str; +use fst::{Automaton, IntoStreamer, Streamer}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; use log::debug; use meilisearch_tokenizer::{Analyzer, AnalyzerConfig}; @@ -285,19 +286,39 @@ pub fn word_derivations<'c>( Entry::Vacant(entry) => { let mut derived_words = Vec::new(); let dfa = build_dfa(word, max_typo, is_prefix); - let mut stream = fst.search_with_state(&dfa).into_stream(); + if max_typo == 1 { + let starts = Str::new(get_first(word)); + let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); - while let Some((word, state)) = stream.next() { - let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state); - derived_words.push((word.to_string(), distance.to_u8())); + while let Some((word, state)) = stream.next() { + let word = std::str::from_utf8(word)?; + let distance = dfa.distance(state.1); + derived_words.push((word.to_string(), distance.to_u8())); + } + + Ok(entry.insert(derived_words)) + } else { + let mut stream = fst.search_with_state(&dfa).into_stream(); +
while let Some((word, state)) = stream.next() { + let word = std::str::from_utf8(word)?; + let distance = dfa.distance(state); + derived_words.push((word.to_string(), distance.to_u8())); + } + + Ok(entry.insert(derived_words)) } - - Ok(entry.insert(derived_words)) } } } +fn get_first(s: &str) -> &str { + match s.chars().next() { + Some(c) => &s[..c.len_utf8()], + None => s, + } +} + pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA { let lev = match typos { 0 => &LEVDIST0,