diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 886a0fe30..ff9b06d85 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -218,6 +218,7 @@ MissingDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ; InvalidVectorDimensions , InvalidRequest , BAD_REQUEST ; +InvalidVectorsType , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; @@ -337,6 +338,7 @@ impl ErrorCode for milli::Error { UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, + UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance diff --git a/milli/src/error.rs b/milli/src/error.rs index a12334f90..3df599b61 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -112,6 +112,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidGeoField(#[from] GeoError), #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] InvalidVectorDimensions { expected: usize, found: usize }, + #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] + InvalidVectorsType { document_id: Value, value: Value }, #[error("{0}")] InvalidFilter(String), #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index c2a08b320..e78dbc080 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -3,9 +3,10 @@ use std::fs::File; use std::io; use bytemuck::cast_slice; -use serde_json::from_slice; +use serde_json::{from_slice, Value}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; +use crate::error::UserError; use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; /// Extracts the embedding vector contained in each document under the `_vectors` field. @@ -15,6 +16,7 @@ use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, + primary_key_id: FieldId, vectors_fid: FieldId, ) -> Result> { let mut writer = create_writer( @@ -27,14 +29,23 @@ pub fn extract_vector_points( while let Some((docid_bytes, value)) = cursor.move_on_next()? { let obkv = obkv::KvReader::new(value); + // since we only needs the primary key when we throw an error we create this getter to + // lazily get it when needed + let document_id = || -> Value { + let document_id = obkv.get(primary_key_id).unwrap(); + serde_json::from_slice(document_id).unwrap() + }; + // first we retrieve the _vectors field if let Some(vectors) = obkv.get(vectors_fid) { // extract the vectors - // TODO return a user error before unwrapping - let vectors = from_slice(vectors) - .map_err(InternalError::SerdeJson) - .map(VectorOrArrayOfVectors::into_array_of_vectors) - .unwrap(); + let vectors = match from_slice(vectors) { + Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors), + Err(_) => return Err(UserError::InvalidVectorsType { + document_id: document_id(), + value: from_slice(vectors).map_err(InternalError::SerdeJson)?, + }.into()), + }; for (i, vector) in vectors.into_iter().enumerate() { match u16::try_from(i) { diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 14f08b106..6259c7272 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -316,7 +316,12 @@ fn send_and_extract_flattened_documents_data( let documents_chunk_cloned = flattened_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); rayon::spawn(move || { - let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); + let result = extract_vector_points( + documents_chunk_cloned, + indexer, + primary_key_id, + vectors_field_id, + ); let _ = match result { Ok(vector_points) => { lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))