From e83534a4305963c857423cf03c3612e4e31a2b07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 27 Nov 2024 16:27:43 +0100 Subject: [PATCH] Fix the indexer::index to correctly use the rayon::ThreadPool --- crates/milli/src/update/new/channel.rs | 49 +++++----------------- crates/milli/src/update/new/indexer/mod.rs | 17 ++++---- 2 files changed, 19 insertions(+), 47 deletions(-) diff --git a/crates/milli/src/update/new/channel.rs b/crates/milli/src/update/new/channel.rs index beba80ac8..70c4a6042 100644 --- a/crates/milli/src/update/new/channel.rs +++ b/crates/milli/src/update/new/channel.rs @@ -4,6 +4,7 @@ use std::mem; use std::num::NonZeroU16; use bbqueue::framed::{FrameGrantR, FrameProducer}; +use bbqueue::BBBuffer; use bytemuck::{checked, CheckedBitPattern, NoUninit}; use crossbeam_channel::SendError; use heed::types::Bytes; @@ -25,6 +26,9 @@ use crate::{CboRoaringBitmapCodec, DocumentId, Index}; /// Creates a tuple of senders/receiver to be used by /// the extractors and the writer loop. /// +/// The `bbqueue_capacity` represent the number of bytes allocated +/// to each BBQueue buffer and is not the sum of all of them. +/// /// The `channel_capacity` parameter defines the number of /// too-large-to-fit-in-BBQueue entries that can be sent through /// a crossbeam channel. This parameter must stay low to make @@ -40,14 +44,11 @@ use crate::{CboRoaringBitmapCodec, DocumentId, Index}; /// Panics if the number of provided BBQueues is not exactly equal /// to the number of available threads in the rayon threadpool. pub fn extractor_writer_bbqueue( - bbbuffers: &[bbqueue::BBBuffer], + bbbuffers: &mut Vec, + bbbuffer_capacity: usize, channel_capacity: usize, ) -> (ExtractorBbqueueSender, WriterBbqueueReceiver) { - assert_eq!( - bbbuffers.len(), - rayon::current_num_threads(), - "You must provide as many BBBuffer as the available number of threads to extract" - ); + bbbuffers.resize_with(rayon::current_num_threads(), || BBBuffer::new(bbbuffer_capacity)); let capacity = bbbuffers.first().unwrap().capacity(); // Read the field description to understand this @@ -55,12 +56,6 @@ pub fn extractor_writer_bbqueue( let producers = ThreadLocal::with_capacity(bbbuffers.len()); let consumers = rayon::broadcast(|bi| { - eprintln!( - "hello thread #{:?} (#{:?}, #{:?})", - bi.index(), - std::thread::current().name(), - std::thread::current().id(), - ); let bbqueue = &bbbuffers[bi.index()]; let (producer, consumer) = bbqueue.try_split_framed().unwrap(); producers.get_or(|| FullySend(RefCell::new(producer))); @@ -405,15 +400,7 @@ impl<'b> ExtractorBbqueueSender<'b> { fn delete_vector(&self, docid: DocumentId) -> crate::Result<()> { let capacity = self.capacity; - let refcell = match self.producers.get() { - Some(refcell) => refcell, - None => panic!( - "hello thread #{:?} (#{:?}, #{:?})", - rayon::current_thread_index(), - std::thread::current().name(), - std::thread::current().id() - ), - }; + let refcell = self.producers.get().unwrap(); let mut producer = refcell.0.borrow_mut_or_yield(); let payload_header = EntryHeader::ArroyDeleteVector(ArroyDeleteVector { docid }); @@ -452,15 +439,7 @@ impl<'b> ExtractorBbqueueSender<'b> { embedding: &[f32], ) -> crate::Result<()> { let capacity = self.capacity; - let refcell = match self.producers.get() { - Some(refcell) => refcell, - None => panic!( - "hello thread #{:?} (#{:?}, #{:?})", - rayon::current_thread_index(), - std::thread::current().name(), - std::thread::current().id() - ), - }; + let refcell = self.producers.get().unwrap(); let mut producer = refcell.0.borrow_mut_or_yield(); let payload_header = @@ -518,15 +497,7 @@ impl<'b> ExtractorBbqueueSender<'b> { F: FnOnce(&mut [u8]) -> crate::Result<()>, { let capacity = self.capacity; - let refcell = match self.producers.get() { - Some(refcell) => refcell, - None => panic!( - "hello thread #{:?} (#{:?}, #{:?})", - rayon::current_thread_index(), - std::thread::current().name(), - std::thread::current().id() - ), - }; + let refcell = self.producers.get().unwrap(); let mut producer = refcell.0.borrow_mut_or_yield(); let operation = DbOperation { database, key_length: Some(key_length) }; diff --git a/crates/milli/src/update/new/indexer/mod.rs b/crates/milli/src/update/new/indexer/mod.rs index b7d5431b4..3a4406aef 100644 --- a/crates/milli/src/update/new/indexer/mod.rs +++ b/crates/milli/src/update/new/indexer/mod.rs @@ -77,17 +77,18 @@ where MSP: Fn() -> bool + Sync, SP: Fn(Progress) + Sync, { - /// TODO restrict memory and remove this memory from the extractors bump allocators - let bbbuffers: Vec<_> = pool + let mut bbbuffers = Vec::new(); + let finished_extraction = AtomicBool::new(false); + let (extractor_sender, mut writer_receiver) = pool .install(|| { - (0..rayon::current_num_threads()) - .map(|_| bbqueue::BBBuffer::new(100 * 1024 * 1024)) // 100 MiB by thread - .collect() + /// TODO restrict memory and remove this memory from the extractors bump allocators + extractor_writer_bbqueue( + &mut bbbuffers, + 100 * 1024 * 1024, // 100 MiB + 1000, + ) }) .unwrap(); - let (extractor_sender, mut writer_receiver) = - pool.install(|| extractor_writer_bbqueue(&bbbuffers, 1000)).unwrap(); - let finished_extraction = AtomicBool::new(false); let metadata_builder = MetadataBuilder::from_index(index, wtxn)?; let new_fields_ids_map = FieldIdMapWithMetadata::new(new_fields_ids_map, metadata_builder);