From fa41d2489ee2ac9976542e5bb563cf4212693b75 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Thu, 15 Jun 2023 17:36:40 +0200
Subject: [PATCH] Score for sort

---
 milli/src/search/new/sort.rs | 72 +++++++++++++++++++++++++++++++++---
 1 file changed, 66 insertions(+), 6 deletions(-)

diff --git a/milli/src/search/new/sort.rs b/milli/src/search/new/sort.rs
index 3e66ef2bc..f17aed6ed 100644
--- a/milli/src/search/new/sort.rs
+++ b/milli/src/search/new/sort.rs
@@ -1,9 +1,11 @@
+use heed::BytesDecode;
 use roaring::RoaringBitmap;
 
 use super::logger::SearchLogger;
 use super::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait, SearchContext};
-use crate::heed_codec::facet::FacetGroupKeyCodec;
-use crate::heed_codec::ByteSliceRefCodec;
+use crate::heed_codec::facet::{FacetGroupKeyCodec, OrderedF64Codec};
+use crate::heed_codec::{ByteSliceRefCodec, StrRefCodec};
+use crate::score_details::{self, ScoreDetails};
 use crate::search::facet::{ascending_facet_sort, descending_facet_sort};
 use crate::{FieldId, Index, Result};
 
@@ -49,6 +51,7 @@ pub struct Sort<'ctx, Query> {
     is_ascending: bool,
     original_query: Option<Query>,
     iter: Option<RankingRuleOutputIterWrapper<'ctx, Query>>,
+    must_redact: bool,
 }
 impl<'ctx, Query> Sort<'ctx, Query> {
     pub fn new(
@@ -59,8 +62,23 @@ impl<'ctx, Query> Sort<'ctx, Query> {
     ) -> Result<Self> {
         let fields_ids_map = index.fields_ids_map(rtxn)?;
         let field_id = fields_ids_map.id(&field_name);
+        let must_redact = Self::must_redact(index, rtxn, &field_name)?;
 
-        Ok(Self { field_name, field_id, is_ascending, original_query: None, iter: None })
+        Ok(Self {
+            field_name,
+            field_id,
+            is_ascending,
+            original_query: None,
+            iter: None,
+            must_redact,
+        })
+    }
+
+    fn must_redact(index: &Index, rtxn: &'ctx heed::RoTxn, field_name: &str) -> Result<bool> {
+        let Some(displayed_fields) = index.displayed_fields(rtxn)?
+        else { return Ok(false); };
+
+        Ok(!displayed_fields.iter().any(|&field| field == field_name))
     }
 }
 
@@ -118,12 +136,45 @@ impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx,
 
                     (itertools::Either::Right(number_iter), itertools::Either::Right(string_iter))
                 };
+                let number_iter = number_iter.map(|r| -> Result<_> {
+                    let (docids, bytes) = r?;
+                    Ok((
+                        docids,
+                        serde_json::Value::Number(
+                            serde_json::Number::from_f64(
+                                OrderedF64Codec::bytes_decode(bytes).expect("some number"),
+                            )
+                            .expect("too big float"),
+                        ),
+                    ))
+                });
+                let string_iter = string_iter.map(|r| -> Result<_> {
+                    let (docids, bytes) = r?;
+                    Ok((
+                        docids,
+                        serde_json::Value::String(
+                            StrRefCodec::bytes_decode(bytes).expect("some string").to_owned(),
+                        ),
+                    ))
+                });
 
                 let query_graph = parent_query.clone();
+                let ascending = self.is_ascending;
+                let field_name = self.field_name.clone();
+                let must_redact = self.must_redact;
                 RankingRuleOutputIterWrapper::new(Box::new(number_iter.chain(string_iter).map(
                     move |r| {
-                        let (docids, _) = r?;
-                        Ok(RankingRuleOutput { query: query_graph.clone(), candidates: docids })
+                        let (docids, value) = r?;
+                        Ok(RankingRuleOutput {
+                            query: query_graph.clone(),
+                            candidates: docids,
+                            score: ScoreDetails::Sort(score_details::Sort {
+                                field_name: field_name.clone(),
+                                ascending,
+                                redacted: must_redact,
+                                value,
+                            }),
+                        })
                     },
                 )))
             }
@@ -146,7 +197,16 @@ impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx,
             Ok(Some(bucket))
         } else {
             let query = self.original_query.as_ref().unwrap().clone();
-            Ok(Some(RankingRuleOutput { query, candidates: universe.clone() }))
+            Ok(Some(RankingRuleOutput {
+                query,
+                candidates: universe.clone(),
+                score: ScoreDetails::Sort(score_details::Sort {
+                    field_name: self.field_name.clone(),
+                    ascending: self.is_ascending,
+                    redacted: self.must_redact,
+                    value: serde_json::Value::Null,
+                }),
+            }))
         }
     }
 