From 3647c9d226c30d73dd55f470ecbe1aca97735284 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 29 Jan 2025 11:07:33 +0100 Subject: [PATCH] Add WeightedScoreValues to be able to compare remote scores --- crates/milli/src/score_details.rs | 105 +++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 24 deletions(-) diff --git a/crates/milli/src/score_details.rs b/crates/milli/src/score_details.rs index 1efa3b8e6..940e5f395 100644 --- a/crates/milli/src/score_details.rs +++ b/crates/milli/src/score_details.rs @@ -1,7 +1,7 @@ use std::cmp::Ordering; use itertools::Itertools; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::distance_between_two_points; @@ -36,6 +36,15 @@ enum RankOrValue<'a> { Score(f64), } +#[derive(Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum WeightedScoreValue { + WeightedScore(f64), + Sort { asc: bool, value: serde_json::Value }, + GeoSort { asc: bool, distance: Option }, + VectorSort(f64), +} + impl ScoreDetails { pub fn local_score(&self) -> Option { self.rank().map(Rank::local_score) @@ -87,6 +96,30 @@ impl ScoreDetails { }) } + pub fn weighted_score_values<'a>( + details: impl Iterator + 'a, + weight: f64, + ) -> impl Iterator + 'a { + details + .map(ScoreDetails::rank_or_value) + .coalesce(|left, right| match (left, right) { + (RankOrValue::Rank(left), RankOrValue::Rank(right)) => { + Ok(RankOrValue::Rank(Rank::merge(left, right))) + } + (left, right) => Err((left, right)), + }) + .map(move |rank_or_value| match rank_or_value { + RankOrValue::Rank(r) => WeightedScoreValue::WeightedScore(r.local_score() * weight), + RankOrValue::Sort(s) => { + WeightedScoreValue::Sort { asc: s.ascending, value: s.value.clone() } + } + RankOrValue::GeoSort(g) => { + WeightedScoreValue::GeoSort { asc: g.ascending, distance: g.distance() } + } + RankOrValue::Score(s) => WeightedScoreValue::VectorSort(s * weight), + }) + } + fn rank_or_value(&self) -> RankOrValue<'_> { match self { ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()), @@ -423,34 +456,58 @@ pub struct Sort { pub value: serde_json::Value, } +pub fn compare_sort_values( + ascending: bool, + left: &serde_json::Value, + right: &serde_json::Value, +) -> Ordering { + use serde_json::Value::*; + match (left, right) { + (Null, Null) => Ordering::Equal, + (Null, _) => Ordering::Less, + (_, Null) => Ordering::Greater, + // numbers are always before strings + (Number(_), String(_)) => Ordering::Greater, + (String(_), Number(_)) => Ordering::Less, + (Number(left), Number(right)) => { + // FIXME: unwrap permitted here? + let order = left + .as_f64() + .unwrap() + .partial_cmp(&right.as_f64().unwrap()) + .unwrap_or(Ordering::Equal); + // 12 < 42, and when ascending, we want to see 12 first, so the smallest. + // Hence, when ascending, smaller is better + if ascending { + order.reverse() + } else { + order + } + } + (String(left), String(right)) => { + let order = left.cmp(right); + // Taking e.g. "a" and "z" + // "a" < "z", and when ascending, we want to see "a" first, so the smallest. + // Hence, when ascending, smaller is better + if ascending { + order.reverse() + } else { + order + } + } + (left, right) => { + tracing::warn!(%left, %right, "sort values that are neither numbers, strings or null, handling as equal"); + Ordering::Equal + } + } +} + impl PartialOrd for Sort { fn partial_cmp(&self, other: &Self) -> Option { if self.ascending != other.ascending { return None; } - match (&self.value, &other.value) { - (serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal), - (serde_json::Value::Null, _) => Some(Ordering::Less), - (_, serde_json::Value::Null) => Some(Ordering::Greater), - // numbers are always before strings - (serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater), - (serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less), - (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { - // FIXME: unwrap permitted here? - let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?; - // 12 < 42, and when ascending, we want to see 12 first, so the smallest. - // Hence, when ascending, smaller is better - Some(if self.ascending { order.reverse() } else { order }) - } - (serde_json::Value::String(left), serde_json::Value::String(right)) => { - let order = left.cmp(right); - // Taking e.g. "a" and "z" - // "a" < "z", and when ascending, we want to see "a" first, so the smallest. - // Hence, when ascending, smaller is better - Some(if self.ascending { order.reverse() } else { order }) - } - _ => None, - } + Some(compare_sort_values(self.ascending, &self.value, &other.value)) } }