mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-01-19 01:18:31 +08:00
Merge #4593
4593: Stop crashing when panic occurs in thread pool r=ManyTheFish a=Kerollmops This PR fixes #4362 by introducing a new boolean to catch panics in the rayon thread pool. The boolean is read after performing the operations in rayon, and the indexation process is stopped. This first version doesn't expose the panic message but marks the task as failed. The current implementation exposes a `ThreadPoolNoAbort` wrapper. The `rayon::ThreadPool` has been wrapped to check that nothing went wrong after running the `ThreadPool::install` function. An atomic boolean and some `store/load` logic make the system work efficiently. Before, Meilisearch was completely crashing... <img width="1563" alt="Capture d’écran 2024-04-22 à 15 49 02" src="https://github.com/meilisearch/meilisearch/assets/3610253/ce114917-a881-4fbb-85df-c195fcf0c7cb"> Now, it handles the panics correctly and marks the task as failed. <img width="1558" alt="Capture d’écran 2024-04-22 à 15 42 14" src="https://github.com/meilisearch/meilisearch/assets/3610253/8bd031ef-5e8f-4a12-a91e-c823597a2344"> Co-authored-by: Clément Renault <clement@meilisearch.com>
This commit is contained in:
commit
3e5cd027a5
@ -13,6 +13,7 @@ use byte_unit::{Byte, ByteError};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use meilisearch_types::features::InstanceTogglableFeatures;
|
use meilisearch_types::features::InstanceTogglableFeatures;
|
||||||
use meilisearch_types::milli::update::IndexerConfig;
|
use meilisearch_types::milli::update::IndexerConfig;
|
||||||
|
use meilisearch_types::milli::ThreadPoolNoAbortBuilder;
|
||||||
use rustls::server::{
|
use rustls::server::{
|
||||||
AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache,
|
AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache,
|
||||||
};
|
};
|
||||||
@ -666,7 +667,7 @@ impl TryFrom<&IndexerOpts> for IndexerConfig {
|
|||||||
type Error = anyhow::Error;
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> {
|
fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> {
|
||||||
let thread_pool = rayon::ThreadPoolBuilder::new()
|
let thread_pool = ThreadPoolNoAbortBuilder::new()
|
||||||
.thread_name(|index| format!("indexing-thread:{index}"))
|
.thread_name(|index| format!("indexing-thread:{index}"))
|
||||||
.num_threads(*other.max_indexing_threads)
|
.num_threads(*other.max_indexing_threads)
|
||||||
.build()?;
|
.build()?;
|
||||||
|
@ -9,6 +9,7 @@ use serde_json::Value;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::documents::{self, DocumentsBatchCursorError};
|
use crate::documents::{self, DocumentsBatchCursorError};
|
||||||
|
use crate::thread_pool_no_abort::PanicCatched;
|
||||||
use crate::{CriterionError, DocumentId, FieldId, Object, SortError};
|
use crate::{CriterionError, DocumentId, FieldId, Object, SortError};
|
||||||
|
|
||||||
pub fn is_reserved_keyword(keyword: &str) -> bool {
|
pub fn is_reserved_keyword(keyword: &str) -> bool {
|
||||||
@ -39,17 +40,19 @@ pub enum InternalError {
|
|||||||
Fst(#[from] fst::Error),
|
Fst(#[from] fst::Error),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
DocumentsError(#[from] documents::Error),
|
DocumentsError(#[from] documents::Error),
|
||||||
#[error("Invalid compression type have been specified to grenad.")]
|
#[error("Invalid compression type have been specified to grenad")]
|
||||||
GrenadInvalidCompressionType,
|
GrenadInvalidCompressionType,
|
||||||
#[error("Invalid grenad file with an invalid version format.")]
|
#[error("Invalid grenad file with an invalid version format")]
|
||||||
GrenadInvalidFormatVersion,
|
GrenadInvalidFormatVersion,
|
||||||
#[error("Invalid merge while processing {process}.")]
|
#[error("Invalid merge while processing {process}")]
|
||||||
IndexingMergingKeys { process: &'static str },
|
IndexingMergingKeys { process: &'static str },
|
||||||
#[error("{}", HeedError::InvalidDatabaseTyping)]
|
#[error("{}", HeedError::InvalidDatabaseTyping)]
|
||||||
InvalidDatabaseTyping,
|
InvalidDatabaseTyping,
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
RayonThreadPool(#[from] ThreadPoolBuildError),
|
RayonThreadPool(#[from] ThreadPoolBuildError),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
|
PanicInThreadPool(#[from] PanicCatched),
|
||||||
|
#[error(transparent)]
|
||||||
SerdeJson(#[from] serde_json::Error),
|
SerdeJson(#[from] serde_json::Error),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Serialization(#[from] SerializationError),
|
Serialization(#[from] SerializationError),
|
||||||
@ -57,9 +60,9 @@ pub enum InternalError {
|
|||||||
Store(#[from] MdbError),
|
Store(#[from] MdbError),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Utf8(#[from] str::Utf8Error),
|
Utf8(#[from] str::Utf8Error),
|
||||||
#[error("An indexation process was explicitly aborted.")]
|
#[error("An indexation process was explicitly aborted")]
|
||||||
AbortedIndexation,
|
AbortedIndexation,
|
||||||
#[error("The matching words list contains at least one invalid member.")]
|
#[error("The matching words list contains at least one invalid member")]
|
||||||
InvalidMatchingWords,
|
InvalidMatchingWords,
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
ArroyError(#[from] arroy::Error),
|
ArroyError(#[from] arroy::Error),
|
||||||
|
@ -21,6 +21,7 @@ pub mod prompt;
|
|||||||
pub mod proximity;
|
pub mod proximity;
|
||||||
pub mod score_details;
|
pub mod score_details;
|
||||||
mod search;
|
mod search;
|
||||||
|
mod thread_pool_no_abort;
|
||||||
pub mod update;
|
pub mod update;
|
||||||
pub mod vector;
|
pub mod vector;
|
||||||
|
|
||||||
@ -42,6 +43,7 @@ pub use search::new::{
|
|||||||
SearchLogger, VisualSearchLogger,
|
SearchLogger, VisualSearchLogger,
|
||||||
};
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
|
||||||
pub use {charabia as tokenizer, heed};
|
pub use {charabia as tokenizer, heed};
|
||||||
|
|
||||||
pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
|
pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
|
||||||
|
69
milli/src/thread_pool_no_abort.rs
Normal file
69
milli/src/thread_pool_no_abort.rs
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use rayon::{ThreadPool, ThreadPoolBuilder};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// A rayon ThreadPool wrapper that can catch panics in the pool
|
||||||
|
/// and modifies the install function accordingly.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ThreadPoolNoAbort {
|
||||||
|
thread_pool: ThreadPool,
|
||||||
|
/// Set to true if the thread pool catched a panic.
|
||||||
|
pool_catched_panic: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadPoolNoAbort {
|
||||||
|
pub fn install<OP, R>(&self, op: OP) -> Result<R, PanicCatched>
|
||||||
|
where
|
||||||
|
OP: FnOnce() -> R + Send,
|
||||||
|
R: Send,
|
||||||
|
{
|
||||||
|
let output = self.thread_pool.install(op);
|
||||||
|
// While reseting the pool panic catcher we return an error if we catched one.
|
||||||
|
if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
|
||||||
|
Err(PanicCatched)
|
||||||
|
} else {
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn current_num_threads(&self) -> usize {
|
||||||
|
self.thread_pool.current_num_threads()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
#[error("A panic occured. Read the logs to find more information about it")]
|
||||||
|
pub struct PanicCatched;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder);
|
||||||
|
|
||||||
|
impl ThreadPoolNoAbortBuilder {
|
||||||
|
pub fn new() -> ThreadPoolNoAbortBuilder {
|
||||||
|
ThreadPoolNoAbortBuilder::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn thread_name<F>(mut self, closure: F) -> Self
|
||||||
|
where
|
||||||
|
F: FnMut(usize) -> String + 'static,
|
||||||
|
{
|
||||||
|
self.0 = self.0.thread_name(closure);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder {
|
||||||
|
self.0 = self.0.num_threads(num_threads);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> {
|
||||||
|
let pool_catched_panic = Arc::new(AtomicBool::new(false));
|
||||||
|
self.0 = self.0.panic_handler({
|
||||||
|
let catched_panic = pool_catched_panic.clone();
|
||||||
|
move |_result| catched_panic.store(true, Ordering::SeqCst)
|
||||||
|
});
|
||||||
|
Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic })
|
||||||
|
}
|
||||||
|
}
|
@ -19,7 +19,7 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
|
|||||||
use crate::update::index_documents::helpers::try_split_at;
|
use crate::update::index_documents::helpers::try_split_at;
|
||||||
use crate::update::settings::InnerIndexSettingsDiff;
|
use crate::update::settings::InnerIndexSettingsDiff;
|
||||||
use crate::vector::Embedder;
|
use crate::vector::Embedder;
|
||||||
use crate::{DocumentId, InternalError, Result, VectorOrArrayOfVectors};
|
use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, VectorOrArrayOfVectors};
|
||||||
|
|
||||||
/// The length of the elements that are always in the buffer when inserting new values.
|
/// The length of the elements that are always in the buffer when inserting new values.
|
||||||
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
||||||
@ -347,7 +347,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
prompt_reader: grenad::Reader<R>,
|
prompt_reader: grenad::Reader<R>,
|
||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
embedder: Arc<Embedder>,
|
embedder: Arc<Embedder>,
|
||||||
request_threads: &rayon::ThreadPool,
|
request_threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||||
puffin::profile_function!();
|
puffin::profile_function!();
|
||||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
||||||
|
@ -31,7 +31,7 @@ use self::extract_word_position_docids::extract_word_position_docids;
|
|||||||
use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
|
use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
|
||||||
use super::{helpers, TypedChunk};
|
use super::{helpers, TypedChunk};
|
||||||
use crate::update::settings::InnerIndexSettingsDiff;
|
use crate::update::settings::InnerIndexSettingsDiff;
|
||||||
use crate::{FieldId, Result};
|
use crate::{FieldId, Result, ThreadPoolNoAbortBuilder};
|
||||||
|
|
||||||
/// Extract data for each databases from obkv documents in parallel.
|
/// Extract data for each databases from obkv documents in parallel.
|
||||||
/// Send data in grenad file over provided Sender.
|
/// Send data in grenad file over provided Sender.
|
||||||
@ -229,7 +229,7 @@ fn send_original_documents_data(
|
|||||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||||
|
|
||||||
let request_threads = rayon::ThreadPoolBuilder::new()
|
let request_threads = ThreadPoolNoAbortBuilder::new()
|
||||||
.num_threads(crate::vector::REQUEST_PARALLELISM)
|
.num_threads(crate::vector::REQUEST_PARALLELISM)
|
||||||
.thread_name(|index| format!("embedding-request-{index}"))
|
.thread_name(|index| format!("embedding-request-{index}"))
|
||||||
.build()?;
|
.build()?;
|
||||||
|
@ -33,6 +33,7 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
|
|||||||
pub use self::transform::{Transform, TransformOutput};
|
pub use self::transform::{Transform, TransformOutput};
|
||||||
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
||||||
use crate::error::{Error, InternalError, UserError};
|
use crate::error::{Error, InternalError, UserError};
|
||||||
|
use crate::thread_pool_no_abort::ThreadPoolNoAbortBuilder;
|
||||||
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
||||||
use crate::update::{
|
use crate::update::{
|
||||||
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
||||||
@ -298,18 +299,18 @@ where
|
|||||||
let backup_pool;
|
let backup_pool;
|
||||||
let pool = match self.indexer_config.thread_pool {
|
let pool = match self.indexer_config.thread_pool {
|
||||||
Some(ref pool) => pool,
|
Some(ref pool) => pool,
|
||||||
#[cfg(not(test))]
|
|
||||||
None => {
|
None => {
|
||||||
// We initialize a bakcup pool with the default
|
// We initialize a backup pool with the default
|
||||||
// settings if none have already been set.
|
// settings if none have already been set.
|
||||||
backup_pool = rayon::ThreadPoolBuilder::new().build()?;
|
#[allow(unused_mut)]
|
||||||
&backup_pool
|
let mut pool_builder = ThreadPoolNoAbortBuilder::new();
|
||||||
}
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
None => {
|
{
|
||||||
// We initialize a bakcup pool with the default
|
pool_builder = pool_builder.num_threads(1);
|
||||||
// settings if none have already been set.
|
}
|
||||||
backup_pool = rayon::ThreadPoolBuilder::new().num_threads(1).build()?;
|
|
||||||
|
backup_pool = pool_builder.build()?;
|
||||||
&backup_pool
|
&backup_pool
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -533,7 +534,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
}).map_err(InternalError::from)??;
|
||||||
|
|
||||||
// We write the field distribution into the main database
|
// We write the field distribution into the main database
|
||||||
self.index.put_field_distribution(self.wtxn, &field_distribution)?;
|
self.index.put_field_distribution(self.wtxn, &field_distribution)?;
|
||||||
@ -562,7 +563,8 @@ where
|
|||||||
writer.build(wtxn, &mut rng, None)?;
|
writer.build(wtxn, &mut rng, None)?;
|
||||||
}
|
}
|
||||||
Result::Ok(())
|
Result::Ok(())
|
||||||
})?;
|
})
|
||||||
|
.map_err(InternalError::from)??;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.execute_prefix_databases(
|
self.execute_prefix_databases(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use grenad::CompressionType;
|
use grenad::CompressionType;
|
||||||
use rayon::ThreadPool;
|
|
||||||
|
use crate::thread_pool_no_abort::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct IndexerConfig {
|
pub struct IndexerConfig {
|
||||||
@ -9,7 +10,7 @@ pub struct IndexerConfig {
|
|||||||
pub max_memory: Option<usize>,
|
pub max_memory: Option<usize>,
|
||||||
pub chunk_compression_type: CompressionType,
|
pub chunk_compression_type: CompressionType,
|
||||||
pub chunk_compression_level: Option<u32>,
|
pub chunk_compression_level: Option<u32>,
|
||||||
pub thread_pool: Option<ThreadPool>,
|
pub thread_pool: Option<ThreadPoolNoAbort>,
|
||||||
pub max_positions_per_attributes: Option<u32>,
|
pub max_positions_per_attributes: Option<u32>,
|
||||||
pub skip_index_budget: bool,
|
pub skip_index_budget: bool,
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ use std::path::PathBuf;
|
|||||||
use hf_hub::api::sync::ApiError;
|
use hf_hub::api::sync::ApiError;
|
||||||
|
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
|
use crate::PanicCatched;
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("Error while generating embeddings: {inner}")]
|
#[error("Error while generating embeddings: {inner}")]
|
||||||
@ -80,6 +81,8 @@ pub enum EmbedErrorKind {
|
|||||||
OpenAiUnexpectedDimension(usize, usize),
|
OpenAiUnexpectedDimension(usize, usize),
|
||||||
#[error("no embedding was produced")]
|
#[error("no embedding was produced")]
|
||||||
MissingEmbedding,
|
MissingEmbedding,
|
||||||
|
#[error(transparent)]
|
||||||
|
PanicInThreadPool(#[from] PanicCatched),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedError {
|
impl EmbedError {
|
||||||
|
@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use self::error::{EmbedError, NewEmbedderError};
|
use self::error::{EmbedError, NewEmbedderError};
|
||||||
use crate::prompt::{Prompt, PromptData};
|
use crate::prompt::{Prompt, PromptData};
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod hf;
|
pub mod hf;
|
||||||
@ -254,7 +255,7 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
|
@ -3,6 +3,8 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
|||||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::{DistributionShift, Embeddings};
|
||||||
|
use crate::error::FaultSource;
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
@ -71,11 +73,16 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads.install(move || {
|
threads
|
||||||
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
})
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
@ -4,7 +4,9 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
|||||||
use super::error::{EmbedError, NewEmbedderError};
|
use super::error::{EmbedError, NewEmbedderError};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::{DistributionShift, Embeddings};
|
||||||
|
use crate::error::FaultSource;
|
||||||
use crate::vector::error::EmbedErrorKind;
|
use crate::vector::error::EmbedErrorKind;
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
@ -241,11 +243,16 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads.install(move || {
|
threads
|
||||||
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
})
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
@ -2,9 +2,12 @@ use deserr::Deserr;
|
|||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::error::EmbedErrorKind;
|
||||||
use super::{
|
use super::{
|
||||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||||
};
|
};
|
||||||
|
use crate::error::FaultSource;
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
// retrying in case of failure
|
// retrying in case of failure
|
||||||
|
|
||||||
@ -158,11 +161,16 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads.install(move || {
|
threads
|
||||||
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
})
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
@ -217,9 +217,7 @@ fn add_memory_samples(
|
|||||||
memory_counters: &mut Option<MemoryCounterHandles>,
|
memory_counters: &mut Option<MemoryCounterHandles>,
|
||||||
last_memory: &mut MemoryStats,
|
last_memory: &mut MemoryStats,
|
||||||
) -> Option<MemoryStats> {
|
) -> Option<MemoryStats> {
|
||||||
let Some(stats) = memory else {
|
let stats = memory?;
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let memory_counters =
|
let memory_counters =
|
||||||
memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main));
|
memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main));
|
||||||
|
Loading…
Reference in New Issue
Block a user