Refactor vector indexing

- use the parsed_vectors module
- only parse `_vectors` once per document, instead of once per embedder per document
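
The shape of the change, as a minimal sketch (hypothetical stand-in types, not the actual ParsedVectorsDiff/EmbedderVectorExtractor API introduced below; assumes serde_json and a map-shaped `_vectors` object):

use std::collections::BTreeMap;

/// `_vectors` parsed once per document: one entry per embedder name.
struct ParsedVectors(BTreeMap<String, Vec<Vec<f32>>>);

impl ParsedVectors {
    /// The single JSON parse for the document.
    fn parse(raw: &[u8]) -> serde_json::Result<Self> {
        Ok(ParsedVectors(serde_json::from_slice(raw)?))
    }

    /// Each embedder then takes its entry with a cheap map lookup, where it
    /// previously re-deserialized the whole `_vectors` value for itself.
    fn remove(&mut self, embedder_name: &str) -> Option<Vec<Vec<f32>>> {
        self.0.remove(embedder_name)
    }
}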
Author: Louis Dureuil
Date:   2024-05-14 11:42:26 +02:00
Parent: 261de888b7
Commit: 52d9cb6e5a
No known key found for this signature in database
5 changed files with 218 additions and 236 deletions

View File

@@ -120,7 +120,7 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_)
     #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
     InvalidVectorsType { document_id: Value, value: Value, subfield: String },
     #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
-    InvalidVectorsMapType { document_id: Value, value: Value },
+    InvalidVectorsMapType { document_id: String, value: Value },
     #[error("{0}")]
    InvalidFilter(String),
    #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]

View File

@@ -362,35 +362,6 @@ pub fn normalize_facet(original: &str) -> String {
     CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
 }
 
-/// Represents either a vector or an array of multiple vectors.
-#[derive(serde::Serialize, serde::Deserialize, Debug)]
-#[serde(transparent)]
-pub struct VectorOrArrayOfVectors {
-    #[serde(with = "either::serde_untagged_optional")]
-    inner: Option<either::Either<Vec<f32>, Vec<Vec<f32>>>>,
-}
-
-impl VectorOrArrayOfVectors {
-    pub fn into_array_of_vectors(self) -> Option<Vec<Vec<f32>>> {
-        match self.inner? {
-            either::Either::Left(vector) => Some(vec![vector]),
-            either::Either::Right(vectors) => Some(vectors),
-        }
-    }
-}
-
-/// Normalize a vector by dividing the dimensions by the length of it.
-pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
-    let squared: f32 = vector.iter().map(|x| x * x).sum();
-    let length = squared.sqrt();
-    if length <= f32::EPSILON {
-        vector
-    } else {
-        vector.iter_mut().for_each(|x| *x /= length);
-        vector
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use serde_json::json;
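
For context on the type moved out of this file: it round-trips either a single vector or an array of vectors through the same untagged field. A small sketch of its behavior, assuming the VectorOrArrayOfVectors definition deleted above is in scope and serde_json is a dependency:

fn main() {
    // a single vector becomes a one-element array of vectors
    let one: VectorOrArrayOfVectors = serde_json::from_str("[1.0, 2.0]").unwrap();
    assert_eq!(one.into_array_of_vectors(), Some(vec![vec![1.0, 2.0]]));

    // an array of vectors is passed through unchanged
    let many: VectorOrArrayOfVectors = serde_json::from_str("[[1.0], [2.0]]").unwrap();
    assert_eq!(many.into_array_of_vectors(), Some(vec![vec![1.0], vec![2.0]]));

    // `null` deserializes to the `None` inner value
    let none: VectorOrArrayOfVectors = serde_json::from_str("null").unwrap();
    assert!(none.into_array_of_vectors().is_none());
}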

View File

@@ -10,16 +10,16 @@ use bytemuck::cast_slice;
 use grenad::Writer;
 use itertools::EitherOrBoth;
 use ordered_float::OrderedFloat;
-use serde_json::{from_slice, Value};
+use serde_json::Value;
 
 use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
-use crate::error::UserError;
 use crate::prompt::Prompt;
 use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
 use crate::update::index_documents::helpers::try_split_at;
 use crate::update::settings::InnerIndexSettingsDiff;
+use crate::vector::parsed_vectors::{ParsedVectorsDiff, RESERVED_VECTORS_FIELD_NAME};
 use crate::vector::Embedder;
-use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, VectorOrArrayOfVectors};
+use crate::{DocumentId, Result, ThreadPoolNoAbort};
 
 /// The length of the elements that are always in the buffer when inserting new values.
 const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
@@ -31,6 +31,10 @@ pub struct ExtractedVectorPoints {
     pub remove_vectors: grenad::Reader<BufReader<File>>,
     // docid -> prompt
     pub prompts: grenad::Reader<BufReader<File>>,
+    // embedder
+    pub embedder_name: String,
+    pub embedder: Arc<Embedder>,
 }
 
 enum VectorStateDelta {
@@ -65,6 +69,19 @@ impl VectorStateDelta {
     }
 }
 
+struct EmbedderVectorExtractor {
+    embedder_name: String,
+    embedder: Arc<Embedder>,
+    prompt: Arc<Prompt>,
+
+    // (docid, _index) -> KvWriterDelAdd -> Vector
+    manual_vectors_writer: Writer<BufWriter<File>>,
+    // (docid) -> (prompt)
+    prompts_writer: Writer<BufWriter<File>>,
+    // (docid) -> ()
+    remove_vectors_writer: Writer<BufWriter<File>>,
+}
+
 /// Extracts the embedding vector contained in each document under the `_vectors` field.
 ///
 /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
@@ -72,36 +89,56 @@ impl VectorStateDelta {
 pub fn extract_vector_points<R: io::Read + io::Seek>(
     obkv_documents: grenad::Reader<R>,
     indexer: GrenadParameters,
-    settings_diff: &InnerIndexSettingsDiff,
-    prompt: &Prompt,
-    embedder_name: &str,
-) -> Result<ExtractedVectorPoints> {
+    settings_diff: Arc<InnerIndexSettingsDiff>,
+) -> Result<Vec<ExtractedVectorPoints>> {
     puffin::profile_function!();
 
+    let reindex_vectors = settings_diff.reindex_vectors();
+
     let old_fields_ids_map = &settings_diff.old.fields_ids_map;
     let new_fields_ids_map = &settings_diff.new.fields_ids_map;
 
+    // the vector field id may have changed
+    let old_vectors_fid = old_fields_ids_map.id(RESERVED_VECTORS_FIELD_NAME);
+    // filter the old vector fid if the settings has been changed forcing reindexing.
+    let old_vectors_fid = old_vectors_fid.filter(|_| !reindex_vectors);
+    let new_vectors_fid = new_fields_ids_map.id(RESERVED_VECTORS_FIELD_NAME);
+
+    let mut extractors = Vec::new();
+    for (embedder_name, (embedder, prompt)) in
+        settings_diff.new.embedding_configs.clone().into_iter()
+    {
         // (docid, _index) -> KvWriterDelAdd -> Vector
-    let mut manual_vectors_writer = create_writer(
+        let manual_vectors_writer = create_writer(
             indexer.chunk_compression_type,
             indexer.chunk_compression_level,
             tempfile::tempfile()?,
         );
 
         // (docid) -> (prompt)
-    let mut prompts_writer = create_writer(
+        let prompts_writer = create_writer(
            indexer.chunk_compression_type,
            indexer.chunk_compression_level,
            tempfile::tempfile()?,
        );
 
        // (docid) -> ()
-    let mut remove_vectors_writer = create_writer(
+        let remove_vectors_writer = create_writer(
            indexer.chunk_compression_type,
            indexer.chunk_compression_level,
            tempfile::tempfile()?,
        );
+
+        extractors.push(EmbedderVectorExtractor {
+            embedder_name,
+            embedder,
+            prompt,
+            manual_vectors_writer,
+            prompts_writer,
+            remove_vectors_writer,
+        });
+    }
 
     let mut key_buffer = Vec::new();
     let mut cursor = obkv_documents.into_cursor()?;
     while let Some((key, value)) = cursor.move_on_next()? {
@@ -114,42 +151,27 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
         key_buffer.clear();
         key_buffer.extend_from_slice(docid_bytes);
 
-        // since we only needs the primary key when we throw an error we create this getter to
+        // since we only need the primary key when we throw an error we create this getter to
         // lazily get it when needed
         let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
 
-        // the vector field id may have changed
-        let old_vectors_fid = old_fields_ids_map.id("_vectors");
-        // filter the old vector fid if the settings has been changed forcing reindexing.
-        let old_vectors_fid = old_vectors_fid.filter(|_| !settings_diff.reindex_vectors());
-        let new_vectors_fid = new_fields_ids_map.id("_vectors");
-
-        let vectors_field = {
-            let del = old_vectors_fid
-                .and_then(|vectors_fid| obkv.get(vectors_fid))
-                .map(KvReaderDelAdd::new)
-                .map(|obkv| to_vector_map(obkv, DelAdd::Deletion, &document_id))
-                .transpose()?
-                .flatten();
-            let add = new_vectors_fid
-                .and_then(|vectors_fid| obkv.get(vectors_fid))
-                .map(KvReaderDelAdd::new)
-                .map(|obkv| to_vector_map(obkv, DelAdd::Addition, &document_id))
-                .transpose()?
-                .flatten();
-            (del, add)
-        };
-
-        let (del_map, add_map) = vectors_field;
-
-        let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
-        let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
-
-        let delta = match (del_value, add_value) {
+        let mut parsed_vectors = ParsedVectorsDiff::new(obkv, old_vectors_fid, new_vectors_fid)
+            .map_err(|error| error.to_crate_error(document_id().to_string()))?;
+
+        for EmbedderVectorExtractor {
+            embedder_name,
+            embedder: _,
+            prompt,
+            manual_vectors_writer,
+            prompts_writer,
+            remove_vectors_writer,
+        } in extractors.iter_mut()
+        {
+            let delta = match parsed_vectors.remove(embedder_name) {
                 (Some(old), Some(new)) => {
                     // no autogeneration
-                    let del_vectors = extract_vectors(old, document_id, embedder_name)?;
-                    let add_vectors = extract_vectors(new, document_id, embedder_name)?;
+                    let del_vectors = old.into_array_of_vectors();
+                    let add_vectors = new.into_array_of_vectors();
 
                     if add_vectors.len() > usize::from(u8::MAX) {
                         return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
@@ -179,7 +201,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
                 }
                 (None, Some(new)) => {
                     // was possibly autogenerated, remove all vectors for that document
-                    let add_vectors = extract_vectors(new, document_id, embedder_name)?;
+                    let add_vectors = new.into_array_of_vectors();
                     if add_vectors.len() > usize::from(u8::MAX) {
                         return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
                             document_id().to_string(),
@@ -198,14 +220,16 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
                     if document_is_kept {
                         // Don't give up if the old prompt was failing
-                        let old_prompt = Some(prompt)
+                        let old_prompt = Some(&prompt)
                             // TODO: this filter works because we erase the vec database when a embedding setting changes.
                             // When vector pipeline will be optimized, this should be removed.
                             .filter(|_| !settings_diff.reindex_vectors())
                             .map(|p| {
-                                p.render(obkv, DelAdd::Deletion, old_fields_ids_map).unwrap_or_default()
+                                p.render(obkv, DelAdd::Deletion, old_fields_ids_map)
+                                    .unwrap_or_default()
                             });
-                        let new_prompt = prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?;
+                        let new_prompt =
+                            prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?;
                         if old_prompt.as_ref() != Some(&new_prompt) {
                             let old_prompt = old_prompt.unwrap_or_default();
                             tracing::trace!(
@@ -224,42 +248,43 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
             // and we finally push the unique vectors into the writer
             push_vectors_diff(
-                &mut remove_vectors_writer,
-                &mut prompts_writer,
-                &mut manual_vectors_writer,
+                remove_vectors_writer,
+                prompts_writer,
+                manual_vectors_writer,
                 &mut key_buffer,
                 delta,
-                settings_diff,
+                reindex_vectors,
             )?;
         }
+    }
 
-    Ok(ExtractedVectorPoints {
+    /////
+    let mut results = Vec::new();
+
+    for EmbedderVectorExtractor {
+        embedder_name,
+        embedder,
+        prompt: _,
+        manual_vectors_writer,
+        prompts_writer,
+        remove_vectors_writer,
+    } in extractors
+    {
+        results.push(ExtractedVectorPoints {
             // docid, _index -> KvWriterDelAdd -> Vector
             manual_vectors: writer_into_reader(manual_vectors_writer)?,
             // docid -> ()
             remove_vectors: writer_into_reader(remove_vectors_writer)?,
             // docid -> prompt
             prompts: writer_into_reader(prompts_writer)?,
-    })
-}
-
-fn to_vector_map(
-    obkv: KvReaderDelAdd,
-    side: DelAdd,
-    document_id: &impl Fn() -> Value,
-) -> Result<Option<serde_json::Map<String, Value>>> {
-    Ok(if let Some(value) = obkv.get(side) {
-        let Ok(value) = from_slice(value) else {
-            let value = from_slice(value).map_err(InternalError::SerdeJson)?;
-            return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
-                document_id: document_id(),
-                value,
-            }));
-        };
-        Some(value)
-    } else {
-        None
-    })
+            embedder,
+            embedder_name,
+        })
+    }
+
+    Ok(results)
 }
 
 /// Computes the diff between both Del and Add numbers and
@@ -270,14 +295,14 @@ fn push_vectors_diff(
     manual_vectors_writer: &mut Writer<BufWriter<File>>,
     key_buffer: &mut Vec<u8>,
     delta: VectorStateDelta,
-    settings_diff: &InnerIndexSettingsDiff,
+    reindex_vectors: bool,
 ) -> Result<()> {
     puffin::profile_function!();
     let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
     if must_remove
         // TODO: the below condition works because we erase the vec database when a embedding setting changes.
         // When vector pipeline will be optimized, this should be removed.
-        && !settings_diff.reindex_vectors()
+        && !reindex_vectors
     {
         key_buffer.truncate(TRUNCATE_SIZE);
         remove_vectors_writer.insert(&key_buffer, [])?;
@@ -308,7 +333,7 @@ fn push_vectors_diff(
             EitherOrBoth::Left(vector) => {
                 // TODO: the below condition works because we erase the vec database when a embedding setting changes.
                 // When vector pipeline will be optimized, this should be removed.
-                if !settings_diff.reindex_vectors() {
+                if !reindex_vectors {
                     // We insert only the Del part of the Obkv to inform
                     // that we only want to remove all those vectors.
                     let mut obkv = KvWriterDelAdd::memory();
@@ -336,26 +361,6 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
     a.iter().copied().map(OrderedFloat).cmp(b.iter().copied().map(OrderedFloat))
 }
 
-/// Extracts the vectors from a JSON value.
-fn extract_vectors(
-    value: Value,
-    document_id: impl Fn() -> Value,
-    name: &str,
-) -> Result<Vec<Vec<f32>>> {
-    // FIXME: ugly clone of the vectors here
-    match serde_json::from_value(value.clone()) {
-        Ok(vectors) => {
-            Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
-        }
-        Err(_) => Err(UserError::InvalidVectorsType {
-            document_id: document_id(),
-            value,
-            subfield: name.to_owned(),
-        }
-        .into()),
-    }
-}
-
 #[tracing::instrument(level = "trace", skip_all, target = "indexing::extract")]
 pub fn extract_embeddings<R: io::Read + io::Seek>(
     // docid, prompt
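
The heart of the rewritten loop is the per-embedder delta computed from the pre-parsed (del, add) pair. A reduced sketch of the two arms visible in the hunks above (a hypothetical simplification; the real VectorStateDelta also covers removal and prompt-driven autogeneration):

type Vectors = Vec<Vec<f32>>;

/// What to do for one embedder, given its old and new manually provided vectors.
fn manual_delta(del: Option<Vectors>, add: Option<Vectors>) -> Result<Option<Vectors>, String> {
    match (del, add) {
        // manual vectors on both sides, or newly provided: keep the new ones,
        // never autogenerate, and enforce the same u8::MAX cap as above
        (Some(_), Some(new)) | (None, Some(new)) => {
            if new.len() > usize::from(u8::MAX) {
                return Err("TooManyVectors".to_string());
            }
            Ok(Some(new))
        }
        // remaining cases (removal, possible autogeneration) elided in this sketch
        _ => Ok(None),
    }
}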

View File

@@ -226,27 +226,31 @@ fn send_original_documents_data(
     let original_documents_chunk =
         original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
 
-    let documents_chunk_cloned = original_documents_chunk.clone();
-    let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
-
     let request_threads = ThreadPoolNoAbortBuilder::new()
         .num_threads(crate::vector::REQUEST_PARALLELISM)
         .thread_name(|index| format!("embedding-request-{index}"))
         .build()?;
 
-    if settings_diff.reindex_vectors() || !settings_diff.settings_update_only() {
+    let index_vectors = (settings_diff.reindex_vectors() || !settings_diff.settings_update_only())
+        // no point in indexing vectors without embedders
+        && (!settings_diff.new.embedding_configs.inner_as_ref().is_empty());
+
+    if index_vectors {
         let settings_diff = settings_diff.clone();
+        let original_documents_chunk = original_documents_chunk.clone();
+        let lmdb_writer_sx = lmdb_writer_sx.clone();
         rayon::spawn(move || {
-            for (name, (embedder, prompt)) in settings_diff.new.embedding_configs.clone() {
-                let result = extract_vector_points(
-                    documents_chunk_cloned.clone(),
-                    indexer,
-                    &settings_diff,
-                    &prompt,
-                    &name,
-                );
-                match result {
-                    Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
+            match extract_vector_points(original_documents_chunk.clone(), indexer, settings_diff) {
+                Ok(extracted_vectors) => {
+                    for ExtractedVectorPoints {
+                        manual_vectors,
+                        remove_vectors,
+                        prompts,
+                        embedder_name,
+                        embedder,
+                    } in extracted_vectors
+                    {
                         let embeddings = match extract_embeddings(
                             prompts,
                             indexer,
@@ -255,28 +259,26 @@ fn send_original_documents_data(
                         ) {
                             Ok(results) => Some(results),
                             Err(error) => {
-                                let _ = lmdb_writer_sx_cloned.send(Err(error));
+                                let _ = lmdb_writer_sx.send(Err(error));
                                 None
                             }
                         };
 
                         if !(remove_vectors.is_empty()
                             && manual_vectors.is_empty()
                             && embeddings.as_ref().map_or(true, |e| e.is_empty()))
                         {
-                            let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
+                            let _ = lmdb_writer_sx.send(Ok(TypedChunk::VectorPoints {
                                 remove_vectors,
                                 embeddings,
                                 expected_dimension: embedder.dimensions(),
                                 manual_vectors,
-                                embedder_name: name,
+                                embedder_name,
                             }));
                         }
                     }
+                }
-                    Err(error) => {
-                        let _ = lmdb_writer_sx_cloned.send(Err(error));
-                    }
+                Err(error) => {
+                    let _ = lmdb_writer_sx.send(Err(error));
+                }
             }
-            }
         });
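
The dispatch above, reduced to a skeleton (a sketch with stand-in types, not the real chunk or channel types; assumes the rayon crate): the cloned sender moves into a rayon task along with the work, and every outcome, success or error, is forwarded through the channel.

use std::sync::mpsc;

fn spawn_extraction(chunks: Vec<Vec<u8>>, lmdb_writer_sx: mpsc::Sender<Result<usize, String>>) {
    rayon::spawn(move || {
        for chunk in chunks {
            // stand-in for extract_vector_points + extract_embeddings
            let result: Result<usize, String> = Ok(chunk.len());
            // like the `let _ =` above, a send error (receiver gone) is ignored
            let _ = lmdb_writer_sx.send(result);
        }
    });
}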

View File

@@ -148,6 +148,10 @@ impl EmbeddingConfigs {
         self.get(self.get_default_embedder_name())
     }
 
+    pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>)> {
+        &self.0
+    }
+
     /// Get the name of the default embedder configuration.
     ///
     /// The default embedder is determined as follows:
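
This new getter backs the emptiness check added in the previous file: `settings_diff.new.embedding_configs.inner_as_ref().is_empty()` is what lets the pipeline skip vector extraction entirely when no embedder is configured.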