From abdc4afcca564b5227c6ba3735fac3e5ea82ec48 Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Wed, 29 May 2024 11:06:39 +0200 Subject: [PATCH] Implement Frequency matching strategy --- milli/src/search/mod.rs | 2 + .../search/new/graph_based_ranking_rule.rs | 15 +++++ milli/src/search/new/mod.rs | 5 ++ milli/src/search/new/query_graph.rs | 56 ++++++++++++++++++- milli/tests/search/mod.rs | 1 + 5 files changed, 78 insertions(+), 1 deletion(-) diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index ca0eda49e..c85b80d2f 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -276,6 +276,8 @@ pub enum TermsMatchingStrategy { Last, // all words are mandatory All, + // remove the most frequent words first + Frequency, } impl Default for TermsMatchingStrategy { diff --git a/milli/src/search/new/graph_based_ranking_rule.rs b/milli/src/search/new/graph_based_ranking_rule.rs index 3136eb190..b066f82bd 100644 --- a/milli/src/search/new/graph_based_ranking_rule.rs +++ b/milli/src/search/new/graph_based_ranking_rule.rs @@ -164,6 +164,21 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase } costs } + TermsMatchingStrategy::Frequency => { + let removal_order = + query_graph.removal_order_for_terms_matching_strategy_frequency(ctx)?; + let mut forbidden_nodes = + SmallBitmap::for_interned_values_in(&query_graph.nodes); + let mut costs = query_graph.nodes.map(|_| None); + // FIXME: this works because only words use TermsMatchingStrategy at the moment. 
+ for ns in removal_order { + for n in ns.iter() { + *costs.get_mut(n) = Some((1, forbidden_nodes.clone())); + } + forbidden_nodes.union(&ns); + } + costs + } TermsMatchingStrategy::All => query_graph.nodes.map(|_| None), } } else { diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 5e4c2f829..f178b03cf 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -197,6 +197,11 @@ fn resolve_maximally_reduced_query_graph( .iter() .flat_map(|x| x.iter()) .collect(), + TermsMatchingStrategy::Frequency => query_graph + .removal_order_for_terms_matching_strategy_frequency(ctx)? + .iter() + .flat_map(|x| x.iter()) + .collect(), TermsMatchingStrategy::All => vec![], }; graph.remove_nodes_keep_edges(&nodes_to_remove); diff --git a/milli/src/search/new/query_graph.rs b/milli/src/search/new/query_graph.rs index d34d0afb5..9cbe55aff 100644 --- a/milli/src/search/new/query_graph.rs +++ b/milli/src/search/new/query_graph.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use std::hash::{Hash, Hasher}; use fxhash::{FxHashMap, FxHasher}; +use roaring::RoaringBitmap; use super::interner::{FixedSizeInterner, Interned}; use super::query_term::{ @@ -11,6 +12,7 @@ use super::query_term::{ use super::small_bitmap::SmallBitmap; use super::SearchContext; use crate::search::new::interner::Interner; +use crate::search::new::resolve_query_graph::compute_query_term_subset_docids; use crate::Result; /// A node of the [`QueryGraph`]. 
@@ -290,6 +292,49 @@ impl QueryGraph { } } + pub fn removal_order_for_terms_matching_strategy_frequency( + &self, + ctx: &mut SearchContext, + ) -> Result>> { + // lookup frequency for each term + let mut term_with_frequency: Vec<(u8, u64)> = { + let mut term_docids: BTreeMap = Default::default(); + for (_, node) in self.nodes.iter() { + match &node.data { + QueryNodeData::Term(t) => { + let docids = compute_query_term_subset_docids(ctx, &t.term_subset)?; + for id in t.term_ids.clone() { + term_docids + .entry(id) + .and_modify(|curr| *curr |= &docids) + .or_insert_with(|| docids.clone()); + } + } + QueryNodeData::Deleted | QueryNodeData::Start | QueryNodeData::End => continue, + } + } + term_docids + .into_iter() + .map(|(idx, docids)| match docids.len() { + 0 => (idx, u64::max_value()), + frequency => (idx, frequency), + }) + .collect() + }; + term_with_frequency.sort_by_key(|(_, frequency)| *frequency); + let mut term_weight = BTreeMap::new(); + let mut weight: u16 = 1; + let mut peekable = term_with_frequency.into_iter().peekable(); + while let Some((idx, frequency)) = peekable.next() { + term_weight.insert(idx, weight); + if peekable.peek().map_or(false, |(_, f)| frequency < *f) { + weight += 1; + } + } + let cost_of_term_idx = move |term_idx: u8| *term_weight.get(&term_idx).unwrap(); + Ok(self.removal_order_for_terms_matching_strategy(ctx, cost_of_term_idx)) + } + pub fn removal_order_for_terms_matching_strategy_last( &self, ctx: &SearchContext, @@ -315,10 +360,19 @@ impl QueryGraph { if first_term_idx >= last_term_idx { return vec![]; } + let cost_of_term_idx = |term_idx: u8| { let rank = 1 + last_term_idx - term_idx; rank as u16 }; + self.removal_order_for_terms_matching_strategy(ctx, cost_of_term_idx) + } + + pub fn removal_order_for_terms_matching_strategy( + &self, + ctx: &SearchContext, + order: impl Fn(u8) -> u16, + ) -> Vec> { let mut nodes_to_remove = BTreeMap::>::new(); let mut at_least_one_mandatory_term = false; for (node_id, node) in 
self.nodes.iter() { @@ -329,7 +383,7 @@ impl QueryGraph { } let mut cost = 0; for id in t.term_ids.clone() { - cost = std::cmp::max(cost, cost_of_term_idx(id)); + cost = std::cmp::max(cost, order(id)); } nodes_to_remove .entry(cost) diff --git a/milli/tests/search/mod.rs b/milli/tests/search/mod.rs index 9193ab762..310780e03 100644 --- a/milli/tests/search/mod.rs +++ b/milli/tests/search/mod.rs @@ -159,6 +159,7 @@ pub fn expected_order( match optional_words { TermsMatchingStrategy::Last => groups.into_iter().flatten().collect(), + TermsMatchingStrategy::Frequency => groups.into_iter().flatten().collect(), TermsMatchingStrategy::All => { groups.into_iter().flatten().filter(|d| d.word_rank == 0).collect() }