diff --git a/crates/milli/src/vector/error.rs b/crates/milli/src/vector/error.rs index 41765f6ab..97bbe5d68 100644 --- a/crates/milli/src/vector/error.rs +++ b/crates/milli/src/vector/error.rs @@ -60,7 +60,7 @@ pub enum EmbedErrorKind { ManualEmbed(String), #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))] OllamaModelNotFoundError(Option), - #[error("error deserialization the response body as JSON:\n - {0}")] + #[error("error deserializing the response body as JSON:\n - {0}")] RestResponseDeserialization(std::io::Error), #[error("expected a response containing {0} embeddings, got only {1}")] RestResponseEmbeddingCount(usize, usize), diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs index eeb5b16af..81ca6598d 100644 --- a/crates/milli/src/vector/rest.rs +++ b/crates/milli/src/vector/rest.rs @@ -257,12 +257,12 @@ where for attempt in 0..10 { let response = request.clone().send_json(&body); - let result = check_response(response, data.configuration_source); + let result = check_response(response, data.configuration_source).and_then(|response| { + response_to_embedding(response, data, expected_count, expected_dimension) + }); let retry_duration = match result { - Ok(response) => { - return response_to_embedding(response, data, expected_count, expected_dimension) - } + Ok(response) => return Ok(response), Err(retry) => { tracing::warn!("Failed: {}", retry.error); retry.into_duration(attempt) @@ -283,6 +283,7 @@ where let result = check_response(response, data.configuration_source); result.map_err(Retry::into_error).and_then(|response| { response_to_embedding(response, data, expected_count, expected_dimension) + .map_err(Retry::into_error) }) } @@ -324,20 +325,28 @@ fn response_to_embedding( data: &EmbedderData, expected_count: usize, expected_dimensions: Option, -) -> Result, EmbedError> { - let response: serde_json::Value = - response.into_json().map_err(EmbedError::rest_response_deserialization)?; +) -> Result, Retry> { + let response: serde_json::Value = response + .into_json() + .map_err(EmbedError::rest_response_deserialization) + .map_err(Retry::retry_later)?; - let embeddings = data.response.extract_embeddings(response)?; + let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?; if embeddings.len() != expected_count { - return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); + return Err(Retry::give_up(EmbedError::rest_response_embedding_count( + expected_count, + embeddings.len(), + ))); } if let Some(dimensions) = expected_dimensions { for embedding in &embeddings { if embedding.len() != dimensions { - return Err(EmbedError::rest_unexpected_dimension(dimensions, embedding.len())); + return Err(Retry::give_up(EmbedError::rest_unexpected_dimension( + dimensions, + embedding.len(), + ))); } } }