Compute score for the ranking rules

This commit is contained in:
Louis Dureuil 2023-06-06 18:20:59 +02:00
parent 63ddea8ae4
commit 4a2a6dc529
No known key found for this signature in database
13 changed files with 140 additions and 22 deletions

View File

@ -2,6 +2,7 @@ use roaring::{MultiOps, RoaringBitmap};
use super::query_graph::QueryGraph;
use super::ranking_rules::{RankingRule, RankingRuleOutput};
use crate::score_details::{self, ScoreDetails};
use crate::search::new::query_graph::QueryNodeData;
use crate::search::new::query_term::ExactTerm;
use crate::{Result, SearchContext, SearchLogger};
@ -244,7 +245,13 @@ impl State {
candidates &= universe;
(
State::AttributeStarts(query_graph.clone(), candidates_per_attribute),
Some(RankingRuleOutput { query: query_graph, candidates }),
Some(RankingRuleOutput {
query: query_graph,
candidates,
score: ScoreDetails::ExactAttribute(
score_details::ExactAttribute::MatchesFull,
),
}),
)
}
State::AttributeStarts(query_graph, candidates_per_attribute) => {
@ -257,12 +264,24 @@ impl State {
candidates &= universe;
(
State::Empty(query_graph.clone()),
Some(RankingRuleOutput { query: query_graph, candidates }),
Some(RankingRuleOutput {
query: query_graph,
candidates,
score: ScoreDetails::ExactAttribute(
score_details::ExactAttribute::MatchesStart,
),
}),
)
}
State::Empty(query_graph) => (
State::Empty(query_graph.clone()),
Some(RankingRuleOutput { query: query_graph, candidates: universe.clone() }),
Some(RankingRuleOutput {
query: query_graph,
candidates: universe.clone(),
score: ScoreDetails::ExactAttribute(
score_details::ExactAttribute::NoExactMatch,
),
}),
),
};
(state, output)

View File

@ -8,6 +8,7 @@ use rstar::RTree;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec};
use crate::score_details::{self, ScoreDetails};
use crate::{
distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext,
SearchLogger,
@ -80,7 +81,7 @@ pub struct GeoSort<Q: RankingRuleQueryTrait> {
field_ids: Option<[u16; 2]>,
rtree: Option<RTree<GeoPoint>>,
cached_sorted_docids: VecDeque<u32>,
cached_sorted_docids: VecDeque<(u32, [f64; 2])>,
geo_candidates: RoaringBitmap,
}
@ -130,7 +131,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
let point = lat_lng_to_xyz(&self.point);
for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_back(point.data.0);
self.cached_sorted_docids.push_back(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
}
@ -142,7 +143,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
let point = lat_lng_to_xyz(&opposite_of(self.point));
for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_front(point.data.0);
self.cached_sorted_docids.push_front(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
}
@ -177,7 +178,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// computing the distance between two points is expensive thus we cache the result
documents
.sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize);
self.cached_sorted_docids.extend(documents.into_iter().map(|(doc_id, _)| doc_id));
self.cached_sorted_docids.extend(documents.into_iter());
};
Ok(())
@ -220,12 +221,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> {
assert!(universe.len() > 1);
let query = self.query.as_ref().unwrap().clone();
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput { query, candidates: universe.clone() }));
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: None,
}),
}));
}
let ascending = self.ascending;
@ -236,11 +244,16 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
cache.pop_back()
}
};
while let Some(id) = next(&mut self.cached_sorted_docids) {
while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if self.geo_candidates.contains(id) {
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([id]),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point),
}),
}));
}
}

View File

