diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 3a0376511..3c4754e5d 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -339,9 +339,7 @@ pub fn extract_embeddings( indexer: GrenadParameters, embedder: Arc, ) -> Result>> { - let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; - - let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism + let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk // docid, state with embedding @@ -375,11 +373,8 @@ pub fn extract_embeddings( current_chunk_ids.push(docid); if chunks.len() == chunks.capacity() { - let chunked_embeds = rt - .block_on( - embedder - .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), - ) + let chunked_embeds = embedder + .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; @@ -396,8 +391,8 @@ pub fn extract_embeddings( // send last chunk if !chunks.is_empty() { - let chunked_embeds = rt - .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) + let chunked_embeds = embedder + .embed_chunks(std::mem::take(&mut chunks)) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; for (docid, embeddings) in chunks_ids @@ -410,13 +405,15 @@ pub fn extract_embeddings( } if !current_chunk.is_empty() { - let embeds = rt - .block_on(embedder.embed(std::mem::take(&mut current_chunk))) + let embeds = embedder + .embed_chunks(vec![std::mem::take(&mut current_chunk)]) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; - for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { - state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + if let Some(embeds) = embeds.first() { + for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + } } } diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index c5cce622d..3673c85e3 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -67,6 +67,10 @@ pub enum EmbedErrorKind { OpenAiUnhandledStatusCode(u16), #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] ManualEmbed(String), + #[error("could not initialize asynchronous runtime: {0}")] + OpenAiRuntimeInit(std::io::Error), + #[error("initializing web client for sending embedding requests failed: {0}")] + InitWebClient(reqwest::Error), } impl EmbedError { @@ -117,6 +121,14 @@ impl EmbedError { pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } } + + pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime } + } + + pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { + Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } + } } #[derive(Debug, thiserror::Error)] @@ -183,10 +195,6 @@ impl NewEmbedderError { } } - pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { - Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } - } - pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } } @@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind { #[error("loading model failed: {0}")] LoadModel(candle_core::Error), // openai - #[error("initializing web client for sending embedding requests failed: {0}")] - InitWebClient(reqwest::Error), #[error("The API key passed to Authorization error was in an invalid format: {0}")] InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), } diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 0a6bcbe93..08804e515 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -145,7 +145,8 @@ impl Embedder { let token_ids = tokens .iter() .map(|tokens| { - let tokens = tokens.get_ids().to_vec(); + let mut tokens = tokens.get_ids().to_vec(); + tokens.truncate(512); Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) }) .collect::, EmbedError>>()?; diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 81c4cf4a1..99b7bff7e 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -163,18 +163,24 @@ impl Embedder { ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts), - Embedder::OpenAi(embedder) => embedder.embed(texts).await, + Embedder::OpenAi(embedder) => { + let client = embedder.new_client()?; + embedder.embed(texts, &client).await + } Embedder::UserProvided(embedder) => embedder.embed(texts), } } - pub async fn embed_chunks( + /// # Panics + /// + /// - if called from an asynchronous context + pub fn embed_chunks( &self, text_chunks: Vec>, ) -> std::result::Result>>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), - Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks), Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), } } diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 53e8a041b..524f83b80 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings}; #[derive(Debug)] pub struct Embedder { - client: reqwest::Client, + headers: reqwest::header::HeaderMap, tokenizer: tiktoken_rs::CoreBPE, options: EmbedderOptions, } @@ -95,6 +95,13 @@ impl EmbedderOptions { } impl Embedder { + pub fn new_client(&self) -> Result { + reqwest::ClientBuilder::new() + .default_headers(self.headers.clone()) + .build() + .map_err(EmbedError::openai_initialize_web_client) + } + pub fn new(options: EmbedderOptions) -> Result { let mut headers = reqwest::header::HeaderMap::new(); let mut inferred_api_key = Default::default(); @@ -111,25 +118,25 @@ impl Embedder { reqwest::header::CONTENT_TYPE, reqwest::header::HeaderValue::from_static("application/json"), ); - let client = reqwest::ClientBuilder::new() - .default_headers(headers) - .build() - .map_err(NewEmbedderError::openai_initialize_web_client)?; // looking at the code it is very unclear that this can actually fail. let tokenizer = tiktoken_rs::cl100k_base().unwrap(); - Ok(Self { options, client, tokenizer }) + Ok(Self { options, headers, tokenizer }) } - pub async fn embed(&self, texts: Vec) -> Result>, EmbedError> { + pub async fn embed( + &self, + texts: Vec, + client: &reqwest::Client, + ) -> Result>, EmbedError> { let mut tokenized = false; for attempt in 0..7 { let result = if tokenized { - self.try_embed_tokenized(&texts).await + self.try_embed_tokenized(&texts, client).await } else { - self.try_embed(&texts).await + self.try_embed(&texts, client).await }; let retry_duration = match result { @@ -145,9 +152,9 @@ impl Embedder { } let result = if tokenized { - self.try_embed_tokenized(&texts).await + self.try_embed_tokenized(&texts, client).await } else { - self.try_embed(&texts).await + self.try_embed(&texts, client).await }; result.map_err(Retry::into_error) @@ -225,13 +232,13 @@ impl Embedder { async fn try_embed + serde::Serialize>( &self, texts: &[S], + client: &reqwest::Client, ) -> Result>, Retry> { for text in texts { log::trace!("Received prompt: {}", text.as_ref()) } let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; - let response = self - .client + let response = client .post(OPENAI_EMBEDDINGS_URL) .json(&request) .send() @@ -256,7 +263,11 @@ impl Embedder { .collect()) } - async fn try_embed_tokenized(&self, text: &[String]) -> Result>, Retry> { + async fn try_embed_tokenized( + &self, + text: &[String], + client: &reqwest::Client, + ) -> Result>, Retry> { pub const OVERLAP_SIZE: usize = 200; let mut all_embeddings = Vec::with_capacity(text.len()); for text in text { @@ -264,7 +275,7 @@ impl Embedder { let encoded = self.tokenizer.encode_ordinary(text.as_str()); let len = encoded.len(); if len < max_token_count { - all_embeddings.append(&mut self.try_embed(&[text]).await?); + all_embeddings.append(&mut self.try_embed(&[text], client).await?); continue; } @@ -273,22 +284,26 @@ impl Embedder { Embeddings::new(self.options.embedding_model.dimensions()); while tokens.len() > max_token_count { let window = &tokens[..max_token_count]; - embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap(); + embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap(); tokens = &tokens[max_token_count - OVERLAP_SIZE..]; } // end of text - embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap(); + embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap(); all_embeddings.push(embeddings_for_prompt); } Ok(all_embeddings) } - async fn embed_tokens(&self, tokens: &[usize]) -> Result { + async fn embed_tokens( + &self, + tokens: &[usize], + client: &reqwest::Client, + ) -> Result { for attempt in 0..9 { - let duration = match self.try_embed_tokens(tokens).await { + let duration = match self.try_embed_tokens(tokens, client).await { Ok(embedding) => return Ok(embedding), Err(retry) => retry.into_duration(attempt), } @@ -297,14 +312,19 @@ impl Embedder { tokio::time::sleep(duration).await; } - self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) + self.try_embed_tokens(tokens, client) + .await + .map_err(|retry| Retry::give_up(retry.into_error())) } - async fn try_embed_tokens(&self, tokens: &[usize]) -> Result { + async fn try_embed_tokens( + &self, + tokens: &[usize], + client: &reqwest::Client, + ) -> Result { let request = OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; - let response = self - .client + let response = client .post(OPENAI_EMBEDDINGS_URL) .json(&request) .send() @@ -322,12 +342,19 @@ impl Embedder { Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) } - pub async fn embed_chunks( + pub fn embed_chunks( &self, text_chunks: Vec>, ) -> Result>>, EmbedError> { - futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) - .await + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .map_err(EmbedError::openai_runtime_init)?; + let client = self.new_client()?; + rt.block_on(futures::future::try_join_all( + text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), + )) } pub fn chunk_count_hint(&self) -> usize {