expose a new parameter to retrieve the embedders at search time

This commit is contained in:
Tamo 2024-05-29 17:22:58 +02:00
parent 30d66abf8d
commit 04f6523f3c
10 changed files with 79 additions and 33 deletions

View File

@ -5045,25 +5045,25 @@ mod tests {
// add one doc, specifying vectors // add one doc, specifying vectors
let doc = serde_json::json!( let doc = serde_json::json!(
{ {
"id": 0, "id": 0,
"doggo": "Intel", "doggo": "Intel",
"breed": "beagle", "breed": "beagle",
"_vectors": { "_vectors": {
&fakerest_name: { &fakerest_name: {
// this will never trigger regeneration, which is good because we can't actually generate with // this will never trigger regeneration, which is good because we can't actually generate with
// this embedder // this embedder
"userProvided": true, "userProvided": true,
"embeddings": beagle_embed, "embeddings": beagle_embed,
}, },
&simple_hf_name: { &simple_hf_name: {
// this will be regenerated on updates // this will be regenerated on updates
"userProvided": false, "userProvided": false,
"embeddings": lab_embed, "embeddings": lab_embed,
}, },
"noise": [0.1, 0.2, 0.3] "noise": [0.1, 0.2, 0.3]
} }
} }
); );
let (uuid, mut file) = index_scheduler.create_update_file_with_uuid(0u128).unwrap(); let (uuid, mut file) = index_scheduler.create_update_file_with_uuid(0u128).unwrap();
@ -5163,7 +5163,9 @@ mod tests {
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "Intel to kefir"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "Intel to kefir");
handle.advance_one_successful_batch(); println!("HEEEEERE");
// handle.advance_one_successful_batch();
handle.advance_one_failed_batch();
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "Intel to kefir succeeds"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "Intel to kefir succeeds");
{ {

View File

@ -240,9 +240,11 @@ InvalidSearchAttributesToSearchOn , InvalidRequest , BAD_REQUEST ;
InvalidSearchAttributesToCrop , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToCrop , InvalidRequest , BAD_REQUEST ;
InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ;
InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSimilarRetrieveVectors , InvalidRequest , BAD_REQUEST ;
InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ;
InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ;
InvalidSearchRetrieveVectors , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;

View File

@ -662,6 +662,7 @@ impl SearchAggregator {
page, page,
hits_per_page, hits_per_page,
attributes_to_retrieve: _, attributes_to_retrieve: _,
retrieve_vectors: _,
attributes_to_crop: _, attributes_to_crop: _,
crop_length, crop_length,
attributes_to_highlight: _, attributes_to_highlight: _,
@ -1079,6 +1080,7 @@ impl MultiSearchAggregator {
page: _, page: _,
hits_per_page: _, hits_per_page: _,
attributes_to_retrieve: _, attributes_to_retrieve: _,
retrieve_vectors: _,
attributes_to_crop: _, attributes_to_crop: _,
crop_length: _, crop_length: _,
attributes_to_highlight: _, attributes_to_highlight: _,
@ -1646,6 +1648,7 @@ impl SimilarAggregator {
offset, offset,
limit, limit,
attributes_to_retrieve: _, attributes_to_retrieve: _,
retrieve_vectors: _,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
filter, filter,

View File

@ -115,6 +115,7 @@ impl From<FacetSearchQuery> for SearchQuery {
page: None, page: None,
hits_per_page: None, hits_per_page: None,
attributes_to_retrieve: None, attributes_to_retrieve: None,
retrieve_vectors: false,
attributes_to_crop: None, attributes_to_crop: None,
crop_length: DEFAULT_CROP_LENGTH(), crop_length: DEFAULT_CROP_LENGTH(),
attributes_to_highlight: None, attributes_to_highlight: None,

View File

@ -51,6 +51,8 @@ pub struct SearchQueryGet {
hits_per_page: Option<Param<usize>>, hits_per_page: Option<Param<usize>>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToRetrieve>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToRetrieve>)]
attributes_to_retrieve: Option<CS<String>>, attributes_to_retrieve: Option<CS<String>>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchRetrieveVectors>)]
retrieve_vectors: bool,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToCrop>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToCrop>)]
attributes_to_crop: Option<CS<String>>, attributes_to_crop: Option<CS<String>>,
#[deserr(default = Param(DEFAULT_CROP_LENGTH()), error = DeserrQueryParamError<InvalidSearchCropLength>)] #[deserr(default = Param(DEFAULT_CROP_LENGTH()), error = DeserrQueryParamError<InvalidSearchCropLength>)]
@ -153,6 +155,7 @@ impl From<SearchQueryGet> for SearchQuery {
page: other.page.as_deref().copied(), page: other.page.as_deref().copied(),
hits_per_page: other.hits_per_page.as_deref().copied(), hits_per_page: other.hits_per_page.as_deref().copied(),
attributes_to_retrieve: other.attributes_to_retrieve.map(|o| o.into_iter().collect()), attributes_to_retrieve: other.attributes_to_retrieve.map(|o| o.into_iter().collect()),
retrieve_vectors: other.retrieve_vectors,
attributes_to_crop: other.attributes_to_crop.map(|o| o.into_iter().collect()), attributes_to_crop: other.attributes_to_crop.map(|o| o.into_iter().collect()),
crop_length: other.crop_length.0, crop_length: other.crop_length.0,
attributes_to_highlight: other.attributes_to_highlight.map(|o| o.into_iter().collect()), attributes_to_highlight: other.attributes_to_highlight.map(|o| o.into_iter().collect()),

View File

@ -4,11 +4,7 @@ use deserr::actix_web::{AwebJson, AwebQueryParameter};
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::query_params::Param;
use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::{ use meilisearch_types::error::deserr_codes::*;
InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId,
InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarRankingScoreThreshold,
InvalidSimilarShowRankingScore, InvalidSimilarShowRankingScoreDetails,
};
use meilisearch_types::error::{ErrorCode as _, ResponseError}; use meilisearch_types::error::{ErrorCode as _, ResponseError};
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::keys::actions; use meilisearch_types::keys::actions;
@ -122,6 +118,8 @@ pub struct SimilarQueryGet {
limit: Param<usize>, limit: Param<usize>,
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarAttributesToRetrieve>)] #[deserr(default, error = DeserrQueryParamError<InvalidSimilarAttributesToRetrieve>)]
attributes_to_retrieve: Option<CS<String>>, attributes_to_retrieve: Option<CS<String>>,
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarRetrieveVectors>)]
retrieve_vectors: Param<bool>,
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarFilter>)] #[deserr(default, error = DeserrQueryParamError<InvalidSimilarFilter>)]
filter: Option<String>, filter: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarShowRankingScore>)] #[deserr(default, error = DeserrQueryParamError<InvalidSimilarShowRankingScore>)]
@ -156,6 +154,7 @@ impl TryFrom<SimilarQueryGet> for SimilarQuery {
offset, offset,
limit, limit,
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
filter, filter,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
@ -180,6 +179,7 @@ impl TryFrom<SimilarQueryGet> for SimilarQuery {
filter, filter,
embedder, embedder,
attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()),
retrieve_vectors: retrieve_vectors.0,
show_ranking_score: show_ranking_score.0, show_ranking_score: show_ranking_score.0,
show_ranking_score_details: show_ranking_score_details.0, show_ranking_score_details: show_ranking_score_details.0,
ranking_score_threshold: ranking_score_threshold.map(|x| x.0), ranking_score_threshold: ranking_score_threshold.map(|x| x.0),

View File

@ -59,6 +59,8 @@ pub struct SearchQuery {
pub hits_per_page: Option<usize>, pub hits_per_page: Option<usize>,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToRetrieve>)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToRetrieve>)]
pub attributes_to_retrieve: Option<BTreeSet<String>>, pub attributes_to_retrieve: Option<BTreeSet<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchRetrieveVectors>)]
pub retrieve_vectors: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToCrop>)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToCrop>)]
pub attributes_to_crop: Option<Vec<String>>, pub attributes_to_crop: Option<Vec<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchCropLength>, default = DEFAULT_CROP_LENGTH())] #[deserr(default, error = DeserrJsonError<InvalidSearchCropLength>, default = DEFAULT_CROP_LENGTH())]
@ -141,6 +143,7 @@ impl fmt::Debug for SearchQuery {
page, page,
hits_per_page, hits_per_page,
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
attributes_to_highlight, attributes_to_highlight,
@ -173,6 +176,9 @@ impl fmt::Debug for SearchQuery {
if let Some(q) = q { if let Some(q) = q {
debug.field("q", &q); debug.field("q", &q);
} }
if *retrieve_vectors {
debug.field("retrieve_vectors", &retrieve_vectors);
}
if let Some(v) = vector { if let Some(v) = vector {
if v.len() < 10 { if v.len() < 10 {
debug.field("vector", &v); debug.field("vector", &v);
@ -370,6 +376,8 @@ pub struct SearchQueryWithIndex {
pub hits_per_page: Option<usize>, pub hits_per_page: Option<usize>,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToRetrieve>)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToRetrieve>)]
pub attributes_to_retrieve: Option<BTreeSet<String>>, pub attributes_to_retrieve: Option<BTreeSet<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchRetrieveVectors>)]
pub retrieve_vectors: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToCrop>)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToCrop>)]
pub attributes_to_crop: Option<Vec<String>>, pub attributes_to_crop: Option<Vec<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchCropLength>, default = DEFAULT_CROP_LENGTH())] #[deserr(default, error = DeserrJsonError<InvalidSearchCropLength>, default = DEFAULT_CROP_LENGTH())]
@ -413,6 +421,7 @@ impl SearchQueryWithIndex {
page, page,
hits_per_page, hits_per_page,
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
attributes_to_highlight, attributes_to_highlight,
@ -440,6 +449,7 @@ impl SearchQueryWithIndex {
page, page,
hits_per_page, hits_per_page,
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
attributes_to_highlight, attributes_to_highlight,
@ -478,6 +488,8 @@ pub struct SimilarQuery {
pub embedder: Option<String>, pub embedder: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSimilarAttributesToRetrieve>)] #[deserr(default, error = DeserrJsonError<InvalidSimilarAttributesToRetrieve>)]
pub attributes_to_retrieve: Option<BTreeSet<String>>, pub attributes_to_retrieve: Option<BTreeSet<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSimilarRetrieveVectors>)]
pub retrieve_vectors: bool,
#[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScore>, default)] #[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScore>, default)]
pub show_ranking_score: bool, pub show_ranking_score: bool,
#[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScoreDetails>, default)] #[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScoreDetails>, default)]
@ -847,6 +859,7 @@ pub fn perform_search(
page, page,
hits_per_page, hits_per_page,
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
attributes_to_highlight, attributes_to_highlight,
@ -870,6 +883,7 @@ pub fn perform_search(
let format = AttributesFormat { let format = AttributesFormat {
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
attributes_to_highlight, attributes_to_highlight,
attributes_to_crop, attributes_to_crop,
crop_length, crop_length,
@ -953,6 +967,7 @@ pub fn perform_search(
struct AttributesFormat { struct AttributesFormat {
attributes_to_retrieve: Option<BTreeSet<String>>, attributes_to_retrieve: Option<BTreeSet<String>>,
retrieve_vectors: bool,
attributes_to_highlight: Option<HashSet<String>>, attributes_to_highlight: Option<HashSet<String>>,
attributes_to_crop: Option<Vec<String>>, attributes_to_crop: Option<Vec<String>>,
crop_length: usize, crop_length: usize,
@ -1000,6 +1015,9 @@ fn make_hits(
.intersection(&displayed_ids) .intersection(&displayed_ids)
.cloned() .cloned()
.collect(); .collect();
let is_vectors_displayed =
fields_ids_map.id("_vectors").is_some_and(|fid| displayed_ids.contains(&fid));
let retrieve_vectors = format.retrieve_vectors && is_vectors_displayed;
let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default();
let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); let attr_to_crop = format.attributes_to_crop.unwrap_or_default();
let formatted_options = compute_formatted_options( let formatted_options = compute_formatted_options(
@ -1034,7 +1052,7 @@ fn make_hits(
formatter_builder.highlight_suffix(format.highlight_post_tag); formatter_builder.highlight_suffix(format.highlight_post_tag);
let mut documents = Vec::new(); let mut documents = Vec::new();
let documents_iter = index.documents(rtxn, documents_ids)?; let documents_iter = index.documents(rtxn, documents_ids)?;
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { for ((id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
// First generate a document with all the displayed fields // First generate a document with all the displayed fields
let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?;
@ -1045,6 +1063,19 @@ fn make_hits(
let mut document = let mut document =
permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve);
if retrieve_vectors {
let mut vectors = serde_json::Map::new();
for (name, mut vector) in index.embeddings(&rtxn, id)? {
if vector.len() == 1 {
let vector = vector.pop().unwrap();
vectors.insert(name.into(), vector.into());
} else {
vectors.insert(name.into(), vector.into());
}
}
document.insert("_vectors".into(), vectors.into());
}
let (matches_position, formatted) = format_fields( let (matches_position, formatted) = format_fields(
&displayed_document, &displayed_document,
&fields_ids_map, &fields_ids_map,
@ -1125,6 +1156,7 @@ pub fn perform_similar(
filter: _, filter: _,
embedder: _, embedder: _,
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
ranking_score_threshold, ranking_score_threshold,
@ -1171,6 +1203,7 @@ pub fn perform_similar(
let format = AttributesFormat { let format = AttributesFormat {
attributes_to_retrieve, attributes_to_retrieve,
retrieve_vectors,
attributes_to_highlight: None, attributes_to_highlight: None,
attributes_to_crop: None, attributes_to_crop: None,
crop_length: DEFAULT_CROP_LENGTH(), crop_length: DEFAULT_CROP_LENGTH(),

View File

@ -124,7 +124,7 @@ async fn simple_search() {
let (response, code) = index let (response, code) = index
.search_post( .search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}), json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}, "retrieveVectors": true}),
) )
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
@ -133,7 +133,7 @@ async fn simple_search() {
let (response, code) = index let (response, code) = index
.search_post( .search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.5}, "showRankingScore": true}), json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.5}, "showRankingScore": true, "retrieveVectors": true}),
) )
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
@ -142,7 +142,7 @@ async fn simple_search() {
let (response, code) = index let (response, code) = index
.search_post( .search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}, "showRankingScore": true}), json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}, "showRankingScore": true, "retrieveVectors": true}),
) )
.await; .await;
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");

View File

@ -557,7 +557,7 @@ async fn limit_and_offset() {
index.wait_task(value.uid()).await; index.wait_task(value.uid()).await;
index index
.similar(json!({"id": 143, "limit": 1}), |response, code| { .similar(json!({"id": 143, "limit": 1, "retrieveVectors": true}), |response, code| {
snapshot!(code, @"200 OK"); snapshot!(code, @"200 OK");
snapshot!(json_string!(response["hits"]), @r###" snapshot!(json_string!(response["hits"]), @r###"
[ [
@ -567,9 +567,9 @@ async fn limit_and_offset() {
"id": "522681", "id": "522681",
"_vectors": { "_vectors": {
"manual": [ "manual": [
0.1, 0.10000000149011612,
0.6, 0.6000000238418579,
0.8 0.800000011920929
] ]
} }
} }

View File

@ -163,6 +163,7 @@ impl Embedder {
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort, threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
dbg!(&text_chunks);
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
@ -230,6 +231,7 @@ where
input_value input_value
} }
[input] => { [input] => {
dbg!(&options);
let mut body = options.query.clone(); let mut body = options.query.clone();
body.as_object_mut() body.as_object_mut()