From 04f6523f3c90e16068c8b540853c24a2e19ea597 Mon Sep 17 00:00:00 2001 From: Tamo Date: Wed, 29 May 2024 17:22:58 +0200 Subject: [PATCH] expose a new parameter to retrieve the embedders at search time --- index-scheduler/src/lib.rs | 42 ++++++++++--------- meilisearch-types/src/error.rs | 2 + .../src/analytics/segment_analytics.rs | 3 ++ .../src/routes/indexes/facet_search.rs | 1 + meilisearch/src/routes/indexes/search.rs | 3 ++ meilisearch/src/routes/indexes/similar.rs | 10 ++--- meilisearch/src/search.rs | 35 +++++++++++++++- meilisearch/tests/search/hybrid.rs | 6 +-- meilisearch/tests/similar/mod.rs | 8 ++-- milli/src/vector/rest.rs | 2 + 10 files changed, 79 insertions(+), 33 deletions(-) diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index 29b7c861f..c76a207f5 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -5045,25 +5045,25 @@ mod tests { // add one doc, specifying vectors let doc = serde_json::json!( - { - "id": 0, - "doggo": "Intel", - "breed": "beagle", - "_vectors": { - &fakerest_name: { - // this will never trigger regeneration, which is good because we can't actually generate with - // this embedder - "userProvided": true, - "embeddings": beagle_embed, - }, - &simple_hf_name: { - // this will be regenerated on updates - "userProvided": false, - "embeddings": lab_embed, - }, - "noise": [0.1, 0.2, 0.3] - } - } + { + "id": 0, + "doggo": "Intel", + "breed": "beagle", + "_vectors": { + &fakerest_name: { + // this will never trigger regeneration, which is good because we can't actually generate with + // this embedder + "userProvided": true, + "embeddings": beagle_embed, + }, + &simple_hf_name: { + // this will be regenerated on updates + "userProvided": false, + "embeddings": lab_embed, + }, + "noise": [0.1, 0.2, 0.3] + } + } ); let (uuid, mut file) = index_scheduler.create_update_file_with_uuid(0u128).unwrap(); @@ -5163,7 +5163,9 @@ mod tests { snapshot!(snapshot_index_scheduler(&index_scheduler), name: "Intel to kefir"); - handle.advance_one_successful_batch(); + println!("HEEEEERE"); + // handle.advance_one_successful_batch(); + handle.advance_one_failed_batch(); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "Intel to kefir succeeds"); { diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 150c56b9d..63543fb1b 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -240,9 +240,11 @@ InvalidSearchAttributesToSearchOn , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToCrop , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; +InvalidSimilarRetrieveVectors , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; +InvalidSearchRetrieveVectors , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index aed29e612..3eb74c7d1 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -662,6 +662,7 @@ impl SearchAggregator { page, hits_per_page, attributes_to_retrieve: _, + retrieve_vectors: _, attributes_to_crop: _, crop_length, attributes_to_highlight: _, @@ -1079,6 +1080,7 @@ impl MultiSearchAggregator { page: _, hits_per_page: _, attributes_to_retrieve: _, + retrieve_vectors: _, attributes_to_crop: _, crop_length: _, attributes_to_highlight: _, @@ -1646,6 +1648,7 @@ impl SimilarAggregator { offset, limit, attributes_to_retrieve: _, + retrieve_vectors: _, show_ranking_score, show_ranking_score_details, filter, diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 10b371f2d..2e9cf6e1b 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -115,6 +115,7 @@ impl From for SearchQuery { page: None, hits_per_page: None, attributes_to_retrieve: None, + retrieve_vectors: false, attributes_to_crop: None, crop_length: DEFAULT_CROP_LENGTH(), attributes_to_highlight: None, diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 348d8295c..91c8c8178 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -51,6 +51,8 @@ pub struct SearchQueryGet { hits_per_page: Option>, #[deserr(default, error = DeserrQueryParamError)] attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrQueryParamError)] + retrieve_vectors: bool, #[deserr(default, error = DeserrQueryParamError)] attributes_to_crop: Option>, #[deserr(default = Param(DEFAULT_CROP_LENGTH()), error = DeserrQueryParamError)] @@ -153,6 +155,7 @@ impl From for SearchQuery { page: other.page.as_deref().copied(), hits_per_page: other.hits_per_page.as_deref().copied(), attributes_to_retrieve: other.attributes_to_retrieve.map(|o| o.into_iter().collect()), + retrieve_vectors: other.retrieve_vectors, attributes_to_crop: other.attributes_to_crop.map(|o| o.into_iter().collect()), crop_length: other.crop_length.0, attributes_to_highlight: other.attributes_to_highlight.map(|o| o.into_iter().collect()), diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs index 518fedab7..54ea912ec 100644 --- a/meilisearch/src/routes/indexes/similar.rs +++ b/meilisearch/src/routes/indexes/similar.rs @@ -4,11 +4,7 @@ use deserr::actix_web::{AwebJson, AwebQueryParameter}; use index_scheduler::IndexScheduler; use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; -use meilisearch_types::error::deserr_codes::{ - InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId, - InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarRankingScoreThreshold, - InvalidSimilarShowRankingScore, InvalidSimilarShowRankingScoreDetails, -}; +use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::{ErrorCode as _, ResponseError}; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::keys::actions; @@ -122,6 +118,8 @@ pub struct SimilarQueryGet { limit: Param, #[deserr(default, error = DeserrQueryParamError)] attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrQueryParamError)] + retrieve_vectors: Param, #[deserr(default, error = DeserrQueryParamError)] filter: Option, #[deserr(default, error = DeserrQueryParamError)] @@ -156,6 +154,7 @@ impl TryFrom for SimilarQuery { offset, limit, attributes_to_retrieve, + retrieve_vectors, filter, show_ranking_score, show_ranking_score_details, @@ -180,6 +179,7 @@ impl TryFrom for SimilarQuery { filter, embedder, attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), + retrieve_vectors: retrieve_vectors.0, show_ranking_score: show_ranking_score.0, show_ranking_score_details: show_ranking_score_details.0, ranking_score_threshold: ranking_score_threshold.map(|x| x.0), diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 05b3c1aff..1ab42a79f 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -59,6 +59,8 @@ pub struct SearchQuery { pub hits_per_page: Option, #[deserr(default, error = DeserrJsonError)] pub attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrJsonError)] + pub retrieve_vectors: bool, #[deserr(default, error = DeserrJsonError)] pub attributes_to_crop: Option>, #[deserr(default, error = DeserrJsonError, default = DEFAULT_CROP_LENGTH())] @@ -141,6 +143,7 @@ impl fmt::Debug for SearchQuery { page, hits_per_page, attributes_to_retrieve, + retrieve_vectors, attributes_to_crop, crop_length, attributes_to_highlight, @@ -173,6 +176,9 @@ impl fmt::Debug for SearchQuery { if let Some(q) = q { debug.field("q", &q); } + if *retrieve_vectors { + debug.field("retrieve_vectors", &retrieve_vectors); + } if let Some(v) = vector { if v.len() < 10 { debug.field("vector", &v); @@ -370,6 +376,8 @@ pub struct SearchQueryWithIndex { pub hits_per_page: Option, #[deserr(default, error = DeserrJsonError)] pub attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrJsonError)] + pub retrieve_vectors: bool, #[deserr(default, error = DeserrJsonError)] pub attributes_to_crop: Option>, #[deserr(default, error = DeserrJsonError, default = DEFAULT_CROP_LENGTH())] @@ -413,6 +421,7 @@ impl SearchQueryWithIndex { page, hits_per_page, attributes_to_retrieve, + retrieve_vectors, attributes_to_crop, crop_length, attributes_to_highlight, @@ -440,6 +449,7 @@ impl SearchQueryWithIndex { page, hits_per_page, attributes_to_retrieve, + retrieve_vectors, attributes_to_crop, crop_length, attributes_to_highlight, @@ -478,6 +488,8 @@ pub struct SimilarQuery { pub embedder: Option, #[deserr(default, error = DeserrJsonError)] pub attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrJsonError)] + pub retrieve_vectors: bool, #[deserr(default, error = DeserrJsonError, default)] pub show_ranking_score: bool, #[deserr(default, error = DeserrJsonError, default)] @@ -847,6 +859,7 @@ pub fn perform_search( page, hits_per_page, attributes_to_retrieve, + retrieve_vectors, attributes_to_crop, crop_length, attributes_to_highlight, @@ -870,6 +883,7 @@ pub fn perform_search( let format = AttributesFormat { attributes_to_retrieve, + retrieve_vectors, attributes_to_highlight, attributes_to_crop, crop_length, @@ -953,6 +967,7 @@ pub fn perform_search( struct AttributesFormat { attributes_to_retrieve: Option>, + retrieve_vectors: bool, attributes_to_highlight: Option>, attributes_to_crop: Option>, crop_length: usize, @@ -1000,6 +1015,9 @@ fn make_hits( .intersection(&displayed_ids) .cloned() .collect(); + let is_vectors_displayed = + fields_ids_map.id("_vectors").is_some_and(|fid| displayed_ids.contains(&fid)); + let retrieve_vectors = format.retrieve_vectors && is_vectors_displayed; let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); let formatted_options = compute_formatted_options( @@ -1034,7 +1052,7 @@ fn make_hits( formatter_builder.highlight_suffix(format.highlight_post_tag); let mut documents = Vec::new(); let documents_iter = index.documents(rtxn, documents_ids)?; - for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { + for ((id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { // First generate a document with all the displayed fields let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; @@ -1045,6 +1063,19 @@ fn make_hits( let mut document = permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); + if retrieve_vectors { + let mut vectors = serde_json::Map::new(); + for (name, mut vector) in index.embeddings(&rtxn, id)? { + if vector.len() == 1 { + let vector = vector.pop().unwrap(); + vectors.insert(name.into(), vector.into()); + } else { + vectors.insert(name.into(), vector.into()); + } + } + document.insert("_vectors".into(), vectors.into()); + } + let (matches_position, formatted) = format_fields( &displayed_document, &fields_ids_map, @@ -1125,6 +1156,7 @@ pub fn perform_similar( filter: _, embedder: _, attributes_to_retrieve, + retrieve_vectors, show_ranking_score, show_ranking_score_details, ranking_score_threshold, @@ -1171,6 +1203,7 @@ pub fn perform_similar( let format = AttributesFormat { attributes_to_retrieve, + retrieve_vectors, attributes_to_highlight: None, attributes_to_crop: None, crop_length: DEFAULT_CROP_LENGTH(), diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs index 9c50df6e1..0c8b4534c 100644 --- a/meilisearch/tests/search/hybrid.rs +++ b/meilisearch/tests/search/hybrid.rs @@ -124,7 +124,7 @@ async fn simple_search() { let (response, code) = index .search_post( - json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}), + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}, "retrieveVectors": true}), ) .await; snapshot!(code, @"200 OK"); @@ -133,7 +133,7 @@ async fn simple_search() { let (response, code) = index .search_post( - json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.5}, "showRankingScore": true}), + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.5}, "showRankingScore": true, "retrieveVectors": true}), ) .await; snapshot!(code, @"200 OK"); @@ -142,7 +142,7 @@ async fn simple_search() { let (response, code) = index .search_post( - json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}, "showRankingScore": true}), + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}, "showRankingScore": true, "retrieveVectors": true}), ) .await; snapshot!(code, @"200 OK"); diff --git a/meilisearch/tests/similar/mod.rs b/meilisearch/tests/similar/mod.rs index bde23b67f..a2378eb58 100644 --- a/meilisearch/tests/similar/mod.rs +++ b/meilisearch/tests/similar/mod.rs @@ -557,7 +557,7 @@ async fn limit_and_offset() { index.wait_task(value.uid()).await; index - .similar(json!({"id": 143, "limit": 1}), |response, code| { + .similar(json!({"id": 143, "limit": 1, "retrieveVectors": true}), |response, code| { snapshot!(code, @"200 OK"); snapshot!(json_string!(response["hits"]), @r###" [ @@ -567,9 +567,9 @@ async fn limit_and_offset() { "id": "522681", "_vectors": { "manual": [ - 0.1, - 0.6, - 0.8 + 0.10000000149011612, + 0.6000000238418579, + 0.800000011920929 ] } } diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 60f54782e..e7fc509b3 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -163,6 +163,7 @@ impl Embedder { text_chunks: Vec>, threads: &ThreadPoolNoAbort, ) -> Result>>, EmbedError> { + dbg!(&text_chunks); threads .install(move || { text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() @@ -230,6 +231,7 @@ where input_value } [input] => { + dbg!(&options); let mut body = options.query.clone(); body.as_object_mut()