From 7ce2691374daede46dfb2a8051ea062811f391f0 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 30 May 2024 11:21:31 +0200 Subject: [PATCH] Add ranking score threshold to similar API --- meilisearch-types/src/deserr/mod.rs | 1 + meilisearch-types/src/error.rs | 7 +++++ meilisearch/src/routes/indexes/similar.rs | 35 +++++++++++++++++------ meilisearch/src/search.rs | 25 ++++++++++++++++ 4 files changed, 59 insertions(+), 9 deletions(-) diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index 198a4e7b7..1c1b0e987 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -190,4 +190,5 @@ merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); merge_with_error_impl_take_error_message!(InvalidSearchRankingScoreThreshold); +merge_with_error_impl_take_error_message!(InvalidSimilarRankingScoreThreshold); merge_with_error_impl_take_error_message!(InvalidSimilarId); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index bf8eaba1c..150c56b9d 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -242,6 +242,7 @@ InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; +InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; @@ -515,6 +516,12 @@ impl fmt::Display for deserr_codes::InvalidSearchRankingScoreThreshold { } } +impl fmt::Display for deserr_codes::InvalidSimilarRankingScoreThreshold { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + deserr_codes::InvalidSearchRankingScoreThreshold.fmt(f) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs index da73dd63b..518fedab7 100644 --- a/meilisearch/src/routes/indexes/similar.rs +++ b/meilisearch/src/routes/indexes/similar.rs @@ -6,8 +6,8 @@ use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::{ InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId, - InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarShowRankingScore, - InvalidSimilarShowRankingScoreDetails, + InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarRankingScoreThreshold, + InvalidSimilarShowRankingScore, InvalidSimilarShowRankingScoreDetails, }; use meilisearch_types::error::{ErrorCode as _, ResponseError}; use meilisearch_types::index_uid::IndexUid; @@ -21,8 +21,8 @@ use crate::analytics::{Analytics, SimilarAggregator}; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_similar, SearchKind, SimilarQuery, SimilarResult, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_similar, RankingScoreThresholdSimilar, SearchKind, SimilarQuery, + SimilarResult, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -42,9 +42,7 @@ pub async fn similar_get( ) -> Result { let index_uid = IndexUid::try_from(index_uid.into_inner())?; - let query = params.0.try_into().map_err(|code: InvalidSimilarId| { - ResponseError::from_msg(code.to_string(), code.error_code()) - })?; + let query = params.0.try_into()?; let mut aggregate = SimilarAggregator::from_query(&query, &req); @@ -130,12 +128,27 @@ pub struct SimilarQueryGet { show_ranking_score: Param, #[deserr(default, error = DeserrQueryParamError)] show_ranking_score_details: Param, + #[deserr(default, error = DeserrQueryParamError, default)] + pub ranking_score_threshold: Option, #[deserr(default, error = DeserrQueryParamError)] pub embedder: Option, } +#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] +#[deserr(try_from(String) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)] +pub struct RankingScoreThresholdGet(RankingScoreThresholdSimilar); + +impl std::convert::TryFrom for RankingScoreThresholdGet { + type Error = InvalidSimilarRankingScoreThreshold; + + fn try_from(s: String) -> Result { + let f: f64 = s.parse().map_err(|_| InvalidSimilarRankingScoreThreshold)?; + Ok(RankingScoreThresholdGet(RankingScoreThresholdSimilar::try_from(f)?)) + } +} + impl TryFrom for SimilarQuery { - type Error = InvalidSimilarId; + type Error = ResponseError; fn try_from( SimilarQueryGet { @@ -147,6 +160,7 @@ impl TryFrom for SimilarQuery { show_ranking_score, show_ranking_score_details, embedder, + ranking_score_threshold, }: SimilarQueryGet, ) -> Result { let filter = match filter { @@ -158,7 +172,9 @@ impl TryFrom for SimilarQuery { }; Ok(SimilarQuery { - id: id.0.try_into()?, + id: id.0.try_into().map_err(|code: InvalidSimilarId| { + ResponseError::from_msg(code.to_string(), code.error_code()) + })?, offset: offset.0, limit: limit.0, filter, @@ -166,6 +182,7 @@ impl TryFrom for SimilarQuery { attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), show_ranking_score: show_ranking_score.0, show_ranking_score_details: show_ranking_score_details.0, + ranking_score_threshold: ranking_score_threshold.map(|x| x.0), }) } } diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index f4648a9d5..23f9d3f79 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -109,6 +109,24 @@ impl std::convert::TryFrom for RankingScoreThreshold { } } +#[derive(Debug, Clone, Copy, PartialEq, Deserr)] +#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)] +pub struct RankingScoreThresholdSimilar(f64); + +impl std::convert::TryFrom for RankingScoreThresholdSimilar { + type Error = InvalidSimilarRankingScoreThreshold; + + fn try_from(f: f64) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] + if f > 1.0 || f < 0.0 { + Err(InvalidSimilarRankingScoreThreshold) + } else { + Ok(Self(f)) + } + } +} + // Since this structure is logged A LOT we're going to reduce the number of things it logs to the bare minimum. // - Only what IS used, we know everything else is set to None so there is no need to print it // - Re-order the most important field to debug first @@ -464,6 +482,8 @@ pub struct SimilarQuery { pub show_ranking_score: bool, #[deserr(default, error = DeserrJsonError, default)] pub show_ranking_score_details: bool, + #[deserr(default, error = DeserrJsonError, default)] + pub ranking_score_threshold: Option, } #[derive(Debug, Clone, PartialEq, Deserr)] @@ -1102,6 +1122,7 @@ pub fn perform_similar( attributes_to_retrieve, show_ranking_score, show_ranking_score_details, + ranking_score_threshold, } = query; // using let-else rather than `?` so that the borrow checker identifies we're always returning here, @@ -1125,6 +1146,10 @@ pub fn perform_similar( } } + if let Some(ranking_score_threshold) = ranking_score_threshold { + similar.ranking_score_threshold(ranking_score_threshold.0); + } + let milli::SearchResult { documents_ids, matching_words: _,