From aaefbfae1f445d50d9eb856ae09283c34fe7109d Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 28 Jan 2025 16:53:34 +0100 Subject: [PATCH] Do not create too many rayon tasks --- crates/milli/src/thread_pool_no_abort.rs | 17 +++++++++-- crates/milli/src/vector/ollama.rs | 38 +++++++++++++++--------- crates/milli/src/vector/openai.rs | 37 ++++++++++++++--------- crates/milli/src/vector/rest.rs | 36 ++++++++++++++-------- 4 files changed, 85 insertions(+), 43 deletions(-) diff --git a/crates/milli/src/thread_pool_no_abort.rs b/crates/milli/src/thread_pool_no_abort.rs index 14e5b0491..b57050a63 100644 --- a/crates/milli/src/thread_pool_no_abort.rs +++ b/crates/milli/src/thread_pool_no_abort.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use rayon::{ThreadPool, ThreadPoolBuilder}; @@ -9,6 +9,8 @@ use thiserror::Error; #[derive(Debug)] pub struct ThreadPoolNoAbort { thread_pool: ThreadPool, + /// The number of active operations. + active_operations: AtomicUsize, /// Set to true if the thread pool catched a panic. pool_catched_panic: Arc, } @@ -19,7 +21,9 @@ impl ThreadPoolNoAbort { OP: FnOnce() -> R + Send, R: Send, { + self.active_operations.fetch_add(1, Ordering::Relaxed); let output = self.thread_pool.install(op); + self.active_operations.fetch_sub(1, Ordering::Relaxed); // While reseting the pool panic catcher we return an error if we catched one. if self.pool_catched_panic.swap(false, Ordering::SeqCst) { Err(PanicCatched) @@ -31,6 +35,11 @@ impl ThreadPoolNoAbort { pub fn current_num_threads(&self) -> usize { self.thread_pool.current_num_threads() } + + /// The number of active operations. + pub fn active_operations(&self) -> usize { + self.active_operations.load(Ordering::Relaxed) + } } #[derive(Error, Debug)] @@ -64,6 +73,10 @@ impl ThreadPoolNoAbortBuilder { let catched_panic = pool_catched_panic.clone(); move |_result| catched_panic.store(true, Ordering::SeqCst) }); - Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic }) + Ok(ThreadPoolNoAbort { + thread_pool: self.0.build()?, + active_operations: AtomicUsize::new(0), + pool_catched_panic, + }) } } diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs index cc70e2c47..2276bbd3e 100644 --- a/crates/milli/src/vector/ollama.rs +++ b/crates/milli/src/vector/ollama.rs @@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; -use super::DistributionShift; +use super::{DistributionShift, REQUEST_PARALLELISM}; use crate::error::FaultSource; use crate::vector::Embedding; use crate::ThreadPoolNoAbort; @@ -133,20 +133,30 @@ impl Embedder { texts: &[&str], threads: &ThreadPoolNoAbort, ) -> Result>, EmbedError> { - threads - .install(move || { - let embeddings: Result>, _> = texts - .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk, None)) - .collect(); + if threads.active_operations() >= REQUEST_PARALLELISM { + let embeddings: Result>, _> = texts + .chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed(chunk, None)) + .collect(); - let embeddings = embeddings?; - Ok(embeddings.into_iter().flatten().collect()) - }) - .map_err(|error| EmbedError { - kind: EmbedErrorKind::PanicInThreadPool(error), - fault: FaultSource::Bug, - })? + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + } else { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed(chunk, None)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } } pub fn chunk_count_hint(&self) -> usize { diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs index 938c04fe3..c9da3d2da 100644 --- a/crates/milli/src/vector/openai.rs +++ b/crates/milli/src/vector/openai.rs @@ -7,7 +7,7 @@ use rayon::slice::ParallelSlice as _; use super::error::{EmbedError, NewEmbedderError}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; -use super::DistributionShift; +use super::{DistributionShift, REQUEST_PARALLELISM}; use crate::error::FaultSource; use crate::vector::error::EmbedErrorKind; use crate::vector::Embedding; @@ -270,20 +270,29 @@ impl Embedder { texts: &[&str], threads: &ThreadPoolNoAbort, ) -> Result>, EmbedError> { - threads - .install(move || { - let embeddings: Result>, _> = texts - .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk, None)) - .collect(); + if threads.active_operations() >= REQUEST_PARALLELISM { + let embeddings: Result>, _> = texts + .chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed(chunk, None)) + .collect(); + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + } else { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed(chunk, None)) + .collect(); - let embeddings = embeddings?; - Ok(embeddings.into_iter().flatten().collect()) - }) - .map_err(|error| EmbedError { - kind: EmbedErrorKind::PanicInThreadPool(error), - fault: FaultSource::Bug, - })? + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } } pub fn chunk_count_hint(&self) -> usize { diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs index eb05bac64..0abb98315 100644 --- a/crates/milli/src/vector/rest.rs +++ b/crates/milli/src/vector/rest.rs @@ -203,20 +203,30 @@ impl Embedder { texts: &[&str], threads: &ThreadPoolNoAbort, ) -> Result, EmbedError> { - threads - .install(move || { - let embeddings: Result>, _> = texts - .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed_ref(chunk, None)) - .collect(); + if threads.active_operations() >= REQUEST_PARALLELISM { + let embeddings: Result>, _> = texts + .chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed_ref(chunk, None)) + .collect(); - let embeddings = embeddings?; - Ok(embeddings.into_iter().flatten().collect()) - }) - .map_err(|error| EmbedError { - kind: EmbedErrorKind::PanicInThreadPool(error), - fault: FaultSource::Bug, - })? + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + } else { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed_ref(chunk, None)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } } pub fn chunk_count_hint(&self) -> usize {