mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-22 18:17:39 +08:00
Add embed_chunks_ref
This commit is contained in:
parent
50de3fba7b
commit
c22dc55694
@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
|
|||||||
use tokenizers::{PaddingParams, Tokenizer};
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
pub use super::error::{EmbedError, Error, NewEmbedderError};
|
pub use super::error::{EmbedError, Error, NewEmbedderError};
|
||||||
use super::{DistributionShift, Embedding, Embeddings};
|
use super::{DistributionShift, Embedding};
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
Debug,
|
Debug,
|
||||||
@ -139,15 +139,12 @@ impl Embedder {
|
|||||||
let embeddings = this
|
let embeddings = this
|
||||||
.embed(vec!["test".into()])
|
.embed(vec!["test".into()])
|
||||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
this.dimensions = embeddings.first().unwrap().dimension();
|
this.dimensions = embeddings.first().unwrap().len();
|
||||||
|
|
||||||
Ok(this)
|
Ok(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(
|
pub fn embed(&self, mut texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
|
||||||
&self,
|
|
||||||
mut texts: Vec<String>,
|
|
||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
|
||||||
let tokens = match texts.len() {
|
let tokens = match texts.len() {
|
||||||
1 => vec![self
|
1 => vec![self
|
||||||
.tokenizer
|
.tokenizer
|
||||||
@ -177,13 +174,31 @@ impl Embedder {
|
|||||||
.map_err(EmbedError::tensor_shape)?;
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
|
||||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||||
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
|
||||||
|
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
|
||||||
|
let token_ids = tokens.get_ids();
|
||||||
|
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
|
||||||
|
let token_ids =
|
||||||
|
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
||||||
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embeddings =
|
||||||
|
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
|
let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||||
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
|
||||||
|
Ok(embedding)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
|
||||||
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,4 +226,8 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_chunks_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
|
texts.iter().map(|text| self.embed_one(text)).collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use super::error::EmbedError;
|
use super::error::EmbedError;
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::DistributionShift;
|
||||||
|
use crate::vector::Embedding;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
@ -18,11 +19,13 @@ impl Embedder {
|
|||||||
Self { dimensions: options.dimensions, distribution: options.distribution }
|
Self { dimensions: options.dimensions, distribution: options.distribution }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed<S: AsRef<str>>(&self, texts: &[S]) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
let Some(text) = texts.pop() else { return Ok(Default::default()) };
|
texts.as_ref().iter().map(|text| self.embed_one(text)).collect()
|
||||||
Err(EmbedError::embed_on_manual_embedder(text.chars().take(250).collect()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn embed_one<S: AsRef<str>>(&self, text: S) -> Result<Embedding, EmbedError> {
|
||||||
|
Err(EmbedError::embed_on_manual_embedder(text.as_ref().chars().take(250).collect()))
|
||||||
|
}
|
||||||
pub fn dimensions(&self) -> usize {
|
pub fn dimensions(&self) -> usize {
|
||||||
self.dimensions
|
self.dimensions
|
||||||
}
|
}
|
||||||
@ -30,11 +33,15 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embedding>>, EmbedError> {
|
||||||
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
text_chunks.into_iter().map(|prompts| self.embed(&prompts)).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
self.distribution
|
self.distribution
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_chunks_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
|
texts.iter().map(|text| self.embed_one(text)).collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -376,28 +376,20 @@ impl Embedder {
|
|||||||
/// Embed one or multiple texts.
|
/// Embed one or multiple texts.
|
||||||
///
|
///
|
||||||
/// Each text can be embedded as one or multiple embeddings.
|
/// Each text can be embedded as one or multiple embeddings.
|
||||||
pub fn embed(
|
pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
|
||||||
&self,
|
|
||||||
texts: Vec<String>,
|
|
||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
Embedder::OpenAi(embedder) => embedder.embed(&texts),
|
||||||
Embedder::Ollama(embedder) => embedder.embed(texts),
|
Embedder::Ollama(embedder) => embedder.embed(&texts),
|
||||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
Embedder::UserProvided(embedder) => embedder.embed(&texts),
|
||||||
Embedder::Rest(embedder) => embedder.embed(texts),
|
Embedder::Rest(embedder) => embedder.embed(texts),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
|
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
|
||||||
let mut embeddings = self.embed(vec![text])?;
|
let mut embedding = self.embed(vec![text])?;
|
||||||
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
|
let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?;
|
||||||
Ok(if embeddings.iter().nth(1).is_some() {
|
Ok(embedding)
|
||||||
tracing::warn!("Ignoring embeddings past the first one in long search query");
|
|
||||||
embeddings.iter().next().unwrap().to_vec()
|
|
||||||
} else {
|
|
||||||
embeddings.into_inner()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Embed multiple chunks of texts.
|
/// Embed multiple chunks of texts.
|
||||||
@ -407,7 +399,7 @@ impl Embedder {
|
|||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &ThreadPoolNoAbort,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||||
@ -417,6 +409,20 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn embed_chunks_ref(
|
||||||
|
&self,
|
||||||
|
texts: &[&str],
|
||||||
|
threads: &ThreadPoolNoAbort,
|
||||||
|
) -> std::result::Result<Vec<Embedding>, EmbedError> {
|
||||||
|
match self {
|
||||||
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks_ref(texts),
|
||||||
|
Embedder::OpenAi(embedder) => embedder.embed_chunks_ref(texts, threads),
|
||||||
|
Embedder::Ollama(embedder) => embedder.embed_chunks_ref(texts, threads),
|
||||||
|
Embedder::UserProvided(embedder) => embedder.embed_chunks_ref(texts),
|
||||||
|
Embedder::Rest(embedder) => embedder.embed_chunks_ref(texts, threads),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
|
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
|
use rayon::slice::ParallelSlice as _;
|
||||||
|
|
||||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::DistributionShift;
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
|
use crate::vector::Embedding;
|
||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -75,8 +77,11 @@ impl Embedder {
|
|||||||
Ok(Self { rest_embedder })
|
Ok(Self { rest_embedder })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed<S: AsRef<str> + serde::Serialize>(
|
||||||
match self.rest_embedder.embed(texts) {
|
&self,
|
||||||
|
texts: &[S],
|
||||||
|
) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
|
match self.rest_embedder.embed_ref(texts) {
|
||||||
Ok(embeddings) => Ok(embeddings),
|
Ok(embeddings) => Ok(embeddings),
|
||||||
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
||||||
Err(EmbedError::ollama_model_not_found(error))
|
Err(EmbedError::ollama_model_not_found(error))
|
||||||
@ -89,10 +94,31 @@ impl Embedder {
|
|||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &ThreadPoolNoAbort,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embedding>>, EmbedError> {
|
||||||
threads
|
threads
|
||||||
.install(move || {
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect()
|
||||||
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_chunks_ref(
|
||||||
|
&self,
|
||||||
|
texts: &[&str],
|
||||||
|
threads: &ThreadPoolNoAbort,
|
||||||
|
) -> Result<Vec<Vec<f32>>, EmbedError> {
|
||||||
|
threads
|
||||||
|
.install(move || {
|
||||||
|
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
|
||||||
|
.par_chunks(self.chunk_count_hint())
|
||||||
|
.map(move |chunk| self.embed(chunk))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let embeddings = embeddings?;
|
||||||
|
Ok(embeddings.into_iter().flatten().collect())
|
||||||
})
|
})
|
||||||
.map_err(|error| EmbedError {
|
.map_err(|error| EmbedError {
|
||||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||||
|
use rayon::slice::ParallelSlice as _;
|
||||||
|
|
||||||
use super::error::{EmbedError, NewEmbedderError};
|
use super::error::{EmbedError, NewEmbedderError};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::DistributionShift;
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::vector::error::EmbedErrorKind;
|
use crate::vector::error::EmbedErrorKind;
|
||||||
|
use crate::vector::Embedding;
|
||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
@ -206,22 +208,26 @@ impl Embedder {
|
|||||||
Ok(Self { options, rest_embedder, tokenizer })
|
Ok(Self { options, rest_embedder, tokenizer })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed<S: AsRef<str> + serde::Serialize>(
|
||||||
match self.rest_embedder.embed_ref(&texts) {
|
&self,
|
||||||
|
texts: &[S],
|
||||||
|
) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
|
match self.rest_embedder.embed_ref(texts) {
|
||||||
Ok(embeddings) => Ok(embeddings),
|
Ok(embeddings) => Ok(embeddings),
|
||||||
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
|
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
|
||||||
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
||||||
self.try_embed_tokenized(&texts)
|
self.try_embed_tokenized(texts)
|
||||||
}
|
}
|
||||||
Err(error) => Err(error),
|
Err(error) => Err(error),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
fn try_embed_tokenized<S: AsRef<str>>(&self, text: &[S]) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||||
for text in text {
|
for text in text {
|
||||||
|
let text = text.as_ref();
|
||||||
let max_token_count = self.options.embedding_model.max_token();
|
let max_token_count = self.options.embedding_model.max_token();
|
||||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
let encoded = self.tokenizer.encode_ordinary(text);
|
||||||
let len = encoded.len();
|
let len = encoded.len();
|
||||||
if len < max_token_count {
|
if len < max_token_count {
|
||||||
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
|
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
|
||||||
@ -229,14 +235,10 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let tokens = &encoded.as_slice()[0..max_token_count];
|
let tokens = &encoded.as_slice()[0..max_token_count];
|
||||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
|
||||||
|
|
||||||
let embedding = self.rest_embedder.embed_tokens(tokens)?;
|
let embedding = self.rest_embedder.embed_tokens(tokens)?;
|
||||||
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
|
|
||||||
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
all_embeddings.push(embeddings_for_prompt);
|
all_embeddings.push(embedding);
|
||||||
}
|
}
|
||||||
Ok(all_embeddings)
|
Ok(all_embeddings)
|
||||||
}
|
}
|
||||||
@ -245,10 +247,31 @@ impl Embedder {
|
|||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &ThreadPoolNoAbort,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embedding>>, EmbedError> {
|
||||||
threads
|
threads
|
||||||
.install(move || {
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect()
|
||||||
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_chunks_ref(
|
||||||
|
&self,
|
||||||
|
texts: &[&str],
|
||||||
|
threads: &ThreadPoolNoAbort,
|
||||||
|
) -> Result<Vec<Vec<f32>>, EmbedError> {
|
||||||
|
threads
|
||||||
|
.install(move || {
|
||||||
|
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
|
||||||
|
.par_chunks(self.chunk_count_hint())
|
||||||
|
.map(move |chunk| self.embed(chunk))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let embeddings = embeddings?;
|
||||||
|
Ok(embeddings.into_iter().flatten().collect())
|
||||||
})
|
})
|
||||||
.map_err(|error| EmbedError {
|
.map_err(|error| EmbedError {
|
||||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
@ -3,13 +3,12 @@ use std::collections::BTreeMap;
|
|||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
|
use rayon::slice::ParallelSlice as _;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use super::error::EmbedErrorKind;
|
use super::error::EmbedErrorKind;
|
||||||
use super::json_template::ValueTemplate;
|
use super::json_template::ValueTemplate;
|
||||||
use super::{
|
use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM};
|
||||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
|
||||||
};
|
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
@ -154,18 +153,18 @@ impl Embedder {
|
|||||||
Ok(Self { data, dimensions, distribution: options.distribution })
|
Ok(Self { data, dimensions, distribution: options.distribution })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
|
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embedding>, EmbedError>
|
||||||
where
|
where
|
||||||
S: AsRef<str> + Serialize,
|
S: AsRef<str> + Serialize,
|
||||||
{
|
{
|
||||||
embed(&self.data, texts, texts.len(), Some(self.dimensions))
|
embed(&self.data, texts, texts.len(), Some(self.dimensions))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, EmbedError> {
|
||||||
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
|
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
|
||||||
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(embeddings.pop().unwrap())
|
Ok(embeddings.pop().unwrap())
|
||||||
@ -175,7 +174,7 @@ impl Embedder {
|
|||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &ThreadPoolNoAbort,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embedding>>, EmbedError> {
|
||||||
threads
|
threads
|
||||||
.install(move || {
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
@ -186,6 +185,27 @@ impl Embedder {
|
|||||||
})?
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_chunks_ref(
|
||||||
|
&self,
|
||||||
|
texts: &[&str],
|
||||||
|
threads: &ThreadPoolNoAbort,
|
||||||
|
) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
|
threads
|
||||||
|
.install(move || {
|
||||||
|
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
|
||||||
|
.par_chunks(self.chunk_count_hint())
|
||||||
|
.map(move |chunk| self.embed_ref(chunk))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let embeddings = embeddings?;
|
||||||
|
Ok(embeddings.into_iter().flatten().collect())
|
||||||
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
super::REQUEST_PARALLELISM
|
super::REQUEST_PARALLELISM
|
||||||
}
|
}
|
||||||
@ -210,7 +230,7 @@ fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
|||||||
let v = embed(data, ["test"].as_slice(), 1, None)
|
let v = embed(data, ["test"].as_slice(), 1, None)
|
||||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(v.first().unwrap().dimension())
|
Ok(v.first().unwrap().len())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embed<S>(
|
fn embed<S>(
|
||||||
@ -218,7 +238,7 @@ fn embed<S>(
|
|||||||
inputs: &[S],
|
inputs: &[S],
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
expected_dimension: Option<usize>,
|
expected_dimension: Option<usize>,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
) -> Result<Vec<Embedding>, EmbedError>
|
||||||
where
|
where
|
||||||
S: Serialize,
|
S: Serialize,
|
||||||
{
|
{
|
||||||
@ -304,7 +324,7 @@ fn response_to_embedding(
|
|||||||
data: &EmbedderData,
|
data: &EmbedderData,
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
expected_dimensions: Option<usize>,
|
expected_dimensions: Option<usize>,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
let response: serde_json::Value =
|
let response: serde_json::Value =
|
||||||
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
||||||
|
|
||||||
@ -316,11 +336,8 @@ fn response_to_embedding(
|
|||||||
|
|
||||||
if let Some(dimensions) = expected_dimensions {
|
if let Some(dimensions) = expected_dimensions {
|
||||||
for embedding in &embeddings {
|
for embedding in &embeddings {
|
||||||
if embedding.dimension() != dimensions {
|
if embedding.len() != dimensions {
|
||||||
return Err(EmbedError::rest_unexpected_dimension(
|
return Err(EmbedError::rest_unexpected_dimension(dimensions, embedding.len()));
|
||||||
dimensions,
|
|
||||||
embedding.dimension(),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -394,7 +411,7 @@ impl Response {
|
|||||||
pub fn extract_embeddings(
|
pub fn extract_embeddings(
|
||||||
&self,
|
&self,
|
||||||
response: serde_json::Value,
|
response: serde_json::Value,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
||||||
Ok(extracted_values) => extracted_values,
|
Ok(extracted_values) => extracted_values,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
@ -403,8 +420,7 @@ impl Response {
|
|||||||
return Err(EmbedError::rest_extraction_error(error_message));
|
return Err(EmbedError::rest_extraction_error(error_message));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let embeddings: Vec<Embeddings<f32>> =
|
let embeddings: Vec<Embedding> = extracted_values.into_iter().collect();
|
||||||
extracted_values.into_iter().map(Embeddings::from_single_embedding).collect();
|
|
||||||
|
|
||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user