From 12940d79a96905f38a9016cb647f2013693e849a Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 23:39:01 +0100 Subject: [PATCH] WIP - manual embedder - multi embedders OK - clippy + tests OK --- meilisearch-types/src/error.rs | 3 + meilisearch/src/routes/indexes/search.rs | 60 +++--- meilisearch/src/search.rs | 21 +- meilisearch/tests/dumps/mod.rs | 39 ++-- meilisearch/tests/search/mod.rs | 26 ++- meilisearch/tests/settings/get_settings.rs | 3 +- milli/src/error.rs | 8 +- .../extract/extract_vector_points.rs | 197 +++++++++++------- .../src/update/index_documents/extract/mod.rs | 1 + milli/src/update/index_documents/mod.rs | 10 +- .../src/update/index_documents/typed_chunk.rs | 26 ++- milli/src/vector/manual.rs | 34 +++ milli/src/vector/mod.rs | 4 + 13 files changed, 292 insertions(+), 140 deletions(-) create mode 100644 milli/src/vector/manual.rs diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 5df1ae106..9df41b68f 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -305,6 +305,7 @@ NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENT PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ; TaskNotFound , InvalidRequest , NOT_FOUND ; TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ; +TooManyVectors , InvalidRequest , BAD_REQUEST ; UnretrievableDocument , Internal , BAD_REQUEST ; UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ; @@ -362,7 +363,9 @@ impl ErrorCode for milli::Error { UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, + UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType, UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, + UserError::TooManyVectors(_, _) => Code::TooManyVectors, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index ec4825661..c057d4809 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -235,38 +235,42 @@ pub async fn embed( index_scheduler: &IndexScheduler, index: &milli::Index, ) -> Result<(), ResponseError> { - if let Some(VectorQuery::String(prompt)) = query.vector.take() { - let embedder_configs = index.embedding_configs(&index.read_txn()?)?; - let embedder = index_scheduler.embedders(embedder_configs)?; + match query.vector.take() { + Some(VectorQuery::String(prompt)) => { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedder = index_scheduler.embedders(embedder_configs)?; - let embedder_name = if let Some(HybridQuery { - semantic_ratio: _, - embedder: Some(embedder), - }) = &query.hybrid - { - embedder - } else { - "default" - }; + let embedder_name = + if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) = + &query.hybrid + { + embedder + } else { + "default" + }; - let embeddings = embedder - .get(embedder_name) - .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) - .map_err(milli::Error::from)? - .0 - .embed(vec![prompt]) - .await - .map_err(milli::vector::Error::from) - .map_err(milli::Error::from)? 
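
A rough sketch of the control flow this hunk moves to, using stand-in types rather than the real Meilisearch structs: a textual prompt is embedded with the embedder named in the hybrid query (falling back to `default`), while a caller-supplied vector now passes through the new `match` untouched.

```rust
/// Simplified stand-ins for the route's query types (assumed shapes, for illustration only).
enum VectorQuery {
    String(String),   // a prompt that still needs to be embedded
    Vector(Vec<f32>), // an embedding supplied by the caller, used as-is
}

struct HybridQuery {
    embedder: Option<String>,
}

fn resolve_vector(
    vector: Option<VectorQuery>,
    hybrid: Option<&HybridQuery>,
    embed: impl Fn(&str, &str) -> Vec<f32>, // (embedder_name, prompt) -> embedding
) -> Option<VectorQuery> {
    // Pick the embedder named in the hybrid query, falling back to "default".
    let embedder_name = hybrid.and_then(|h| h.embedder.as_deref()).unwrap_or("default");

    match vector {
        // A textual prompt: embed it now so the engine only ever sees raw vectors.
        Some(VectorQuery::String(prompt)) => {
            Some(VectorQuery::Vector(embed(embedder_name, &prompt)))
        }
        // An already-computed vector: pass it through unchanged.
        Some(other) => Some(other),
        None => None,
    }
}
```
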
- .pop() - .expect("No vector returned from embedding"); + let embeddings = embedder + .get(embedder_name) + .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) + .map_err(milli::Error::from)? + .0 + .embed(vec![prompt]) + .await + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); - if embeddings.iter().nth(1).is_some() { - warn!("Ignoring embeddings past the first one in long search query"); - query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); - } else { - query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = + Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); + } else { + query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); + } } + Some(vector) => query.vector = Some(vector), + None => {} }; Ok(()) } diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index c1e667570..d496da1a3 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -13,7 +13,7 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; +use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -562,8 +562,17 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - /// FIXME: remove this or set to value from the score details - let semantic_score = None; + let mut semantic_score = None; + for details in &score { + if let ScoreDetails::Vector(score_details::Vector { + target_vector: _, + value_similarity: Some((_matching_vector, similarity)), + }) = details + { + semantic_score = Some(*similarity); + break; + } + } let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -648,8 +657,10 @@ pub fn perform_search( hits: documents, hits_info, query: query.q.unwrap_or_default(), - // FIXME: display input vector - vector: None, + vector: match query.vector { + Some(VectorQuery::Vector(vector)) => Some(vector), + _ => None, + }, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, diff --git a/meilisearch/tests/dumps/mod.rs b/meilisearch/tests/dumps/mod.rs index 9e949436a..07cfddd37 100644 --- a/meilisearch/tests/dumps/mod.rs +++ b/meilisearch/tests/dumps/mod.rs @@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -807,7 +812,8 @@ async fn 
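
Looking back at the `perform_search` hunk above, the former `FIXME` placeholder for the semantic score is replaced by scanning each hit's score details for a vector similarity. A compact sketch of that scan, with the score-detail enum reduced to the one shape used here:

```rust
/// Reduced stand-in for milli's score details (only the shape used by the hunk above).
enum ScoreDetails {
    Vector { target_vector: Vec<f32>, value_similarity: Option<(Vec<f32>, f32)> },
    Other,
}

/// Returns the similarity of the first vector score found, if any.
fn semantic_score(details: &[ScoreDetails]) -> Option<f32> {
    details.iter().find_map(|detail| match detail {
        ScoreDetails::Vector { value_similarity: Some((_matching_vector, similarity)), .. } => {
            Some(*similarity)
        }
        _ => None,
    })
}
```
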
import_dump_v2_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1896,7 +1908,8 @@ async fn import_dump_v6_containing_experimental_features() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "###); diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index 00678f7d4..fa97beaaf 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -876,7 +876,31 @@ async fn experimental_feature_vector_store() { })) .await; meili_snap::snapshot!(code, @"200 OK"); - meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]"); + // vector search returns all documents that don't have vectors in the last bucket, like all sorts + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Shazam!", + "id": "287947" + }, + { + "title": "Captain Marvel", + "id": "299537" + }, + { + "title": "Escape Room", + "id": "522681" + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "id": "166428" + }, + { + "title": "Gläss", + "id": "450465" + } + ] + "###); } #[cfg(feature = "default")] diff --git a/meilisearch/tests/settings/get_settings.rs b/meilisearch/tests/settings/get_settings.rs index 0ea556b94..9ab53c51e 100644 --- a/meilisearch/tests/settings/get_settings.rs +++ b/meilisearch/tests/settings/get_settings.rs @@ -54,7 +54,7 @@ async fn get_settings() { let (response, code) = index.settings().await; assert_eq!(code, 200); let settings = response.as_object().unwrap(); - assert_eq!(settings.keys().len(), 15); + assert_eq!(settings.keys().len(), 16); assert_eq!(settings["displayedAttributes"], json!(["*"])); assert_eq!(settings["searchableAttributes"], json!(["*"])); assert_eq!(settings["filterableAttributes"], json!([])); @@ -83,6 +83,7 @@ async fn get_settings() { "maxTotalHits": 1000, }) ); + assert_eq!(settings["embedders"], json!({})); } #[actix_rt::test] diff --git a/milli/src/error.rs b/milli/src/error.rs index 95a0aba6d..9c5d8f416 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -114,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidGeoField(#[from] GeoError), #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] InvalidVectorDimensions { expected: usize, found: usize }, - #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. 
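
The `milli` error hunk continuing below splits the `_vectors` validation into three user errors. They follow the file's existing `thiserror` pattern; a self-contained sketch of the same three variants (messages shortened):

```rust
use serde_json::Value;
use thiserror::Error;

/// Sketch of the new `_vectors`-related user errors (messages paraphrased from the patch).
#[derive(Debug, Error)]
enum UserError {
    #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array.")]
    InvalidVectorsType { document_id: Value, value: Value, subfield: String },
    #[error("The `_vectors` field in the document with id: `{document_id}` is not an object.")]
    InvalidVectorsMapType { document_id: Value, value: Value },
    #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
    TooManyVectors(String, usize),
}
```
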
Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] - InvalidVectorsType { document_id: Value, value: Value }, + #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] + InvalidVectorsType { document_id: Value, value: Value, subfield: String }, + #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")] + InvalidVectorsMapType { document_id: Value, value: Value }, #[error("{0}")] InvalidFilter(String), #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] @@ -196,6 +198,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco TooManyEmbedders(usize), #[error("Cannot find embedder with name {0}.")] InvalidEmbedder(String), + #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] + TooManyVectors(String, usize), } impl From for Error { diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 6edde98fb..3a0376511 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -73,6 +73,7 @@ pub fn extract_vector_points( indexer: GrenadParameters, field_id_map: &FieldsIdsMap, prompt: &Prompt, + embedder_name: &str, ) -> Result { puffin::profile_function!(); @@ -115,89 +116,87 @@ pub fn extract_vector_points( // lazily get it when needed let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; - let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) { - let vectors_obkv = KvReaderDelAdd::new(value); - match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) { - (Some(old), Some(new)) => { - // no autogeneration - let del_vectors = extract_vectors(old, document_id)?; - let add_vectors = extract_vectors(new, document_id)?; + let vectors_field = vectors_fid + .and_then(|vectors_fid| obkv.get(vectors_fid)) + .map(KvReaderDelAdd::new) + .map(|obkv| to_vector_maps(obkv, document_id)) + .transpose()?; - VectorStateDelta::ManualDelta( - del_vectors.unwrap_or_default(), - add_vectors.unwrap_or_default(), - ) - } - (None, Some(new)) => { - // was possibly autogenerated, remove all vectors for that document - let add_vectors = extract_vectors(new, document_id)?; + let (del_map, add_map) = vectors_field.unzip(); + let del_map = del_map.flatten(); + let add_map = add_map.flatten(); - VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default()) - } - (Some(_old), None) => { - // Do we keep this document? - let document_is_kept = obkv - .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) - .any(|deladd| deladd.get(DelAdd::Addition).is_some()); - if document_is_kept { - // becomes autogenerated - VectorStateDelta::NowGenerated(prompt.render( - obkv, - DelAdd::Addition, - field_id_map, - )?) - } else { - VectorStateDelta::NowRemoved - } - } - (None, None) => { - // Do we keep this document? 
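
The rewritten extractor below stops treating `_vectors` as a bare array: the field is parsed as a map keyed by embedder name, and each extraction pass only keeps its own entry. A minimal sketch of that lookup, using a hypothetical helper and a plain string error in place of `InvalidVectorsMapType`:

```rust
use std::collections::BTreeMap;

use serde_json::Value;

/// Pull the vectors that were provided manually for one embedder out of a `_vectors` value.
/// Errors when `_vectors` is not a JSON object, mirroring `InvalidVectorsMapType`.
fn vectors_for_embedder(
    vectors_field: Value,
    embedder_name: &str,
) -> Result<Option<Value>, String> {
    let mut map: BTreeMap<String, Value> = match serde_json::from_value(vectors_field) {
        Ok(map) => map,
        Err(_) => return Err("`_vectors` must be an object keyed by embedder name".to_string()),
    };
    // Other embedders' entries are simply ignored by this pass.
    Ok(map.remove(embedder_name))
}
```
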
- let document_is_kept = obkv - .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) - .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + let del_value = del_map.and_then(|mut map| map.remove(embedder_name)); + let add_value = add_map.and_then(|mut map| map.remove(embedder_name)); - if document_is_kept { - // Don't give up if the old prompt was failing - let old_prompt = - prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); - let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; - if old_prompt != new_prompt { - log::trace!( - "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" - ); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange - } - } else { - VectorStateDelta::NowRemoved - } + let delta = match (del_value, add_value) { + (Some(old), Some(new)) => { + // no autogeneration + let del_vectors = extract_vectors(old, document_id, embedder_name)?; + let add_vectors = extract_vectors(new, document_id, embedder_name)?; + + if add_vectors.len() > u8::MAX.into() { + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + document_id().to_string(), + add_vectors.len(), + ))); + } + + VectorStateDelta::ManualDelta(del_vectors, add_vectors) + } + (Some(_old), None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + if document_is_kept { + // becomes autogenerated + VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + field_id_map, + )?) + } else { + VectorStateDelta::NowRemoved } } - } else { - // Do we keep this document? - let document_is_kept = obkv - .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) - .any(|deladd| deladd.get(DelAdd::Addition).is_some()); - - if document_is_kept { - // Don't give up if the old prompt was failing - let old_prompt = - prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); - let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; - if old_prompt != new_prompt { - log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange + (None, Some(new)) => { + // was possibly autogenerated, remove all vectors for that document + let add_vectors = extract_vectors(new, document_id, embedder_name)?; + if add_vectors.len() > u8::MAX.into() { + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + document_id().to_string(), + add_vectors.len(), + ))); + } + + VectorStateDelta::WasGeneratedNowManual(add_vectors) + } + (None, None) => { + // Do we keep this document? 
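
The `match (del_value, add_value)` arms around this point form a small decision table. Restated as one standalone function (simplified payloads, prompt rendering replaced by plain arguments, and the 256-vector guard left out), the transitions are easier to audit:

```rust
/// Same states as the extractor's `VectorStateDelta`, with the payloads simplified.
enum VectorStateDelta {
    NoChange,
    NowRemoved,
    NowGenerated(String),                      // re-embed this rendered prompt
    WasGeneratedNowManual(Vec<Vec<f32>>),      // drop generated vectors, keep these
    ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>), // (deleted, added) manual vectors
}

fn vector_delta(
    del: Option<Vec<Vec<f32>>>,
    add: Option<Vec<Vec<f32>>>,
    document_is_kept: bool,
    prompt_changed: bool,
    new_prompt: String,
) -> VectorStateDelta {
    match (del, add) {
        // Manually provided before and after: index only the difference.
        (Some(old), Some(new)) => VectorStateDelta::ManualDelta(old, new),
        // Manual vectors appear: anything that was auto-generated must go.
        (None, Some(new)) => VectorStateDelta::WasGeneratedNowManual(new),
        // Manual vectors disappear: re-generate from the prompt if the document survives.
        (Some(_), None) if document_is_kept => VectorStateDelta::NowGenerated(new_prompt),
        (Some(_), None) => VectorStateDelta::NowRemoved,
        // No manual vectors on either side: re-embed only when the rendered prompt changed.
        (None, None) if document_is_kept && prompt_changed => {
            VectorStateDelta::NowGenerated(new_prompt)
        }
        (None, None) if document_is_kept => VectorStateDelta::NoChange,
        (None, None) => VectorStateDelta::NowRemoved,
    }
}
```
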
+ let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange + } + } else { + VectorStateDelta::NowRemoved } - } else { - VectorStateDelta::NowRemoved } }; @@ -221,6 +220,34 @@ pub fn extract_vector_points( }) } +fn to_vector_maps( + obkv: KvReaderDelAdd, + document_id: impl Fn() -> Value, +) -> Result<(Option>, Option>)> { + let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?; + let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?; + Ok((del, add)) +} + +fn to_vector_map( + obkv: KvReaderDelAdd, + side: DelAdd, + document_id: &impl Fn() -> Value, +) -> Result>> { + Ok(if let Some(value) = obkv.get(side) { + let Ok(value) = from_slice(value) else { + let value = from_slice(value).map_err(InternalError::SerdeJson)?; + return Err(crate::Error::UserError(UserError::InvalidVectorsMapType { + document_id: document_id(), + value, + })); + }; + Some(value) + } else { + None + }) +} + /// Computes the diff between both Del and Add numbers and /// only inserts the parts that differ in the sorter. fn push_vectors_diff( @@ -286,12 +313,20 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering { } /// Extracts the vectors from a JSON value. -fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result>>> { - match from_slice(value) { - Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)), +fn extract_vectors( + value: Value, + document_id: impl Fn() -> Value, + name: &str, +) -> Result>> { + // FIXME: ugly clone of the vectors here + match serde_json::from_value(value.clone()) { + Ok(vectors) => { + Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default()) + } Err(_) => Err(UserError::InvalidVectorsType { document_id: document_id(), - value: from_slice(value).map_err(InternalError::SerdeJson)?, + value, + subfield: name.to_owned(), } .into()), } diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 4831cc69d..a852b035b 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -298,6 +298,7 @@ fn send_original_documents_data( indexer, &field_id_map, &prompt, + &name, ); match result { Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index c3c39b90f..075dcd184 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -514,16 +514,18 @@ where // We write the primary key field id into the main database self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; + let mut rng = rand::rngs::StdRng::from_entropy(); for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; let vector_arroy = self.index.vector_arroy; - /// FIXME: unwrap - let embedder_index = - self.index.embedder_category_id.get(wtxn, 
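
Both `/// FIXME: unwrap` sites, here and in `typed_chunk.rs`, are resolved the same way: the `Option` returned by the `embedder_category_id` lookup becomes a typed internal error instead of a panic. The pattern in isolation, with a hypothetical error type standing in for milli's `InternalError`:

```rust
#[derive(Debug)]
enum InternalError {
    DatabaseMissingEntry { db_name: &'static str, key: Option<&'static str> },
}

/// Look up the category id registered for an embedder, failing with a typed
/// error rather than `unwrap()` when the entry is unexpectedly absent.
fn embedder_category_id(
    lookup: impl Fn(&str) -> Option<u8>,
    embedder_name: &str,
) -> Result<u8, InternalError> {
    lookup(embedder_name).ok_or(InternalError::DatabaseMissingEntry {
        db_name: "embedder_category_id",
        key: None,
    })
}
```
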
&embedder_name)?.unwrap(); + + let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, + )?; + pool.install(|| { let writer_index = (embedder_index as u16) << 8; - let mut rng = rand::rngs::StdRng::from_entropy(); for k in 0..=u8::MAX { let writer = arroy::Writer::prepare( wtxn, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index dde2124ed..f8fb30c7b 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -22,7 +22,9 @@ use crate::index::db_name::DOCUMENTS; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; +use crate::{ + lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError, +}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -363,8 +365,9 @@ pub(crate) fn write_typed_chunk_into_index( expected_dimension, embedder_name, } => { - /// FIXME: unwrap - let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); + let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, + )?; let writer_index = (embedder_index as u16) << 8; // FIXME: allow customizing distance let writers: std::result::Result, _> = (0..=u8::MAX) @@ -404,7 +407,20 @@ pub(crate) fn write_typed_chunk_into_index( // code error if we somehow got the wrong dimension .unwrap(); - /// FIXME: detect overflow + if embeddings.embedding_count() > u8::MAX.into() { + let external_docid = if let Ok(Some(Ok(index))) = index + .external_id_of(wtxn, std::iter::once(docid)) + .map(|it| it.into_iter().next()) + { + index + } else { + format!("internal docid={docid}") + }; + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + external_docid, + embeddings.embedding_count(), + ))); + } for (embedding, writer) in embeddings.iter().zip(&writers) { writer.add_item(wtxn, docid, embedding)?; } @@ -455,7 +471,7 @@ pub(crate) fn write_typed_chunk_into_index( if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { let vector = pod_collect_to_vec(value); - /// FIXME: detect overflow + // overflow was detected during vector extraction. for writer in &writers { if !writer.contains_item(wtxn, docid)? 
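
The writer loops in both files rely on the same encoding: the `u16` arroy index id carries the embedder's category id in its high byte and one of 256 per-embedder stores in its low byte, which is also why documents carrying more vectors than that are rejected. A sketch of the arithmetic; combining the low byte with `writer_index` is implied by the `for k in 0..=u8::MAX` loops rather than spelled out in the hunks:

```rust
/// High byte: embedder category id. Low byte: one of that embedder's 256 vector stores.
fn arroy_index_id(embedder_id: u8, store_id: u8) -> u16 {
    ((embedder_id as u16) << 8) | store_id as u16
}

/// All index ids reserved for one embedder, matching the `for k in 0..=u8::MAX` loops.
fn store_ids(embedder_id: u8) -> impl Iterator<Item = u16> {
    (0..=u8::MAX).map(move |k| arroy_index_id(embedder_id, k))
}

/// Mirror of the guard applied before writing: a single document may not carry
/// more manual vectors than there are per-embedder stores.
fn check_vector_count(external_docid: &str, count: usize) -> Result<(), String> {
    if count > u8::MAX as usize {
        return Err(format!(
            "Too many vectors for document with id {external_docid}: found {count}, but limited to 256."
        ));
    }
    Ok(())
}
```
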
{ writer.add_item(wtxn, docid, &vector)?; diff --git a/milli/src/vector/manual.rs b/milli/src/vector/manual.rs new file mode 100644 index 000000000..7ed48a251 --- /dev/null +++ b/milli/src/vector/manual.rs @@ -0,0 +1,34 @@ +use super::error::EmbedError; +use super::Embeddings; + +#[derive(Debug, Clone, Copy)] +pub struct Embedder { + dimensions: usize, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub dimensions: usize, +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> Self { + Self { dimensions: options.dimensions } + } + + pub fn embed(&self, mut texts: Vec) -> Result>, EmbedError> { + let Some(text) = texts.pop() else { return Ok(Default::default()) }; + Err(EmbedError::embed_on_manual_embedder(text)) + } + + pub fn dimensions(&self) -> usize { + self.dimensions + } + + pub fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() + } +} diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 7185e56b1..fa39c20a2 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -31,6 +31,10 @@ impl Embeddings { Ok(this) } + pub fn embedding_count(&self) -> usize { + self.data.len() / self.dimension + } + pub fn dimension(&self) -> usize { self.dimension }
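
The `embedding_count` helper added at the end is what makes the overflow check above cheap: `Embeddings` keeps all of a document's same-dimension vectors in one flat buffer, so the count is just the buffer length divided by the dimension. A cut-down sketch of such a container; only `embedding_count` is taken from the hunk, the rest is an assumed shape:

```rust
/// Flat storage for several same-dimension embeddings belonging to one document.
struct Embeddings {
    dimension: usize,
    data: Vec<f32>,
}

impl Embeddings {
    fn new(dimension: usize) -> Self {
        Self { dimension, data: Vec::new() }
    }

    /// Appends one embedding, refusing mismatched dimensions.
    fn push(&mut self, embedding: Vec<f32>) -> Result<(), String> {
        if embedding.len() != self.dimension {
            return Err(format!("expected {} dimensions, got {}", self.dimension, embedding.len()));
        }
        self.data.extend(embedding);
        Ok(())
    }

    /// Number of embeddings currently stored: the flat buffer length over the dimension.
    fn embedding_count(&self) -> usize {
        self.data.len() / self.dimension
    }

    /// Iterate over the stored embeddings as `dimension`-sized slices.
    fn iter(&self) -> impl Iterator<Item = &[f32]> {
        self.data.chunks_exact(self.dimension)
    }
}
```
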