diff --git a/crates/index-scheduler/src/lib.rs b/crates/index-scheduler/src/lib.rs index df8870470..16ad3f194 100644 --- a/crates/index-scheduler/src/lib.rs +++ b/crates/index-scheduler/src/lib.rs @@ -5214,9 +5214,10 @@ mod tests { let configs = index_scheduler.embedders(configs).unwrap(); let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap(); - let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap(); - let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap(); - let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap(); + let beagle_embed = + hf_embedder.embed_one(S("Intel the beagle best doggo"), None).unwrap(); + let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo"), None).unwrap(); + let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo"), None).unwrap(); (fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed) }; diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs index c873ab387..ec36b01bb 100644 --- a/crates/meilisearch/src/search/mod.rs +++ b/crates/meilisearch/src/search/mod.rs @@ -796,8 +796,10 @@ fn prepare_search<'t>( let span = tracing::trace_span!(target: "search::vector", "embed_one"); let _entered = span.enter(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + embedder - .embed_one(query.q.clone().unwrap()) + .embed_one(query.q.clone().unwrap(), Some(deadline)) .map_err(milli::vector::Error::from) .map_err(milli::Error::from)? } diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs index 8b274804c..5187b572b 100644 --- a/crates/milli/src/search/hybrid.rs +++ b/crates/milli/src/search/hybrid.rs @@ -201,7 +201,9 @@ impl<'a> Search<'a> { let span = tracing::trace_span!(target: "search::hybrid", "embed_one"); let _entered = span.enter(); - match embedder.embed_one(query) { + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); + + match embedder.embed_one(query, Some(deadline)) { Ok(embedding) => embedding, Err(error) => { tracing::error!(error=%error, "Embedding failed"); diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs index 24ea77541..3047e6dfc 100644 --- a/crates/milli/src/vector/mod.rs +++ b/crates/milli/src/vector/mod.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::Instant; use arroy::distances::{BinaryQuantizedCosine, Cosine}; use arroy::ItemId; @@ -595,18 +596,26 @@ impl Embedder { /// Embed one or multiple texts. /// /// Each text can be embedded as one or multiple embeddings. - pub fn embed(&self, texts: Vec) -> std::result::Result, EmbedError> { + pub fn embed( + &self, + texts: Vec, + deadline: Option, + ) -> std::result::Result, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts), - Embedder::OpenAi(embedder) => embedder.embed(&texts), - Embedder::Ollama(embedder) => embedder.embed(&texts), + Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline), + Embedder::Ollama(embedder) => embedder.embed(&texts, deadline), Embedder::UserProvided(embedder) => embedder.embed(&texts), - Embedder::Rest(embedder) => embedder.embed(texts), + Embedder::Rest(embedder) => embedder.embed(texts, deadline), } } - pub fn embed_one(&self, text: String) -> std::result::Result { - let mut embedding = self.embed(vec![text])?; + pub fn embed_one( + &self, + text: String, + deadline: Option, + ) -> std::result::Result { + let mut embedding = self.embed(vec![text], deadline)?; let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?; Ok(embedding) } diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs index 263d9d3c9..7ee775cbf 100644 --- a/crates/milli/src/vector/ollama.rs +++ b/crates/milli/src/vector/ollama.rs @@ -1,3 +1,5 @@ +use std::time::Instant; + use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use rayon::slice::ParallelSlice as _; @@ -80,8 +82,9 @@ impl Embedder { pub fn embed + serde::Serialize>( &self, texts: &[S], + deadline: Option, ) -> Result, EmbedError> { - match self.rest_embedder.embed_ref(texts) { + match self.rest_embedder.embed_ref(texts, deadline) { Ok(embeddings) => Ok(embeddings), Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError::ollama_model_not_found(error)) @@ -97,7 +100,7 @@ impl Embedder { ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), @@ -114,7 +117,7 @@ impl Embedder { .install(move || { let embeddings: Result>, _> = texts .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk)) + .map(move |chunk| self.embed(chunk, None)) .collect(); let embeddings = embeddings?; diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs index 375b2878a..7262bfef8 100644 --- a/crates/milli/src/vector/openai.rs +++ b/crates/milli/src/vector/openai.rs @@ -1,3 +1,5 @@ +use std::time::Instant; + use ordered_float::OrderedFloat; use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use rayon::slice::ParallelSlice as _; @@ -211,18 +213,23 @@ impl Embedder { pub fn embed + serde::Serialize>( &self, texts: &[S], + deadline: Option, ) -> Result, EmbedError> { - match self.rest_embedder.embed_ref(texts) { + match self.rest_embedder.embed_ref(texts, deadline) { Ok(embeddings) => Ok(embeddings), Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); - self.try_embed_tokenized(texts) + self.try_embed_tokenized(texts, deadline) } Err(error) => Err(error), } } - fn try_embed_tokenized>(&self, text: &[S]) -> Result, EmbedError> { + fn try_embed_tokenized>( + &self, + text: &[S], + deadline: Option, + ) -> Result, EmbedError> { let mut all_embeddings = Vec::with_capacity(text.len()); for text in text { let text = text.as_ref(); @@ -230,13 +237,13 @@ impl Embedder { let encoded = self.tokenizer.encode_ordinary(text); let len = encoded.len(); if len < max_token_count { - all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); + all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?); continue; } let tokens = &encoded.as_slice()[0..max_token_count]; - let embedding = self.rest_embedder.embed_tokens(tokens)?; + let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?; all_embeddings.push(embedding); } @@ -250,7 +257,7 @@ impl Embedder { ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), @@ -267,7 +274,7 @@ impl Embedder { .install(move || { let embeddings: Result>, _> = texts .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk)) + .map(move |chunk| self.embed(chunk, None)) .collect(); let embeddings = embeddings?; diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs index 81ca6598d..98be311d4 100644 --- a/crates/milli/src/vector/rest.rs +++ b/crates/milli/src/vector/rest.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::time::Instant; use deserr::Deserr; use rand::Rng; @@ -153,19 +154,31 @@ impl Embedder { Ok(Self { data, dimensions, distribution: options.distribution }) } - pub fn embed(&self, texts: Vec) -> Result, EmbedError> { - embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) + pub fn embed( + &self, + texts: Vec, + deadline: Option, + ) -> Result, EmbedError> { + embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline) } - pub fn embed_ref(&self, texts: &[S]) -> Result, EmbedError> + pub fn embed_ref( + &self, + texts: &[S], + deadline: Option, + ) -> Result, EmbedError> where S: AsRef + Serialize, { - embed(&self.data, texts, texts.len(), Some(self.dimensions)) + embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline) } - pub fn embed_tokens(&self, tokens: &[usize]) -> Result { - let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; + pub fn embed_tokens( + &self, + tokens: &[usize], + deadline: Option, + ) -> Result { + let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?; // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error Ok(embeddings.pop().unwrap()) } @@ -177,7 +190,7 @@ impl Embedder { ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect() }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), @@ -194,7 +207,7 @@ impl Embedder { .install(move || { let embeddings: Result>, _> = texts .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed_ref(chunk)) + .map(move |chunk| self.embed_ref(chunk, None)) .collect(); let embeddings = embeddings?; @@ -227,7 +240,7 @@ impl Embedder { } fn infer_dimensions(data: &EmbedderData) -> Result { - let v = embed(data, ["test"].as_slice(), 1, None) + let v = embed(data, ["test"].as_slice(), 1, None, None) .map_err(NewEmbedderError::could_not_determine_dimension)?; // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error Ok(v.first().unwrap().len()) @@ -238,6 +251,7 @@ fn embed( inputs: &[S], expected_count: usize, expected_dimension: Option, + deadline: Option, ) -> Result, EmbedError> where S: Serialize, @@ -265,7 +279,18 @@ where Ok(response) => return Ok(response), Err(retry) => { tracing::warn!("Failed: {}", retry.error); - retry.into_duration(attempt) + if let Some(deadline) = deadline { + let now = std::time::Instant::now(); + if now > deadline { + tracing::warn!("Could not embed due to deadline"); + return Err(retry.into_error()); + } + + let duration_to_deadline = deadline - now; + retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline)) + } else { + retry.into_duration(attempt) + } } }?;