From c22dc556945835cd73af4c6f3d29d0a33a1cf1a4 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 28 Oct 2024 14:08:54 +0100 Subject: [PATCH] Add embed_chunks_ref --- milli/src/vector/hf.rs | 35 +++++++++++++++++++------ milli/src/vector/manual.rs | 19 +++++++++----- milli/src/vector/mod.rs | 38 ++++++++++++++++------------ milli/src/vector/ollama.rs | 36 ++++++++++++++++++++++---- milli/src/vector/openai.rs | 49 +++++++++++++++++++++++++---------- milli/src/vector/rest.rs | 52 +++++++++++++++++++++++++------------- 6 files changed, 163 insertions(+), 66 deletions(-) diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index dc1e7d324..ea892ca57 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType}; use tokenizers::{PaddingParams, Tokenizer}; pub use super::error::{EmbedError, Error, NewEmbedderError}; -use super::{DistributionShift, Embedding, Embeddings}; +use super::{DistributionShift, Embedding}; #[derive( Debug, @@ -139,15 +139,12 @@ impl Embedder { let embeddings = this .embed(vec!["test".into()]) .map_err(NewEmbedderError::could_not_determine_dimension)?; - this.dimensions = embeddings.first().unwrap().dimension(); + this.dimensions = embeddings.first().unwrap().len(); Ok(this) } - pub fn embed( - &self, - mut texts: Vec, - ) -> std::result::Result>, EmbedError> { + pub fn embed(&self, mut texts: Vec) -> std::result::Result, EmbedError> { let tokens = match texts.len() { 1 => vec![self .tokenizer @@ -177,13 +174,31 @@ impl Embedder { .map_err(EmbedError::tensor_shape)?; let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; - Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) + Ok(embeddings) + } + + pub fn embed_one(&self, text: &str) -> std::result::Result { + let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?; + let token_ids = tokens.get_ids(); + let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids }; + let token_ids = + Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; + let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; + let embeddings = + self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; + + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?; + let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + .map_err(EmbedError::tensor_shape)?; + let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?; + Ok(embedding) } pub fn embed_chunks( &self, text_chunks: Vec>, - ) -> std::result::Result>>, EmbedError> { + ) -> std::result::Result>, EmbedError> { text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() } @@ -211,4 +226,8 @@ impl Embedder { } }) } + + pub(crate) fn embed_chunks_ref(&self, texts: &[&str]) -> Result, EmbedError> { + texts.iter().map(|text| self.embed_one(text)).collect() + } } diff --git a/milli/src/vector/manual.rs b/milli/src/vector/manual.rs index 4cfbb0d3c..8c2ef97b2 100644 --- a/milli/src/vector/manual.rs +++ b/milli/src/vector/manual.rs @@ -1,5 +1,6 @@ use super::error::EmbedError; -use super::{DistributionShift, Embeddings}; +use super::DistributionShift; +use crate::vector::Embedding; #[derive(Debug, Clone, Copy)] pub struct Embedder { @@ -18,11 +19,13 @@ impl Embedder { Self { dimensions: options.dimensions, distribution: options.distribution } } - pub fn embed(&self, mut texts: Vec) -> Result>, EmbedError> { - let Some(text) = texts.pop() else { return Ok(Default::default()) }; - Err(EmbedError::embed_on_manual_embedder(text.chars().take(250).collect())) + pub fn embed>(&self, texts: &[S]) -> Result, EmbedError> { + texts.as_ref().iter().map(|text| self.embed_one(text)).collect() } + pub fn embed_one>(&self, text: S) -> Result { + Err(EmbedError::embed_on_manual_embedder(text.as_ref().chars().take(250).collect())) + } pub fn dimensions(&self) -> usize { self.dimensions } @@ -30,11 +33,15 @@ impl Embedder { pub fn embed_chunks( &self, text_chunks: Vec>, - ) -> Result>>, EmbedError> { - text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() + ) -> Result>, EmbedError> { + text_chunks.into_iter().map(|prompts| self.embed(&prompts)).collect() } pub fn distribution(&self) -> Option { self.distribution } + + pub(crate) fn embed_chunks_ref(&self, texts: &[&str]) -> Result, EmbedError> { + texts.iter().map(|text| self.embed_one(text)).collect() + } } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index d52e68bbe..2e9a498c0 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -376,28 +376,20 @@ impl Embedder { /// Embed one or multiple texts. /// /// Each text can be embedded as one or multiple embeddings. - pub fn embed( - &self, - texts: Vec, - ) -> std::result::Result>, EmbedError> { + pub fn embed(&self, texts: Vec) -> std::result::Result, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts), - Embedder::OpenAi(embedder) => embedder.embed(texts), - Embedder::Ollama(embedder) => embedder.embed(texts), - Embedder::UserProvided(embedder) => embedder.embed(texts), + Embedder::OpenAi(embedder) => embedder.embed(&texts), + Embedder::Ollama(embedder) => embedder.embed(&texts), + Embedder::UserProvided(embedder) => embedder.embed(&texts), Embedder::Rest(embedder) => embedder.embed(texts), } } pub fn embed_one(&self, text: String) -> std::result::Result { - let mut embeddings = self.embed(vec![text])?; - let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?; - Ok(if embeddings.iter().nth(1).is_some() { - tracing::warn!("Ignoring embeddings past the first one in long search query"); - embeddings.iter().next().unwrap().to_vec() - } else { - embeddings.into_inner() - }) + let mut embedding = self.embed(vec![text])?; + let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?; + Ok(embedding) } /// Embed multiple chunks of texts. @@ -407,7 +399,7 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> std::result::Result>>, EmbedError> { + ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads), @@ -417,6 +409,20 @@ impl Embedder { } } + pub fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> std::result::Result, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed_chunks_ref(texts), + Embedder::OpenAi(embedder) => embedder.embed_chunks_ref(texts, threads), + Embedder::Ollama(embedder) => embedder.embed_chunks_ref(texts, threads), + Embedder::UserProvided(embedder) => embedder.embed_chunks_ref(texts), + Embedder::Rest(embedder) => embedder.embed_chunks_ref(texts, threads), + } + } + /// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`] pub fn chunk_count_hint(&self) -> usize { match self { diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index 7d41ab4e9..65fd05416 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -1,9 +1,11 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use rayon::slice::ParallelSlice as _; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; -use super::{DistributionShift, Embeddings}; +use super::DistributionShift; use crate::error::FaultSource; +use crate::vector::Embedding; use crate::ThreadPoolNoAbort; #[derive(Debug)] @@ -75,8 +77,11 @@ impl Embedder { Ok(Self { rest_embedder }) } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - match self.rest_embedder.embed(texts) { + pub fn embed + serde::Serialize>( + &self, + texts: &[S], + ) -> Result, EmbedError> { + match self.rest_embedder.embed_ref(texts) { Ok(embeddings) => Ok(embeddings), Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError::ollama_model_not_found(error)) @@ -89,10 +94,31 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> Result>>, EmbedError> { + ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } + + pub(crate) fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> Result>, EmbedError> { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.chunk_count_hint()) + .map(move |chunk| self.embed(chunk)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 152d1fb7a..466fd1660 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -1,11 +1,13 @@ use ordered_float::OrderedFloat; use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; +use rayon::slice::ParallelSlice as _; use super::error::{EmbedError, NewEmbedderError}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; -use super::{DistributionShift, Embeddings}; +use super::DistributionShift; use crate::error::FaultSource; use crate::vector::error::EmbedErrorKind; +use crate::vector::Embedding; use crate::ThreadPoolNoAbort; #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] @@ -206,22 +208,26 @@ impl Embedder { Ok(Self { options, rest_embedder, tokenizer }) } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - match self.rest_embedder.embed_ref(&texts) { + pub fn embed + serde::Serialize>( + &self, + texts: &[S], + ) -> Result, EmbedError> { + match self.rest_embedder.embed_ref(texts) { Ok(embeddings) => Ok(embeddings), Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); - self.try_embed_tokenized(&texts) + self.try_embed_tokenized(texts) } Err(error) => Err(error), } } - fn try_embed_tokenized(&self, text: &[String]) -> Result>, EmbedError> { + fn try_embed_tokenized>(&self, text: &[S]) -> Result, EmbedError> { let mut all_embeddings = Vec::with_capacity(text.len()); for text in text { + let text = text.as_ref(); let max_token_count = self.options.embedding_model.max_token(); - let encoded = self.tokenizer.encode_ordinary(text.as_str()); + let encoded = self.tokenizer.encode_ordinary(text); let len = encoded.len(); if len < max_token_count { all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); @@ -229,14 +235,10 @@ impl Embedder { } let tokens = &encoded.as_slice()[0..max_token_count]; - let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); let embedding = self.rest_embedder.embed_tokens(tokens)?; - embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { - EmbedError::rest_unexpected_dimension(self.dimensions(), got.len()) - })?; - all_embeddings.push(embeddings_for_prompt); + all_embeddings.push(embedding); } Ok(all_embeddings) } @@ -245,10 +247,31 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> Result>>, EmbedError> { + ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } + + pub(crate) fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> Result>, EmbedError> { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.chunk_count_hint()) + .map(move |chunk| self.embed(chunk)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 2538f2fff..dc2ab95f9 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -3,13 +3,12 @@ use std::collections::BTreeMap; use deserr::Deserr; use rand::Rng; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use rayon::slice::ParallelSlice as _; use serde::{Deserialize, Serialize}; use super::error::EmbedErrorKind; use super::json_template::ValueTemplate; -use super::{ - DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, -}; +use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM}; use crate::error::FaultSource; use crate::ThreadPoolNoAbort; @@ -154,18 +153,18 @@ impl Embedder { Ok(Self { data, dimensions, distribution: options.distribution }) } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { + pub fn embed(&self, texts: Vec) -> Result, EmbedError> { embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) } - pub fn embed_ref(&self, texts: &[S]) -> Result>, EmbedError> + pub fn embed_ref(&self, texts: &[S]) -> Result, EmbedError> where S: AsRef + Serialize, { embed(&self.data, texts, texts.len(), Some(self.dimensions)) } - pub fn embed_tokens(&self, tokens: &[usize]) -> Result, EmbedError> { + pub fn embed_tokens(&self, tokens: &[usize]) -> Result { let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error Ok(embeddings.pop().unwrap()) @@ -175,7 +174,7 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> Result>>, EmbedError> { + ) -> Result>, EmbedError> { threads .install(move || { text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() @@ -186,6 +185,27 @@ impl Embedder { })? } + pub(crate) fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> Result, EmbedError> { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.chunk_count_hint()) + .map(move |chunk| self.embed_ref(chunk)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } + pub fn chunk_count_hint(&self) -> usize { super::REQUEST_PARALLELISM } @@ -210,7 +230,7 @@ fn infer_dimensions(data: &EmbedderData) -> Result { let v = embed(data, ["test"].as_slice(), 1, None) .map_err(NewEmbedderError::could_not_determine_dimension)?; // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error - Ok(v.first().unwrap().dimension()) + Ok(v.first().unwrap().len()) } fn embed( @@ -218,7 +238,7 @@ fn embed( inputs: &[S], expected_count: usize, expected_dimension: Option, -) -> Result>, EmbedError> +) -> Result, EmbedError> where S: Serialize, { @@ -304,7 +324,7 @@ fn response_to_embedding( data: &EmbedderData, expected_count: usize, expected_dimensions: Option, -) -> Result>, EmbedError> { +) -> Result, EmbedError> { let response: serde_json::Value = response.into_json().map_err(EmbedError::rest_response_deserialization)?; @@ -316,11 +336,8 @@ fn response_to_embedding( if let Some(dimensions) = expected_dimensions { for embedding in &embeddings { - if embedding.dimension() != dimensions { - return Err(EmbedError::rest_unexpected_dimension( - dimensions, - embedding.dimension(), - )); + if embedding.len() != dimensions { + return Err(EmbedError::rest_unexpected_dimension(dimensions, embedding.len())); } } } @@ -394,7 +411,7 @@ impl Response { pub fn extract_embeddings( &self, response: serde_json::Value, - ) -> Result>, EmbedError> { + ) -> Result, EmbedError> { let extracted_values: Vec = match self.template.extract(response) { Ok(extracted_values) => extracted_values, Err(error) => { @@ -403,8 +420,7 @@ impl Response { return Err(EmbedError::rest_extraction_error(error_message)); } }; - let embeddings: Vec> = - extracted_values.into_iter().map(Embeddings::from_single_embedding).collect(); + let embeddings: Vec = extracted_values.into_iter().collect(); Ok(embeddings) }