diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index fc13a5e1e..87f922c4c 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -169,6 +169,7 @@ impl<'a> Search<'a> { index: self.index, semantic: self.semantic.clone(), time_budget: self.time_budget.clone(), + ranking_score_threshold: self.ranking_score_threshold, }; let semantic = search.semantic.take(); diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 76068b1f2..f7bcf6e7b 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -50,6 +50,7 @@ pub struct Search<'a> { index: &'a Index, semantic: Option, time_budget: TimeBudget, + ranking_score_threshold: Option, } impl<'a> Search<'a> { @@ -70,6 +71,7 @@ impl<'a> Search<'a> { index, semantic: None, time_budget: TimeBudget::max(), + ranking_score_threshold: None, } } @@ -146,6 +148,14 @@ impl<'a> Search<'a> { self } + pub fn ranking_score_threshold( + &mut self, + ranking_score_threshold: Option, + ) -> &mut Search<'a> { + self.ranking_score_threshold = ranking_score_threshold; + self + } + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { if has_vector_search { let ctx = SearchContext::new(self.index, self.rtxn)?; @@ -184,6 +194,7 @@ impl<'a> Search<'a> { embedder_name, embedder, self.time_budget.clone(), + self.ranking_score_threshold, )? } _ => execute_search( @@ -201,6 +212,7 @@ impl<'a> Search<'a> { &mut DefaultSearchLogger, &mut DefaultSearchLogger, self.time_budget.clone(), + self.ranking_score_threshold, )?, }; @@ -239,6 +251,7 @@ impl fmt::Debug for Search<'_> { index: _, semantic, time_budget, + ranking_score_threshold, } = self; f.debug_struct("Search") .field("query", query) @@ -257,6 +270,7 @@ impl fmt::Debug for Search<'_> { &semantic.as_ref().map(|semantic| &semantic.embedder_name), ) .field("time_budget", time_budget) + .field("ranking_score_threshold", ranking_score_threshold) .finish() } } diff --git a/milli/src/search/new/bucket_sort.rs b/milli/src/search/new/bucket_sort.rs index e9bc5449d..b15e735d0 100644 --- a/milli/src/search/new/bucket_sort.rs +++ b/milli/src/search/new/bucket_sort.rs @@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( scoring_strategy: ScoringStrategy, logger: &mut dyn SearchLogger, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { logger.initial_query(query); logger.ranking_rules(&ranking_rules); @@ -144,6 +145,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( ctx, from, length, + ranking_score_threshold, logger, &mut valid_docids, &mut valid_scores, @@ -164,7 +166,9 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( loop { let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]); ranking_rule_scores.push(ScoreDetails::Skipped); + maybe_add_to_results!(bucket); + ranking_rule_scores.pop(); if cur_ranking_rule_index == 0 { @@ -220,6 +224,17 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( debug_assert!( ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates) ); + + if let Some(ranking_score_threshold) = ranking_score_threshold { + let current_score = ScoreDetails::global_score(ranking_rule_scores.iter()); + if current_score < ranking_score_threshold { + all_candidates -= + next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index]; + back!(); + continue; + } + } + ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates; if cur_ranking_rule_index == ranking_rules_len - 1 @@ -262,6 +277,7 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>( ctx: &mut SearchContext<'ctx>, from: usize, length: usize, + ranking_score_threshold: Option, logger: &mut dyn SearchLogger, valid_docids: &mut Vec, @@ -279,6 +295,15 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>( ranking_rule_scores: &[ScoreDetails], candidates: RoaringBitmap, ) -> Result<()> { + // remove candidates from the universe without adding them to result if their score is below the threshold + if let Some(ranking_score_threshold) = ranking_score_threshold { + let score = ScoreDetails::global_score(ranking_rule_scores.iter()); + if score < ranking_score_threshold { + *all_candidates -= candidates | &ranking_rule_universes[cur_ranking_rule_index]; + return Ok(()); + } + } + // First apply the distinct rule on the candidates, reducing the universes if necessary let candidates = if let Some(distinct_fid) = distinct_fid { let DistinctOutput { remaining, excluded } = diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index f121971b8..87ddb2915 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -523,6 +523,7 @@ mod tests { &mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger, TimeBudget::max(), + None, ) .unwrap(); diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index e152dd233..bbeab31fd 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -568,6 +568,7 @@ pub fn execute_vector_search( embedder_name: &str, embedder: &Embedder, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -597,6 +598,7 @@ pub fn execute_vector_search( scoring_strategy, placeholder_search_logger, time_budget, + ranking_score_threshold, )?; Ok(PartialSearchResult { @@ -626,6 +628,7 @@ pub fn execute_search( placeholder_search_logger: &mut dyn SearchLogger, query_graph_logger: &mut dyn SearchLogger, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -714,6 +717,7 @@ pub fn execute_search( scoring_strategy, query_graph_logger, time_budget, + ranking_score_threshold, )? } else { let ranking_rules = @@ -728,6 +732,7 @@ pub fn execute_search( scoring_strategy, placeholder_search_logger, time_budget, + ranking_score_threshold, )? };