Deserialize semantic ratio

This commit is contained in:
ManyTheFish 2023-12-14 10:21:10 +01:00 committed by Louis Dureuil
parent ac68f33194
commit 93dcbf598d
No known key found for this signature in database
4 changed files with 36 additions and 12 deletions

View File

@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError);
merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskKindError);
merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(ParseTaskStatusError);
merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(IndexUidFormatError);
merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio);

View File

@ -235,7 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
@ -459,6 +459,15 @@ impl fmt::Display for DeserrParseIntError {
} }
} }
impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`."
)
}
}
#[macro_export] #[macro_export]
macro_rules! internal_error { macro_rules! internal_error {
($target:ty : $($other:path), *) => { ($target:ty : $($other:path), *) => {

View File

@ -17,7 +17,7 @@ use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler; use crate::extractors::sequential_extractor::SeqHandler;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
}; };
@ -75,10 +75,10 @@ pub struct SearchQueryGet {
matching_strategy: MatchingStrategy, matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)]
pub attributes_to_search_on: Option<CS<String>>, pub attributes_to_search_on: Option<CS<String>>,
#[deserr(default, error = DeserrQueryParamError<InvalidHybridQuery>)] #[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)]
pub hybrid_embedder: Option<String>, pub hybrid_embedder: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidHybridQuery>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchSemanticRatio>)]
pub hybrid_semantic_ratio: Option<f32>, pub hybrid_semantic_ratio: Option<SemanticRatio>,
} }
impl From<SearchQueryGet> for SearchQuery { impl From<SearchQueryGet> for SearchQuery {

View File

@ -36,7 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10;
pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string(); pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string();
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5; pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5);
#[derive(Debug, Clone, Default, PartialEq, Deserr)] #[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
@ -91,12 +91,27 @@ pub struct SearchQuery {
#[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)] #[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery { pub struct HybridQuery {
/// TODO validate that sementic ratio is between 0.0 and 1,0 /// TODO validate that sementic ratio is between 0.0 and 1,0
#[deserr(default, error = DeserrJsonError<InvalidSemanticRatio>, default = DEFAULT_SEMANTIC_RATIO())] #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
pub semantic_ratio: f32, pub semantic_ratio: SemanticRatio,
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)] #[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
pub embedder: Option<String>, pub embedder: Option<String>,
} }
#[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)]
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatio(f32);
impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> {
if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio)
} else {
Ok(SemanticRatio(f))
}
}
}
impl SearchQuery { impl SearchQuery {
pub fn is_finite_pagination(&self) -> bool { pub fn is_finite_pagination(&self) -> bool {
self.page.or(self.hits_per_page).is_some() self.page.or(self.hits_per_page).is_some()
@ -457,10 +472,9 @@ pub fn perform_search(
/// + < 1.0 or remove q /// + < 1.0 or remove q
/// + > 0.0 or remove vector /// + > 0.0 or remove vector
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
if query.q.is_some() && query.vector.is_some() { match query.hybrid {
search.execute_hybrid()? Some(_) => search.execute_hybrid()?,
} else { None => search.execute()?,
search.execute()?
}; };
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();