@ -50,6 +50,7 @@ use super::ranking_rule_graph::{
};
use super::small_bitmap::SmallBitmap;
use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext};
use crate::score_details::Rank;
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::ranking_rule_graph::PathVisitor;
use crate::{Result, TermsMatchingStrategy};
@ -118,6 +119,8 @@ pub struct GraphBasedRankingRuleState<G: RankingRuleGraphTrait> {
all_costs: MappedInterner<QueryNode, Vec<u64>>,
/// An index in the first element of `all_distances`, giving the cost of the next bucket
cur_cost: u64,
/// One above the highest possible cost for this rule
next_max_cost: u64,
}
impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBasedRankingRule<G> {
@ -161,12 +164,16 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
// Then pre-compute the cost of all paths from each node to the end node
let all_costs = graph.find_all_costs_to_end();
let next_max_cost =
all_costs.get(graph.query_graph.root_node).iter().copied().max().unwrap_or(0) + 1;
let state = GraphBasedRankingRuleState {
graph,
conditions_cache: condition_docids_cache,
dead_ends_cache,
all_costs,
cur_cost: 0,
next_max_cost,
};
self.state = Some(state);
@ -180,17 +187,13 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
logger: &mut dyn SearchLogger<QueryGraph>,
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<QueryGraph>>> {
// If universe.len() <= 1, the bucket sort algorithm
// should not have called this function.
assert!(universe.len() > 1);
// Will crash if `next_bucket` is called before `start_iteration` or after `end_iteration`,
// should never happen
let mut state = self.state.take().unwrap();
let all_costs = state.all_costs.get(state.graph.query_graph.root_node);
// Retrieve the cost of the paths to compute
let Some(&cost) = state
.all_costs
.get(state.graph.query_graph.root_node)
let Some(&cost) = all_costs
.iter()
.find(|c| **c >= state.cur_cost) else {
self.state = None;
@ -206,8 +209,12 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
dead_ends_cache,
all_costs,
cur_cost: _,
next_max_cost,
} = &mut state;
let rank = *next_max_cost - cost;
let score = G::rank_to_score(Rank { rank: rank as u32, max_rank: *next_max_cost as u32 });
let mut universe = universe.clone();
let mut used_conditions = SmallBitmap::for_interned_values_in(&graph.conditions_interner);
@ -324,7 +331,7 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase
self.state = Some(state);
Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket }))
Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket, score }))
}
fn end_iteration(

View File

@ -44,6 +44,7 @@ use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words;
use self::interner::Interned;
use crate::score_details::ScoreDetails;
use crate::search::new::distinct::apply_distinct_rule;
use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError};

View File

@ -1,6 +1,7 @@
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::{ExactTerm, LocatedQueryTermSubset};
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids;
@ -84,4 +85,8 @@ impl RankingRuleGraphTrait for ExactnessGraph {
Ok(vec![(0, exact_condition), (dest_node.term_ids.len() as u32, skip_condition)])
}
fn rank_to_score(rank: Rank) -> ScoreDetails {
ScoreDetails::Exactness(rank)
}
}

View File

@ -2,6 +2,7 @@ use fxhash::FxHashSet;
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_field_id;
@ -107,4 +108,8 @@ impl RankingRuleGraphTrait for FidGraph {
Ok(edges)
}
fn rank_to_score(rank: Rank) -> ScoreDetails {
ScoreDetails::Fid(rank)
}
}

View File

@ -41,6 +41,7 @@ use super::interner::{DedupInterner, FixedSizeInterner, Interned, MappedInterner
use super::query_term::LocatedQueryTermSubset;
use super::small_bitmap::SmallBitmap;
use super::{QueryGraph, QueryNode, SearchContext};
use crate::score_details::{Rank, ScoreDetails};
use crate::Result;
pub struct ComputedCondition {
@ -110,6 +111,9 @@ pub trait RankingRuleGraphTrait: Sized + 'static {
source_node: Option<&LocatedQueryTermSubset>,
dest_node: &LocatedQueryTermSubset,
) -> Result<Vec<(u32, Interned<Self::Condition>)>>;
/// Convert the rank of a path to its corresponding score for the ranking rule
fn rank_to_score(rank: Rank) -> ScoreDetails;
}
/// The graph used by graph-based ranking rules.

View File

@ -2,6 +2,7 @@ use fxhash::{FxHashMap, FxHashSet};
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_position;
@ -115,6 +116,10 @@ impl RankingRuleGraphTrait for PositionGraph {
Ok(edges)
}
fn rank_to_score(rank: Rank) -> ScoreDetails {
ScoreDetails::Position(rank)
}
}
fn cost_from_position(sum_positions: u32) -> u32 {

View File

@ -4,6 +4,7 @@ pub mod compute_docids;
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::SearchContext;
@ -36,4 +37,8 @@ impl RankingRuleGraphTrait for ProximityGraph {
) -> Result<Vec<(u32, Interned<Self::Condition>)>> {
build::build_edges(ctx, conditions_interner, source_term, dest_term)
}
fn rank_to_score(rank: Rank) -> ScoreDetails {
ScoreDetails::Proximity(rank)
}
}

