diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 1f5edeeeb..317a9aec3 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,15 +1,25 @@ +use std::cmp::Ordering; use std::convert::TryFrom; use std::fs::File; -use std::io::{self, BufReader}; +use std::io::{self, BufReader, BufWriter}; +use std::mem::size_of; +use std::str::from_utf8; use bytemuck::cast_slice; +use grenad::Writer; +use itertools::EitherOrBoth; +use ordered_float::OrderedFloat; use serde_json::{from_slice, Value}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::error::UserError; +use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::try_split_at; use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; +/// The length of the elements that are always in the buffer when inserting new values. +const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); + /// Extracts the embedding vector contained in each document under the `_vectors` field. /// /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32> @@ -27,45 +37,112 @@ pub fn extract_vector_points( tempfile::tempfile()?, ); + let mut key_buffer = Vec::new(); let mut cursor = obkv_documents.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()?
{ // this must always be serialized as (docid, external_docid); let (docid_bytes, external_id_bytes) = try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap(); - debug_assert!(std::str::from_utf8(external_id_bytes).is_ok()); + debug_assert!(from_utf8(external_id_bytes).is_ok()); let obkv = obkv::KvReader::new(value); + key_buffer.clear(); + key_buffer.extend_from_slice(docid_bytes); // since we only needs the primary key when we throw an error we create this getter to // lazily get it when needed - let document_id = || -> Value { std::str::from_utf8(external_id_bytes).unwrap().into() }; + let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; // first we retrieve the _vectors field - if let Some(vectors) = obkv.get(vectors_fid) { - // extract the vectors - let vectors = match from_slice(vectors) { - Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors), - Err(_) => { - return Err(UserError::InvalidVectorsType { - document_id: document_id(), - value: from_slice(vectors).map_err(InternalError::SerdeJson)?, - } - .into()) - } - }; + if let Some(value) = obkv.get(vectors_fid) { + let vectors_obkv = KvReaderDelAdd::new(value); - if let Some(vectors) = vectors { - for (i, vector) in vectors.into_iter().enumerate().take(u16::MAX as usize) { - let index = u16::try_from(i).unwrap(); - let mut key = docid_bytes.to_vec(); - key.extend_from_slice(&index.to_be_bytes()); - let bytes = cast_slice(&vector); - writer.insert(key, bytes)?; - } - } + // then we extract the values + let del_vectors = vectors_obkv + .get(DelAdd::Deletion) + .map(|vectors| extract_vectors(vectors, document_id)) + .transpose()? + .flatten(); + let add_vectors = vectors_obkv + .get(DelAdd::Addition) + .map(|vectors| extract_vectors(vectors, document_id)) + .transpose()?
+ .flatten(); + + // and we finally push the unique vectors into the writer + push_vectors_diff( + &mut writer, + &mut key_buffer, + del_vectors.unwrap_or_default(), + add_vectors.unwrap_or_default(), + )?; } - // else => the `_vectors` object was `null`, there is nothing to do } writer_into_reader(writer) } + +/// Computes the diff between both Del and Add vectors and +/// only inserts the parts that differ in the writer. +fn push_vectors_diff( + writer: &mut Writer<BufWriter<File>>, + key_buffer: &mut Vec<u8>, + mut del_vectors: Vec<Vec<f32>>, + mut add_vectors: Vec<Vec<f32>>, +) -> Result<()> { + // We sort and dedup the vectors + del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); + add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); + del_vectors.dedup_by(|a, b| compare_vectors(a, b).is_eq()); + add_vectors.dedup_by(|a, b| compare_vectors(a, b).is_eq()); + + let merged_vectors_iter = + itertools::merge_join_by(del_vectors, add_vectors, |del, add| compare_vectors(del, add)); + + // insert vectors into the writer + for (i, eob) in merged_vectors_iter.into_iter().enumerate().take(u16::MAX as usize) { + // Generate the key by extending the unique index to it. + key_buffer.truncate(TRUNCATE_SIZE); + let index = u16::try_from(i).unwrap(); + key_buffer.extend_from_slice(&index.to_be_bytes()); + + match eob { + EitherOrBoth::Both(_, _) => (), // no need to touch anything + EitherOrBoth::Left(vector) => { + // We insert only the Del part of the Obkv to inform + // that we only want to remove all those vectors. + let mut obkv = KvWriterDelAdd::memory(); + obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; + let bytes = obkv.into_inner()?; + writer.insert(&key_buffer, bytes)?; + } + EitherOrBoth::Right(vector) => { + // We insert only the Add part of the Obkv to inform + // that we only want to add all those vectors.
+ let mut obkv = KvWriterDelAdd::memory(); + obkv.insert(DelAdd::Addition, cast_slice(&vector))?; + let bytes = obkv.into_inner()?; + writer.insert(&key_buffer, bytes)?; + } + } + } + + Ok(()) +} + +/// Compares two vectors by using the OrderedFloat helper. +fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering { + a.iter().copied().map(OrderedFloat).cmp(b.iter().copied().map(OrderedFloat)) +} + +/// Extracts the vectors from a JSON value. +fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> { + match from_slice(value) { + Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)), + Err(_) => Err(UserError::InvalidVectorsType { + document_id: document_id(), + value: from_slice(value).map_err(InternalError::SerdeJson)?, + } + .into()), + } +} diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 7c3f587d2..80671e39f 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::convert::TryInto; use std::fs::File; use std::io::{self, BufReader}; @@ -8,7 +8,9 @@ use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::ByteSlice; use heed::RwTxn; +use log::error; use obkv::{KvReader, KvWriter}; +use ordered_float::OrderedFloat; use roaring::RoaringBitmap; use super::helpers::{self, merge_ignore_values, valid_lmdb_key, CursorClonableMmap}; @@ -22,10 +24,9 @@ use crate::index::Hnsw; use crate::update::del_add::{DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::update::index_documents::validate_document_id_value; use crate::{ - lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, FieldId, GeoPoint, Index, InternalError, - Result, SerializationError, BEU32, + lat_lng_to_xyz,
CboRoaringBitmapCodec, DocumentId, FieldId, GeoPoint, Index, Result, + SerializationError, BEU32, }; pub(crate) enum TypedChunk { @@ -366,44 +367,70 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } TypedChunk::VectorPoints(vector_points) => { - let (pids, mut points): (Vec<_>, Vec<_>) = match index.vector_hnsw(wtxn)? { - Some(hnsw) => hnsw.iter().map(|(pid, point)| (pid, point.clone())).unzip(), - None => Default::default(), - }; - - // Convert the PointIds into DocumentIds - let mut docids = Vec::new(); - for pid in pids { - let docid = - index.vector_id_docid.get(wtxn, &BEU32::new(pid.into_inner()))?.unwrap(); - docids.push(docid.get()); + let mut vectors_set = HashSet::new(); + // We extract and store the previous vectors + if let Some(hnsw) = index.vector_hnsw(wtxn)? { + for (pid, point) in hnsw.iter() { + let pid_key = BEU32::new(pid.into_inner()); + let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap().get(); + let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); + vectors_set.insert((docid, vector)); + } } - let mut expected_dimensions = points.get(0).map(|p| p.len()); let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? 
{ // convert the key back to a u32 (4 bytes) let (left, _index) = try_split_array_at(key).unwrap(); let docid = DocumentId::from_be_bytes(left); - // convert the vector back to a Vec<f32> - let vector: Vec<f32> = pod_collect_to_vec(value); - // TODO Inform the user about the document that has a wrong `_vectors` - let found = vector.len(); - let expected = *expected_dimensions.get_or_insert(found); - if expected != found { - return Err(UserError::InvalidVectorDimensions { expected, found }.into()); + let vector_deladd_obkv = KvReaderDelAdd::new(value); + if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { + // convert the vector back to a Vec<f32> + let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + let key = (docid, vector); + if !vectors_set.remove(&key) { + error!("Unable to delete the vector: {:?}", key.1); + } + } + if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { + // convert the vector back to a Vec<f32> + let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + vectors_set.insert((docid, vector)); } - - points.push(NDotProductPoint::new(vector)); - docids.push(docid); } - assert_eq!(docids.len(), points.len()); + // Extract the most common vector dimension + let expected_dimension_size = { + let mut dims = HashMap::new(); + vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); + dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) + }; + + // Ensure that the vector lengths are correct and + // prepare the vectors before inserting them in the HNSW.
+ let mut points = Vec::new(); + let mut docids = Vec::new(); + for (docid, vector) in vectors_set { + if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { + return Err(UserError::InvalidVectorDimensions { + expected: expected_dimension_size.unwrap_or(vector.len()), + found: vector.len(), + } + .into()); + } else { + let vector = vector.into_iter().map(OrderedFloat::into_inner).collect(); + points.push(NDotProductPoint::new(vector)); + docids.push(docid); + } + } let hnsw_length = points.len(); let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); + assert_eq!(docids.len(), pids.len()); + + // Store the vectors in the point-docid relation database index.vector_id_docid.clear(wtxn)?; for (docid, pid) in docids.into_iter().zip(pids) { index.vector_id_docid.put(