From b2b413db12d2db1bb57c704c16dc9d7d9ae5f325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 27 Jun 2023 12:31:23 +0200 Subject: [PATCH] Return all the _semanticScore values in the documents --- meilisearch/src/search.rs | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index a8c6765bc..346c9b1ec 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -228,6 +228,8 @@ pub struct SearchHit { pub ranking_score: Option, #[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")] pub ranking_score_details: Option>, + #[serde(rename = "_semanticScore", skip_serializing_if = "Option::is_none")] + pub semantic_score: Option, } #[derive(Serialize, Debug, Clone, PartialEq)] @@ -462,11 +464,13 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - if let Some(vector) = query.vector.as_ref() { - if let Some(vectors) = extract_field("_vectors", &fields_ids_map, obkv)? { - insert_semantic_score(vector, vectors, &mut document); - } - } + let semantic_score = match query.vector.as_ref() { + Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { + Some(vectors) => compute_semantic_score(vector, vectors)?, + None => None, + }, + None => None, + }; let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -479,6 +483,7 @@ pub fn perform_search( matches_position, ranking_score_details, ranking_score, + semantic_score, }; documents.push(hit); } @@ -553,18 +558,15 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } -fn insert_semantic_score(query: &[f32], vectors: Value, document: &mut Document) { - let vectors = - match serde_json::from_value(vectors).map(VectorOrArrayOfVectors::into_array_of_vectors) { - Ok(vectors) => vectors, - Err(_) => return, - }; - let similarity = vectors +fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result> { + let vectors = serde_json::from_value(vectors) + .map(VectorOrArrayOfVectors::into_array_of_vectors) + .map_err(InternalError::SerdeJson)?; + Ok(vectors .into_iter() .map(|v| OrderedFloat(dot_product_similarity(query, &v))) .max() - .map(OrderedFloat::into_inner); - document.insert("_semanticScore".to_string(), json!(similarity)); + .map(OrderedFloat::into_inner)) } fn compute_formatted_options(