Store the scores for each bucket

This commit is contained in:
Louis Dureuil 2023-06-06 18:25:25 +02:00
parent 4a2a6dc529
commit 16898c661e
No known key found for this signature in database
2 changed files with 45 additions and 10 deletions

View File

@@ -3,11 +3,13 @@ use roaring::RoaringBitmap;
use super::logger::SearchLogger; use super::logger::SearchLogger;
use super::ranking_rules::{BoxRankingRule, RankingRuleQueryTrait}; use super::ranking_rules::{BoxRankingRule, RankingRuleQueryTrait};
use super::SearchContext; use super::SearchContext;
use crate::score_details::ScoreDetails;
use crate::search::new::distinct::{apply_distinct_rule, distinct_single_docid, DistinctOutput}; use crate::search::new::distinct::{apply_distinct_rule, distinct_single_docid, DistinctOutput};
use crate::Result; use crate::Result;
pub struct BucketSortOutput { pub struct BucketSortOutput {
pub docids: Vec<u32>, pub docids: Vec<u32>,
pub scores: Vec<Vec<ScoreDetails>>,
pub all_candidates: RoaringBitmap, pub all_candidates: RoaringBitmap,
} }
@@ -31,7 +33,11 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
}; };
if universe.len() < from as u64 { if universe.len() < from as u64 {
return Ok(BucketSortOutput { docids: vec![], all_candidates: universe.clone() }); return Ok(BucketSortOutput {
docids: vec![],
scores: vec![],
all_candidates: universe.clone(),
});
} }
if ranking_rules.is_empty() { if ranking_rules.is_empty() {
if let Some(distinct_fid) = distinct_fid { if let Some(distinct_fid) = distinct_fid {
@@ -49,22 +55,32 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
} }
let mut all_candidates = universe - excluded; let mut all_candidates = universe - excluded;
all_candidates.extend(results.iter().copied()); all_candidates.extend(results.iter().copied());
return Ok(BucketSortOutput { docids: results, all_candidates }); return Ok(BucketSortOutput {
scores: vec![Default::default(); results.len()],
docids: results,
all_candidates,
});
} else { } else {
let docids = universe.iter().skip(from).take(length).collect(); let docids: Vec<u32> = universe.iter().skip(from).take(length).collect();
return Ok(BucketSortOutput { docids, all_candidates: universe.clone() }); return Ok(BucketSortOutput {
scores: vec![Default::default(); docids.len()],
docids,
all_candidates: universe.clone(),
});
}; };
} }
let ranking_rules_len = ranking_rules.len(); let ranking_rules_len = ranking_rules.len();
logger.start_iteration_ranking_rule(0, ranking_rules[0].as_ref(), query, universe); logger.start_iteration_ranking_rule(0, ranking_rules[0].as_ref(), query, universe);
ranking_rules[0].start_iteration(ctx, logger, universe, query)?; ranking_rules[0].start_iteration(ctx, logger, universe, query)?;
let mut ranking_rule_scores: Vec<ScoreDetails> = vec![];
let mut ranking_rule_universes: Vec<RoaringBitmap> = let mut ranking_rule_universes: Vec<RoaringBitmap> =
vec![RoaringBitmap::default(); ranking_rules_len]; vec![RoaringBitmap::default(); ranking_rules_len];
ranking_rule_universes[0] = universe.clone(); ranking_rule_universes[0] = universe.clone();
let mut cur_ranking_rule_index = 0; let mut cur_ranking_rule_index = 0;
/// Finish iterating over the current ranking rule, yielding /// Finish iterating over the current ranking rule, yielding
@@ -89,11 +105,16 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
} else { } else {
cur_ranking_rule_index -= 1; cur_ranking_rule_index -= 1;
} }
// FIXME: check off by one
if ranking_rule_scores.len() > cur_ranking_rule_index {
ranking_rule_scores.pop();
}
}; };
} }
let mut all_candidates = universe.clone(); let mut all_candidates = universe.clone();
let mut valid_docids = vec![]; let mut valid_docids = vec![];
let mut valid_scores = vec![];
let mut cur_offset = 0usize; let mut cur_offset = 0usize;
macro_rules! maybe_add_to_results { macro_rules! maybe_add_to_results {
@@ -130,6 +151,8 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
continue; continue;
}; };
ranking_rule_scores.push(next_bucket.score);
logger.next_bucket_ranking_rule( logger.next_bucket_ranking_rule(
cur_ranking_rule_index, cur_ranking_rule_index,
ranking_rules[cur_ranking_rule_index].as_ref(), ranking_rules[cur_ranking_rule_index].as_ref(),
@@ -146,6 +169,8 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|| cur_offset + (next_bucket.candidates.len() as usize) < from || cur_offset + (next_bucket.candidates.len() as usize) < from
{ {
maybe_add_to_results!(next_bucket.candidates); maybe_add_to_results!(next_bucket.candidates);
// FIXME: use index based logic like all the other rules so that you don't have to maintain the pop/push?
ranking_rule_scores.pop();
continue; continue;
} }
@@ -165,7 +190,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
)?; )?;
} }
Ok(BucketSortOutput { docids: valid_docids, all_candidates }) Ok(BucketSortOutput { docids: valid_docids, scores: valid_scores, all_candidates })
} }
/// Add the candidates to the results. Take `distinct`, `from`, `length`, and `cur_offset` /// Add the candidates to the results. Take `distinct`, `from`, `length`, and `cur_offset`
@@ -178,14 +203,18 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
logger: &mut dyn SearchLogger<Q>, logger: &mut dyn SearchLogger<Q>,
valid_docids: &mut Vec<u32>, valid_docids: &mut Vec<u32>,
valid_scores: &mut Vec<Vec<ScoreDetails>>,
all_candidates: &mut RoaringBitmap, all_candidates: &mut RoaringBitmap,
ranking_rule_universes: &mut [RoaringBitmap], ranking_rule_universes: &mut [RoaringBitmap],
ranking_rules: &mut [BoxRankingRule<'ctx, Q>], ranking_rules: &mut [BoxRankingRule<'ctx, Q>],
cur_ranking_rule_index: usize, cur_ranking_rule_index: usize,
cur_offset: &mut usize, cur_offset: &mut usize,
distinct_fid: Option<u16>, distinct_fid: Option<u16>,
ranking_rule_scores: &[ScoreDetails],
candidates: RoaringBitmap, candidates: RoaringBitmap,
) -> Result<()> { ) -> Result<()> {
// First apply the distinct rule on the candidates, reducing the universes if necessary // First apply the distinct rule on the candidates, reducing the universes if necessary
@@ -230,13 +259,17 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
let candidates = let candidates =
candidates.iter().take(length - valid_docids.len()).copied().collect::<Vec<_>>(); candidates.iter().take(length - valid_docids.len()).copied().collect::<Vec<_>>();
logger.add_to_results(&candidates); logger.add_to_results(&candidates);
valid_docids.extend(&candidates); valid_docids.extend_from_slice(&candidates);
valid_scores
.extend(std::iter::repeat(ranking_rule_scores.to_owned()).take(candidates.len()));
} }
} else { } else {
// if we have passed the offset already, add some of the documents (up to the limit) // if we have passed the offset already, add some of the documents (up to the limit)
let candidates = candidates.iter().take(length - valid_docids.len()).collect::<Vec<u32>>(); let candidates = candidates.iter().take(length - valid_docids.len()).collect::<Vec<u32>>();
logger.add_to_results(&candidates); logger.add_to_results(&candidates);
valid_docids.extend(&candidates); valid_docids.extend_from_slice(&candidates);
valid_scores
.extend(std::iter::repeat(ranking_rule_scores.to_owned()).take(candidates.len()));
} }
*cur_offset += candidates.len() as usize; *cur_offset += candidates.len() as usize;

View File

@@ -427,13 +427,15 @@ pub fn execute_search(
)? )?
}; };
let BucketSortOutput { docids, mut all_candidates } = bucket_sort_output; let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
// The candidates is the universe unless the exhaustive number of hits // The candidates is the universe unless the exhaustive number of hits
// is requested and a distinct attribute is set. // is requested and a distinct attribute is set.
if exhaustive_number_hits { if exhaustive_number_hits {
if let Some(f) = ctx.index.distinct_field(ctx.txn)? { if let Some(f) = ctx.index.distinct_field(ctx.txn)? {
if let Some(distinct_fid) = ctx.index.fields_ids_map(ctx.txn)?.id(f) { if let Some(distinct_fid) = fields_ids_map.id(f) {
all_candidates = apply_distinct_rule(ctx, distinct_fid, &all_candidates)?.remaining; all_candidates = apply_distinct_rule(ctx, distinct_fid, &all_candidates)?.remaining;
} }
} }