diff --git a/crates/index-scheduler/src/batch.rs b/crates/index-scheduler/src/batch.rs
index 89c7f6f45..449d2aaed 100644
--- a/crates/index-scheduler/src/batch.rs
+++ b/crates/index-scheduler/src/batch.rs
@@ -839,12 +839,14 @@ impl IndexScheduler {
                 for document in
                     serde_json::de::Deserializer::from_reader(content_file).into_iter()
                 {
-                    let document = document.map_err(|e| {
-                        Error::from_milli(
-                            milli::InternalError::SerdeJson(e).into(),
-                            None,
-                        )
-                    })?;
+                    let document = document
+                        .map_err(|e| {
+                            Error::from_milli(
+                                milli::InternalError::SerdeJson(e).into(),
+                                None,
+                            )
+                        })
+                        .unwrap();
                     dump_content_file.push_document(&document)?;
                 }
diff --git a/crates/index-scheduler/src/index_mapper/index_map.rs b/crates/index-scheduler/src/index_mapper/index_map.rs
index 480dafa7c..931cff162 100644
--- a/crates/index-scheduler/src/index_mapper/index_map.rs
+++ b/crates/index-scheduler/src/index_mapper/index_map.rs
@@ -1,5 +1,7 @@
 use std::collections::BTreeMap;
+use std::env::VarError;
 use std::path::Path;
+use std::str::FromStr;
 use std::time::Duration;

 use meilisearch_types::heed::{EnvClosingEvent, EnvFlags, EnvOpenOptions};
@@ -302,7 +304,15 @@ fn create_or_open_index(
 ) -> Result<Index> {
     let mut options = EnvOpenOptions::new();
     options.map_size(clamp_to_page_size(map_size));
-    options.max_readers(1024);
+
+    let max_readers = match std::env::var("MEILI_EXPERIMENTAL_INDEX_MAX_READERS") {
+        Ok(value) => u32::from_str(&value).unwrap(),
+        Err(VarError::NotPresent) => 1024,
+        Err(VarError::NotUnicode(value)) => panic!(
+            "Invalid unicode for the `MEILI_EXPERIMENTAL_INDEX_MAX_READERS` env var: {value:?}"
+        ),
+    };
+    options.max_readers(max_readers);
     if enable_mdb_writemap {
         unsafe { options.flags(EnvFlags::WRITE_MAP) };
     }
diff --git a/crates/index-scheduler/src/lib.rs b/crates/index-scheduler/src/lib.rs
index ac51e584a..9f2594f31 100644
--- a/crates/index-scheduler/src/lib.rs
+++ b/crates/index-scheduler/src/lib.rs
@@ -2024,9 +2024,11 @@ impl<'a> Dump<'a> {
                 let mut writer = io::BufWriter::new(file);
                 for doc in content_file {
                     let doc = doc?;
-                    serde_json::to_writer(&mut writer, &doc).map_err(|e| {
-                        Error::from_milli(milli::InternalError::SerdeJson(e).into(), None)
-                    })?;
+                    serde_json::to_writer(&mut writer, &doc)
+                        .map_err(|e| {
+                            Error::from_milli(milli::InternalError::SerdeJson(e).into(), None)
+                        })
+                        .unwrap();
                 }
                 let file = writer.into_inner().map_err(|e| e.into_error())?;
                 file.persist()?;
diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs
index 674ae226b..2aec59d2e 100644
--- a/crates/meilisearch/src/search/mod.rs
+++ b/crates/meilisearch/src/search/mod.rs
@@ -1337,7 +1337,7 @@ impl<'a> HitMaker<'a> {
                 ExplicitVectors { embeddings: Some(vector.into()), regenerate: !user_provided };
             vectors.insert(
                 name,
-                serde_json::to_value(embeddings).map_err(InternalError::SerdeJson)?,
+                serde_json::to_value(embeddings).map_err(InternalError::SerdeJson).unwrap(),
             );
         }
         document.insert("_vectors".into(), vectors.into());
@@ -1717,7 +1717,7 @@ fn make_document(
     // recreate the original json
     for (key, value) in obkv.iter() {
-        let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+        let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
         let key = field_ids_map.name(key).expect("Missing field name").to_string();
         document.insert(key, value);
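The `MEILI_EXPERIMENTAL_INDEX_MAX_READERS` lookup added in `index_map.rs` above handles three cases: parse the variable when it is set, fall back to the previous hard-coded 1024 when it is absent, and panic on non-unicode input (a malformed number also panics, through `unwrap`, in keeping with the rest of this diff). A minimal standalone sketch of the same pattern; only the variable name and the default come from the diff, the surrounding function and `main` are illustrative:

use std::env::VarError;
use std::str::FromStr;

fn max_readers_from_env() -> u32 {
    match std::env::var("MEILI_EXPERIMENTAL_INDEX_MAX_READERS") {
        // A malformed value panics via `unwrap`, matching the diff's debugging style.
        Ok(value) => u32::from_str(&value).unwrap(),
        // An absent variable falls back to the previous hard-coded default.
        Err(VarError::NotPresent) => 1024,
        Err(VarError::NotUnicode(value)) => {
            panic!("Invalid unicode for the `MEILI_EXPERIMENTAL_INDEX_MAX_READERS` env var: {value:?}")
        }
    }
}

fn main() {
    // e.g. MEILI_EXPERIMENTAL_INDEX_MAX_READERS=4096 ./this-binary
    println!("max readers: {}", max_readers_from_env());
}

Since the variable is read inside `create_or_open_index`, each newly opened index environment picks up the current value; environments that are already open keep their reader limit.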
diff --git a/crates/milli/src/documents/mod.rs b/crates/milli/src/documents/mod.rs
index 88fa38d30..91dcd348e 100644
--- a/crates/milli/src/documents/mod.rs
+++ b/crates/milli/src/documents/mod.rs
@@ -33,7 +33,7 @@ pub fn obkv_to_object(obkv: &KvReader<FieldId>, index: &DocumentsBatchIndex) ->
             let field_name = index
                 .name(field_id)
                 .ok_or(FieldIdMapMissingEntry::FieldId { field_id, process: "obkv_to_object" })?;
-            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
             Ok((field_name.to_string(), value))
         })
         .collect()
@@ -84,7 +84,8 @@ impl DocumentsBatchIndex {
             let key =
                 self.0.get_by_left(&k).ok_or(crate::error::InternalError::DatabaseClosing)?.clone();
             let value = serde_json::from_slice::<serde_json::Value>(v)
-                .map_err(crate::error::InternalError::SerdeJson)?;
+                .map_err(crate::error::InternalError::SerdeJson)
+                .unwrap();
             map.insert(key, value);
         }
diff --git a/crates/milli/src/documents/primary_key.rs b/crates/milli/src/documents/primary_key.rs
index c1dd9a9b8..b8d151922 100644
--- a/crates/milli/src/documents/primary_key.rs
+++ b/crates/milli/src/documents/primary_key.rs
@@ -92,7 +92,8 @@ impl<'a> PrimaryKey<'a> {
             PrimaryKey::Flat { name: _, field_id } => match document.get(*field_id) {
                 Some(document_id_bytes) => {
                     let document_id = serde_json::from_slice(document_id_bytes)
-                        .map_err(InternalError::SerdeJson)?;
+                        .map_err(InternalError::SerdeJson)
+                        .unwrap();
                     match validate_document_id_value(document_id) {
                         Ok(document_id) => Ok(Ok(document_id)),
                         Err(user_error) => {
@@ -108,7 +109,8 @@ impl<'a> PrimaryKey<'a> {
                 if let Some(field_id) = fields.id(first_level_name) {
                     if let Some(value_bytes) = document.get(field_id) {
                         let object = serde_json::from_slice(value_bytes)
-                            .map_err(InternalError::SerdeJson)?;
+                            .map_err(InternalError::SerdeJson)
+                            .unwrap();
                         fetch_matching_values(object, right, &mut matching_documents_ids);

                         if matching_documents_ids.len() >= 2 {
@@ -151,11 +153,12 @@ impl<'a> PrimaryKey<'a> {
         };

         let document_id: &RawValue =
-            serde_json::from_slice(document_id).map_err(InternalError::SerdeJson)?;
+            serde_json::from_slice(document_id).map_err(InternalError::SerdeJson).unwrap();

         let document_id = document_id
             .deserialize_any(crate::update::new::indexer::de::DocumentIdVisitor(indexer))
-            .map_err(InternalError::SerdeJson)?;
+            .map_err(InternalError::SerdeJson)
+            .unwrap();

         let external_document_id = match document_id {
             Ok(document_id) => Ok(document_id),
@@ -173,7 +176,7 @@ impl<'a> PrimaryKey<'a> {
                 let Some(value) = document.get(fid) else { continue };

                 let value: &RawValue =
-                    serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+                    serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
                 match match_component(first_level, right, value, indexer, &mut docid) {
                     ControlFlow::Continue(()) => continue,
                     ControlFlow::Break(Ok(_)) => {
@@ -183,7 +186,7 @@ impl<'a> PrimaryKey<'a> {
                         .into())
                     }
                     ControlFlow::Break(Err(err)) => {
-                        return Err(InternalError::SerdeJson(err).into())
+                        panic!("{err}");
                     }
                 }
             }
diff --git a/crates/milli/src/lib.rs b/crates/milli/src/lib.rs
index 3ae0bfdb9..1e90b50e5 100644
--- a/crates/milli/src/lib.rs
+++ b/crates/milli/src/lib.rs
@@ -228,7 +228,8 @@ pub fn obkv_to_json(
                 field_id: id,
                 process: "obkv_to_json",
             })?;
-            let value = serde_json::from_slice(value).map_err(error::InternalError::SerdeJson)?;
+            let value =
+                serde_json::from_slice(value).map_err(error::InternalError::SerdeJson).unwrap();
             Ok((name.to_owned(), value))
         })
         .collect()
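Most hunks in this diff make the same mechanical change: `map_err(InternalError::SerdeJson)?` becomes `.unwrap()` (or, in `primary_key.rs`, an outright `panic!`), so a JSON failure aborts with a backtrace pointing at the exact call site instead of bubbling up as an opaque internal error. A small sketch contrasting the two styles; the `InternalError` enum here is a stand-in for milli's real error type:

use serde_json::Value;

// Stand-in for milli's error type, only here so the sketch is self-contained.
#[derive(Debug)]
enum InternalError {
    SerdeJson(serde_json::Error),
}

// Old style: the deserialization error is wrapped and propagated to the caller.
fn parse_propagating(bytes: &[u8]) -> Result<Value, InternalError> {
    serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)
}

// New style, as in this diff: panic at the call site, so with RUST_BACKTRACE=1
// the unwound stack identifies exactly which payload failed to deserialize.
fn parse_panicking(bytes: &[u8]) -> Value {
    serde_json::from_slice(bytes).map_err(InternalError::SerdeJson).unwrap()
}

fn main() {
    assert!(parse_propagating(b"{\"ok\":true}").is_ok());
    assert_eq!(parse_panicking(b"[1]"), Value::Array(vec![1.into()]));
}

Keeping the `map_err` in front of the `unwrap` preserves the original wrapping, so the panic message still names the error variant while the backtrace names the caller.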
diff --git a/crates/milli/src/thread_pool_no_abort.rs b/crates/milli/src/thread_pool_no_abort.rs
index 14e5b0491..b57050a63 100644
--- a/crates/milli/src/thread_pool_no_abort.rs
+++ b/crates/milli/src/thread_pool_no_abort.rs
@@ -1,4 +1,4 @@
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use std::sync::Arc;

 use rayon::{ThreadPool, ThreadPoolBuilder};
@@ -9,6 +9,8 @@ use thiserror::Error;
 #[derive(Debug)]
 pub struct ThreadPoolNoAbort {
     thread_pool: ThreadPool,
+    /// The number of active operations.
+    active_operations: AtomicUsize,
     /// Set to true if the thread pool catched a panic.
     pool_catched_panic: Arc<AtomicBool>,
 }
@@ -19,7 +21,9 @@ impl ThreadPoolNoAbort {
         OP: FnOnce() -> R + Send,
         R: Send,
     {
+        self.active_operations.fetch_add(1, Ordering::Relaxed);
         let output = self.thread_pool.install(op);
+        self.active_operations.fetch_sub(1, Ordering::Relaxed);
         // While reseting the pool panic catcher we return an error if we catched one.
         if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
             Err(PanicCatched)
@@ -31,6 +35,11 @@ impl ThreadPoolNoAbort {
     pub fn current_num_threads(&self) -> usize {
         self.thread_pool.current_num_threads()
     }
+
+    /// The number of active operations.
+    pub fn active_operations(&self) -> usize {
+        self.active_operations.load(Ordering::Relaxed)
+    }
 }

 #[derive(Error, Debug)]
@@ -64,6 +73,10 @@ impl ThreadPoolNoAbortBuilder {
             let catched_panic = pool_catched_panic.clone();
             move |_result| catched_panic.store(true, Ordering::SeqCst)
         });
-        Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic })
+        Ok(ThreadPoolNoAbort {
+            thread_pool: self.0.build()?,
+            active_operations: AtomicUsize::new(0),
+            pool_catched_panic,
+        })
     }
 }
diff --git a/crates/milli/src/update/index_documents/enrich.rs b/crates/milli/src/update/index_documents/enrich.rs
index 85f871830..1626adcd9 100644
--- a/crates/milli/src/update/index_documents/enrich.rs
+++ b/crates/milli/src/update/index_documents/enrich.rs
@@ -123,7 +123,8 @@ pub fn enrich_documents_batch(
             }
         }

-        let document_id = serde_json::to_vec(&document_id).map_err(InternalError::SerdeJson)?;
+        let document_id =
+            serde_json::to_vec(&document_id).map_err(InternalError::SerdeJson).unwrap();
         external_ids.insert(count.to_be_bytes(), document_id)?;

         count += 1;
@@ -237,7 +238,7 @@ pub fn validate_geo_from_json(id: &DocumentId, bytes: &[u8]) -> Result<StdResult
-    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
+    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson).unwrap() {
         Value::Object(mut object) => match (object.remove("lat"), object.remove("lng")) {
             (Some(lat), Some(lng)) => {
                 match (extract_finite_float_from_value(lat), extract_finite_float_from_value(lng)) {
diff --git a/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs b/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
index 606ae6b54..9a63e5299 100644
--- a/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
+++ b/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
@@ -206,7 +206,7 @@ fn tokens_from_document<'a>(
         if let Some(field_bytes) = KvReaderDelAdd::from_slice(field_bytes).get(del_add) {
             // parse json.
             let value =
-                serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
+                serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson).unwrap();

             // prepare writing destination.
             buffers.obkv_positions_buffer.clear();
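The `active_operations` counter added to `ThreadPoolNoAbort` above is plain `AtomicUsize` bookkeeping around `rayon::ThreadPool::install`, later consulted by the embedders to detect a saturated pool. A self-contained sketch of the same mechanism, with illustrative names rather than milli's:

use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::{ThreadPool, ThreadPoolBuilder};

// Illustrative mirror of the bookkeeping added to ThreadPoolNoAbort.
struct CountingPool {
    thread_pool: ThreadPool,
    active_operations: AtomicUsize,
}

impl CountingPool {
    fn new(threads: usize) -> Self {
        let thread_pool = ThreadPoolBuilder::new().num_threads(threads).build().unwrap();
        Self { thread_pool, active_operations: AtomicUsize::new(0) }
    }

    // Counts callers currently inside `install`, exactly as the diff does.
    fn install<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce() -> R + Send,
        R: Send,
    {
        self.active_operations.fetch_add(1, Ordering::Relaxed);
        let output = self.thread_pool.install(op);
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
        output
    }

    fn active_operations(&self) -> usize {
        self.active_operations.load(Ordering::Relaxed)
    }
}

fn main() {
    let pool = CountingPool::new(2);
    let sum: i32 = pool.install(|| (1..=10).sum());
    assert_eq!(sum, 55);
    assert_eq!(pool.active_operations(), 0);
}

One caveat the sketch shares with the patch: if `op` unwinds, the matching `fetch_sub` never runs and the counter drifts upward; a small `Drop` guard around the decrement would make the bookkeeping panic-safe.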
diff --git a/crates/milli/src/update/new/document.rs b/crates/milli/src/update/new/document.rs
index 930b0c078..2cce03d87 100644
--- a/crates/milli/src/update/new/document.rs
+++ b/crates/milli/src/update/new/document.rs
@@ -86,7 +86,7 @@ impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> {
             let res = (|| {
                 let value =
-                    serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
+                    serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson).unwrap();

                 Ok((name, value))
             })();
@@ -139,7 +139,7 @@ impl<'t, Mapper: FieldIdMapper> DocumentFromDb<'t, Mapper> {
             return Ok(None);
         };
         let Some(value) = self.content.get(fid) else { return Ok(None) };
-        Ok(Some(serde_json::from_slice(value).map_err(InternalError::SerdeJson)?))
+        Ok(Some(serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap()))
     }
 }
diff --git a/crates/milli/src/update/new/extract/faceted/facet_document.rs b/crates/milli/src/update/new/extract/faceted/facet_document.rs
index eff529120..35475e7cc 100644
--- a/crates/milli/src/update/new/extract/faceted/facet_document.rs
+++ b/crates/milli/src/update/new/extract/faceted/facet_document.rs
@@ -27,7 +27,7 @@ pub fn extract_document_facets<'doc>(
         let selection = perm_json_p::select_field(field_name, Some(attributes_to_extract), &[]);
         if selection != perm_json_p::Selection::Skip {
             // parse json.
-            match serde_json::value::to_value(value).map_err(InternalError::SerdeJson)? {
+            match serde_json::value::to_value(value).map_err(InternalError::SerdeJson).unwrap() {
                 Value::Object(object) => {
                     perm_json_p::seek_leaf_values_in_object(
                         &object,
diff --git a/crates/milli/src/update/new/extract/geo/mod.rs b/crates/milli/src/update/new/extract/geo/mod.rs
index a3820609d..438d64d31 100644
--- a/crates/milli/src/update/new/extract/geo/mod.rs
+++ b/crates/milli/src/update/new/extract/geo/mod.rs
@@ -256,15 +256,16 @@ pub fn extract_geo_coordinates(
     external_id: &str,
     raw_value: &RawValue,
 ) -> Result<Option<[f64; 2]>> {
-    let mut geo = match serde_json::from_str(raw_value.get()).map_err(InternalError::SerdeJson)? {
-        Value::Null => return Ok(None),
-        Value::Object(map) => map,
-        value => {
-            return Err(
-                GeoError::NotAnObject { document_id: Value::from(external_id), value }.into()
-            )
-        }
-    };
+    let mut geo =
+        match serde_json::from_str(raw_value.get()).map_err(InternalError::SerdeJson).unwrap() {
+            Value::Null => return Ok(None),
+            Value::Object(map) => map,
+            value => {
+                return Err(
+                    GeoError::NotAnObject { document_id: Value::from(external_id), value }.into()
+                )
+            }
+        };

     let [lat, lng] = match (geo.remove("lat"), geo.remove("lng")) {
         (Some(lat), Some(lng)) => {
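`extract_geo_coordinates` above accepts `_geo` as either JSON `null` or an object carrying `lat` and `lng` members, and rejects anything else. A reduced sketch of that shape check on a plain `serde_json::Value`; the function name and the simplified `Option` return are illustrative, while the real code reports typed `GeoError`s:

use serde_json::{json, Value};

// Null means "no geo point"; an object must provide both members;
// any other JSON type is rejected (the real code raises NotAnObject).
fn geo_coordinates(value: Value) -> Option<(Value, Value)> {
    match value {
        Value::Null => None,
        Value::Object(mut map) => match (map.remove("lat"), map.remove("lng")) {
            (Some(lat), Some(lng)) => Some((lat, lng)),
            _ => None, // the real code names the missing member
        },
        _ => None,
    }
}

fn main() {
    assert!(geo_coordinates(json!({ "lat": 48.86, "lng": 2.34 })).is_some());
    assert!(geo_coordinates(json!(null)).is_none());
    assert!(geo_coordinates(json!("not an object")).is_none());
}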
diff --git a/crates/milli/src/update/new/extract/searchable/tokenize_document.rs b/crates/milli/src/update/new/extract/searchable/tokenize_document.rs
index 1c1605b66..e8c858c9a 100644
--- a/crates/milli/src/update/new/extract/searchable/tokenize_document.rs
+++ b/crates/milli/src/update/new/extract/searchable/tokenize_document.rs
@@ -94,7 +94,7 @@ impl<'a> DocumentTokenizer<'a> {
             };

             // parse json.
-            match serde_json::to_value(value).map_err(InternalError::SerdeJson)? {
+            match serde_json::to_value(value).map_err(InternalError::SerdeJson).unwrap() {
                 Value::Object(object) => seek_leaf_values_in_object(
                     &object,
                     None,
diff --git a/crates/milli/src/update/new/indexer/document_operation.rs b/crates/milli/src/update/new/indexer/document_operation.rs
index 5ccac4297..13894c407 100644
--- a/crates/milli/src/update/new/indexer/document_operation.rs
+++ b/crates/milli/src/update/new/indexer/document_operation.rs
@@ -158,7 +158,7 @@ fn extract_addition_payload_changes<'r, 'pl: 'r>(
     let mut previous_offset = 0;
     let mut iter = Deserializer::from_slice(payload).into_iter::<&RawValue>();
-    while let Some(doc) = iter.next().transpose().map_err(InternalError::SerdeJson)? {
+    while let Some(doc) = iter.next().transpose().map_err(InternalError::SerdeJson).unwrap() {
         *bytes = previous_offset as u64;

         // Only guess the primary key if it is the first document
diff --git a/crates/milli/src/update/new/indexer/partial_dump.rs b/crates/milli/src/update/new/indexer/partial_dump.rs
index 6e4abd898..7fdcdae75 100644
--- a/crates/milli/src/update/new/indexer/partial_dump.rs
+++ b/crates/milli/src/update/new/indexer/partial_dump.rs
@@ -78,7 +78,8 @@ where
         let external_document_id = external_document_id.to_de();

         let document = RawMap::from_raw_value_and_hasher(document, FxBuildHasher, doc_alloc)
-            .map_err(InternalError::SerdeJson)?;
+            .map_err(InternalError::SerdeJson)
+            .unwrap();

         let insertion = Insertion::create(docid, external_document_id, Versions::single(document));
         Ok(Some(DocumentChange::Insertion(insertion)))
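`document_operation.rs` above walks a payload of concatenated JSON documents through `serde_json`'s streaming deserializer, pulling each document as a borrowed `&RawValue` and measuring its size from the parser's byte offset. A compact sketch of that loop; the counting function is illustrative:

use serde_json::value::RawValue;
use serde_json::Deserializer;

// Iterate raw documents in a concatenated-JSON payload without
// materializing them, tracking how many bytes each one consumed.
fn count_documents(payload: &[u8]) -> usize {
    let mut count = 0;
    let mut previous_offset = 0;
    let mut iter = Deserializer::from_slice(payload).into_iter::<&RawValue>();
    while let Some(doc) = iter.next().transpose().unwrap() {
        // byte_offset() reports how far the parser advanced past this document.
        let document_size = iter.byte_offset() - previous_offset;
        previous_offset = iter.byte_offset();
        let _ = (doc, document_size);
        count += 1;
    }
    count
}

fn main() {
    assert_eq!(count_documents(br#"{"id":1} {"id":2} {"id":3}"#), 3);
}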
diff --git a/crates/milli/src/update/new/indexer/update_by_function.rs b/crates/milli/src/update/new/indexer/update_by_function.rs
index 3001648e6..3423bde5e 100644
--- a/crates/milli/src/update/new/indexer/update_by_function.rs
+++ b/crates/milli/src/update/new/indexer/update_by_function.rs
@@ -58,9 +58,9 @@ impl UpdateByFunction {
         let ast = engine.compile(code).map_err(UserError::DocumentEditionCompilationError)?;
         let context = match context {
-            Some(context) => {
-                Some(serde_json::from_value(context.into()).map_err(InternalError::SerdeJson)?)
-            }
+            Some(context) => Some(
+                serde_json::from_value(context.into()).map_err(InternalError::SerdeJson).unwrap(),
+            ),
             None => None,
         };

@@ -137,9 +137,11 @@ impl<'index> DocumentChanges<'index> for UpdateByFunctionChanges<'index> {
                 Some(new_rhai_document) => {
                     let mut buffer = bumpalo::collections::Vec::new_in(doc_alloc);
                     serde_json::to_writer(&mut buffer, &new_rhai_document)
-                        .map_err(InternalError::SerdeJson)?;
+                        .map_err(InternalError::SerdeJson)
+                        .unwrap();
                     let raw_new_doc = serde_json::from_slice(buffer.into_bump_slice())
-                        .map_err(InternalError::SerdeJson)?;
+                        .map_err(InternalError::SerdeJson)
+                        .unwrap();

                     // Note: This condition is not perfect. Sometimes it detect changes
                     // like with floating points numbers and consider updating
@@ -166,7 +168,8 @@ impl<'index> DocumentChanges<'index> for UpdateByFunctionChanges<'index> {
                             FxBuildHasher,
                             doc_alloc,
                         )
-                        .map_err(InternalError::SerdeJson)?;
+                        .map_err(InternalError::SerdeJson)
+                        .unwrap();

                         Ok(Some(DocumentChange::Update(Update::create(
                             docid,
@@ -200,7 +203,7 @@ fn obkv_to_rhaimap(obkv: &KvReaderFieldId, fields_ids_map: &FieldsIdsMap) -> Res
                 field_id: id,
                 process: "all_obkv_to_rhaimap",
             })?;
-            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
             Ok((name.into(), value))
         })
         .collect();
diff --git a/crates/milli/src/update/new/vector_document.rs b/crates/milli/src/update/new/vector_document.rs
index 8d14a749d..e49cffa57 100644
--- a/crates/milli/src/update/new/vector_document.rs
+++ b/crates/milli/src/update/new/vector_document.rs
@@ -105,7 +105,8 @@ impl<'t> VectorDocumentFromDb<'t> {
         let vectors_field = match vectors {
             Some(vectors) => Some(
                 RawMap::from_raw_value_and_hasher(vectors, FxBuildHasher, doc_alloc)
-                    .map_err(InternalError::SerdeJson)?,
+                    .map_err(InternalError::SerdeJson)
+                    .unwrap(),
             ),
             None => None,
         };
diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs
index 7ee775cbf..a0698c5d0 100644
--- a/crates/milli/src/vector/ollama.rs
+++ b/crates/milli/src/vector/ollama.rs
@@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _;

 use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::DistributionShift;
+use super::{DistributionShift, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;
@@ -113,20 +113,30 @@ impl Embedder {
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
-        threads
-            .install(move || {
-                let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
-                    .par_chunks(self.prompt_count_in_chunk_hint())
-                    .map(move |chunk| self.embed(chunk, None))
-                    .collect();
+        if threads.active_operations() >= REQUEST_PARALLELISM {
+            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                .chunks(self.prompt_count_in_chunk_hint())
+                .map(move |chunk| self.embed(chunk, None))
+                .collect();

-                let embeddings = embeddings?;
-                Ok(embeddings.into_iter().flatten().collect())
-            })
-            .map_err(|error| EmbedError {
-                kind: EmbedErrorKind::PanicInThreadPool(error),
-                fault: FaultSource::Bug,
-            })?
+            let embeddings = embeddings?;
+            Ok(embeddings.into_iter().flatten().collect())
+        } else {
+            threads
+                .install(move || {
+                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                        .par_chunks(self.prompt_count_in_chunk_hint())
+                        .map(move |chunk| self.embed(chunk, None))
+                        .collect();
+
+                    let embeddings = embeddings?;
+                    Ok(embeddings.into_iter().flatten().collect())
+                })
+                .map_err(|error| EmbedError {
+                    kind: EmbedErrorKind::PanicInThreadPool(error),
+                    fault: FaultSource::Bug,
+                })?
+        }
     }

     pub fn chunk_count_hint(&self) -> usize {
diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs
index 7262bfef8..b1af381b1 100644
--- a/crates/milli/src/vector/openai.rs
+++ b/crates/milli/src/vector/openai.rs
@@ -6,7 +6,7 @@ use rayon::slice::ParallelSlice as _;

 use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::DistributionShift;
+use super::{DistributionShift, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::error::EmbedErrorKind;
 use crate::vector::Embedding;
@@ -270,20 +270,29 @@ impl Embedder {
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
-        threads
-            .install(move || {
-                let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
-                    .par_chunks(self.prompt_count_in_chunk_hint())
-                    .map(move |chunk| self.embed(chunk, None))
-                    .collect();
+        if threads.active_operations() >= REQUEST_PARALLELISM {
+            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                .chunks(self.prompt_count_in_chunk_hint())
+                .map(move |chunk| self.embed(chunk, None))
+                .collect();
+            let embeddings = embeddings?;
+            Ok(embeddings.into_iter().flatten().collect())
+        } else {
+            threads
+                .install(move || {
+                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                        .par_chunks(self.prompt_count_in_chunk_hint())
+                        .map(move |chunk| self.embed(chunk, None))
+                        .collect();

-                let embeddings = embeddings?;
-                Ok(embeddings.into_iter().flatten().collect())
-            })
-            .map_err(|error| EmbedError {
-                kind: EmbedErrorKind::PanicInThreadPool(error),
-                fault: FaultSource::Bug,
-            })?
+                    let embeddings = embeddings?;
+                    Ok(embeddings.into_iter().flatten().collect())
+                })
+                .map_err(|error| EmbedError {
+                    kind: EmbedErrorKind::PanicInThreadPool(error),
+                    fault: FaultSource::Bug,
+                })?
+        }
     }

     pub fn chunk_count_hint(&self) -> usize {
diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs
index 98be311d4..736dc3b2f 100644
--- a/crates/milli/src/vector/rest.rs
+++ b/crates/milli/src/vector/rest.rs
@@ -203,20 +203,30 @@ impl Embedder {
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Embedding>, EmbedError> {
-        threads
-            .install(move || {
-                let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
-                    .par_chunks(self.prompt_count_in_chunk_hint())
-                    .map(move |chunk| self.embed_ref(chunk, None))
-                    .collect();
+        if threads.active_operations() >= REQUEST_PARALLELISM {
+            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                .chunks(self.prompt_count_in_chunk_hint())
+                .map(move |chunk| self.embed_ref(chunk, None))
+                .collect();

-                let embeddings = embeddings?;
-                Ok(embeddings.into_iter().flatten().collect())
-            })
-            .map_err(|error| EmbedError {
-                kind: EmbedErrorKind::PanicInThreadPool(error),
-                fault: FaultSource::Bug,
-            })?
+            let embeddings = embeddings?;
+            Ok(embeddings.into_iter().flatten().collect())
+        } else {
+            threads
+                .install(move || {
+                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                        .par_chunks(self.prompt_count_in_chunk_hint())
+                        .map(move |chunk| self.embed_ref(chunk, None))
+                        .collect();
+
+                    let embeddings = embeddings?;
+                    Ok(embeddings.into_iter().flatten().collect())
+                })
+                .map_err(|error| EmbedError {
+                    kind: EmbedErrorKind::PanicInThreadPool(error),
+                    fault: FaultSource::Bug,
+                })?
+        }
     }

     pub fn chunk_count_hint(&self) -> usize {
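The `ollama.rs`, `openai.rs` and `rest.rs` hunks share one shape: when the shared pool already reports at least `REQUEST_PARALLELISM` installed operations, the embedder walks its chunks sequentially on the calling thread with `chunks`, and only otherwise enters `threads.install` with `par_chunks`. That keeps nested rayon scopes from piling onto an already saturated pool. A condensed sketch of that decision; it elides the real code's `install` wrapper and its mapping of pool panics to `EmbedError`, and the constant's value here is an assumption, the real one lives in milli's vector module:

use rayon::prelude::*;

const REQUEST_PARALLELISM: usize = 10; // assumed value for the sketch

type Embedding = Vec<f32>;

// Placeholder for the real per-chunk HTTP embedding request.
fn embed_chunk(chunk: &[&str]) -> Vec<Embedding> {
    chunk.iter().map(|text| vec![text.len() as f32]).collect()
}

// Saturated pool: stay sequential on the caller's thread.
// Otherwise: fan the chunks out across the rayon pool.
fn embed_all(texts: &[&str], chunk_size: usize, active_operations: usize) -> Vec<Embedding> {
    if active_operations >= REQUEST_PARALLELISM {
        texts.chunks(chunk_size).flat_map(embed_chunk).collect()
    } else {
        texts.par_chunks(chunk_size).flat_map_iter(embed_chunk).collect()
    }
}

fn main() {
    let texts = ["alpha", "beta", "gamma"];
    assert_eq!(embed_all(&texts, 2, REQUEST_PARALLELISM), embed_all(&texts, 2, 0));
}

Both branches produce the same flattened vector, so callers are unaffected by which path runs; only the degree of parallelism changes.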