Merge cd58a71f576f26efddf97794a3b88eee3ee264fc into 9bcb271f0021dd4f8651458b6f573df1068bc396

Clément Renault 2025-01-29 11:58:30 +01:00 committed by GitHub
commit db4798f4f4
21 changed files with 158 additions and 90 deletions

View File

@@ -839,12 +839,14 @@ impl IndexScheduler {
                 for document in
                     serde_json::de::Deserializer::from_reader(content_file).into_iter()
                 {
-                    let document = document.map_err(|e| {
-                        Error::from_milli(
-                            milli::InternalError::SerdeJson(e).into(),
-                            None,
-                        )
-                    })?;
+                    let document = document
+                        .map_err(|e| {
+                            Error::from_milli(
+                                milli::InternalError::SerdeJson(e).into(),
+                                None,
+                            )
+                        })
+                        .unwrap();
                     dump_content_file.push_document(&document)?;
                 }
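For context, the loop above streams documents out of the content file with serde_json's stream deserializer; each item it yields is a `Result`, which is why there is a `map_err`/`unwrap` step per document. A self-contained sketch of that streaming pattern over an in-memory reader (the sample data and the `Value` target type are illustrative, not taken from this commit):

use serde_json::Value;

fn main() {
    // Two concatenated JSON documents, as a dump content file would contain.
    let data = br#"{"id": 1, "title": "Shazam!"}{"id": 2, "title": "Dune"}"#;

    // `into_iter` yields one `Result<Value, serde_json::Error>` per document.
    for document in serde_json::de::Deserializer::from_reader(&data[..]).into_iter::<Value>() {
        let document = document.expect("invalid JSON in the stream");
        println!("{document}");
    }
}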

View File

@@ -1,5 +1,7 @@
 use std::collections::BTreeMap;
+use std::env::VarError;
 use std::path::Path;
+use std::str::FromStr;
 use std::time::Duration;

 use meilisearch_types::heed::{EnvClosingEvent, EnvFlags, EnvOpenOptions};
@@ -302,7 +304,15 @@ fn create_or_open_index(
 ) -> Result<Index> {
     let mut options = EnvOpenOptions::new();
     options.map_size(clamp_to_page_size(map_size));
-    options.max_readers(1024);
+    let max_readers = match std::env::var("MEILI_EXPERIMENTAL_INDEX_MAX_READERS") {
+        Ok(value) => u32::from_str(&value).unwrap(),
+        Err(VarError::NotPresent) => 1024,
+        Err(VarError::NotUnicode(value)) => panic!(
+            "Invalid unicode for the `MEILI_EXPERIMENTAL_INDEX_MAX_READERS` env var: {value:?}"
+        ),
+    };
+    options.max_readers(max_readers);
     if enable_mdb_writemap {
         unsafe { options.flags(EnvFlags::WRITE_MAP) };
     }
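The hunk above introduces `MEILI_EXPERIMENTAL_INDEX_MAX_READERS` to override the LMDB reader-slot limit, keeping 1024 as the default when the variable is absent. A minimal standalone sketch of the same read-or-default pattern, outside of `create_or_open_index` (the variable name and default come from the hunk; the helper name and the parse-error handling are illustrative):

use std::env::{self, VarError};
use std::str::FromStr;

/// Reads an optional numeric environment variable, falling back to a default.
fn env_u32_or(name: &str, default: u32) -> u32 {
    match env::var(name) {
        Ok(value) => u32::from_str(&value)
            .unwrap_or_else(|e| panic!("`{name}` must be an integer: {e}")),
        Err(VarError::NotPresent) => default,
        Err(VarError::NotUnicode(value)) => {
            panic!("Invalid unicode for the `{name}` env var: {value:?}")
        }
    }
}

fn main() {
    let max_readers = env_u32_or("MEILI_EXPERIMENTAL_INDEX_MAX_READERS", 1024);
    println!("max_readers = {max_readers}");
}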

View File

@@ -2024,9 +2024,11 @@ impl<'a> Dump<'a> {
         let mut writer = io::BufWriter::new(file);
         for doc in content_file {
             let doc = doc?;
-            serde_json::to_writer(&mut writer, &doc).map_err(|e| {
-                Error::from_milli(milli::InternalError::SerdeJson(e).into(), None)
-            })?;
+            serde_json::to_writer(&mut writer, &doc)
+                .map_err(|e| {
+                    Error::from_milli(milli::InternalError::SerdeJson(e).into(), None)
+                })
+                .unwrap();
         }
         let file = writer.into_inner().map_err(|e| e.into_error())?;
         file.persist()?;

View File

@@ -1337,7 +1337,7 @@ impl<'a> HitMaker<'a> {
                     ExplicitVectors { embeddings: Some(vector.into()), regenerate: !user_provided };
                 vectors.insert(
                     name,
-                    serde_json::to_value(embeddings).map_err(InternalError::SerdeJson)?,
+                    serde_json::to_value(embeddings).map_err(InternalError::SerdeJson).unwrap(),
                 );
             }
             document.insert("_vectors".into(), vectors.into());
@@ -1717,7 +1717,7 @@ fn make_document(
     // recreate the original json
     for (key, value) in obkv.iter() {
-        let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+        let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
         let key = field_ids_map.name(key).expect("Missing field name").to_string();
         document.insert(key, value);

View File

@@ -33,7 +33,7 @@ pub fn obkv_to_object(obkv: &KvReader<FieldId>, index: &DocumentsBatchIndex) ->
             let field_name = index
                 .name(field_id)
                 .ok_or(FieldIdMapMissingEntry::FieldId { field_id, process: "obkv_to_object" })?;
-            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
             Ok((field_name.to_string(), value))
         })
         .collect()
@@ -84,7 +84,8 @@ impl DocumentsBatchIndex {
             let key =
                 self.0.get_by_left(&k).ok_or(crate::error::InternalError::DatabaseClosing)?.clone();
             let value = serde_json::from_slice::<serde_json::Value>(v)
-                .map_err(crate::error::InternalError::SerdeJson)?;
+                .map_err(crate::error::InternalError::SerdeJson)
+                .unwrap();
             map.insert(key, value);
         }

View File

@@ -92,7 +92,8 @@ impl<'a> PrimaryKey<'a> {
             PrimaryKey::Flat { name: _, field_id } => match document.get(*field_id) {
                 Some(document_id_bytes) => {
                     let document_id = serde_json::from_slice(document_id_bytes)
-                        .map_err(InternalError::SerdeJson)?;
+                        .map_err(InternalError::SerdeJson)
+                        .unwrap();
                     match validate_document_id_value(document_id) {
                         Ok(document_id) => Ok(Ok(document_id)),
                         Err(user_error) => {
@@ -108,7 +109,8 @@ impl<'a> PrimaryKey<'a> {
                     if let Some(field_id) = fields.id(first_level_name) {
                         if let Some(value_bytes) = document.get(field_id) {
                             let object = serde_json::from_slice(value_bytes)
-                                .map_err(InternalError::SerdeJson)?;
+                                .map_err(InternalError::SerdeJson)
+                                .unwrap();
                             fetch_matching_values(object, right, &mut matching_documents_ids);

                             if matching_documents_ids.len() >= 2 {
@@ -151,11 +153,12 @@ impl<'a> PrimaryKey<'a> {
                 };

                 let document_id: &RawValue =
-                    serde_json::from_slice(document_id).map_err(InternalError::SerdeJson)?;
+                    serde_json::from_slice(document_id).map_err(InternalError::SerdeJson).unwrap();

                 let document_id = document_id
                     .deserialize_any(crate::update::new::indexer::de::DocumentIdVisitor(indexer))
-                    .map_err(InternalError::SerdeJson)?;
+                    .map_err(InternalError::SerdeJson)
+                    .unwrap();

                 let external_document_id = match document_id {
                     Ok(document_id) => Ok(document_id),
@@ -173,7 +176,7 @@ impl<'a> PrimaryKey<'a> {
                     let Some(value) = document.get(fid) else { continue };
                     let value: &RawValue =
-                        serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+                        serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
                     match match_component(first_level, right, value, indexer, &mut docid) {
                         ControlFlow::Continue(()) => continue,
                         ControlFlow::Break(Ok(_)) => {
@@ -183,7 +186,7 @@ impl<'a> PrimaryKey<'a> {
                             .into())
                         }
                         ControlFlow::Break(Err(err)) => {
-                            return Err(InternalError::SerdeJson(err).into())
+                            panic!("{err}");
                         }
                     }
                 }

View File

@@ -228,7 +228,8 @@ pub fn obkv_to_json(
                 field_id: id,
                 process: "obkv_to_json",
             })?;
-            let value = serde_json::from_slice(value).map_err(error::InternalError::SerdeJson)?;
+            let value =
+                serde_json::from_slice(value).map_err(error::InternalError::SerdeJson).unwrap();
             Ok((name.to_owned(), value))
         })
         .collect()

View File

@@ -1,4 +1,4 @@
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use std::sync::Arc;

 use rayon::{ThreadPool, ThreadPoolBuilder};
@@ -9,6 +9,8 @@ use thiserror::Error;
 #[derive(Debug)]
 pub struct ThreadPoolNoAbort {
     thread_pool: ThreadPool,
+    /// The number of active operations.
+    active_operations: AtomicUsize,
     /// Set to true if the thread pool catched a panic.
     pool_catched_panic: Arc<AtomicBool>,
 }
@@ -19,7 +21,9 @@ impl ThreadPoolNoAbort {
         OP: FnOnce() -> R + Send,
         R: Send,
     {
+        self.active_operations.fetch_add(1, Ordering::Relaxed);
         let output = self.thread_pool.install(op);
+        self.active_operations.fetch_sub(1, Ordering::Relaxed);
         // While reseting the pool panic catcher we return an error if we catched one.
         if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
             Err(PanicCatched)
@@ -31,6 +35,11 @@ impl ThreadPoolNoAbort {
     pub fn current_num_threads(&self) -> usize {
         self.thread_pool.current_num_threads()
     }
+
+    /// The number of active operations.
+    pub fn active_operations(&self) -> usize {
+        self.active_operations.load(Ordering::Relaxed)
+    }
 }

 #[derive(Error, Debug)]
@@ -64,6 +73,10 @@ impl ThreadPoolNoAbortBuilder {
             let catched_panic = pool_catched_panic.clone();
             move |_result| catched_panic.store(true, Ordering::SeqCst)
         });
-        Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic })
+        Ok(ThreadPoolNoAbort {
+            thread_pool: self.0.build()?,
+            active_operations: AtomicUsize::new(0),
+            pool_catched_panic,
+        })
     }
 }
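The new `active_operations` field is a plain `AtomicUsize` bumped around `install`, so callers can ask the pool whether work is already in flight before queueing more; the embedder hunks later in this diff use exactly that check. A minimal sketch of the same in-flight counter pattern on top of a bare rayon pool (the `CountingPool` type and its methods are illustrative, not the crate's `ThreadPoolNoAbort`, and this sketch omits the panic handling the real type carries):

use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::{ThreadPool, ThreadPoolBuilder};

/// A thread pool that counts how many `install` calls are currently running.
struct CountingPool {
    pool: ThreadPool,
    active: AtomicUsize,
}

impl CountingPool {
    fn new(threads: usize) -> Self {
        let pool = ThreadPoolBuilder::new().num_threads(threads).build().unwrap();
        CountingPool { pool, active: AtomicUsize::new(0) }
    }

    fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
        // Relaxed is enough: the counter is only a hint, not a synchronization point.
        self.active.fetch_add(1, Ordering::Relaxed);
        let out = self.pool.install(op);
        self.active.fetch_sub(1, Ordering::Relaxed);
        out
    }

    fn active_operations(&self) -> usize {
        self.active.load(Ordering::Relaxed)
    }
}

fn main() {
    let pool = CountingPool::new(4);
    let sum: i64 = pool.install(|| (0..1_000_000i64).sum());
    println!("sum = {sum}, active now = {}", pool.active_operations());
}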

View File

@@ -123,7 +123,8 @@ pub fn enrich_documents_batch<R: Read + Seek>(
             }
         }

-        let document_id = serde_json::to_vec(&document_id).map_err(InternalError::SerdeJson)?;
+        let document_id =
+            serde_json::to_vec(&document_id).map_err(InternalError::SerdeJson).unwrap();
         external_ids.insert(count.to_be_bytes(), document_id)?;

         count += 1;
@@ -237,7 +238,7 @@ pub fn validate_geo_from_json(id: &DocumentId, bytes: &[u8]) -> Result<StdResult
     let debug_id = || {
         serde_json::from_slice(id.value().as_bytes()).unwrap_or_else(|_| Value::from(id.debug()))
     };
-    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
+    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson).unwrap() {
         Value::Object(mut object) => match (object.remove("lat"), object.remove("lng")) {
             (Some(lat), Some(lng)) => {
                 match (extract_finite_float_from_value(lat), extract_finite_float_from_value(lng)) {

View File

@@ -206,7 +206,7 @@ fn tokens_from_document<'a>(
             if let Some(field_bytes) = KvReaderDelAdd::from_slice(field_bytes).get(del_add) {
                 // parse json.
                 let value =
-                    serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
+                    serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson).unwrap();

                 // prepare writing destination.
                 buffers.obkv_positions_buffer.clear();

View File

@@ -86,7 +86,7 @@ impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> {
             let res = (|| {
                 let value =
-                    serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
+                    serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson).unwrap();

                 Ok((name, value))
             })();
@@ -139,7 +139,7 @@ impl<'t, Mapper: FieldIdMapper> DocumentFromDb<'t, Mapper> {
             return Ok(None);
         };
         let Some(value) = self.content.get(fid) else { return Ok(None) };
-        Ok(Some(serde_json::from_slice(value).map_err(InternalError::SerdeJson)?))
+        Ok(Some(serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap()))
     }
 }

View File

@@ -27,7 +27,7 @@ pub fn extract_document_facets<'doc>(
         let selection = perm_json_p::select_field(field_name, Some(attributes_to_extract), &[]);
         if selection != perm_json_p::Selection::Skip {
             // parse json.
-            match serde_json::value::to_value(value).map_err(InternalError::SerdeJson)? {
+            match serde_json::value::to_value(value).map_err(InternalError::SerdeJson).unwrap() {
                 Value::Object(object) => {
                     perm_json_p::seek_leaf_values_in_object(
                         &object,

View File

@@ -256,15 +256,16 @@ pub fn extract_geo_coordinates(
     external_id: &str,
     raw_value: &RawValue,
 ) -> Result<Option<[f64; 2]>> {
-    let mut geo = match serde_json::from_str(raw_value.get()).map_err(InternalError::SerdeJson)? {
-        Value::Null => return Ok(None),
-        Value::Object(map) => map,
-        value => {
-            return Err(
-                GeoError::NotAnObject { document_id: Value::from(external_id), value }.into()
-            )
-        }
-    };
+    let mut geo =
+        match serde_json::from_str(raw_value.get()).map_err(InternalError::SerdeJson).unwrap() {
+            Value::Null => return Ok(None),
+            Value::Object(map) => map,
+            value => {
+                return Err(
+                    GeoError::NotAnObject { document_id: Value::from(external_id), value }.into()
+                )
+            }
+        };

     let [lat, lng] = match (geo.remove("lat"), geo.remove("lng")) {
         (Some(lat), Some(lng)) => {

View File

@@ -94,7 +94,7 @@ impl<'a> DocumentTokenizer<'a> {
             };

             // parse json.
-            match serde_json::to_value(value).map_err(InternalError::SerdeJson)? {
+            match serde_json::to_value(value).map_err(InternalError::SerdeJson).unwrap() {
                 Value::Object(object) => seek_leaf_values_in_object(
                     &object,
                     None,

View File

@@ -158,7 +158,7 @@ fn extract_addition_payload_changes<'r, 'pl: 'r>(
     let mut previous_offset = 0;
     let mut iter = Deserializer::from_slice(payload).into_iter::<&RawValue>();
-    while let Some(doc) = iter.next().transpose().map_err(InternalError::SerdeJson)? {
+    while let Some(doc) = iter.next().transpose().map_err(InternalError::SerdeJson).unwrap() {
         *bytes = previous_offset as u64;

         // Only guess the primary key if it is the first document

View File

@@ -78,7 +78,8 @@ where
         let external_document_id = external_document_id.to_de();
         let document = RawMap::from_raw_value_and_hasher(document, FxBuildHasher, doc_alloc)
-            .map_err(InternalError::SerdeJson)?;
+            .map_err(InternalError::SerdeJson)
+            .unwrap();

         let insertion = Insertion::create(docid, external_document_id, Versions::single(document));
         Ok(Some(DocumentChange::Insertion(insertion)))

View File

@@ -58,9 +58,9 @@ impl UpdateByFunction {
         let ast = engine.compile(code).map_err(UserError::DocumentEditionCompilationError)?;
         let context = match context {
-            Some(context) => {
-                Some(serde_json::from_value(context.into()).map_err(InternalError::SerdeJson)?)
-            }
+            Some(context) => Some(
+                serde_json::from_value(context.into()).map_err(InternalError::SerdeJson).unwrap(),
+            ),
             None => None,
         };
@@ -137,9 +137,11 @@ impl<'index> DocumentChanges<'index> for UpdateByFunctionChanges<'index> {
             Some(new_rhai_document) => {
                 let mut buffer = bumpalo::collections::Vec::new_in(doc_alloc);
                 serde_json::to_writer(&mut buffer, &new_rhai_document)
-                    .map_err(InternalError::SerdeJson)?;
+                    .map_err(InternalError::SerdeJson)
+                    .unwrap();
                 let raw_new_doc = serde_json::from_slice(buffer.into_bump_slice())
-                    .map_err(InternalError::SerdeJson)?;
+                    .map_err(InternalError::SerdeJson)
+                    .unwrap();

                 // Note: This condition is not perfect. Sometimes it detect changes
                 // like with floating points numbers and consider updating
@@ -166,7 +168,8 @@ impl<'index> DocumentChanges<'index> for UpdateByFunctionChanges<'index> {
                     FxBuildHasher,
                     doc_alloc,
                 )
-                .map_err(InternalError::SerdeJson)?;
+                .map_err(InternalError::SerdeJson)
+                .unwrap();

                 Ok(Some(DocumentChange::Update(Update::create(
                     docid,
@@ -200,7 +203,7 @@ fn obkv_to_rhaimap(obkv: &KvReaderFieldId, fields_ids_map: &FieldsIdsMap) -> Res
                 field_id: id,
                 process: "all_obkv_to_rhaimap",
             })?;
-            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson).unwrap();
             Ok((name.into(), value))
         })
         .collect();

View File

@@ -105,7 +105,8 @@ impl<'t> VectorDocumentFromDb<'t> {
         let vectors_field = match vectors {
             Some(vectors) => Some(
                 RawMap::from_raw_value_and_hasher(vectors, FxBuildHasher, doc_alloc)
-                    .map_err(InternalError::SerdeJson)?,
+                    .map_err(InternalError::SerdeJson)
+                    .unwrap(),
             ),
             None => None,
         };

View File

@@ -5,7 +5,7 @@ use rayon::slice::ParallelSlice as _;
 use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::DistributionShift;
+use super::{DistributionShift, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::Embedding;
 use crate::ThreadPoolNoAbort;
@@ -113,20 +113,30 @@ impl Embedder {
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<f32>>, EmbedError> {
-        threads
-            .install(move || {
-                let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
-                    .par_chunks(self.prompt_count_in_chunk_hint())
-                    .map(move |chunk| self.embed(chunk, None))
-                    .collect();
-
-                let embeddings = embeddings?;
-                Ok(embeddings.into_iter().flatten().collect())
-            })
-            .map_err(|error| EmbedError {
-                kind: EmbedErrorKind::PanicInThreadPool(error),
-                fault: FaultSource::Bug,
-            })?
+        if threads.active_operations() >= REQUEST_PARALLELISM {
+            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                .chunks(self.prompt_count_in_chunk_hint())
+                .map(move |chunk| self.embed(chunk, None))
+                .collect();
+
+            let embeddings = embeddings?;
+            Ok(embeddings.into_iter().flatten().collect())
+        } else {
+            threads
+                .install(move || {
+                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                        .par_chunks(self.prompt_count_in_chunk_hint())
+                        .map(move |chunk| self.embed(chunk, None))
+                        .collect();
+
+                    let embeddings = embeddings?;
+                    Ok(embeddings.into_iter().flatten().collect())
+                })
+                .map_err(|error| EmbedError {
+                    kind: EmbedErrorKind::PanicInThreadPool(error),
+                    fault: FaultSource::Bug,
+                })?
+        }
     }

     pub fn chunk_count_hint(&self) -> usize {
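This hunk, like the two embedder hunks that follow, adds the same guard: when the pool already reports at least `REQUEST_PARALLELISM` operations in flight, the texts are embedded sequentially with `chunks` on the calling thread instead of being handed to the pool again with `par_chunks`, presumably to avoid stacking more nested `install` calls onto an already-busy pool. A reduced sketch of that branch shape, where `embed_one_chunk`, `embed_all`, and the `MAX_IN_FLIGHT` threshold are stand-ins for the crate's own embedder method and `REQUEST_PARALLELISM` (illustrative, not the real API):

use rayon::prelude::*;

/// Hypothetical threshold standing in for `REQUEST_PARALLELISM`.
const MAX_IN_FLIGHT: usize = 10;

fn embed_one_chunk(chunk: &[&str]) -> Result<Vec<Vec<f32>>, String> {
    // Stand-in for a network call; returns one dummy embedding per text.
    Ok(chunk.iter().map(|_| vec![0.0f32; 3]).collect())
}

fn embed_all(texts: &[&str], active_operations: usize) -> Result<Vec<Vec<f32>>, String> {
    let chunk_size = 4;
    let results: Result<Vec<Vec<Vec<f32>>>, _> = if active_operations >= MAX_IN_FLIGHT {
        // Pool is already saturated: stay on the current thread.
        texts.chunks(chunk_size).map(embed_one_chunk).collect()
    } else {
        // Otherwise fan the chunks out on the rayon pool.
        texts.par_chunks(chunk_size).map(embed_one_chunk).collect()
    };
    Ok(results?.into_iter().flatten().collect())
}

fn main() {
    let texts = ["a", "b", "c", "d", "e"];
    let embeddings = embed_all(&texts, 0).unwrap();
    println!("{} embeddings", embeddings.len());
}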

View File

@@ -6,7 +6,7 @@ use rayon::slice::ParallelSlice as _;
 use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
-use super::DistributionShift;
+use super::{DistributionShift, REQUEST_PARALLELISM};
 use crate::error::FaultSource;
 use crate::vector::error::EmbedErrorKind;
 use crate::vector::Embedding;
@@ -270,20 +270,29 @@ impl Embedder {
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<f32>>, EmbedError> {
-        threads
-            .install(move || {
-                let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
-                    .par_chunks(self.prompt_count_in_chunk_hint())
-                    .map(move |chunk| self.embed(chunk, None))
-                    .collect();
-
-                let embeddings = embeddings?;
-                Ok(embeddings.into_iter().flatten().collect())
-            })
-            .map_err(|error| EmbedError {
-                kind: EmbedErrorKind::PanicInThreadPool(error),
-                fault: FaultSource::Bug,
-            })?
+        if threads.active_operations() >= REQUEST_PARALLELISM {
+            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                .chunks(self.prompt_count_in_chunk_hint())
+                .map(move |chunk| self.embed(chunk, None))
+                .collect();
+            let embeddings = embeddings?;
+            Ok(embeddings.into_iter().flatten().collect())
+        } else {
+            threads
+                .install(move || {
+                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                        .par_chunks(self.prompt_count_in_chunk_hint())
+                        .map(move |chunk| self.embed(chunk, None))
+                        .collect();
+
+                    let embeddings = embeddings?;
+                    Ok(embeddings.into_iter().flatten().collect())
+                })
+                .map_err(|error| EmbedError {
+                    kind: EmbedErrorKind::PanicInThreadPool(error),
+                    fault: FaultSource::Bug,
+                })?
+        }
     }

     pub fn chunk_count_hint(&self) -> usize {

View File

@@ -203,20 +203,30 @@ impl Embedder {
         texts: &[&str],
         threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Embedding>, EmbedError> {
-        threads
-            .install(move || {
-                let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
-                    .par_chunks(self.prompt_count_in_chunk_hint())
-                    .map(move |chunk| self.embed_ref(chunk, None))
-                    .collect();
-
-                let embeddings = embeddings?;
-                Ok(embeddings.into_iter().flatten().collect())
-            })
-            .map_err(|error| EmbedError {
-                kind: EmbedErrorKind::PanicInThreadPool(error),
-                fault: FaultSource::Bug,
-            })?
+        if threads.active_operations() >= REQUEST_PARALLELISM {
+            let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                .chunks(self.prompt_count_in_chunk_hint())
+                .map(move |chunk| self.embed_ref(chunk, None))
+                .collect();
+
+            let embeddings = embeddings?;
+            Ok(embeddings.into_iter().flatten().collect())
+        } else {
+            threads
+                .install(move || {
+                    let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
+                        .par_chunks(self.prompt_count_in_chunk_hint())
+                        .map(move |chunk| self.embed_ref(chunk, None))
+                        .collect();
+
+                    let embeddings = embeddings?;
+                    Ok(embeddings.into_iter().flatten().collect())
+                })
+                .map_err(|error| EmbedError {
+                    kind: EmbedErrorKind::PanicInThreadPool(error),
+                    fault: FaultSource::Bug,
+                })?
+        }
     }

     pub fn chunk_count_hint(&self) -> usize {