Add deadline of 3 seconds to embedding requests made in the context of hybrid search

This commit is contained in:
Louis Dureuil 2024-11-06 09:24:51 +01:00
parent 3753f87fd8
commit 37a4fd7f99
No known key found for this signature in database
7 changed files with 81 additions and 29 deletions

View File

@ -5201,9 +5201,10 @@ mod tests {
let configs = index_scheduler.embedders(configs).unwrap(); let configs = index_scheduler.embedders(configs).unwrap();
let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap(); let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap();
let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap(); let beagle_embed =
let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap(); hf_embedder.embed_one(S("Intel the beagle best doggo"), None).unwrap();
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap(); let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo"), None).unwrap();
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo"), None).unwrap();
(fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed) (fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed)
}; };

View File

@ -796,8 +796,10 @@ fn prepare_search<'t>(
let span = tracing::trace_span!(target: "search::vector", "embed_one"); let span = tracing::trace_span!(target: "search::vector", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
embedder embedder
.embed_one(query.q.clone().unwrap()) .embed_one(query.q.clone().unwrap(), Some(deadline))
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
} }

View File

@ -201,7 +201,9 @@ impl<'a> Search<'a> {
let span = tracing::trace_span!(target: "search::hybrid", "embed_one"); let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
match embedder.embed_one(query) { let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_one(query, Some(deadline)) {
Ok(embedding) => embedding, Ok(embedding) => embedding,
Err(error) => { Err(error) => {
tracing::error!(error=%error, "Embedding failed"); tracing::error!(error=%error, "Embedding failed");

View File

@ -1,5 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant;
use arroy::distances::{BinaryQuantizedCosine, Cosine}; use arroy::distances::{BinaryQuantizedCosine, Cosine};
use arroy::ItemId; use arroy::ItemId;
@ -594,18 +595,23 @@ impl Embedder {
pub fn embed( pub fn embed(
&self, &self,
texts: Vec<String>, texts: Vec<String>,
deadline: Option<Instant>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts), Embedder::OpenAi(embedder) => embedder.embed(texts, deadline),
Embedder::Ollama(embedder) => embedder.embed(texts), Embedder::Ollama(embedder) => embedder.embed(texts, deadline),
Embedder::UserProvided(embedder) => embedder.embed(texts), Embedder::UserProvided(embedder) => embedder.embed(texts),
Embedder::Rest(embedder) => embedder.embed(texts), Embedder::Rest(embedder) => embedder.embed(texts, deadline),
} }
} }
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> { pub fn embed_one(
let mut embeddings = self.embed(vec![text])?; &self,
text: String,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
let mut embeddings = self.embed(vec![text], deadline)?;
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?; let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
Ok(if embeddings.iter().nth(1).is_some() { Ok(if embeddings.iter().nth(1).is_some() {
tracing::warn!("Ignoring embeddings past the first one in long search query"); tracing::warn!("Ignoring embeddings past the first one in long search query");

View File

@ -1,3 +1,5 @@
use std::time::Instant;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
@ -75,8 +77,12 @@ impl Embedder {
Ok(Self { rest_embedder }) Ok(Self { rest_embedder })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(
match self.rest_embedder.embed(texts) { &self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed(texts, deadline) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
Err(EmbedError::ollama_model_not_found(error)) Err(EmbedError::ollama_model_not_found(error))
@ -92,7 +98,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),

View File

@ -1,3 +1,5 @@
use std::time::Instant;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
@ -206,32 +208,40 @@ impl Embedder {
Ok(Self { options, rest_embedder, tokenizer }) Ok(Self { options, rest_embedder, tokenizer })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(
match self.rest_embedder.embed_ref(&texts) { &self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed_ref(&texts, deadline) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
self.try_embed_tokenized(&texts) self.try_embed_tokenized(&texts, deadline)
} }
Err(error) => Err(error), Err(error) => Err(error),
} }
} }
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> { fn try_embed_tokenized(
&self,
text: &[String],
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut all_embeddings = Vec::with_capacity(text.len()); let mut all_embeddings = Vec::with_capacity(text.len());
for text in text { for text in text {
let max_token_count = self.options.embedding_model.max_token(); let max_token_count = self.options.embedding_model.max_token();
let encoded = self.tokenizer.encode_ordinary(text.as_str()); let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len(); let len = encoded.len();
if len < max_token_count { if len < max_token_count {
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
continue; continue;
} }
let tokens = &encoded.as_slice()[0..max_token_count]; let tokens = &encoded.as_slice()[0..max_token_count];
let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
let embedding = self.rest_embedder.embed_tokens(tokens)?; let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len()) EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
})?; })?;
@ -248,7 +258,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),

View File

@ -1,4 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::time::Instant;
use deserr::Deserr; use deserr::Deserr;
use rand::Rng; use rand::Rng;
@ -154,19 +155,31 @@ impl Embedder {
Ok(Self { data, dimensions, distribution: options.distribution }) Ok(Self { data, dimensions, distribution: options.distribution })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) &self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
} }
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError> pub fn embed_ref<S>(
&self,
texts: &[S],
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError>
where where
S: AsRef<str> + Serialize, S: AsRef<str> + Serialize,
{ {
embed(&self.data, texts, texts.len(), Some(self.dimensions)) embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
} }
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> { pub fn embed_tokens(
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; &self,
tokens: &[usize],
deadline: Option<Instant>,
) -> Result<Embeddings<f32>, EmbedError> {
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap()) Ok(embeddings.pop().unwrap())
} }
@ -178,7 +191,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -207,7 +220,7 @@ impl Embedder {
} }
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> { fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
let v = embed(data, ["test"].as_slice(), 1, None) let v = embed(data, ["test"].as_slice(), 1, None, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?; .map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().dimension()) Ok(v.first().unwrap().dimension())
@ -218,6 +231,7 @@ fn embed<S>(
inputs: &[S], inputs: &[S],
expected_count: usize, expected_count: usize,
expected_dimension: Option<usize>, expected_dimension: Option<usize>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> ) -> Result<Vec<Embeddings<f32>>, EmbedError>
where where
S: Serialize, S: Serialize,
@ -245,8 +259,19 @@ where
} }
Err(retry) => { Err(retry) => {
tracing::warn!("Failed: {}", retry.error); tracing::warn!("Failed: {}", retry.error);
if let Some(deadline) = deadline {
let now = std::time::Instant::now();
if now > deadline {
tracing::warn!("Could not embed due to deadline");
return Err(retry.into_error());
}
let duration_to_deadline = deadline - now;
retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
} else {
retry.into_duration(attempt) retry.into_duration(attempt)
} }
}
}?; }?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute