diff --git a/crates/milli/src/update/new/channel.rs b/crates/milli/src/update/new/channel.rs
index 26e375a5a..7eaa50df1 100644
--- a/crates/milli/src/update/new/channel.rs
+++ b/crates/milli/src/update/new/channel.rs
@@ -1,4 +1,5 @@
 use std::cell::RefCell;
+use std::io::{self, BufWriter};
 use std::marker::PhantomData;
 use std::mem;
 use std::num::NonZeroU16;
@@ -9,7 +10,7 @@ use bytemuck::{checked, CheckedBitPattern, NoUninit};
 use crossbeam_channel::SendError;
 use heed::types::Bytes;
 use heed::BytesDecode;
-use memmap2::Mmap;
+use memmap2::{Mmap, MmapMut};
 use roaring::RoaringBitmap;
 
 use super::extract::FacetKind;
@@ -98,20 +99,63 @@ pub struct WriterBbqueueReceiver<'a> {
 pub enum ReceiverAction {
     /// Wake up, you have frames to read for the BBQueue buffers.
     WakeUp,
-    /// An entry that cannot fit in the BBQueue buffers has been
-    /// written to disk, memory-mapped and must be written in the
-    /// database.
-    LargeEntry {
-        /// The database where the entry must be written.
-        database: Database,
-        /// The key of the entry that must be written in the database.
-        key: Box<[u8]>,
-        /// The large value that must be written.
-        ///
-        /// Note: We can probably use a `File` here and
-        /// use `Database::put_reserved` instead of memory-mapping.
-        value: Mmap,
-    },
+    LargeEntry(LargeEntry),
+    LargeVector(LargeVector),
+    LargeVectors(LargeVectors),
+}
+
+/// An entry that cannot fit in the BBQueue buffers has been
+/// written to disk, memory-mapped and must be written in the
+/// database.
+#[derive(Debug)]
+pub struct LargeEntry {
+    /// The database where the entry must be written.
+    pub database: Database,
+    /// The key of the entry that must be written in the database.
+    pub key: Box<[u8]>,
+    /// The large value that must be written.
+    ///
+    /// Note: We can probably use a `File` here and
+    /// use `Database::put_reserved` instead of memory-mapping.
+    pub value: Mmap,
+}
+
+/// When an embedding is larger than the available
+/// BBQueue space, it arrives here.
+#[derive(Debug)]
+pub struct LargeVector {
+    /// The document id associated with the large embedding.
+    pub docid: DocumentId,
+    /// The embedder id in which to insert the large embedding.
+    pub embedder_id: u8,
+    /// The large embedding that must be written.
+    pub embedding: Mmap,
+}
+
+impl LargeVector {
+    pub fn read_embedding(&self) -> &[f32] {
+        bytemuck::cast_slice(&self.embedding)
+    }
+}
+
+/// When embeddings are larger than the available
+/// BBQueue space, they arrive here.
+#[derive(Debug)]
+pub struct LargeVectors {
+    /// The document id associated with the large embeddings.
+    pub docid: DocumentId,
+    /// The embedder id in which to insert the large embeddings.
+    pub embedder_id: u8,
+    /// The dimensions of the embeddings in this payload.
+    pub dimensions: u16,
+    /// The large embeddings that must be written.
+    pub embeddings: Mmap,
+}
+
+impl LargeVectors {
+    pub fn read_embeddings(&self) -> impl Iterator<Item = &[f32]> {
+        // Each embedding spans `dimensions` f32s, i.e. `dimensions * 4` bytes.
+        let chunk_size = self.dimensions as usize * mem::size_of::<f32>();
+        self.embeddings.chunks_exact(chunk_size).map(bytemuck::cast_slice)
+    }
+}
 
 impl<'a> WriterBbqueueReceiver<'a> {
@@ -209,12 +253,55 @@ impl ArroySetVector {
     }
 }
 
+#[derive(Debug, Clone, Copy, NoUninit, CheckedBitPattern)]
+#[repr(C)]
+/// The embeddings are in the remaining space and represent
+/// non-aligned `[f32]` slices, each `dimensions` f32s long.
+pub struct ArroySetVectors {
+    pub docid: DocumentId,
+    pub dimensions: u16,
+    pub embedder_id: u8,
+    _padding: u8,
+}
+
+impl ArroySetVectors {
+    fn remaining_bytes<'a>(frame: &'a FrameGrantR<'_>) -> &'a [u8] {
+        let skip = EntryHeader::variant_size() + mem::size_of::<Self>();
+        &frame[skip..]
+    }
+
+    // /// The number of embeddings in this payload.
+    // pub fn embedding_count(&self, frame: &FrameGrantR<'_>) -> usize {
+    //     let bytes = Self::remaining_bytes(frame);
+    //     bytes.len().checked_div(self.dimensions as usize).unwrap()
+    // }
+
+    /// Reads the embedding at `index`, returning `None` if out of bounds.
+    pub fn read_embedding_into_vec<'v>(
+        &self,
+        frame: &FrameGrantR<'_>,
+        index: usize,
+        vec: &'v mut Vec<f32>,
+    ) -> Option<&'v [f32]> {
+        vec.clear();
+        let bytes = Self::remaining_bytes(frame);
+        let embedding_size = self.dimensions as usize * mem::size_of::<f32>();
+        let embedding_bytes = bytes.chunks_exact(embedding_size).nth(index)?;
+        embedding_bytes.chunks_exact(mem::size_of::<f32>()).for_each(|bytes| {
+            let f = bytes.try_into().map(f32::from_ne_bytes).unwrap();
+            vec.push(f);
+        });
+        Some(&vec[..])
+    }
+}
+
 #[derive(Debug, Clone, Copy)]
 #[repr(u8)]
 pub enum EntryHeader {
     DbOperation(DbOperation),
     ArroyDeleteVector(ArroyDeleteVector),
     ArroySetVector(ArroySetVector),
+    ArroySetVectors(ArroySetVectors),
 }
 
 impl EntryHeader {
@@ -227,6 +314,7 @@
             EntryHeader::DbOperation(_) => 0,
             EntryHeader::ArroyDeleteVector(_) => 1,
             EntryHeader::ArroySetVector(_) => 2,
+            EntryHeader::ArroySetVectors(_) => 3,
         }
     }
 
@@ -245,11 +333,15 @@ impl EntryHeader {
         Self::variant_size() + mem::size_of::<ArroyDeleteVector>()
     }
 
-    /// The `embedding_length` corresponds to the number of `f32` in the embedding.
-    fn total_set_vector_size(embedding_length: usize) -> usize {
-        Self::variant_size()
-            + mem::size_of::<ArroySetVector>()
-            + embedding_length * mem::size_of::<f32>()
+    /// The `dimensions` corresponds to the number of `f32` in the embedding.
+    fn total_set_vector_size(dimensions: usize) -> usize {
+        Self::variant_size() + mem::size_of::<ArroySetVector>() + dimensions * mem::size_of::<f32>()
+    }
+
+    /// The `dimensions` corresponds to the number of `f32` in each embedding.
+    fn total_set_vectors_size(count: usize, dimensions: usize) -> usize {
+        let embedding_size = dimensions * mem::size_of::<f32>();
+        Self::variant_size() + mem::size_of::<ArroySetVectors>() + embedding_size * count
     }
 
     fn header_size(&self) -> usize {
@@ -257,6 +349,7 @@
             EntryHeader::DbOperation(op) => mem::size_of_val(op),
             EntryHeader::ArroyDeleteVector(adv) => mem::size_of_val(adv),
             EntryHeader::ArroySetVector(asv) => mem::size_of_val(asv),
+            EntryHeader::ArroySetVectors(asvs) => mem::size_of_val(asvs),
         };
         Self::variant_size() + payload_size
     }
@@ -279,6 +372,11 @@
                 let header = checked::pod_read_unaligned(header_bytes);
                 EntryHeader::ArroySetVector(header)
             }
+            3 => {
+                let header_bytes = &remaining[..mem::size_of::<ArroySetVectors>()];
+                let header = checked::pod_read_unaligned(header_bytes);
+                EntryHeader::ArroySetVectors(header)
+            }
             id => panic!("invalid variant id: {id}"),
         }
     }
@@ -289,6 +387,7 @@
             EntryHeader::DbOperation(op) => bytemuck::bytes_of(op),
             EntryHeader::ArroyDeleteVector(adv) => bytemuck::bytes_of(adv),
             EntryHeader::ArroySetVector(asv) => bytemuck::bytes_of(asv),
+            EntryHeader::ArroySetVectors(asvs) => bytemuck::bytes_of(asvs),
         };
         *first = self.variant_id();
         remaining.copy_from_slice(payload_bytes);
@@ -405,7 +504,7 @@ impl<'b> ExtractorBbqueueSender<'b> {
         let payload_header = EntryHeader::ArroyDeleteVector(ArroyDeleteVector { docid });
         let total_length = EntryHeader::total_delete_vector_size();
         if total_length > capacity {
-            unreachable!("entry larger that the BBQueue capacity");
+            panic!("The entry is larger ({total_length} bytes) than the BBQueue capacity ({capacity} bytes)");
         }
 
         // Spin loop to have a frame the size we requested.
@@ -441,11 +540,21 @@ impl<'b> ExtractorBbqueueSender<'b> {
         let refcell = self.producers.get().unwrap();
         let mut producer = refcell.0.borrow_mut_or_yield();
 
-        let payload_header =
-            EntryHeader::ArroySetVector(ArroySetVector { docid, embedder_id, _padding: [0; 3] });
+        let arroy_set_vector = ArroySetVector { docid, embedder_id, _padding: [0; 3] };
+        let payload_header = EntryHeader::ArroySetVector(arroy_set_vector);
         let total_length = EntryHeader::total_set_vector_size(embedding.len());
         if total_length > capacity {
-            unreachable!("entry larger that the BBQueue capacity");
+            let mut embedding_bytes = bytemuck::cast_slice(embedding);
+            let mut value_file = tempfile::tempfile().map(BufWriter::new)?;
+            io::copy(&mut embedding_bytes, &mut value_file)?;
+            let value_file = value_file.into_inner().map_err(|ie| ie.into_error())?;
+            value_file.sync_all()?;
+            let embedding = unsafe { Mmap::map(&value_file)? };
+
+            let large_vector = LargeVector { docid, embedder_id, embedding };
+            self.sender.send(ReceiverAction::LargeVector(large_vector)).unwrap();
+
+            return Ok(());
         }
 
         // Spin loop to have a frame the size we requested.
@@ -457,7 +566,6 @@ impl<'b> ExtractorBbqueueSender<'b> {
             }
         };
 
-        // payload_header.serialize_into(&mut grant);
         let header_size = payload_header.header_size();
         let (header_bytes, remaining) = grant.split_at_mut(header_size);
         payload_header.serialize_into(header_bytes);
@@ -475,6 +583,83 @@ impl<'b> ExtractorBbqueueSender<'b> {
         Ok(())
     }
 
+    fn set_vectors(
+        &self,
+        docid: u32,
+        embedder_id: u8,
+        embeddings: &[Vec<f32>],
+    ) -> crate::Result<()> {
+        let capacity = self.capacity;
+        let refcell = self.producers.get().unwrap();
+        let mut producer = refcell.0.borrow_mut_or_yield();
+
+        let dimensions = match embeddings.first() {
+            Some(embedding) => embedding.len(),
+            None => return Ok(()),
+        };
+
+        let arroy_set_vector = ArroySetVectors {
+            docid,
+            dimensions: dimensions.try_into().unwrap(),
+            embedder_id,
+            _padding: 0,
+        };
+
+        let payload_header = EntryHeader::ArroySetVectors(arroy_set_vector);
+        let total_length = EntryHeader::total_set_vectors_size(embeddings.len(), dimensions);
+        if total_length > capacity {
+            let mut value_file = tempfile::tempfile().map(BufWriter::new)?;
+            for embedding in embeddings {
+                let mut embedding_bytes = bytemuck::cast_slice(embedding);
+                io::copy(&mut embedding_bytes, &mut value_file)?;
+            }
+
+            let value_file = value_file.into_inner().map_err(|ie| ie.into_error())?;
+            value_file.sync_all()?;
+            let embeddings = unsafe { Mmap::map(&value_file)? };
+
+            let large_vectors = LargeVectors {
+                docid,
+                embedder_id,
+                dimensions: dimensions.try_into().unwrap(),
+                embeddings,
+            };
+
+            self.sender.send(ReceiverAction::LargeVectors(large_vectors)).unwrap();
+
+            return Ok(());
+        }
+
+        // Spin loop to have a frame the size we requested.
+        let mut grant = loop {
+            match producer.grant(total_length) {
+                Ok(grant) => break grant,
+                Err(bbqueue::Error::InsufficientSize) => continue,
+                Err(e) => unreachable!("{e:?}"),
+            }
+        };
+
+        let header_size = payload_header.header_size();
+        let (header_bytes, remaining) = grant.split_at_mut(header_size);
+        payload_header.serialize_into(header_bytes);
+
+        let output_iter = remaining.chunks_exact_mut(dimensions * mem::size_of::<f32>());
+        for (embedding, output) in embeddings.iter().zip(output_iter) {
+            output.copy_from_slice(bytemuck::cast_slice(embedding));
+        }
+
+        // We could commit only the used memory.
+        grant.commit(total_length);
+
+        // We only send a wake up message when the channel is empty
+        // so that we don't fill the channel with too many WakeUps.
+        if self.sender.is_empty() {
+            self.sender.send(ReceiverAction::WakeUp).unwrap();
+        }
+
+        Ok(())
+    }
+
     fn write_key_value(&self, database: Database, key: &[u8], value: &[u8]) -> crate::Result<()> {
         let key_length = NonZeroU16::new(key.len().try_into().unwrap()).unwrap();
         self.write_key_value_with(database, key_length, value.len(), |key_buffer, value_buffer| {
@@ -502,7 +687,22 @@ impl<'b> ExtractorBbqueueSender<'b> {
         let payload_header = EntryHeader::DbOperation(operation);
         let total_length = EntryHeader::total_key_value_size(key_length, value_length);
         if total_length > capacity {
-            unreachable!("entry larger that the BBQueue capacity");
+            let mut key_buffer = vec![0; key_length.get() as usize].into_boxed_slice();
+            let value_file = tempfile::tempfile()?;
+            value_file.set_len(value_length.try_into().unwrap())?;
+            let mut mmap_mut = unsafe { MmapMut::map_mut(&value_file)? };
+
+            key_value_writer(&mut key_buffer, &mut mmap_mut)?;
+
+            self.sender
+                .send(ReceiverAction::LargeEntry(LargeEntry {
+                    database,
+                    key: key_buffer,
+                    value: mmap_mut.make_read_only()?,
+                }))
+                .unwrap();
+
+            return Ok(());
         }
 
         // Spin loop to have a frame the size we requested.
@@ -559,7 +759,7 @@ impl<'b> ExtractorBbqueueSender<'b> {
         let payload_header = EntryHeader::DbOperation(operation);
         let total_length = EntryHeader::total_key_size(key_length);
        if total_length > capacity {
-            unreachable!("entry larger that the BBQueue capacity");
+            panic!("The entry is larger ({total_length} bytes) than the BBQueue capacity ({capacity} bytes)");
         }
 
         // Spin loop to have a frame the size we requested.
@@ -763,10 +963,7 @@ impl EmbeddingSender<'_, '_> {
         embedder_id: u8,
         embeddings: Vec<Embedding>,
     ) -> crate::Result<()> {
-        for embedding in embeddings {
-            self.set_vector(docid, embedder_id, embedding)?;
-        }
-        Ok(())
+        self.0.set_vectors(docid, embedder_id, &embeddings[..])
     }
 
     pub fn set_vector(
@@ -786,11 +983,11 @@ impl GeoSender<'_, '_> {
     pub fn set_rtree(&self, value: Mmap) -> StdResult<(), SendError<()>> {
         self.0
             .sender
-            .send(ReceiverAction::LargeEntry {
+            .send(ReceiverAction::LargeEntry(LargeEntry {
                 database: Database::Main,
                 key: GEO_RTREE_KEY.to_string().into_bytes().into_boxed_slice(),
                 value,
-            })
+            }))
             .map_err(|_| SendError(()))
     }
diff --git a/crates/milli/src/update/new/indexer/mod.rs b/crates/milli/src/update/new/indexer/mod.rs
index 3a4406aef..9ad7a8f0b 100644
--- a/crates/milli/src/update/new/indexer/mod.rs
+++ b/crates/milli/src/update/new/indexer/mod.rs
@@ -16,6 +16,7 @@ use rand::SeedableRng as _;
 use raw_collections::RawMap;
 use time::OffsetDateTime;
 pub use update_by_function::UpdateByFunction;
+use super::channel::{LargeEntry, LargeVector, LargeVectors};
 
 use super::channel::*;
 use super::extract::*;
@@ -40,7 +41,7 @@ use crate::update::new::words_prefix_docids::compute_exact_word_prefix_docids;
 use crate::update::new::{merge_and_send_docids, merge_and_send_facet_docids, FacetDatabases};
 use crate::update::settings::InnerIndexSettings;
 use crate::update::{FacetsUpdateBulk, GrenadParameters};
-use crate::vector::{ArroyWrapper, EmbeddingConfigs};
+use crate::vector::{ArroyWrapper, EmbeddingConfigs, Embeddings};
 use crate::{
     Error, FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result, ThreadPoolNoAbort,
     ThreadPoolNoAbortBuilder, UserError,
@@ -132,7 +133,8 @@ where
                 {
                     let span = tracing::trace_span!(target: "indexing::documents::extract", parent: &indexer_span, "documents");
                     let _entered = span.enter();
-                    extract(document_changes,
+                    extract(
+                        document_changes,
                         &document_extractor,
                         indexing_context,
                         &mut extractor_allocs,
@@ -416,7 +418,7 @@ where
             match action {
                 ReceiverAction::WakeUp => (),
-                ReceiverAction::LargeEntry { database, key, value } => {
+                ReceiverAction::LargeEntry(LargeEntry { database, key, value }) => {
                     let database_name = database.database_name();
                     let database = database.database(index);
                     if let Err(error) = database.put(wtxn, &key, &value) {
@@ -428,6 +430,24 @@ where
                         }));
                     }
                 }
+                ReceiverAction::LargeVector(large_vector) => {
+                    let LargeVector { docid, embedder_id, .. } = &large_vector;
+                    let (_, _, writer, dimensions) =
+                        arroy_writers.get(embedder_id).expect("requested a missing embedder");
+                    writer.del_items(wtxn, *dimensions, *docid)?;
+                    writer.add_item(wtxn, *docid, large_vector.read_embedding())?;
+                }
+                ReceiverAction::LargeVectors(large_vectors) => {
+                    let LargeVectors { docid, embedder_id, .. } = &large_vectors;
+                    let (_, _, writer, dimensions) =
+                        arroy_writers.get(embedder_id).expect("requested a missing embedder");
+                    writer.del_items(wtxn, *dimensions, *docid)?;
+                    let mut embeddings = Embeddings::new(*dimensions);
+                    for embedding in large_vectors.read_embeddings() {
+                        embeddings.push(embedding.to_vec()).unwrap();
+                    }
+                    writer.add_items(wtxn, *docid, &embeddings)?;
+                }
             }
 
             // Every time there is a message in the channel we search
@@ -582,6 +602,19 @@ fn write_from_bbqueue(
                 writer.del_items(wtxn, *dimensions, docid)?;
                 writer.add_item(wtxn, docid, embedding)?;
             }
+            EntryHeader::ArroySetVectors(asvs) => {
+                let ArroySetVectors { docid, embedder_id, .. } = asvs;
+                let frame = frame_with_header.frame();
+                let (_, _, writer, dimensions) =
+                    arroy_writers.get(&embedder_id).expect("requested a missing embedder");
+                writer.del_items(wtxn, *dimensions, docid)?;
+                for index in 0.. {
+                    match asvs.read_embedding_into_vec(frame, index, aligned_embedding) {
+                        Some(embedding) => writer.add_item(wtxn, docid, embedding)?,
+                        None => break,
+                    }
+                }
+            }
         }
     }
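
A note on the recurring fallback in this patch: every sender that finds its payload larger than the BBQueue capacity spills it to an anonymous temporary file, syncs it, and memory-maps it read-only before sending it through the crossbeam channel. Below is a minimal, self-contained sketch of that spill path, assuming only the `tempfile` and `memmap2` crates the patch already depends on; the `spill_to_mmap` helper name is hypothetical and not part of the patch.

```rust
use std::io::{self, BufWriter, Write};

use memmap2::Mmap;

/// Write `payload` to an anonymous temp file, sync it, and hand back a
/// read-only memory map of it, mirroring the over-capacity fallback above.
fn spill_to_mmap(payload: &[u8]) -> io::Result<Mmap> {
    // Buffer writes so large payloads don't hit the kernel chunk by chunk.
    let mut file = tempfile::tempfile().map(BufWriter::new)?;
    file.write_all(payload)?;
    // Recover the inner `File`, flushing the buffer in the process.
    let file = file.into_inner().map_err(|e| e.into_error())?;
    file.sync_all()?;
    // SAFETY: the anonymous file is private to this process and is not
    // written to after this point.
    unsafe { Mmap::map(&file) }
}
```

Mapping the synced file keeps the receiver's read path uniform: whether an entry arrived inline in a BBQueue frame or as a `LargeEntry`/`LargeVector(s)` action, it is ultimately consumed as a plain `&[u8]`.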