View File

@ -1,6 +1,7 @@
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{self, Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids;
@ -75,4 +76,8 @@ impl RankingRuleGraphTrait for TypoGraph {
}
Ok(edges)
}
fn rank_to_score(rank: Rank) -> ScoreDetails {
ScoreDetails::Typo(score_details::Typo::from_rank(rank))
}
}

View File

@ -1,6 +1,7 @@
use roaring::RoaringBitmap;
use super::{ComputedCondition, RankingRuleGraphTrait};
use crate::score_details::{self, Rank, ScoreDetails};
use crate::search::new::interner::{DedupInterner, Interned};
use crate::search::new::query_term::LocatedQueryTermSubset;
use crate::search::new::resolve_query_graph::compute_query_term_subset_docids;
@ -43,4 +44,8 @@ impl RankingRuleGraphTrait for WordsGraph {
) -> Result<Vec<(u32, Interned<Self::Condition>)>> {
Ok(vec![(0, conditions_interner.insert(WordsCondition { term: to_term.clone() }))])
}
fn rank_to_score(rank: Rank) -> ScoreDetails {
ScoreDetails::Words(score_details::Words::from_rank(rank))
}
}

View File

@ -2,6 +2,7 @@ use roaring::RoaringBitmap;
use super::logger::SearchLogger;
use super::{QueryGraph, SearchContext};
use crate::score_details::ScoreDetails;
use crate::Result;
/// An internal trait implemented by only [`PlaceholderQuery`] and [`QueryGraph`]
@ -66,4 +67,6 @@ pub struct RankingRuleOutput<Q> {
pub query: Q,
/// The allowed candidates for the child ranking rule
pub candidates: RoaringBitmap,
/// The score for the candidates of the current bucket
pub score: ScoreDetails,
}

View File

@ -1,9 +1,11 @@
use heed::BytesDecode;
use roaring::RoaringBitmap;
use super::logger::SearchLogger;
use super::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait, SearchContext};
use crate::heed_codec::facet::FacetGroupKeyCodec;
use crate::heed_codec::ByteSliceRefCodec;
use crate::heed_codec::facet::{FacetGroupKeyCodec, OrderedF64Codec};
use crate::heed_codec::{ByteSliceRefCodec, StrRefCodec};
use crate::score_details::{self, ScoreDetails};
use crate::search::facet::{ascending_facet_sort, descending_facet_sort};
use crate::{FieldId, Index, Result};
@ -118,12 +120,43 @@ impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx,
(itertools::Either::Right(number_iter), itertools::Either::Right(string_iter))
};
let number_iter = number_iter.map(|r| -> Result<_> {
let (docids, bytes) = r?;
Ok((
docids,
serde_json::Value::Number(
serde_json::Number::from_f64(
OrderedF64Codec::bytes_decode(bytes).expect("some number"),
)
.expect("too big float"),
),
))
});
let string_iter = string_iter.map(|r| -> Result<_> {
let (docids, bytes) = r?;
Ok((
docids,
serde_json::Value::String(
StrRefCodec::bytes_decode(bytes).expect("some string").to_owned(),
),
))
});
let query_graph = parent_query.clone();
let ascending = self.is_ascending;
let field_name = self.field_name.clone();
RankingRuleOutputIterWrapper::new(Box::new(number_iter.chain(string_iter).map(
move |r| {
let (docids, _) = r?;
Ok(RankingRuleOutput { query: query_graph.clone(), candidates: docids })
let (docids, value) = r?;
Ok(RankingRuleOutput {
query: query_graph.clone(),
candidates: docids,
score: ScoreDetails::Sort(score_details::Sort {
field_name: field_name.clone(),
ascending,
value,
}),
})
},
)))
}
@ -150,7 +183,15 @@ impl<'ctx, Query: RankingRuleQueryTrait> RankingRule<'ctx, Query> for Sort<'ctx,
Ok(Some(bucket))
} else {
let query = self.original_query.as_ref().unwrap().clone();
Ok(Some(RankingRuleOutput { query, candidates: universe.clone() }))
Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Sort(score_details::Sort {
field_name: self.field_name.clone(),
ascending: self.is_ascending,
value: serde_json::Value::Null,
}),
}))
}
}