diff --git a/milli/src/search/new/query_graph.rs b/milli/src/search/new/query_graph.rs
new file mode 100644
index 000000000..726a1460c
--- /dev/null
+++ b/milli/src/search/new/query_graph.rs
@@ -0,0 +1,401 @@
+use std::collections::HashSet;
+use std::fmt::Debug;
+
+use heed::RoTxn;
+
+use super::{
+    db_cache::DatabaseCache,
+    query_term::{LocatedQueryTerm, QueryTerm, WordDerivations},
+};
+use crate::{Index, Result};
+
+#[derive(Clone)]
+pub enum QueryNode {
+    Term(LocatedQueryTerm),
+    Deleted,
+    Start,
+    End,
+}
+
+#[derive(Debug, Clone)]
+pub struct Edges {
+    pub incoming: HashSet<usize>,
+    pub outgoing: HashSet<usize>,
+}
+
+#[derive(Debug, Clone)]
+pub struct QueryGraph {
+    pub root_node: usize,
+    pub end_node: usize,
+    pub nodes: Vec<QueryNode>,
+    pub edges: Vec<Edges>,
+}
+
+fn _assert_sizes() {
+    let _: [u8; 112] = [0; std::mem::size_of::<QueryNode>()];
+    let _: [u8; 96] = [0; std::mem::size_of::<Edges>()];
+}
+
+impl Default for QueryGraph {
+    /// Create a new QueryGraph with two disconnected nodes: the root and end nodes.
+    fn default() -> Self {
+        let nodes = vec![QueryNode::Start, QueryNode::End];
+        let edges = vec![
+            Edges { incoming: HashSet::new(), outgoing: HashSet::new() },
+            Edges { incoming: HashSet::new(), outgoing: HashSet::new() },
+        ];
+
+        Self { root_node: 0, end_node: 1, nodes, edges }
+    }
+}
+
+impl QueryGraph {
+    fn connect_to_node(&mut self, from_nodes: &[usize], end_node: usize) {
+        for &from_node in from_nodes {
+            self.edges[from_node].outgoing.insert(end_node);
+            self.edges[end_node].incoming.insert(from_node);
+        }
+    }
+    fn add_node(&mut self, from_nodes: &[usize], node: QueryNode) -> usize {
+        let new_node_idx = self.nodes.len();
+        self.nodes.push(node);
+        self.edges.push(Edges {
+            incoming: from_nodes.iter().copied().collect(),
+            outgoing: HashSet::new(),
+        });
+        for from_node in from_nodes {
+            self.edges[*from_node].outgoing.insert(new_node_idx);
+        }
+        new_node_idx
+    }
+}
+
+impl QueryGraph {
+    // TODO: return the list of all matching words here as well
+
+    pub fn from_query<'transaction>(
+        index: &Index,
+        txn: &RoTxn,
+        _db_cache: &mut DatabaseCache<'transaction>,
+        query: Vec<LocatedQueryTerm>,
+    ) -> Result<QueryGraph> {
+        // TODO: maybe empty nodes should not be removed here, to compute
+        // the score of the `words` ranking rule correctly
+        // it is very easy to traverse the graph and remove them afterwards anyway
+        // Still, I'm keeping this here as a demo
+        let mut empty_nodes = vec![];
+
+        let word_set = index.words_fst(txn)?;
+        let mut graph = QueryGraph::default();
+
+        let (mut prev2, mut prev1, mut prev0): (Vec<usize>, Vec<usize>, Vec<usize>) =
+            (vec![], vec![], vec![graph.root_node]);
+
+        // TODO: add all the word derivations found in the fst
+        // and add split words / support phrases
+
+        for length in 1..=query.len() {
+            let query = &query[..length];
+
+            let term0 = query.last().unwrap();
+
+            let mut new_nodes = vec![];
+            let new_node_idx = graph.add_node(&prev0, QueryNode::Term(term0.clone()));
+            new_nodes.push(new_node_idx);
+            if term0.is_empty() {
+                empty_nodes.push(new_node_idx);
+            }
+
+            if !prev1.is_empty() {
+                if let Some((ngram2_str, ngram2_pos)) =
+                    LocatedQueryTerm::ngram2(&query[length - 2], &query[length - 1])
+                {
+                    if word_set.contains(ngram2_str.as_bytes()) {
+                        let ngram2 = LocatedQueryTerm {
+                            value: QueryTerm::Word {
+                                derivations: WordDerivations {
+                                    original: ngram2_str.clone(),
+                                    // TODO: could add a typo if it's an ngram?
+                                    zero_typo: vec![ngram2_str],
+                                    one_typo: vec![],
+                                    two_typos: vec![],
+                                    use_prefix_db: false,
+                                },
+                            },
+                            positions: ngram2_pos,
+                        };
+                        let ngram2_idx = graph.add_node(&prev1, QueryNode::Term(ngram2));
+                        new_nodes.push(ngram2_idx);
+                    }
+                }
+            }
+            if !prev2.is_empty() {
+                if let Some((ngram3_str, ngram3_pos)) = LocatedQueryTerm::ngram3(
+                    &query[length - 3],
+                    &query[length - 2],
+                    &query[length - 1],
+                ) {
+                    if word_set.contains(ngram3_str.as_bytes()) {
+                        let ngram3 = LocatedQueryTerm {
+                            value: QueryTerm::Word {
+                                derivations: WordDerivations {
+                                    original: ngram3_str.clone(),
+                                    // TODO: could add a typo if it's an ngram?
+                                    zero_typo: vec![ngram3_str],
+                                    one_typo: vec![],
+                                    two_typos: vec![],
+                                    use_prefix_db: false,
+                                },
+                            },
+                            positions: ngram3_pos,
+                        };
+                        let ngram3_idx = graph.add_node(&prev2, QueryNode::Term(ngram3));
+                        new_nodes.push(ngram3_idx);
+                    }
+                }
+            }
+            (prev0, prev1, prev2) = (new_nodes, prev0, prev1);
+        }
+        graph.connect_to_node(&prev0, graph.end_node);
+
+        graph.remove_nodes_keep_edges(&empty_nodes);
+
+        Ok(graph)
+    }
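
For illustration (not part of the diff): the sliding window in `from_query` is easiest to see in isolation. Below is a minimal, self-contained sketch using plain `String` labels instead of milli's `QueryNode`; all names here are illustrative. `prev0` is the node layer ending at the previous position, `prev1` one position earlier, `prev2` two positions earlier, which is why a 2-gram connects back to `prev1` and a 3-gram to `prev2`:

    use std::collections::HashSet;

    fn window_demo(words: &[&str], word_set: &HashSet<String>) {
        let (mut prev2, mut prev1, mut prev0): (Vec<String>, Vec<String>, Vec<String>) =
            (vec![], vec![], vec!["START".to_string()]);
        for i in 0..words.len() {
            // a plain term always connects back to the `prev0` layer
            let mut new_nodes = vec![words[i].to_string()];
            if i >= 1 && !prev1.is_empty() {
                let ngram2 = format!("{}{}", words[i - 1], words[i]);
                if word_set.contains(&ngram2) {
                    // a 2-gram spans two positions, so it connects back to `prev1`
                    println!("{ngram2:?} connects back to {prev1:?}");
                    new_nodes.push(ngram2);
                }
            }
            if i >= 2 && !prev2.is_empty() {
                let ngram3 = format!("{}{}{}", words[i - 2], words[i - 1], words[i]);
                if word_set.contains(&ngram3) {
                    // a 3-gram spans three positions, so it connects back to `prev2`
                    println!("{ngram3:?} connects back to {prev2:?}");
                    new_nodes.push(ngram3);
                }
            }
            (prev0, prev1, prev2) = (new_nodes, prev0, prev1);
        }
        println!("layer connected to END: {prev0:?}");
    }

For `["sun", "flower", "field"]` with "sunflower" in the word set, the 2-gram appears at `i == 1` and connects back to `["START"]`, sitting in parallel with the `sun -> flower` path, exactly as `from_query` does with real `QueryNode`s.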
+    pub fn remove_nodes(&mut self, nodes: &[usize]) {
+        for &node in nodes {
+            self.nodes[node] = QueryNode::Deleted;
+            let edges = self.edges[node].clone();
+            for &pred in edges.incoming.iter() {
+                self.edges[pred].outgoing.remove(&node);
+            }
+            for succ in edges.outgoing {
+                self.edges[succ].incoming.remove(&node);
+            }
+            self.edges[node] = Edges { incoming: HashSet::new(), outgoing: HashSet::new() };
+        }
+    }
+    pub fn remove_nodes_keep_edges(&mut self, nodes: &[usize]) {
+        for &node in nodes {
+            self.nodes[node] = QueryNode::Deleted;
+            let edges = self.edges[node].clone();
+            for &pred in edges.incoming.iter() {
+                self.edges[pred].outgoing.remove(&node);
+                self.edges[pred].outgoing.extend(edges.outgoing.iter());
+            }
+            for succ in edges.outgoing {
+                self.edges[succ].incoming.remove(&node);
+                self.edges[succ].incoming.extend(edges.incoming.iter());
+            }
+            self.edges[node] = Edges { incoming: HashSet::new(), outgoing: HashSet::new() };
+        }
+    }
+    pub fn remove_words_at_position(&mut self, position: i8) {
+        let mut nodes_to_remove_keeping_edges = vec![];
+        let mut nodes_to_remove = vec![];
+        for (node_idx, node) in self.nodes.iter().enumerate() {
+            let QueryNode::Term(LocatedQueryTerm { value: _, positions }) = node else { continue };
+            // terms starting at `position` are removed but their neighbours stay
+            // connected; ngrams that started earlier and merely cover `position`
+            // are removed entirely
+            if positions.start() == &position {
+                nodes_to_remove_keeping_edges.push(node_idx)
+            } else if positions.contains(&position) {
+                nodes_to_remove.push(node_idx)
+            }
+        }
+
+        self.remove_nodes(&nodes_to_remove);
+        self.remove_nodes_keep_edges(&nodes_to_remove_keeping_edges);
+
+        self.simplify();
+    }
+
+    fn simplify(&mut self) {
+        loop {
+            let mut nodes_to_remove = vec![];
+            for (node_idx, node) in self.nodes.iter().enumerate() {
+                if (!matches!(node, QueryNode::End | QueryNode::Deleted)
+                    && self.edges[node_idx].outgoing.is_empty())
+                    || (!matches!(node, QueryNode::Start | QueryNode::Deleted)
+                        && self.edges[node_idx].incoming.is_empty())
+                {
+                    nodes_to_remove.push(node_idx);
+                }
+            }
+            if nodes_to_remove.is_empty() {
+                break;
+            } else {
+                self.remove_nodes(&nodes_to_remove);
+            }
+        }
+    }
+}
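
A hypothetical usage sketch of the removal API above, assuming `index`, `txn`, `db_cache`, and `query` are set up as in the test further down:

    // remove the terms at position 2 from a built graph: single-word terms
    // starting there are deleted with their neighbours bridged, overlapping
    // ngrams are deleted outright, then `simplify` prunes dead branches
    let mut graph = QueryGraph::from_query(&index, &txn, &mut db_cache, query).unwrap();
    graph.remove_words_at_position(2);
    println!("{}", graph.graphviz());

On a chain START -> a -> b -> END, `remove_nodes(&[b])` leaves `a` with no outgoing edge (so `simplify` would then delete it), whereas `remove_nodes_keep_edges(&[b])` rewires it to START -> a -> END.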
+
+impl Debug for QueryNode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            QueryNode::Term(term @ LocatedQueryTerm { value, positions: _ }) => match value {
+                QueryTerm::Word {
+                    derivations:
+                        WordDerivations { original, zero_typo, one_typo, two_typos, use_prefix_db },
+                } => {
+                    if term.is_empty() {
+                        write!(f, "\"{original} (∅)\"")
+                    } else {
+                        let derivations = std::iter::once(original.clone())
+                            .chain(zero_typo.iter().map(|s| format!("T0 .. {s}")))
+                            .chain(one_typo.iter().map(|s| format!("T1 .. {s}")))
+                            .chain(two_typos.iter().map(|s| format!("T2 .. {s}")))
+                            .collect::<Vec<String>>()
+                            .join(" | ");
+
+                        write!(f, "\"{derivations}")?;
+                        if *use_prefix_db {
+                            write!(f, " | +prefix_db")?;
+                        }
+                        write!(f, " | pos:{}..={}", term.positions.start(), term.positions.end())?;
+                        write!(f, "\"")?;
+                        /*
+                        "beautiful" [label = "<f0> beautiful | beauiful | beautifol"]
+                        */
+                        Ok(())
+                    }
+                }
+                QueryTerm::Phrase(ws) => {
+                    let joined =
+                        ws.iter().filter_map(|x| x.clone()).collect::<Vec<String>>().join(" ");
+                    let in_quotes = format!("\"{joined}\"");
+                    let escaped = in_quotes.escape_default().collect::<String>();
+                    write!(f, "\"{escaped}\"")
+                }
+            },
+            QueryNode::Start => write!(f, "\"START\""),
+            QueryNode::End => write!(f, "\"END\""),
+            QueryNode::Deleted => write!(f, "\"_deleted_\""),
+        }
+    }
+}
+
+/*
+TODO:
+
+1. Find the minimum number of words to check to resolve the 10 query trees at once.
+   (e.g. just 0 | 01 | 012 )
+2. Simplify the query tree after removal of a node ✅
+3. Create the proximity graph ✅
+4. Assign different proximities for the ngrams ✅
+5. Walk the proximity graph, finding all the potential paths of weight N from START to END ✅
+   (without checking the bitmaps)
+
+*/
+impl QueryGraph {
+    pub fn graphviz(&self) -> String {
+        let mut desc = String::new();
+        desc.push_str(
+            r#"
+digraph G {
+rankdir = LR;
+node [shape = "record"]
+"#,
+        );
+
+        for node in 0..self.nodes.len() {
+            if matches!(self.nodes[node], QueryNode::Deleted) {
+                continue;
+            }
+            desc.push_str(&format!("{node} [label = {:?}]", &self.nodes[node],));
+            if node == self.root_node {
+                desc.push_str("[color = blue]");
+            } else if node == self.end_node {
+                desc.push_str("[color = red]");
+            }
+            desc.push_str(";\n");
+
+            for edge in self.edges[node].outgoing.iter() {
+                desc.push_str(&format!("{node} -> {edge};\n"));
+            }
+            // for edge in self.edges[node].incoming.iter() {
+            //     desc.push_str(&format!("{node} -> {edge} [color = grey];\n"));
+            // }
+        }
+
+        desc.push('}');
+        desc
+    }
+}
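
For reference, the DOT this emits for a two-word query looks roughly like the following (node numbering and derivation lists assumed; the root is colored blue, the end node red):

    digraph G {
    rankdir = LR;
    node [shape = "record"]
    0 [label = "START"][color = blue];
    0 -> 2;
    1 [label = "END"][color = red];
    2 [label = "sun | T0 .. sun | pos:0..=0"];
    2 -> 3;
    3 [label = "flower | T0 .. flower | pos:1..=1"];
    3 -> 1;
    }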
+
+#[cfg(test)]
+mod tests {
+    use charabia::Tokenize;
+
+    use super::{LocatedQueryTerm, QueryGraph, QueryNode};
+    use crate::index::tests::TempIndex;
+    use crate::new::db_cache::DatabaseCache;
+    use crate::search::new::query_term::word_derivations;
+
+    #[test]
+    fn build_graph() {
+        let mut index = TempIndex::new();
+        index.index_documents_config.autogenerate_docids = true;
+        index
+            .update_settings(|s| {
+                s.set_searchable_fields(vec!["text".to_owned()]);
+            })
+            .unwrap();
+        index
+            .add_documents(documents!({
+                "text": "0 1 2 3 4 5 6 7 01 23 234 56 79 709 7356",
+            }))
+            .unwrap();
+
+        // let fst = fst::Set::from_iter(["01", "23", "234", "56"]).unwrap();
+        let txn = index.read_txn().unwrap();
+        let mut db_cache = DatabaseCache::default();
+
+        let fst = index.words_fst(&txn).unwrap();
+        let query = LocatedQueryTerm::from_query(
+            "0 no 1 2 3 4 5 6 7".tokenize(),
+            None,
+            |word, is_prefix| {
+                word_derivations(
+                    &index,
+                    &txn,
+                    word,
+                    if word.len() < 3 {
+                        0
+                    } else if word.len() < 6 {
+                        1
+                    } else {
+                        2
+                    },
+                    is_prefix,
+                    &fst,
+                )
+            },
+        )
+        .unwrap();
+
+        let graph = QueryGraph::from_query(&index, &txn, &mut db_cache, query).unwrap();
+        println!("{}", graph.graphviz());
+
+        // let positions_to_remove = vec![3, 6, 0, 4];
+        // for p in positions_to_remove {
+        //     graph.remove_words_at_position(p);
+        //     println!("{}", graph.graphviz());
+        // }
+
+        // let proximities = |w1: &str, w2: &str| -> Vec<i8> {
+        //     if matches!((w1, w2), ("56", "7")) {
+        //         vec![]
+        //     } else {
+        //         vec![1, 2]
+        //     }
+        // };
+
+        // let prox_graph = ProximityGraph::from_query_graph(graph, proximities);
+
+        // println!("{}", prox_graph.graphviz());
+    }
+}
+
+// fn remove_element_from_vector(v: &mut Vec<usize>, el: usize) {
+//     let position = v.iter().position(|&x| x == el).unwrap();
+//     v.swap_remove(position);
+// }
diff --git a/milli/src/search/new/query_term.rs b/milli/src/search/new/query_term.rs
new file mode 100644
index 000000000..4d2b22264
--- /dev/null
+++ b/milli/src/search/new/query_term.rs
@@ -0,0 +1,305 @@
+// TODO: put the primitive query part in here
+
+use std::borrow::Cow;
+use std::mem;
+use std::ops::RangeInclusive;
+
+use charabia::normalizer::NormalizedTokenIter;
+use charabia::{SeparatorKind, TokenKind};
+use fst::automaton::Str;
+use fst::{Automaton, IntoStreamer, Streamer};
+use heed::types::DecodeIgnore;
+use heed::RoTxn;
+
+use crate::search::fst_utils::{Complement, Intersection, StartsWith, Union};
+use crate::search::{build_dfa, get_first};
+use crate::{Index, Result};
+
+#[derive(Debug, Clone)]
+pub struct WordDerivations {
+    // TODO: should have a list for the words corresponding to the prefix as well!
+    // This is to implement the `exactness` ranking rule.
+    // However, we could also consider every term in `zero_typo` (except the first one)
+    // to be a word that the original word is a prefix of.
+    pub original: String,
+    pub zero_typo: Vec<String>,
+    pub one_typo: Vec<String>,
+    pub two_typos: Vec<String>,
+    pub use_prefix_db: bool,
+}
+impl WordDerivations {
+    pub fn all_derivations_except_prefix_db(&self) -> impl Iterator<Item = &String> + Clone {
+        self.zero_typo.iter().chain(self.one_typo.iter()).chain(self.two_typos.iter())
+    }
+    fn is_empty(&self) -> bool {
+        self.zero_typo.is_empty()
+            && self.one_typo.is_empty()
+            && self.two_typos.is_empty()
+            && !self.use_prefix_db
+    }
+}
+
+pub fn word_derivations(
+    index: &Index,
+    txn: &RoTxn,
+    word: &str,
+    max_typo: u8,
+    is_prefix: bool,
+    fst: &fst::Set<Cow<[u8]>>,
+) -> Result<WordDerivations> {
+    let use_prefix_db = is_prefix
+        && index.word_prefix_docids.remap_data_type::<DecodeIgnore>().get(txn, word)?.is_some();
+
+    let mut zero_typo = vec![];
+    let mut one_typo = vec![];
+    let mut two_typos = vec![];
+
+    if max_typo == 0 {
+        if is_prefix {
+            let prefix = Str::new(word).starts_with();
+            let mut stream = fst.search(prefix).into_stream();
+
+            while let Some(word) = stream.next() {
+                let word = std::str::from_utf8(word)?;
+                zero_typo.push(word.to_string());
+            }
+        } else if fst.contains(word) {
+            zero_typo.push(word.to_string());
+        }
+    } else if max_typo == 1 {
+        let dfa = build_dfa(word, 1, is_prefix);
+        let starts = StartsWith(Str::new(get_first(word)));
+        let mut stream = fst.search_with_state(Intersection(starts, &dfa)).into_stream();
+
+        while let Some((word, state)) = stream.next() {
+            let word = std::str::from_utf8(word)?;
+            let d = dfa.distance(state.1);
+            match d.to_u8() {
+                0 => {
+                    zero_typo.push(word.to_string());
+                }
+                1 => {
+                    one_typo.push(word.to_string());
+                }
+                _ => panic!(),
+            }
+        }
+    } else {
+        let starts = StartsWith(Str::new(get_first(word)));
+        let first = Intersection(build_dfa(word, 1, is_prefix), Complement(&starts));
+        let second_dfa = build_dfa(word, 2, is_prefix);
+        let second = Intersection(&second_dfa, &starts);
+        let automaton = Union(first, &second);
+
+        let mut stream = fst.search_with_state(automaton).into_stream();
+
+        while let Some((found_word, state)) = stream.next() {
+            let found_word = std::str::from_utf8(found_word)?;
+            // if the typo is on the first letter, we know that the number of typos is two
+            if get_first(found_word) != get_first(word) {
+                two_typos.push(found_word.to_string());
+            } else {
+                // else, we know that it is the second dfa that matched and compute the
+                // correct distance
+                let d = second_dfa.distance((state.1).0);
+                match d.to_u8() {
+                    0 => {
+                        zero_typo.push(found_word.to_string());
+                    }
+                    1 => {
+                        one_typo.push(found_word.to_string());
+                    }
+                    2 => {
+                        two_typos.push(found_word.to_string());
+                    }
+                    _ => panic!(),
+                }
+            }
+        }
+    }
+
+    Ok(WordDerivations { original: word.to_owned(), zero_typo, one_typo, two_typos, use_prefix_db })
+}
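
A hypothetical call, assuming an `index` and `txn` set up like the ones in the query_graph.rs test; the word and typo budget are illustrative:

    let fst = index.words_fst(&txn).unwrap();
    // derivations for the prefix "beautif", allowing up to two typos
    let derivations = word_derivations(&index, &txn, "beautif", 2, true, &fst).unwrap();
    // `zero_typo` holds exact matches/completions such as "beautiful",
    // `one_typo`/`two_typos` hold fuzzier candidates, and `use_prefix_db`
    // tells the caller it can read the precomputed word_prefix_docids
    // instead of unioning the docids of every derived word.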
+
+#[derive(Debug, Clone)]
+pub enum QueryTerm {
+    Phrase(Vec<Option<String>>),
+    Word { derivations: WordDerivations },
+}
+impl QueryTerm {
+    pub fn original_single_word(&self) -> Option<&str> {
+        match self {
+            QueryTerm::Phrase(_) => None,
+            QueryTerm::Word { derivations } => {
+                if derivations.is_empty() {
+                    None
+                } else {
+                    Some(derivations.original.as_str())
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct LocatedQueryTerm {
+    pub value: QueryTerm, // value should be able to contain the word derivations as well
+    pub positions: RangeInclusive<i8>,
+}
+
+impl LocatedQueryTerm {
+    pub fn is_empty(&self) -> bool {
+        match &self.value {
+            QueryTerm::Phrase(_) => false,
+            QueryTerm::Word { derivations, .. } => derivations.is_empty(),
+        }
+    }
+    /// Create a primitive query from a tokenized query string.
+    /// The primitive query is an intermediate state used to build the query tree.
+    pub fn from_query(
+        query: NormalizedTokenIter<Vec<u8>>,
+        words_limit: Option<usize>,
+        derivations: impl Fn(&str, bool) -> Result<WordDerivations>,
+    ) -> Result<Vec<LocatedQueryTerm>> {
+        let mut primitive_query = Vec::new();
+        let mut phrase = Vec::new();
+
+        let mut quoted = false;
+
+        let parts_limit = words_limit.unwrap_or(usize::MAX);
+
+        let mut position = -1i8;
+        let mut phrase_start = -1i8;
+        let mut phrase_end = -1i8;
+
+        let mut peekable = query.peekable();
+        while let Some(token) = peekable.next() {
+            // early return if the word limit is exceeded
+            if primitive_query.len() >= parts_limit {
+                return Ok(primitive_query);
+            }
+
+            match token.kind {
+                TokenKind::Word | TokenKind::StopWord => {
+                    position += 1;
+                    // 1. if the word is quoted, we push it into a phrase buffer while waiting for the ending quote,
+                    // 2. if the word is not the last token of the query and is not a stop word, we push it as a non-prefix word,
+                    // 3. if the word is the last token of the query, we push it as a prefix word.
+                    if quoted {
+                        phrase_end = position;
+                        if phrase.is_empty() {
+                            phrase_start = position;
+                        }
+                        if let TokenKind::StopWord = token.kind {
+                            phrase.push(None);
+                        } else {
+                            // TODO: in a phrase, check that every word exists
+                            // otherwise return WordDerivations::Empty
+                            phrase.push(Some(token.lemma().to_string()));
+                        }
+                    } else if peekable.peek().is_some() {
+                        if let TokenKind::StopWord = token.kind {
+                        } else {
+                            let derivations = derivations(token.lemma(), false)?;
+                            let located_term = LocatedQueryTerm {
+                                value: QueryTerm::Word { derivations },
+                                positions: position..=position,
+                            };
+                            primitive_query.push(located_term);
+                        }
+                    } else {
+                        let derivations = derivations(token.lemma(), true)?;
+                        let located_term = LocatedQueryTerm {
+                            value: QueryTerm::Word { derivations },
+                            positions: position..=position,
+                        };
+                        primitive_query.push(located_term);
+                    }
+                }
+                TokenKind::Separator(separator_kind) => {
+                    match separator_kind {
+                        SeparatorKind::Hard => {
+                            position += 1;
+                        }
+                        SeparatorKind::Soft => {
+                            position += 0;
+                        }
+                    }
+                    let quote_count = token.lemma().chars().filter(|&s| s == '"').count();
+                    // toggle the quoted state if we encounter a double quote
+                    if quote_count % 2 != 0 {
+                        quoted = !quoted;
+                    }
+                    // if there is a quote or a hard separator, we close the phrase.
+                    if !phrase.is_empty()
+                        && (quote_count > 0 || separator_kind == SeparatorKind::Hard)
+                    {
+                        let located_query_term = LocatedQueryTerm {
+                            value: QueryTerm::Phrase(mem::take(&mut phrase)),
+                            positions: phrase_start..=phrase_end,
+                        };
+                        primitive_query.push(located_query_term);
+                    }
+                }
+                _ => (),
+            }
+        }
+
+        // If a quote is never closed, we consider all of the end of the query as a phrase.
+        if !phrase.is_empty() {
+            let located_query_term = LocatedQueryTerm {
+                value: QueryTerm::Phrase(mem::take(&mut phrase)),
+                positions: phrase_start..=phrase_end,
+            };
+            primitive_query.push(located_query_term);
+        }
+
+        Ok(primitive_query)
+    }
+}
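
A hypothetical usage, mirroring the query_graph.rs test (`index`, `txn`, and `fst` assumed): quoted words collapse into one `Phrase` term, and the trailing word is derived as a prefix because it is the last token.

    use charabia::Tokenize;

    let terms = LocatedQueryTerm::from_query(
        "\"sun flower\" fields".tokenize(),
        None,
        |word, is_prefix| word_derivations(&index, &txn, word, 1, is_prefix, &fst),
    )
    .unwrap();
    // terms[0]: QueryTerm::Phrase([Some("sun"), Some("flower")]) at positions 0..=1
    // terms[1]: QueryTerm::Word for "fields", derived with is_prefix == true, at 2..=2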
+
+impl LocatedQueryTerm {
+    pub fn ngram2(
+        x: &LocatedQueryTerm,
+        y: &LocatedQueryTerm,
+    ) -> Option<(String, RangeInclusive<i8>)> {
+        if *x.positions.end() != y.positions.start() - 1 {
+            return None;
+        }
+        match (&x.value.original_single_word(), &y.value.original_single_word()) {
+            (Some(w1), Some(w2)) => {
+                let term = (format!("{w1}{w2}"), *x.positions.start()..=*y.positions.end());
+                Some(term)
+            }
+            _ => None,
+        }
+    }
+    pub fn ngram3(
+        x: &LocatedQueryTerm,
+        y: &LocatedQueryTerm,
+        z: &LocatedQueryTerm,
+    ) -> Option<(String, RangeInclusive<i8>)> {
+        if *x.positions.end() != y.positions.start() - 1
+            || *y.positions.end() != z.positions.start() - 1
+        {
+            return None;
+        }
+        match (
+            &x.value.original_single_word(),
+            &y.value.original_single_word(),
+            &z.value.original_single_word(),
+        ) {
+            (Some(w1), Some(w2), Some(w3)) => {
+                let term = (format!("{w1}{w2}{w3}"), *x.positions.start()..=*z.positions.end());
+                Some(term)
+            }
+            _ => None,
+        }
+    }
+}
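
A hypothetical check of the ngram helpers, continuing the same assumed setup; the merge only happens for adjacent single-word terms, and any positional gap (or a phrase on either side) yields `None`:

    let terms = LocatedQueryTerm::from_query(
        "sun flower".tokenize(),
        None,
        |word, is_prefix| word_derivations(&index, &txn, word, 1, is_prefix, &fst),
    )
    .unwrap();
    // positions 0..=0 and 1..=1 are adjacent, so a 2-gram candidate is
    // produced (assuming both words have non-empty derivations):
    if let Some((ngram, positions)) = LocatedQueryTerm::ngram2(&terms[0], &terms[1]) {
        assert_eq!(ngram, "sunflower");
        assert_eq!(positions, 0..=1);
    }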