diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 3c8cb4b06..45ed5f3a5 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -58,7 +58,7 @@ pub enum EmbedErrorKind { ManualEmbed(String), #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))] OllamaModelNotFoundError(Option), - #[error("error deserialization the response body as JSON:\n - {0}")] + #[error("error deserializing the response body as JSON:\n - {0}")] RestResponseDeserialization(std::io::Error), #[error("expected a response containing {0} embeddings, got only {1}")] RestResponseEmbeddingCount(usize, usize), diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 05578ef1f..27bbc6ae0 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -252,12 +252,12 @@ where for attempt in 0..10 { let response = request.clone().send_json(&body); - let result = check_response(response, data.configuration_source); + let result = check_response(response, data.configuration_source).and_then(|response| { + response_to_embedding(response, data, expected_count, expected_dimension) + }); let retry_duration = match result { - Ok(response) => { - return response_to_embedding(response, data, expected_count, expected_dimension) - } + Ok(response) => return Ok(response), Err(retry) => { tracing::warn!("Failed: {}", retry.error); if let Some(deadline) = deadline { @@ -289,6 +289,7 @@ where let result = check_response(response, data.configuration_source); result.map_err(Retry::into_error).and_then(|response| { response_to_embedding(response, data, expected_count, expected_dimension) + .map_err(Retry::into_error) }) } @@ -330,23 +331,28 @@ fn response_to_embedding( data: &EmbedderData, expected_count: usize, expected_dimensions: Option, -) -> Result>, EmbedError> { - let response: serde_json::Value = - response.into_json().map_err(EmbedError::rest_response_deserialization)?; +) -> Result>, Retry> { + let response: serde_json::Value = response + .into_json() + .map_err(EmbedError::rest_response_deserialization) + .map_err(Retry::retry_later)?; - let embeddings = data.response.extract_embeddings(response)?; + let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?; if embeddings.len() != expected_count { - return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); + return Err(Retry::give_up(EmbedError::rest_response_embedding_count( + expected_count, + embeddings.len(), + ))); } if let Some(dimensions) = expected_dimensions { for embedding in &embeddings { if embedding.dimension() != dimensions { - return Err(EmbedError::rest_unexpected_dimension( + return Err(Retry::give_up(EmbedError::rest_unexpected_dimension( dimensions, embedding.dimension(), - )); + ))); } } }