diff --git a/milli/src/search/matching_words.rs b/milli/src/search/matching_words.rs new file mode 100644 index 000000000..37a4f49c0 --- /dev/null +++ b/milli/src/search/matching_words.rs @@ -0,0 +1,204 @@ +use std::collections::HashSet; +use std::cmp::{min, Reverse}; +use std::collections::BTreeMap; +use std::ops::{Index, IndexMut}; + +use levenshtein_automata::{DFA, Distance}; + +use crate::search::query_tree::{Operation, Query}; + +use super::build_dfa; + +type IsPrefix = bool; + +/// The query tree builder is the interface to build a query tree. +#[derive(Default)] +pub struct MatchingWords { + dfas: Vec<(DFA, String, u8, IsPrefix)>, +} + +impl MatchingWords { + /// Lists all words which can be considered as a match for the query tree. + pub fn from_query_tree(tree: &Operation) -> Self { + let mut dfas: Vec<_> = fetch_queries(tree) + .into_iter() + .map(|(w, t, p)| (build_dfa(w, t, p), w.to_string(), t, p)) + .collect(); + dfas.sort_unstable_by_key(|(_dfa, query_word, _typo, _is_prefix)| Reverse(query_word.len())); + Self { dfas } + } + + /// Returns the number of matching bytes if the word matches. + pub fn matching_bytes(&self, word: &str) -> Option { + self.dfas.iter().find_map(|(dfa, query_word, typo, is_prefix)| match dfa.eval(word) { + Distance::Exact(t) if t <= *typo => { + if *is_prefix { + let (_dist, len) = prefix_damerau_levenshtein(query_word.as_bytes(), word.as_bytes()); + Some(len) + } else { + Some(word.len()) + } + }, + _otherwise => None, + }) + } +} + +/// Lists all words which can be considered as a match for the query tree. +fn fetch_queries(tree: &Operation) -> HashSet<(&str, u8, IsPrefix)> { + fn resolve_ops<'a>(tree: &'a Operation, out: &mut HashSet<(&'a str, u8, IsPrefix)>) { + match tree { + Operation::Or(_, ops) | Operation::And(ops) | Operation::Consecutive(ops) => { + ops.as_slice().iter().for_each(|op| resolve_ops(op, out)); + }, + Operation::Query(Query { prefix, kind }) => { + let typo = if kind.is_exact() { 0 } else { kind.typo() }; + out.insert((kind.word(), typo, *prefix)); + }, + } + } + + let mut queries = HashSet::new(); + resolve_ops(tree, &mut queries); + queries +} + +// A simple wrapper around vec so we can get contiguous but index it like it's 2D array. +struct N2Array { + y_size: usize, + buf: Vec, +} + +impl N2Array { + fn new(x: usize, y: usize, value: T) -> N2Array { + N2Array { + y_size: y, + buf: vec![value; x * y], + } + } +} + +impl Index<(usize, usize)> for N2Array { + type Output = T; + + #[inline] + fn index(&self, (x, y): (usize, usize)) -> &T { + &self.buf[(x * self.y_size) + y] + } +} + +impl IndexMut<(usize, usize)> for N2Array { + #[inline] + fn index_mut(&mut self, (x, y): (usize, usize)) -> &mut T { + &mut self.buf[(x * self.y_size) + y] + } +} + +fn prefix_damerau_levenshtein(source: &[u8], target: &[u8]) -> (u32, usize) { + let (n, m) = (source.len(), target.len()); + + if n == 0 { + return (m as u32, 0); + } + if m == 0 { + return (n as u32, 0); + } + + if n == m && source == target { + return (0, m); + } + + let inf = n + m; + let mut matrix = N2Array::new(n + 2, m + 2, 0); + + matrix[(0, 0)] = inf; + for i in 0..n + 1 { + matrix[(i + 1, 0)] = inf; + matrix[(i + 1, 1)] = i; + } + for j in 0..m + 1 { + matrix[(0, j + 1)] = inf; + matrix[(1, j + 1)] = j; + } + + let mut last_row = BTreeMap::new(); + + for (row, char_s) in source.iter().enumerate() { + let mut last_match_col = 0; + let row = row + 1; + + for (col, char_t) in target.iter().enumerate() { + let col = col + 1; + let last_match_row = *last_row.get(&char_t).unwrap_or(&0); + let cost = if char_s == char_t { 0 } else { 1 }; + + let dist_add = matrix[(row, col + 1)] + 1; + let dist_del = matrix[(row + 1, col)] + 1; + let dist_sub = matrix[(row, col)] + cost; + let dist_trans = matrix[(last_match_row, last_match_col)] + + (row - last_match_row - 1) + + 1 + + (col - last_match_col - 1); + + let dist = min(min(dist_add, dist_del), min(dist_sub, dist_trans)); + + matrix[(row + 1, col + 1)] = dist; + + if cost == 0 { + last_match_col = col; + } + } + + last_row.insert(char_s, row); + } + + let mut minimum = (u32::max_value(), 0); + + for x in 0..=m { + let dist = matrix[(n + 1, x + 1)] as u32; + if dist < minimum.0 { + minimum = (dist, x) + } + } + + minimum +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::MatchingWords; + use crate::search::query_tree::{Operation, Query, QueryKind}; + + #[test] + fn matched_length() { + let query = "Levenste"; + let text = "Levenshtein"; + + let (dist, length) = prefix_damerau_levenshtein(query.as_bytes(), text.as_bytes()); + assert_eq!(dist, 1); + assert_eq!(&text[..length], "Levenshte"); + } + + #[test] + fn matching_words() { + let query_tree = Operation::Or(false, vec![ + Operation::And(vec![ + Operation::Query(Query { prefix: true, kind: QueryKind::exact("split".to_string()) }), + Operation::Query(Query { prefix: false, kind: QueryKind::exact("this".to_string()) }), + Operation::Query(Query { prefix: true, kind: QueryKind::tolerant(1, "world".to_string()) }), + ]), + ]); + + let matching_words = MatchingWords::from_query_tree(&query_tree); + + assert_eq!(matching_words.matching_bytes("word"), Some(4)); + assert_eq!(matching_words.matching_bytes("nyc"), None); + assert_eq!(matching_words.matching_bytes("world"), Some(5)); + assert_eq!(matching_words.matching_bytes("splitted"), Some(5)); + assert_eq!(matching_words.matching_bytes("thisnew"), None); + assert_eq!(matching_words.matching_bytes("borld"), Some(5)); + assert_eq!(matching_words.matching_bytes("wordsplit"), Some(4)); + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 623581706..fc64d020f 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -17,7 +17,7 @@ use crate::search::criteria::r#final::{Final, FinalResult}; use crate::{Index, DocumentId}; pub use self::facet::{FacetCondition, FacetDistribution, FacetIter, Operator}; -pub use self::query_tree::MatchingWords; +pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; // Building these factories is not free. @@ -29,6 +29,7 @@ mod criteria; mod distinct; mod facet; mod query_tree; +mod matching_words; pub struct Search<'a> { query: Option, diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs index 4876e37c8..3125664ab 100644 --- a/milli/src/search/query_tree.rs +++ b/milli/src/search/query_tree.rs @@ -294,48 +294,6 @@ fn synonyms(ctx: &impl Context, word: &[&str]) -> heed::Result, -} - -impl MatchingWords { - /// List all words which can be considered as a match for the query tree. - pub fn from_query_tree(tree: &Operation) -> Self { - Self { - dfas: fetch_queries(tree).into_iter().map(|(w, t, p)| (build_dfa(w, t, p), t)).collect() - } - } - - /// Return true if the word match. - pub fn matches(&self, word: &str) -> bool { - self.dfas.iter().any(|(dfa, typo)| match dfa.eval(word) { - Distance::Exact(t) => t <= *typo, - Distance::AtLeast(_) => false, - }) - } -} - -/// Lists all words which can be considered as a match for the query tree. -fn fetch_queries(tree: &Operation) -> HashSet<(&str, u8, IsPrefix)> { - fn resolve_ops<'a>(tree: &'a Operation, out: &mut HashSet<(&'a str, u8, IsPrefix)>) { - match tree { - Operation::Or(_, ops) | Operation::And(ops) | Operation::Consecutive(ops) => { - ops.as_slice().iter().for_each(|op| resolve_ops(op, out)); - }, - Operation::Query(Query { prefix, kind }) => { - let typo = if kind.is_exact() { 0 } else { kind.typo() }; - out.insert((kind.word(), typo, *prefix)); - }, - } - } - - let mut queries = HashSet::new(); - resolve_ops(tree, &mut queries); - queries -} - /// Main function that creates the final query tree from the primitive query. fn create_query_tree( ctx: &impl Context, @@ -951,39 +909,6 @@ mod test { assert_eq!(expected, query_tree); } - #[test] - fn fetching_words() { - let query = "wordsplit nyc world"; - let analyzer = Analyzer::new(AnalyzerConfig::>::default()); - let result = analyzer.analyze(query); - let tokens = result.tokens(); - - let context = TestContext::default(); - let (query_tree, _) = context.build(false, true, None, tokens).unwrap().unwrap(); - - let expected = hashset!{ - ("word", 0, false), - ("nyc", 0, false), - ("wordsplit", 2, false), - ("wordsplitnycworld", 2, true), - ("nature", 0, false), - ("new", 0, false), - ("city", 0, false), - ("world", 1, true), - ("york", 0, false), - ("split", 0, false), - ("nycworld", 1, true), - ("earth", 0, false), - ("wordsplitnyc", 2, false), - }; - - let mut keys = context.postings.keys().collect::>(); - keys.sort_unstable(); - - let words = fetch_queries(&query_tree); - assert_eq!(expected, words); - } - #[test] fn words_limit() { let query = "\"hey my\" good friend";