diff --git a/milli/src/search/new/ranking_rule_graph/proximity/build.rs b/milli/src/search/new/ranking_rule_graph/proximity/build.rs
new file mode 100644
index 000000000..07ec3bb5e
--- /dev/null
+++ b/milli/src/search/new/ranking_rule_graph/proximity/build.rs
@@ -0,0 +1,198 @@
+use std::collections::BTreeMap;
+
+use super::ProximityEdge;
+use crate::new::db_cache::DatabaseCache;
+use crate::new::query_term::{LocatedQueryTerm, QueryTerm, WordDerivations};
+use crate::new::ranking_rule_graph::proximity::WordPair;
+use crate::new::ranking_rule_graph::{Edge, EdgeDetails};
+use crate::new::QueryNode;
+use crate::{Index, Result};
+use heed::RoTxn;
+use itertools::Itertools;
+
+/// Word pairs found in the database, grouped by edge cost and then by proximity.
+type CostProximityWordPairs = BTreeMap<u8, BTreeMap<u8, Vec<WordPair>>>;
+
+/// Builds derivations that match `word` exactly: no typos and no prefix database.
+fn exact_word_derivations(word: String) -> WordDerivations {
+    WordDerivations {
+        original: word.clone(),
+        zero_typo: vec![word],
+        one_typo: vec![],
+        two_typos: vec![],
+        use_prefix_db: false,
+    }
+}
+
+/// Registers `pair` under the given `cost` and `proximity`.
+fn record_word_pair(
+    word_pairs: &mut CostProximityWordPairs,
+    cost: u8,
+    proximity: u8,
+    pair: WordPair,
+) {
+    word_pairs.entry(cost).or_default().entry(proximity).or_default().push(pair);
+}
+
+/// Extracts from the source node of an edge the data needed to build proximity edges.
+///
+/// Returns the word derivations of the node together with the end of its position range:
+/// - for a word term, a clone of its derivations;
+/// - for a phrase term, exact derivations of the last word of the phrase;
+/// - for the start node, empty derivations at position `-100`;
+/// - for any other node, `None`.
+///
+/// # Panics
+///
+/// Panics if a phrase is empty or if its last word is `None`.
+pub fn visit_from_node(from_node: &QueryNode) -> Result<Option<(WordDerivations, i8)>> {
+    Ok(Some(match from_node {
+        QueryNode::Term(LocatedQueryTerm { value: value1, positions: pos1 }) => match value1 {
+            QueryTerm::Word { derivations } => (derivations.clone(), *pos1.end()),
+            QueryTerm::Phrase(phrase1) => {
+                // TODO: remove second unwrap
+                let original = phrase1.last().unwrap().as_ref().unwrap().clone();
+                (exact_word_derivations(original), *pos1.end())
+            }
+        },
+        // NOTE(review): `-100` looks like a sentinel keeping the start node from ever being
+        // adjacent to a term in `visit_to_node` -- confirm.
+        QueryNode::Start => (
+            WordDerivations {
+                original: String::new(),
+                zero_typo: vec![],
+                one_typo: vec![],
+                two_typos: vec![],
+                use_prefix_db: false,
+            },
+            -100,
+        ),
+        _ => return Ok(None),
+    }))
+}
+
+/// Builds the edges going from a previously visited source node to `to_node`.
+///
+/// Each edge is returned as a `(cost, details)` tuple:
+/// - if `to_node` is the end node, a single unconditional edge of cost 0;
+/// - if `to_node` is the start node or a deleted node, `None`;
+/// - if the target term does not start right after the source position, a single
+///   unconditional edge of cost 0;
+/// - otherwise, one data edge per (cost, proximity) group for which at least one word pair
+///   was found in the database, followed by an unconditional edge of cost
+///   `8 + (ngram_len2 - 1)`.
+///
+/// # Panics
+///
+/// Panics if the source derivations use the prefix database, or if `to_node` is a phrase
+/// that is empty or whose last word is `None`.
+pub fn visit_to_node<'transaction, 'from_data>(
+    index: &Index,
+    txn: &'transaction RoTxn,
+    db_cache: &mut DatabaseCache<'transaction>,
+    to_node: &QueryNode,
+    from_node_data: &'from_data (WordDerivations, i8),
+) -> Result<Option<Vec<(u8, EdgeDetails<ProximityEdge>)>>> {
+    let (derivations1, pos1) = from_node_data;
+    let term2 = match &to_node {
+        QueryNode::End => return Ok(Some(vec![(0, EdgeDetails::Unconditional)])),
+        QueryNode::Deleted | QueryNode::Start => return Ok(None),
+        QueryNode::Term(term) => term,
+    };
+    let LocatedQueryTerm { value: value2, positions: pos2 } = term2;
+
+    let (derivations2, pos2, ngram_len2) = match value2 {
+        QueryTerm::Word { derivations } => (derivations.clone(), *pos2.start(), pos2.len()),
+        QueryTerm::Phrase(phrase2) => {
+            // TODO: remove second unwrap
+            // NOTE(review): this pairs the *last* word of the phrase with the *start* of its
+            // position range; confirm that the first word is not the intended one.
+            let original = phrase2.last().unwrap().as_ref().unwrap().clone();
+            (exact_word_derivations(original), *pos2.start(), 1)
+        }
+    };
+
+    // TODO: here we would actually do it for each combination of word1 and word2
+    // and take the union of them
+    if pos1 + 1 != pos2 {
+        // TODO: how should this actually be handled?
+        // We want to effectively ignore this pair of terms
+        // Unconditionally walk through the edge without computing the docids
+        // But also what should the cost be?
+        return Ok(Some(vec![(0, EdgeDetails::Unconditional)]));
+    }
+
+    let updb1 = derivations1.use_prefix_db;
+    let updb2 = derivations2.use_prefix_db;
+
+    // left term cannot be a prefix
+    assert!(!updb1);
+
+    let derivations1 = derivations1.all_derivations_except_prefix_db();
+    let original_word_2 = derivations2.original.clone();
+    let mut cost_proximity_word_pairs = CostProximityWordPairs::new();
+
+    if updb2 {
+        for word1 in derivations1.clone() {
+            for proximity in 0..(7 - ngram_len2) {
+                let cost = (proximity + ngram_len2 - 1) as u8;
+                if db_cache
+                    .get_word_prefix_pair_proximity_docids(
+                        index,
+                        txn,
+                        word1,
+                        original_word_2.as_str(),
+                        proximity as u8,
+                    )?
+                    .is_some()
+                {
+                    record_word_pair(
+                        &mut cost_proximity_word_pairs,
+                        cost,
+                        proximity as u8,
+                        WordPair::WordPrefix {
+                            left: word1.to_owned(),
+                            right_prefix: original_word_2.clone(),
+                        },
+                    );
+                }
+            }
+        }
+    }
+
+    let derivations2 = derivations2.all_derivations_except_prefix_db();
+    // TODO: safeguard in case the cartesian product is too large?
+    let product_derivations = derivations1.cartesian_product(derivations2);
+
+    for (word1, word2) in product_derivations {
+        for proximity in 0..(7 - ngram_len2) {
+            let cost = (proximity + ngram_len2 - 1) as u8;
+            // TODO: do the opposite way with a proximity penalty as well!
+            // search for (word2, word1, proximity-1), I guess?
+            if db_cache
+                .get_word_pair_proximity_docids(index, txn, word1, word2, proximity as u8)?
+                .is_some()
+            {
+                record_word_pair(
+                    &mut cost_proximity_word_pairs,
+                    cost,
+                    proximity as u8,
+                    WordPair::Words { left: word1.to_owned(), right: word2.to_owned() },
+                );
+            }
+        }
+    }
+
+    // Turn every (cost, proximity) group into its own data edge, in ascending key order.
+    let mut new_edges = cost_proximity_word_pairs
+        .into_iter()
+        .flat_map(|(cost, proximity_word_pairs)| {
+            proximity_word_pairs.into_iter().map(move |(proximity, pairs)| {
+                (cost, EdgeDetails::Data(ProximityEdge { pairs, proximity }))
+            })
+        })
+        .collect::<Vec<_>>();
+    // Always end with an unconditional edge; its cost is higher than that of any data edge.
+    new_edges.push((8 + (ngram_len2 - 1) as u8, EdgeDetails::Unconditional));
+    Ok(Some(new_edges))
+}
diff --git a/milli/src/search/new/ranking_rule_graph/proximity/compute_docids.rs b/milli/src/search/new/ranking_rule_graph/proximity/compute_docids.rs
new file mode 100644
index 000000000..325042761
--- /dev/null
+++ b/milli/src/search/new/ranking_rule_graph/proximity/compute_docids.rs
@@ -0,0 +1,41 @@
+use roaring::MultiOps;
+
+use super::{ProximityEdge, WordPair};
+use crate::new::db_cache::DatabaseCache;
+use crate::CboRoaringBitmapCodec;
+
+/// Computes the documents matched by a proximity edge.
+///
+/// The result is the union, over every word pair of `edge`, of the document ids stored in
+/// the database for that pair at the proximity of the edge. A pair without an entry in the
+/// database contributes an empty bitmap.
+///
+/// # Errors
+///
+/// Returns an error if a database lookup fails or if a stored bitmap cannot be deserialized.
+pub fn compute_docids<'transaction>(
+    index: &crate::Index,
+    txn: &'transaction heed::RoTxn,
+    db_cache: &mut DatabaseCache<'transaction>,
+    edge: &ProximityEdge,
+) -> crate::Result<roaring::RoaringBitmap> {
+    let ProximityEdge { pairs, proximity } = edge;
+    // TODO: we should know already which pair of words to look for
+    let mut pair_docids = Vec::with_capacity(pairs.len());
+    for pair in pairs {
+        let bytes = match pair {
+            WordPair::Words { left, right } => {
+                db_cache.get_word_pair_proximity_docids(index, txn, left, right, *proximity)
+            }
+            WordPair::WordPrefix { left, right_prefix } => db_cache
+                .get_word_prefix_pair_proximity_docids(index, txn, left, right_prefix, *proximity),
+        }?;
+        let bitmap =
+            bytes.map(CboRoaringBitmapCodec::deserialize_from).transpose()?.unwrap_or_default();
+        pair_docids.push(bitmap);
+    }
+    // Order the bitmaps from smallest to largest before taking their union.
+    pair_docids.sort_by_key(|rb| rb.len());
+    let docids = MultiOps::union(pair_docids);
+    Ok(docids)
+}
diff --git a/milli/src/search/new/ranking_rule_graph/proximity/mod.rs b/milli/src/search/new/ranking_rule_graph/proximity/mod.rs
new file mode 100644
index 000000000..199a5eb4a
--- /dev/null
+++ b/milli/src/search/new/ranking_rule_graph/proximity/mod.rs
@@ -0,0 +1,68 @@
+pub mod build;
+pub mod compute_docids;
+
+use super::{Edge, EdgeDetails, RankingRuleGraphTrait};
+use crate::new::db_cache::DatabaseCache;
+use crate::new::query_term::WordDerivations;
+use crate::new::QueryNode;
+use crate::{Index, Result};
+use heed::RoTxn;
+
+/// A pair of words whose proximity is looked up in the database.
+#[derive(Debug, Clone)]
+pub enum WordPair {
+    // TODO: add WordsSwapped and WordPrefixSwapped case
+    /// Two full words.
+    Words { left: String, right: String },
+    /// A full word followed by a word prefix.
+    WordPrefix { left: String, right_prefix: String },
+}
+
+/// The data attached to an edge of the proximity graph.
+pub struct ProximityEdge {
+    /// Word pairs whose document ids are unioned when the edge is resolved.
+    pairs: Vec<WordPair>,
+    /// Proximity at which every pair of the edge is looked up.
+    proximity: u8,
+}
+
+/// Uninhabited type carrying the `RankingRuleGraphTrait` implementation for proximity.
+pub enum ProximityGraph {}
+
+impl RankingRuleGraphTrait for ProximityGraph {
+    type EdgeDetails = ProximityEdge;
+    type BuildVisitedFromNode = (WordDerivations, i8);
+
+    fn edge_details_dot_label(edge: &Self::EdgeDetails) -> String {
+        let ProximityEdge { pairs, proximity } = edge;
+        format!(", prox {proximity}, {} pairs", pairs.len())
+    }
+
+    fn compute_docids<'db_cache, 'transaction>(
+        index: &Index,
+        txn: &'transaction RoTxn,
+        db_cache: &mut DatabaseCache<'transaction>,
+        edge: &Self::EdgeDetails,
+    ) -> Result<roaring::RoaringBitmap> {
+        compute_docids::compute_docids(index, txn, db_cache, edge)
+    }
+
+    fn build_visit_from_node<'transaction>(
+        _index: &Index,
+        _txn: &'transaction RoTxn,
+        _db_cache: &mut DatabaseCache<'transaction>,
+        from_node: &QueryNode,
+    ) -> Result<Option<Self::BuildVisitedFromNode>> {
+        build::visit_from_node(from_node)
+    }
+
+    fn build_visit_to_node<'from_data, 'transaction: 'from_data>(
+        index: &Index,
+        txn: &'transaction RoTxn,
+        db_cache: &mut DatabaseCache<'transaction>,
+        to_node: &QueryNode,
+        from_node_data: &'from_data Self::BuildVisitedFromNode,
+    ) -> Result<Option<Vec<(u8, EdgeDetails<Self::EdgeDetails>)>>> {
+        build::visit_to_node(index, txn, db_cache, to_node, from_node_data)
+    }
+}