From b9b938c902b68c125786f56ddbc7b90087a332c3 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 13 Jun 2024 17:13:36 +0200 Subject: [PATCH] Change `retrieveVectors` behavior: - when the feature is disabled, documents are never modified - when the feature is enabled and `retrieveVectors` is disabled, `_vectors` is removed from documents - when the feature is enabled and `retrieveVectors` is enabled, vectors from the vectors DB are merged with `_vectors` in documents Additionally `_vectors` is never displayed when the `displayedAttributes` list does not contain either `*` or `_vectors` - fixed an issue where `_vectors` was not injected when all vectors in the dataset were always generated --- meilisearch/src/routes/indexes/documents.rs | 83 +++++++++---------- meilisearch/src/routes/indexes/search.rs | 24 +++--- meilisearch/src/routes/indexes/similar.rs | 12 ++- meilisearch/src/routes/multi_search.rs | 13 +-- meilisearch/src/search.rs | 92 ++++++++++++++++++--- 5 files changed, 150 insertions(+), 74 deletions(-) diff --git a/meilisearch/src/routes/indexes/documents.rs b/meilisearch/src/routes/indexes/documents.rs index bfbe20207..1f413ec7d 100644 --- a/meilisearch/src/routes/indexes/documents.rs +++ b/meilisearch/src/routes/indexes/documents.rs @@ -40,7 +40,7 @@ use crate::extractors::sequential_extractor::SeqHandler; use crate::routes::{ get_task_id, is_dry_run, PaginationView, SummarizedTaskView, PAGINATION_DEFAULT_LIMIT, }; -use crate::search::parse_filter; +use crate::search::{parse_filter, RetrieveVectors}; use crate::Opt; static ACCEPTED_CONTENT_TYPE: Lazy> = Lazy::new(|| { @@ -110,21 +110,20 @@ pub async fn get_document( debug!(parameters = ?params, "Get document"); let index_uid = IndexUid::try_from(index_uid)?; - let GetDocument { fields, retrieve_vectors } = params.into_inner(); + let GetDocument { fields, retrieve_vectors: param_retrieve_vectors } = params.into_inner(); let attributes_to_retrieve = fields.merge_star_and_none(); let features = 
index_scheduler.features(); - if retrieve_vectors.0 { - features.check_vector("Passing `retrieveVectors` as a parameter")?; - } + let retrieve_vectors = RetrieveVectors::new(param_retrieve_vectors.0, features)?; + analytics.get_fetch_documents( - &DocumentFetchKind::PerDocumentId { retrieve_vectors: retrieve_vectors.0 }, + &DocumentFetchKind::PerDocumentId { retrieve_vectors: param_retrieve_vectors.0 }, &req, ); let index = index_scheduler.index(&index_uid)?; let document = - retrieve_document(&index, &document_id, attributes_to_retrieve, retrieve_vectors.0)?; + retrieve_document(&index, &document_id, attributes_to_retrieve, retrieve_vectors)?; debug!(returns = ?document, "Get document"); Ok(HttpResponse::Ok().json(document)) } @@ -195,11 +194,6 @@ pub async fn documents_by_query_post( let body = body.into_inner(); debug!(parameters = ?body, "Get documents POST"); - let features = index_scheduler.features(); - if body.retrieve_vectors { - features.check_vector("Passing `retrieveVectors` as a parameter")?; - } - analytics.post_fetch_documents( &DocumentFetchKind::Normal { with_filter: body.filter.is_some(), @@ -224,11 +218,6 @@ pub async fn get_documents( let BrowseQueryGet { limit, offset, fields, retrieve_vectors, filter } = params.into_inner(); - let features = index_scheduler.features(); - if retrieve_vectors.0 { - features.check_vector("Passing `retrieveVectors` as a parameter")?; - } - let filter = match filter { Some(f) => match serde_json::from_str(&f) { Ok(v) => Some(v), @@ -266,6 +255,9 @@ fn documents_by_query( let index_uid = IndexUid::try_from(index_uid.into_inner())?; let BrowseQuery { offset, limit, fields, retrieve_vectors, filter } = query; + let features = index_scheduler.features(); + let retrieve_vectors = RetrieveVectors::new(retrieve_vectors, features)?; + let index = index_scheduler.index(&index_uid)?; let (total, documents) = retrieve_documents(&index, offset, limit, filter, fields, retrieve_vectors)?; @@ -608,7 +600,7 @@ fn 
some_documents<'a, 't: 'a>( index: &'a Index, rtxn: &'t RoTxn, doc_ids: impl IntoIterator + 'a, - retrieve_vectors: bool, + retrieve_vectors: RetrieveVectors, ) -> Result> + 'a, ResponseError> { let fields_ids_map = index.fields_ids_map(rtxn)?; let all_fields: Vec<_> = fields_ids_map.iter().map(|(id, _)| id).collect(); @@ -617,24 +609,32 @@ fn some_documents<'a, 't: 'a>( Ok(index.iter_documents(rtxn, doc_ids)?.map(move |ret| { ret.map_err(ResponseError::from).and_then(|(key, document)| -> Result<_, ResponseError> { let mut document = milli::obkv_to_json(&all_fields, &fields_ids_map, document)?; - - if retrieve_vectors { - let mut vectors = serde_json::Map::new(); - for (name, vector) in index.embeddings(rtxn, key)? { - let user_provided = embedding_configs - .iter() - .find(|conf| conf.name == name) - .is_some_and(|conf| conf.user_provided.contains(key)); - let embeddings = ExplicitVectors { - embeddings: Some(vector.into()), - regenerate: !user_provided, - }; - vectors.insert( - name, - serde_json::to_value(embeddings).map_err(MeilisearchHttpError::from)?, - ); + match retrieve_vectors { + RetrieveVectors::Ignore => {} + RetrieveVectors::Hide => { + document.remove("_vectors"); + } + RetrieveVectors::Retrieve => { + let mut vectors = match document.remove("_vectors") { + Some(Value::Object(map)) => map, + _ => Default::default(), + }; + for (name, vector) in index.embeddings(rtxn, key)? 
{ + let user_provided = embedding_configs + .iter() + .find(|conf| conf.name == name) + .is_some_and(|conf| conf.user_provided.contains(key)); + let embeddings = ExplicitVectors { + embeddings: Some(vector.into()), + regenerate: !user_provided, + }; + vectors.insert( + name, + serde_json::to_value(embeddings).map_err(MeilisearchHttpError::from)?, + ); + } + document.insert("_vectors".into(), vectors.into()); } - document.insert("_vectors".into(), vectors.into()); } Ok(document) @@ -648,7 +648,7 @@ fn retrieve_documents>( limit: usize, filter: Option, attributes_to_retrieve: Option>, - retrieve_vectors: bool, + retrieve_vectors: RetrieveVectors, ) -> Result<(u64, Vec), ResponseError> { let rtxn = index.read_txn()?; let filter = &filter; @@ -688,10 +688,9 @@ fn retrieve_documents>( Ok(match &attributes_to_retrieve { Some(attributes_to_retrieve) => permissive_json_pointer::select_values( &document?, - attributes_to_retrieve - .iter() - .map(|s| s.as_ref()) - .chain(retrieve_vectors.then_some("_vectors")), + attributes_to_retrieve.iter().map(|s| s.as_ref()).chain( + (retrieve_vectors == RetrieveVectors::Retrieve).then_some("_vectors"), + ), ), None => document?, }) @@ -705,7 +704,7 @@ fn retrieve_document>( index: &Index, doc_id: &str, attributes_to_retrieve: Option>, - retrieve_vectors: bool, + retrieve_vectors: RetrieveVectors, ) -> Result { let txn = index.read_txn()?; @@ -724,7 +723,7 @@ fn retrieve_document>( attributes_to_retrieve .iter() .map(|s| s.as_ref()) - .chain(retrieve_vectors.then_some("_vectors")), + .chain((retrieve_vectors == RetrieveVectors::Retrieve).then_some("_vectors")), ), None => document, }; diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 6fdff4568..421cf2940 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -20,9 +20,9 @@ use crate::extractors::sequential_extractor::SeqHandler; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; 
use crate::search::{ add_search_rules, perform_search, HybridQuery, MatchingStrategy, RankingScoreThreshold, - SearchKind, SearchQuery, SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, - DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, - DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, + RetrieveVectors, SearchKind, SearchQuery, SemanticRatio, DEFAULT_CROP_LENGTH, + DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, + DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; use crate::search_queue::SearchQueue; @@ -225,10 +225,12 @@ pub async fn search_with_url_query( let features = index_scheduler.features(); let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features)?; - + let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors, features)?; let _permit = search_queue.try_get_search_permit().await?; - let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, search_kind)).await?; + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, search_kind, retrieve_vector) + }) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -265,10 +267,13 @@ pub async fn search_with_post( let features = index_scheduler.features(); let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features)?; + let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors, features)?; let _permit = search_queue.try_get_search_permit().await?; - let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, search_kind)).await?; + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, search_kind, retrieve_vectors) + }) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); if search_result.degraded { @@ -295,9 +300,6 @@ pub fn search_kind( if query.hybrid.is_some() 
{ features.check_vector("Passing `hybrid` as a parameter")?; } - if query.retrieve_vectors { - features.check_vector("Passing `retrieveVectors` as a parameter")?; - } // regardless of anything, always do a keyword search when we don't have a vector and the query is whitespace or missing if query.vector.is_none() { diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs index 54ea912ec..1dd83b09b 100644 --- a/meilisearch/src/routes/indexes/similar.rs +++ b/meilisearch/src/routes/indexes/similar.rs @@ -17,8 +17,8 @@ use crate::analytics::{Analytics, SimilarAggregator}; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_similar, RankingScoreThresholdSimilar, SearchKind, SimilarQuery, - SimilarResult, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_similar, RankingScoreThresholdSimilar, RetrieveVectors, SearchKind, + SimilarQuery, SimilarResult, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -93,6 +93,8 @@ async fn similar( features.check_vector("Using the similar API")?; + let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors, features)?; + // Tenant token search_rules. if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { add_search_rules(&mut query.filter, search_rules); @@ -103,8 +105,10 @@ async fn similar( let (embedder_name, embedder) = SearchKind::embedder(&index_scheduler, &index, query.embedder.as_deref(), None)?; - tokio::task::spawn_blocking(move || perform_similar(&index, query, embedder_name, embedder)) - .await? + tokio::task::spawn_blocking(move || { + perform_similar(&index, query, embedder_name, embedder, retrieve_vectors) + }) + .await? 
} #[derive(Debug, deserr::Deserr)] diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index a83dc4bc0..1d697dac6 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -15,7 +15,7 @@ use crate::extractors::authentication::{AuthenticationError, GuardedData}; use crate::extractors::sequential_extractor::SeqHandler; use crate::routes::indexes::search::search_kind; use crate::search::{ - add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, + add_search_rules, perform_search, RetrieveVectors, SearchQueryWithIndex, SearchResultWithIndex, }; use crate::search_queue::SearchQueue; @@ -83,11 +83,14 @@ pub async fn multi_search_with_post( let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features) .with_index(query_index)?; + let retrieve_vector = + RetrieveVectors::new(query.retrieve_vectors, features).with_index(query_index)?; - let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, search_kind)) - .await - .with_index(query_index)?; + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, search_kind, retrieve_vector) + }) + .await + .with_index(query_index)?; search_results.push(SearchResultWithIndex { index_uid: index_uid.into_inner(), diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 60f684ede..9632e3f5d 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -823,6 +823,7 @@ pub fn perform_search( index: &Index, query: SearchQuery, search_kind: SearchKind, + retrieve_vectors: RetrieveVectors, ) -> Result { let before_search = Instant::now(); let rtxn = index.read_txn()?; @@ -860,7 +861,8 @@ pub fn perform_search( page, hits_per_page, attributes_to_retrieve, - retrieve_vectors, + // use the enum passed as parameter + retrieve_vectors: _, attributes_to_crop, crop_length, attributes_to_highlight, @@ -968,7 +970,7 @@ pub fn 
perform_search( struct AttributesFormat { attributes_to_retrieve: Option>, - retrieve_vectors: bool, + retrieve_vectors: RetrieveVectors, attributes_to_highlight: Option>, attributes_to_crop: Option>, crop_length: usize, @@ -981,6 +983,36 @@ struct AttributesFormat { show_ranking_score_details: bool, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RetrieveVectors { + /// Do not touch the `_vectors` field + /// + /// this is the behavior when the vectorStore feature is disabled + Ignore, + /// Remove the `_vectors` field + /// + /// this is the behavior when the vectorStore feature is enabled, and `retrieveVectors` is `false` + Hide, + /// Retrieve vectors from the DB and merge them into the `_vectors` field + /// + /// this is the behavior when the vectorStore feature is enabled, and `retrieveVectors` is `true` + Retrieve, +} + +impl RetrieveVectors { + pub fn new( + retrieve_vector: bool, + features: index_scheduler::RoFeatures, + ) -> Result { + match (retrieve_vector, features.check_vector("Passing `retrieveVectors` as a parameter")) { + (true, Ok(())) => Ok(Self::Retrieve), + (true, Err(error)) => Err(error), + (false, Ok(())) => Ok(Self::Hide), + (false, Err(_)) => Ok(Self::Ignore), + } + } +} + fn make_hits( index: &Index, rtxn: &RoTxn<'_>, @@ -990,10 +1022,32 @@ fn make_hits( document_scores: Vec>, ) -> Result, MeilisearchHttpError> { let fields_ids_map = index.fields_ids_map(rtxn).unwrap(); - let displayed_ids = index - .displayed_fields_ids(rtxn)? 
- .map(|fields| fields.into_iter().collect::>()) - .unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); + let displayed_ids = + index.displayed_fields_ids(rtxn)?.map(|fields| fields.into_iter().collect::>()); + + let vectors_fid = fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); + + let vectors_is_hidden = match (&displayed_ids, vectors_fid) { + // displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid + (None, _) => false, + // displayed_ids is a finite list, and `_vectors` cannot be part of it because it is not an existing field + (Some(_), None) => true, + // displayed_ids is a finite list, so hide if `_vectors` is not part of it + (Some(map), Some(vectors_fid)) => !map.contains(&vectors_fid), + }; + + let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors { + if vectors_is_hidden { + RetrieveVectors::Hide + } else { + RetrieveVectors::Retrieve + } + } else { + format.retrieve_vectors + }; + + let displayed_ids = + displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); let fids = |attrs: &BTreeSet| { let mut ids = BTreeSet::new(); for attr in attrs { @@ -1016,9 +1070,7 @@ fn make_hits( .intersection(&displayed_ids) .cloned() .collect(); - let is_vectors_displayed = - fields_ids_map.id("_vectors").is_some_and(|fid| displayed_ids.contains(&fid)); - let retrieve_vectors = format.retrieve_vectors && is_vectors_displayed; + let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); let formatted_options = compute_formatted_options( @@ -1058,15 +1110,30 @@ fn make_hits( // First generate a document with all the displayed fields let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; + let add_vectors_fid = + vectors_fid.filter(|_fid| retrieve_vectors == RetrieveVectors::Retrieve); + + // select the attributes to retrieve let 
attributes_to_retrieve = to_retrieve_ids .iter() + // skip the vectors_fid if RetrieveVectors::Hide + .filter(|fid| match vectors_fid { + Some(vectors_fid) => { + !(retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid) + } + None => true, + }) + // need to retrieve the existing `_vectors` field if the `RetrieveVectors::Retrieve` + .chain(add_vectors_fid.iter()) .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); let mut document = permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); - if retrieve_vectors { - let mut vectors = serde_json::Map::new(); + if retrieve_vectors == RetrieveVectors::Retrieve { + let mut vectors = match document.remove("_vectors") { + Some(Value::Object(map)) => map, + _ => Default::default(), + }; for (name, vector) in index.embeddings(rtxn, id)? { let user_provided = embedding_configs .iter() @@ -1148,6 +1215,7 @@ pub fn perform_similar( query: SimilarQuery, embedder_name: String, embedder: Arc, + retrieve_vectors: RetrieveVectors, ) -> Result { let before_search = Instant::now(); let rtxn = index.read_txn()?; @@ -1159,7 +1227,7 @@ pub fn perform_similar( filter: _, embedder: _, attributes_to_retrieve, - retrieve_vectors, + retrieve_vectors: _, show_ranking_score, show_ranking_score_details, ranking_score_threshold,