From 096a28656ee3c1bba1900f2335e33a8a88677070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Thu, 28 Nov 2024 15:15:06 +0100 Subject: [PATCH] Fix a bug around deleting all the vectors of a doc --- crates/milli/src/update/new/channel.rs | 68 ++++++--------------- crates/milli/src/update/new/indexer/mod.rs | 7 ++- crates/milli/src/update/new/ref_cell_ext.rs | 1 + 3 files changed, 23 insertions(+), 53 deletions(-) diff --git a/crates/milli/src/update/new/channel.rs b/crates/milli/src/update/new/channel.rs index 237c19a5c..38f436837 100644 --- a/crates/milli/src/update/new/channel.rs +++ b/crates/milli/src/update/new/channel.rs @@ -146,15 +146,13 @@ pub struct LargeVectors { pub docid: DocumentId, /// The embedder id in which to insert the large embedding. pub embedder_id: u8, - /// The dimensions of the embeddings in this payload. - pub dimensions: u16, /// The large embedding that must be written. pub embeddings: Mmap, } impl LargeVectors { - pub fn read_embeddings(&self) -> impl Iterator { - self.embeddings.chunks_exact(self.dimensions as usize).map(bytemuck::cast_slice) + pub fn read_embeddings(&self, dimensions: usize) -> impl Iterator { + self.embeddings.chunks_exact(dimensions).map(bytemuck::cast_slice) } } @@ -241,15 +239,18 @@ impl ArroySetVector { &self, frame: &FrameGrantR<'_>, vec: &'v mut Vec, - ) -> &'v [f32] { + ) -> Option<&'v [f32]> { vec.clear(); let skip = EntryHeader::variant_size() + mem::size_of::(); let bytes = &frame[skip..]; + if bytes.is_empty() { + return None; + } bytes.chunks_exact(mem::size_of::()).for_each(|bytes| { let f = bytes.try_into().map(f32::from_ne_bytes).unwrap(); vec.push(f); }); - &vec[..] + Some(&vec[..]) } } @@ -259,9 +260,8 @@ impl ArroySetVector { /// non-aligned [f32] each with dimensions f32s. 
pub struct ArroySetVectors { pub docid: DocumentId, - pub dimensions: u16, pub embedder_id: u8, - _padding: u8, + _padding: [u8; 3], } impl ArroySetVectors { @@ -270,30 +270,6 @@ impl ArroySetVectors { &frame[skip..] } - // /// The number of embeddings in this payload. - // pub fn embedding_count(&self, frame: &FrameGrantR<'_>) -> usize { - // let bytes = Self::remaining_bytes(frame); - // bytes.len().checked_div(self.dimensions as usize).unwrap() - // } - - /// Read the embedding at `index` or `None` if out of bounds. - pub fn read_embedding_into_vec<'v>( - &self, - frame: &FrameGrantR<'_>, - index: usize, - vec: &'v mut Vec, - ) -> Option<&'v [f32]> { - vec.clear(); - let bytes = Self::remaining_bytes(frame); - let embedding_size = self.dimensions as usize * mem::size_of::(); - let embedding_bytes = bytes.chunks_exact(embedding_size).nth(index)?; - embedding_bytes.chunks_exact(mem::size_of::()).for_each(|bytes| { - let f = bytes.try_into().map(f32::from_ne_bytes).unwrap(); - vec.push(f); - }); - Some(&vec[..]) - } - /// Read all the embeddings and write them into an aligned `f32` Vec. 
pub fn read_all_embeddings_into_vec<'v>( &self, @@ -607,18 +583,14 @@ impl<'b> ExtractorBbqueueSender<'b> { let refcell = self.producers.get().unwrap(); let mut producer = refcell.0.borrow_mut_or_yield(); + // If there are no vectors we set the dimensions + // to zero to allocate no extra space at all let dimensions = match embeddings.first() { Some(embedding) => embedding.len(), - None => return Ok(()), - }; - - let arroy_set_vector = ArroySetVectors { - docid, - dimensions: dimensions.try_into().unwrap(), - embedder_id, - _padding: 0, + None => 0, }; + let arroy_set_vector = ArroySetVectors { docid, embedder_id, _padding: [0; 3] }; let payload_header = EntryHeader::ArroySetVectors(arroy_set_vector); let total_length = EntryHeader::total_set_vectors_size(embeddings.len(), dimensions); if total_length > capacity { @@ -632,13 +604,7 @@ impl<'b> ExtractorBbqueueSender<'b> { value_file.sync_all()?; let embeddings = unsafe { Mmap::map(&value_file)? }; - let large_vectors = LargeVectors { - docid, - embedder_id, - dimensions: dimensions.try_into().unwrap(), - embeddings, - }; - + let large_vectors = LargeVectors { docid, embedder_id, embeddings }; self.sender.send(ReceiverAction::LargeVectors(large_vectors)).unwrap(); return Ok(()); @@ -657,9 +623,11 @@ impl<'b> ExtractorBbqueueSender<'b> { let (header_bytes, remaining) = grant.split_at_mut(header_size); payload_header.serialize_into(header_bytes); - let output_iter = remaining.chunks_exact_mut(dimensions * mem::size_of::()); - for (embedding, output) in embeddings.iter().zip(output_iter) { - output.copy_from_slice(bytemuck::cast_slice(embedding)); + if dimensions != 0 { + let output_iter = remaining.chunks_exact_mut(dimensions * mem::size_of::()); + for (embedding, output) in embeddings.iter().zip(output_iter) { + output.copy_from_slice(bytemuck::cast_slice(embedding)); + } } // We could commit only the used memory. 
diff --git a/crates/milli/src/update/new/indexer/mod.rs b/crates/milli/src/update/new/indexer/mod.rs index a8a94cb7c..07cb9d69e 100644 --- a/crates/milli/src/update/new/indexer/mod.rs +++ b/crates/milli/src/update/new/indexer/mod.rs @@ -443,7 +443,7 @@ where let (_, _, writer, dimensions) = arroy_writers.get(&embedder_id).expect("requested a missing embedder"); let mut embeddings = Embeddings::new(*dimensions); - for embedding in large_vectors.read_embeddings() { + for embedding in large_vectors.read_embeddings(*dimensions) { embeddings.push(embedding.to_vec()).unwrap(); } writer.del_items(wtxn, *dimensions, docid)?; @@ -597,11 +597,12 @@ fn write_from_bbqueue( EntryHeader::ArroySetVector(asv) => { let ArroySetVector { docid, embedder_id, .. } = asv; let frame = frame_with_header.frame(); - let embedding = asv.read_embedding_into_vec(frame, aligned_embedding); let (_, _, writer, dimensions) = arroy_writers.get(&embedder_id).expect("requested a missing embedder"); writer.del_items(wtxn, *dimensions, docid)?; - writer.add_item(wtxn, docid, embedding)?; + if let Some(embedding) = asv.read_embedding_into_vec(frame, aligned_embedding) { + writer.add_item(wtxn, docid, embedding)?; + } } EntryHeader::ArroySetVectors(asvs) => { let ArroySetVectors { docid, embedder_id, .. } = asvs; diff --git a/crates/milli/src/update/new/ref_cell_ext.rs b/crates/milli/src/update/new/ref_cell_ext.rs index c66f4af0a..77f5fa800 100644 --- a/crates/milli/src/update/new/ref_cell_ext.rs +++ b/crates/milli/src/update/new/ref_cell_ext.rs @@ -5,6 +5,7 @@ pub trait RefCellExt { &self, ) -> std::result::Result, std::cell::BorrowMutError>; + #[track_caller] fn borrow_mut_or_yield(&self) -> RefMut<'_, T> { self.try_borrow_mut_or_yield().unwrap() }