Compare commits

..

1 Commits

Author SHA1 Message Date
Louis Dureuil
e186a250df
Merge 4ff2b3c2ee into 94fb55bb6f 2024-11-14 14:45:11 +00:00
33 changed files with 690 additions and 1089 deletions

View File

@ -32,15 +32,18 @@ use meilisearch_types::error::Code;
use meilisearch_types::heed::{RoTxn, RwTxn}; use meilisearch_types::heed::{RoTxn, RwTxn};
use meilisearch_types::milli::documents::{obkv_to_object, DocumentsBatchReader, PrimaryKey}; use meilisearch_types::milli::documents::{obkv_to_object, DocumentsBatchReader, PrimaryKey};
use meilisearch_types::milli::heed::CompactionOption; use meilisearch_types::milli::heed::CompactionOption;
use meilisearch_types::milli::update::new::indexer::{self, UpdateByFunction}; use meilisearch_types::milli::update::new::indexer::{
self, retrieve_or_guess_primary_key, UpdateByFunction,
};
use meilisearch_types::milli::update::{IndexDocumentsMethod, Settings as MilliSettings}; use meilisearch_types::milli::update::{IndexDocumentsMethod, Settings as MilliSettings};
use meilisearch_types::milli::vector::parsed_vectors::{ use meilisearch_types::milli::vector::parsed_vectors::{
ExplicitVectors, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME, ExplicitVectors, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME,
}; };
use meilisearch_types::milli::{self, Filter, ThreadPoolNoAbortBuilder}; use meilisearch_types::milli::{self, Filter, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
use meilisearch_types::settings::{apply_settings_to_builder, Settings, Unchecked}; use meilisearch_types::settings::{apply_settings_to_builder, Settings, Unchecked};
use meilisearch_types::tasks::{Details, IndexSwap, Kind, KindWithContent, Status, Task}; use meilisearch_types::tasks::{Details, IndexSwap, Kind, KindWithContent, Status, Task};
use meilisearch_types::{compression, Index, VERSION_FILE_NAME}; use meilisearch_types::{compression, Index, VERSION_FILE_NAME};
use raw_collections::RawMap;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use time::macros::format_description; use time::macros::format_description;
use time::OffsetDateTime; use time::OffsetDateTime;
@ -1226,7 +1229,9 @@ impl IndexScheduler {
const PRINT_SECS_DELTA: u64 = 1; const PRINT_SECS_DELTA: u64 = 1;
let processing_tasks = self.processing_tasks.clone(); let processing_tasks = self.processing_tasks.clone();
let must_stop_processing = self.must_stop_processing.clone(); let must_stop_processing = self.must_stop_processing.clone();
let send_progress = |progress| { let send_progress = |progress| {
let now = std::time::Instant::now(); let now = std::time::Instant::now();
let elapsed = secs_since_started_processing_at.load(atomic::Ordering::Relaxed); let elapsed = secs_since_started_processing_at.load(atomic::Ordering::Relaxed);
@ -1275,6 +1280,16 @@ impl IndexScheduler {
// TODO: at some point, for better efficiency we might want to reuse the bumpalo for successive batches. // TODO: at some point, for better efficiency we might want to reuse the bumpalo for successive batches.
// this is made difficult by the fact we're doing private clones of the index scheduler and sending it // this is made difficult by the fact we're doing private clones of the index scheduler and sending it
// to a fresh thread. // to a fresh thread.
/// TODO manage errors correctly
let first_addition_uuid = operations
.iter()
.find_map(|op| match op {
DocumentOperation::Add(content_uuid) => Some(content_uuid),
_ => None,
})
.unwrap();
let mut content_files = Vec::new(); let mut content_files = Vec::new();
for operation in &operations { for operation in &operations {
if let DocumentOperation::Add(content_uuid) = operation { if let DocumentOperation::Add(content_uuid) = operation {
@ -1290,30 +1305,88 @@ impl IndexScheduler {
let db_fields_ids_map = index.fields_ids_map(&rtxn)?; let db_fields_ids_map = index.fields_ids_map(&rtxn)?;
let mut new_fields_ids_map = db_fields_ids_map.clone(); let mut new_fields_ids_map = db_fields_ids_map.clone();
let first_document = match content_files.first() {
Some(mmap) => {
let mut iter = serde_json::Deserializer::from_slice(mmap).into_iter();
iter.next().transpose().map_err(|e| e.into()).map_err(Error::IoError)?
}
None => None,
};
let (primary_key, primary_key_has_been_set) = retrieve_or_guess_primary_key(
&rtxn,
index,
&mut new_fields_ids_map,
primary_key.as_deref(),
first_document
.map(|raw| RawMap::from_raw_value(raw, &indexer_alloc))
.transpose()
.map_err(|error| {
milli::Error::UserError(milli::UserError::SerdeJson(error))
})?,
)?
.map_err(milli::Error::from)?;
let mut content_files_iter = content_files.iter(); let mut content_files_iter = content_files.iter();
let mut indexer = indexer::DocumentOperation::new(method); let mut indexer = indexer::DocumentOperation::new(method);
let embedders = index.embedding_configs(index_wtxn)?; let embedders = index.embedding_configs(index_wtxn)?;
let embedders = self.embedders(embedders)?; let embedders = self.embedders(embedders)?;
for operation in operations { for (operation, task) in operations.into_iter().zip(tasks.iter_mut()) {
match operation { match operation {
DocumentOperation::Add(_content_uuid) => { DocumentOperation::Add(_content_uuid) => {
let mmap = content_files_iter.next().unwrap(); let mmap = content_files_iter.next().unwrap();
indexer.add_documents(mmap)?; let stats = indexer.add_documents(mmap)?;
// builder = builder.with_embedders(embedders.clone()); // builder = builder.with_embedders(embedders.clone());
let received_documents =
if let Some(Details::DocumentAdditionOrUpdate {
received_documents,
..
}) = task.details
{
received_documents
} else {
// In the case of a `documentAdditionOrUpdate` the details MUST be set
unreachable!();
};
task.status = Status::Succeeded;
task.details = Some(Details::DocumentAdditionOrUpdate {
received_documents,
indexed_documents: Some(stats.document_count as u64),
})
} }
DocumentOperation::Delete(document_ids) => { DocumentOperation::Delete(document_ids) => {
let count = document_ids.len();
let document_ids: bumpalo::collections::vec::Vec<_> = document_ids let document_ids: bumpalo::collections::vec::Vec<_> = document_ids
.iter() .iter()
.map(|s| &*indexer_alloc.alloc_str(s)) .map(|s| &*indexer_alloc.alloc_str(s))
.collect_in(&indexer_alloc); .collect_in(&indexer_alloc);
indexer.delete_documents(document_ids.into_bump_slice()); indexer.delete_documents(document_ids.into_bump_slice());
// Uses Invariant: remove documents actually always returns Ok for the inner result
// let count = user_result.unwrap();
let provided_ids =
if let Some(Details::DocumentDeletion { provided_ids, .. }) =
task.details
{
provided_ids
} else {
// In the case of a `documentAdditionOrUpdate` the details MUST be set
unreachable!();
};
task.status = Status::Succeeded;
task.details = Some(Details::DocumentDeletion {
provided_ids,
deleted_documents: Some(count as u64),
});
} }
} }
} }
if tasks.iter().any(|res| res.error.is_none()) {
let local_pool; let local_pool;
let indexer_config = self.index_mapper.indexer_config(); let pool = match &self.index_mapper.indexer_config().thread_pool {
let pool = match &indexer_config.thread_pool {
Some(pool) => pool, Some(pool) => pool,
None => { None => {
local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap(); local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap();
@ -1321,55 +1394,22 @@ impl IndexScheduler {
} }
}; };
let (document_changes, operation_stats, primary_key) = indexer.into_changes( // TODO we want to multithread this
let document_changes = indexer.into_changes(
&indexer_alloc, &indexer_alloc,
index, index,
&rtxn, &rtxn,
primary_key.as_deref(), &primary_key,
&mut new_fields_ids_map, &mut new_fields_ids_map,
)?; )?;
let mut addition = 0;
for (stats, task) in operation_stats.into_iter().zip(&mut tasks) {
addition += stats.document_count;
match stats.error {
Some(error) => {
task.status = Status::Failed;
task.error = Some(milli::Error::UserError(error).into());
}
None => task.status = Status::Succeeded,
}
task.details = match task.details {
Some(Details::DocumentAdditionOrUpdate { received_documents, .. }) => {
Some(Details::DocumentAdditionOrUpdate {
received_documents,
indexed_documents: Some(stats.document_count),
})
}
Some(Details::DocumentDeletion { provided_ids, .. }) => {
Some(Details::DocumentDeletion {
provided_ids,
deleted_documents: Some(stats.document_count),
})
}
_ => {
// In the case of a `documentAdditionOrUpdate` or `DocumentDeletion`
// the details MUST be set to either addition or deletion
unreachable!();
}
}
}
if tasks.iter().any(|res| res.error.is_none()) {
pool.install(|| { pool.install(|| {
indexer::index( indexer::index(
index_wtxn, index_wtxn,
index, index,
indexer_config.grenad_parameters(),
&db_fields_ids_map, &db_fields_ids_map,
new_fields_ids_map, new_fields_ids_map,
primary_key, primary_key_has_been_set.then_some(primary_key),
&document_changes, &document_changes,
embedders, embedders,
&|| must_stop_processing.get(), &|| must_stop_processing.get(),
@ -1378,7 +1418,7 @@ impl IndexScheduler {
}) })
.unwrap()?; .unwrap()?;
tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");
} }
// else if primary_key_has_been_set { // else if primary_key_has_been_set {
// // Everything failed but we've set a primary key. // // Everything failed but we've set a primary key.
@ -1395,14 +1435,12 @@ impl IndexScheduler {
Ok(tasks) Ok(tasks)
} }
IndexOperation::DocumentEdition { mut task, .. } => { IndexOperation::DocumentEdition { mut task, .. } => {
let (filter, code) = if let KindWithContent::DocumentEdition { let (filter, context, code) =
filter_expr, if let KindWithContent::DocumentEdition {
context: _, filter_expr, context, function, ..
function,
..
} = &task.kind } = &task.kind
{ {
(filter_expr, function) (filter_expr, context, function)
} else { } else {
unreachable!() unreachable!()
}; };
@ -1459,8 +1497,7 @@ impl IndexScheduler {
if task.error.is_none() { if task.error.is_none() {
let local_pool; let local_pool;
let indexer_config = self.index_mapper.indexer_config(); let pool = match &self.index_mapper.indexer_config().thread_pool {
let pool = match &indexer_config.thread_pool {
Some(pool) => pool, Some(pool) => pool,
None => { None => {
local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap(); local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap();
@ -1478,7 +1515,6 @@ impl IndexScheduler {
indexer::index( indexer::index(
index_wtxn, index_wtxn,
index, index,
indexer_config.grenad_parameters(),
&db_fields_ids_map, &db_fields_ids_map,
new_fields_ids_map, new_fields_ids_map,
None, // cannot change primary key in DocumentEdition None, // cannot change primary key in DocumentEdition
@ -1611,8 +1647,7 @@ impl IndexScheduler {
if !tasks.iter().all(|res| res.error.is_some()) { if !tasks.iter().all(|res| res.error.is_some()) {
let local_pool; let local_pool;
let indexer_config = self.index_mapper.indexer_config(); let pool = match &self.index_mapper.indexer_config().thread_pool {
let pool = match &indexer_config.thread_pool {
Some(pool) => pool, Some(pool) => pool,
None => { None => {
local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap(); local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap();
@ -1630,7 +1665,6 @@ impl IndexScheduler {
indexer::index( indexer::index(
index_wtxn, index_wtxn,
index, index,
indexer_config.grenad_parameters(),
&db_fields_ids_map, &db_fields_ids_map,
new_fields_ids_map, new_fields_ids_map,
None, // document deletion never changes primary key None, // document deletion never changes primary key

View File

@ -971,8 +971,6 @@ impl IndexScheduler {
let ProcessingTasks { started_at, processing, progress, .. } = let ProcessingTasks { started_at, processing, progress, .. } =
self.processing_tasks.read().map_err(|_| Error::CorruptedTaskQueue)?.clone(); self.processing_tasks.read().map_err(|_| Error::CorruptedTaskQueue)?.clone();
let _ = progress;
let ret = tasks.into_iter(); let ret = tasks.into_iter();
if processing.is_empty() { if processing.is_empty() {
Ok((ret.collect(), total)) Ok((ret.collect(), total))
@ -4298,11 +4296,11 @@ mod tests {
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "only_first_task_succeed"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "only_first_task_succeed");
// The second batch should fail. // The second batch should fail.
handle.advance_one_successful_batch(); handle.advance_one_failed_batch();
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_task_fails"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_task_fails");
// The second batch should fail. // The second batch should fail.
handle.advance_one_successful_batch(); handle.advance_one_failed_batch();
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "third_task_fails"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "third_task_fails");
// Is the primary key still what we expect? // Is the primary key still what we expect?
@ -4363,7 +4361,7 @@ mod tests {
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "only_first_task_succeed"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "only_first_task_succeed");
// The second batch should fail and contains two tasks. // The second batch should fail and contains two tasks.
handle.advance_one_successful_batch(); handle.advance_one_failed_batch();
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_and_third_tasks_fails"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_and_third_tasks_fails");
// Is the primary key still what we expect? // Is the primary key still what we expect?
@ -4442,8 +4440,7 @@ mod tests {
snapshot!(primary_key, @"id"); snapshot!(primary_key, @"id");
// We're trying to `bork` again, but now there is already a primary key set for this index. // We're trying to `bork` again, but now there is already a primary key set for this index.
// NOTE: it's marked as successful because the batch didn't fails, it's the individual tasks that failed. handle.advance_one_failed_batch();
handle.advance_one_successful_batch();
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "fourth_task_fails"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "fourth_task_fails");
// Finally the last task should succeed since its primary key is the same as the valid one. // Finally the last task should succeed since its primary key is the same as the valid one.
@ -4603,7 +4600,7 @@ mod tests {
snapshot!(primary_key.is_none(), @"false"); snapshot!(primary_key.is_none(), @"false");
// The second batch should contains only one task that fails because it tries to update the primary key to `bork`. // The second batch should contains only one task that fails because it tries to update the primary key to `bork`.
handle.advance_one_successful_batch(); handle.advance_one_failed_batch();
snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_task_fails"); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_task_fails");
// The third batch should succeed and only contains one task. // The third batch should succeed and only contains one task.
@ -5216,10 +5213,9 @@ mod tests {
let configs = index_scheduler.embedders(configs).unwrap(); let configs = index_scheduler.embedders(configs).unwrap();
let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap(); let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap();
let beagle_embed = let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap();
hf_embedder.embed_one(S("Intel the beagle best doggo"), None).unwrap(); let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap();
let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo"), None).unwrap(); let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap();
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo"), None).unwrap();
(fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed) (fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed)
}; };

View File

@ -1,5 +1,5 @@
--- ---
source: crates/index-scheduler/src/lib.rs source: index-scheduler/src/lib.rs
--- ---
### Autobatching Enabled = true ### Autobatching Enabled = true
### Processing Tasks: ### Processing Tasks:
@ -22,7 +22,7 @@ succeeded [0,1,2,]
doggos [0,1,2,] doggos [0,1,2,]
---------------------------------------------------------------------- ----------------------------------------------------------------------
### Index Mapper: ### Index Mapper:
doggos: { number_of_documents: 1, field_distribution: {"breed": 1, "doggo": 1, "id": 1} } doggos: { number_of_documents: 1, field_distribution: {"_vectors": 1, "breed": 1, "doggo": 1, "id": 1} }
---------------------------------------------------------------------- ----------------------------------------------------------------------
### Canceled By: ### Canceled By:

View File

@ -220,12 +220,11 @@ pub fn read_json(input: &File, output: impl io::Write) -> Result<u64> {
let mut out = BufWriter::new(output); let mut out = BufWriter::new(output);
let mut deserializer = serde_json::Deserializer::from_slice(&input); let mut deserializer = serde_json::Deserializer::from_slice(&input);
let res = array_each(&mut deserializer, |obj: &RawValue| { let count = match array_each(&mut deserializer, |obj: &RawValue| {
doc_alloc.reset(); doc_alloc.reset();
let map = RawMap::from_raw_value(obj, &doc_alloc)?; let map = RawMap::from_raw_value(obj, &doc_alloc)?;
to_writer(&mut out, &map) to_writer(&mut out, &map)
}); }) {
let count = match res {
// The json data has been deserialized and does not need to be processed again. // The json data has been deserialized and does not need to be processed again.
// The data has been transferred to the writer during the deserialization process. // The data has been transferred to the writer during the deserialization process.
Ok(Ok(count)) => count, Ok(Ok(count)) => count,

View File

@ -796,10 +796,8 @@ fn prepare_search<'t>(
let span = tracing::trace_span!(target: "search::vector", "embed_one"); let span = tracing::trace_span!(target: "search::vector", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
embedder embedder
.embed_one(query.q.clone().unwrap(), Some(deadline)) .embed_one(query.q.clone().unwrap())
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
} }

View File

@ -585,9 +585,9 @@ async fn federation_two_search_two_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
-100, -100.0,
340, 340.0,
90 90.0
] ]
}, },
"_federation": { "_federation": {
@ -613,9 +613,9 @@ async fn federation_two_search_two_indexes() {
"cattos": "pésti", "cattos": "pésti",
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
3 3.0
] ]
}, },
"_federation": { "_federation": {
@ -640,9 +640,9 @@ async fn federation_two_search_two_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
54 54.0
] ]
}, },
"_federation": { "_federation": {
@ -707,9 +707,9 @@ async fn federation_multiple_search_multiple_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
-100, -100.0,
340, 340.0,
90 90.0
] ]
}, },
"_federation": { "_federation": {
@ -735,9 +735,9 @@ async fn federation_multiple_search_multiple_indexes() {
"cattos": "pésti", "cattos": "pésti",
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
3 3.0
] ]
}, },
"_federation": { "_federation": {
@ -773,9 +773,9 @@ async fn federation_multiple_search_multiple_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
54 54.0
] ]
}, },
"_federation": { "_federation": {
@ -793,9 +793,9 @@ async fn federation_multiple_search_multiple_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
10, 10.0,
-23, -23.0,
32 32.0
] ]
}, },
"_federation": { "_federation": {
@ -824,9 +824,9 @@ async fn federation_multiple_search_multiple_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
10, 10.0,
23, 23.0,
32 32.0
] ]
}, },
"_federation": { "_federation": {
@ -869,9 +869,9 @@ async fn federation_multiple_search_multiple_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
54 54.0
] ]
}, },
"_federation": { "_federation": {
@ -898,9 +898,9 @@ async fn federation_multiple_search_multiple_indexes() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
-100, -100.0,
231, 231.0,
32 32.0
] ]
}, },
"_federation": { "_federation": {
@ -1522,9 +1522,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() {
"cattos": "pésti", "cattos": "pésti",
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
3 3.0
] ]
}, },
"_federation": { "_federation": {
@ -1550,9 +1550,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
54 54.0
] ]
}, },
"_federation": { "_federation": {
@ -1582,9 +1582,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
10, 10.0,
23, 23.0,
32 32.0
] ]
}, },
"_federation": { "_federation": {
@ -1845,9 +1845,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
54 54.0
] ]
}, },
"_federation": { "_federation": {
@ -1874,9 +1874,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() {
"cattos": "pésti", "cattos": "pésti",
"_vectors": { "_vectors": {
"manual": [ "manual": [
1, 1.0,
2, 2.0,
3 3.0
] ]
}, },
"_federation": { "_federation": {
@ -1906,9 +1906,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() {
], ],
"_vectors": { "_vectors": {
"manual": [ "manual": [
10, 10.0,
23, 23.0,
32 32.0
] ]
}, },
"_federation": { "_federation": {

View File

@ -137,14 +137,13 @@ fn long_text() -> &'static str {
} }
async fn create_mock_tokenized() -> (MockServer, Value) { async fn create_mock_tokenized() -> (MockServer, Value) {
create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false, false).await create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false).await
} }
async fn create_mock_with_template( async fn create_mock_with_template(
document_template: &str, document_template: &str,
model_dimensions: ModelDimensions, model_dimensions: ModelDimensions,
fallible: bool, fallible: bool,
slow: bool,
) -> (MockServer, Value) { ) -> (MockServer, Value) {
let mock_server = MockServer::start().await; let mock_server = MockServer::start().await;
const API_KEY: &str = "my-api-key"; const API_KEY: &str = "my-api-key";
@ -155,11 +154,7 @@ async fn create_mock_with_template(
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/")) .and(path("/"))
.respond_with(move |req: &Request| { .respond_with(move |req: &Request| {
// 0. wait for a long time // 0. maybe return 500
if slow {
std::thread::sleep(std::time::Duration::from_secs(1));
}
// 1. maybe return 500
if fallible { if fallible {
let attempt = attempt.fetch_add(1, Ordering::Relaxed); let attempt = attempt.fetch_add(1, Ordering::Relaxed);
let failed = matches!(attempt % 4, 0 | 1 | 3); let failed = matches!(attempt % 4, 0 | 1 | 3);
@ -172,7 +167,7 @@ async fn create_mock_with_template(
})) }))
} }
} }
// 3. check API key // 1. check API key
match req.headers.get("Authorization") { match req.headers.get("Authorization") {
Some(api_key) if api_key == API_KEY_BEARER => { Some(api_key) if api_key == API_KEY_BEARER => {
{} {}
@ -207,7 +202,7 @@ async fn create_mock_with_template(
) )
} }
} }
// 3. parse text inputs // 2. parse text inputs
let query: serde_json::Value = match req.body_json() { let query: serde_json::Value = match req.body_json() {
Ok(query) => query, Ok(query) => query,
Err(_error) => return ResponseTemplate::new(400).set_body_json( Err(_error) => return ResponseTemplate::new(400).set_body_json(
@ -228,7 +223,7 @@ async fn create_mock_with_template(
panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}") panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}")
} }
// 4. for each text, find embedding in responses // 3. for each text, find embedding in responses
let serde_json::Value::Array(inputs) = &query["input"] else { let serde_json::Value::Array(inputs) = &query["input"] else {
panic!("Unexpected `input` value") panic!("Unexpected `input` value")
}; };
@ -288,7 +283,7 @@ async fn create_mock_with_template(
"embedding": embedding, "embedding": embedding,
})).collect(); })).collect();
// 5. produce output from embeddings // 4. produce output from embeddings
ResponseTemplate::new(200).set_body_json(json!({ ResponseTemplate::new(200).set_body_json(json!({
"object": "list", "object": "list",
"data": data, "data": data,
@ -322,27 +317,23 @@ const DOGGO_TEMPLATE: &str = r#"{%- if doc.gender == "F" -%}Une chienne nommée
{%- endif %}, de race {{doc.breed}}."#; {%- endif %}, de race {{doc.breed}}."#;
async fn create_mock() -> (MockServer, Value) { async fn create_mock() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false).await
} }
async fn create_mock_dimensions() -> (MockServer, Value) { async fn create_mock_dimensions() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false).await
} }
async fn create_mock_small_embedding_model() -> (MockServer, Value) { async fn create_mock_small_embedding_model() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false).await
} }
async fn create_mock_legacy_embedding_model() -> (MockServer, Value) { async fn create_mock_legacy_embedding_model() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false).await
} }
async fn create_fallible_mock() -> (MockServer, Value) { async fn create_fallible_mock() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true).await
}
async fn create_slow_mock() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, true).await
} }
// basic test "it works" // basic test "it works"
@ -1882,114 +1873,4 @@ async fn it_still_works() {
] ]
"###); "###);
} }
// test with a server that responds 500 on 3 out of 4 calls
#[actix_rt::test]
async fn timeout() {
let (_mock, setting) = create_slow_mock().await;
let server = get_server_vector().await;
let index = server.index("doggo");
let (response, code) = index
.update_settings(json!({
"embedders": {
"default": setting,
},
}))
.await;
snapshot!(code, @"202 Accepted");
let task = server.wait_task(response.uid()).await;
snapshot!(task["status"], @r###""succeeded""###);
let documents = json!([
{"id": 0, "name": "kefir", "gender": "M", "birthyear": 2023, "breed": "Patou"},
]);
let (value, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
let task = index.wait_task(value.uid()).await;
snapshot!(task, @r###"
{
"uid": "[uid]",
"indexUid": "doggo",
"status": "succeeded",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 1
},
"error": null,
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
let (documents, _code) = index
.get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() })
.await;
snapshot!(json_string!(documents, {".results.*._vectors.default.embeddings" => "[vector]"}), @r###"
{
"results": [
{
"id": 0,
"name": "kefir",
"gender": "M",
"birthyear": 2023,
"breed": "Patou",
"_vectors": {
"default": {
"embeddings": "[vector]",
"regenerate": true
}
}
}
],
"offset": 0,
"limit": 20,
"total": 1
}
"###);
let (response, code) = index
.search_post(json!({
"q": "grand chien de berger des montagnes",
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["semanticHitCount"]), @"0");
snapshot!(json_string!(response["hits"]), @"[]");
let (response, code) = index
.search_post(json!({
"q": "grand chien de berger des montagnes",
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["semanticHitCount"]), @"1");
snapshot!(json_string!(response["hits"]), @r###"
[
{
"id": 0,
"name": "kefir",
"gender": "M",
"birthyear": 2023,
"breed": "Patou"
}
]
"###);
let (response, code) = index
.search_post(json!({
"q": "grand chien de berger des montagnes",
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["semanticHitCount"]), @"0");
snapshot!(json_string!(response["hits"]), @"[]");
}
// test with a server that wrongly responds 400 // test with a server that wrongly responds 400

View File

@ -201,9 +201,7 @@ impl<'a> Search<'a> {
let span = tracing::trace_span!(target: "search::hybrid", "embed_one"); let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); match embedder.embed_one(query) {
match embedder.embed_one(query, Some(deadline)) {
Ok(embedding) => embedding, Ok(embedding) => embedding,
Err(error) => { Err(error) => {
tracing::error!(error=%error, "Embedding failed"); tracing::error!(error=%error, "Embedding failed");

View File

@ -119,8 +119,12 @@ impl GrenadParameters {
/// ///
/// This should be called inside of a rayon thread pool, /// This should be called inside of a rayon thread pool,
/// otherwise, it will take the global number of threads. /// otherwise, it will take the global number of threads.
///
/// The max memory cannot exceed a given reasonable value.
pub fn max_memory_by_thread(&self) -> Option<usize> { pub fn max_memory_by_thread(&self) -> Option<usize> {
self.max_memory.map(|max_memory| (max_memory / rayon::current_num_threads())) self.max_memory.map(|max_memory| {
(max_memory / rayon::current_num_threads()).min(MAX_GRENAD_SORTER_USAGE)
})
} }
} }

View File

@ -1,6 +1,5 @@
use grenad::CompressionType; use grenad::CompressionType;
use super::GrenadParameters;
use crate::thread_pool_no_abort::ThreadPoolNoAbort; use crate::thread_pool_no_abort::ThreadPoolNoAbort;
#[derive(Debug)] #[derive(Debug)]
@ -16,17 +15,6 @@ pub struct IndexerConfig {
pub skip_index_budget: bool, pub skip_index_budget: bool,
} }
impl IndexerConfig {
pub fn grenad_parameters(&self) -> GrenadParameters {
GrenadParameters {
chunk_compression_type: self.chunk_compression_type,
chunk_compression_level: self.chunk_compression_level,
max_memory: self.max_memory,
max_nb_chunks: self.max_nb_chunks,
}
}
}
impl Default for IndexerConfig { impl Default for IndexerConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {

View File

@ -70,30 +70,28 @@ impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> {
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'t str, &'t RawValue)>> { fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'t str, &'t RawValue)>> {
let mut it = self.content.iter(); let mut it = self.content.iter();
std::iter::from_fn(move || loop { std::iter::from_fn(move || {
let (fid, value) = it.next()?; let (fid, value) = it.next()?;
let name = match self.fields_ids_map.name(fid).ok_or(
let res = (|| loop {
let name = self.fields_ids_map.name(fid).ok_or(
InternalError::FieldIdMapMissingEntry(crate::FieldIdMapMissingEntry::FieldId { InternalError::FieldIdMapMissingEntry(crate::FieldIdMapMissingEntry::FieldId {
field_id: fid, field_id: fid,
process: "getting current document", process: "getting current document",
}), }),
) { )?;
Ok(name) => name,
Err(error) => return Some(Err(error.into())),
};
if name == RESERVED_VECTORS_FIELD_NAME || name == "_geo" { if name == RESERVED_VECTORS_FIELD_NAME || name == "_geo" {
continue; continue;
} }
let res = (|| {
let value = let value =
serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?; serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
Ok((name, value)) return Ok((name, value));
})(); })();
return Some(res); Some(res)
}) })
} }

View File

@ -79,7 +79,7 @@ use roaring::RoaringBitmap;
use rustc_hash::FxBuildHasher; use rustc_hash::FxBuildHasher;
use crate::update::del_add::{DelAdd, KvWriterDelAdd}; use crate::update::del_add::{DelAdd, KvWriterDelAdd};
use crate::update::new::thread_local::MostlySend; use crate::update::new::indexer::document_changes::MostlySend;
use crate::update::new::KvReaderDelAdd; use crate::update::new::KvReaderDelAdd;
use crate::update::MergeDeladdCboRoaringBitmaps; use crate::update::MergeDeladdCboRoaringBitmaps;
use crate::{CboRoaringBitmapCodec, Result}; use crate::{CboRoaringBitmapCodec, Result};

View File

@ -6,9 +6,8 @@ use hashbrown::HashMap;
use super::DelAddRoaringBitmap; use super::DelAddRoaringBitmap;
use crate::update::new::channel::DocumentsSender; use crate::update::new::channel::DocumentsSender;
use crate::update::new::document::{write_to_obkv, Document as _}; use crate::update::new::document::{write_to_obkv, Document as _};
use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor, FullySend};
use crate::update::new::ref_cell_ext::RefCellExt as _; use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::thread_local::FullySend;
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::vector::EmbeddingConfigs; use crate::vector::EmbeddingConfigs;
use crate::Result; use crate::Result;
@ -44,12 +43,10 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentsExtractor<'a> {
let mut document_buffer = bumpalo::collections::Vec::new_in(&context.doc_alloc); let mut document_buffer = bumpalo::collections::Vec::new_in(&context.doc_alloc);
let mut document_extractor_data = context.data.0.borrow_mut_or_yield(); let mut document_extractor_data = context.data.0.borrow_mut_or_yield();
for change in changes {
let change = change?;
// **WARNING**: the exclusive borrow on `new_fields_ids_map` needs to be taken **inside** of the `for change in changes` loop
// Otherwise, `BorrowMutError` will occur for document changes that also need the new_fields_ids_map (e.g.: UpdateByFunction)
let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield(); let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield();
for change in changes {
let change = change?;
let external_docid = change.external_docid().to_owned(); let external_docid = change.external_docid().to_owned();
// document but we need to create a function that collects and compresses documents. // document but we need to create a function that collects and compresses documents.

View File

@ -15,10 +15,10 @@ use crate::heed_codec::facet::OrderedF64Codec;
use crate::update::del_add::DelAdd; use crate::update::del_add::DelAdd;
use crate::update::new::channel::FieldIdDocidFacetSender; use crate::update::new::channel::FieldIdDocidFacetSender;
use crate::update::new::indexer::document_changes::{ use crate::update::new::indexer::document_changes::{
extract, DocumentChangeContext, DocumentChanges, Extractor, IndexingContext, Progress, extract, DocumentChangeContext, DocumentChanges, Extractor, FullySend, IndexingContext,
Progress, ThreadLocal,
}; };
use crate::update::new::ref_cell_ext::RefCellExt as _; use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::thread_local::{FullySend, ThreadLocal};
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::update::GrenadParameters; use crate::update::GrenadParameters;
use crate::{DocumentId, FieldId, Index, Result, MAX_FACET_VALUE_LENGTH}; use crate::{DocumentId, FieldId, Index, Result, MAX_FACET_VALUE_LENGTH};
@ -36,7 +36,7 @@ impl<'a, 'extractor> Extractor<'extractor> for FacetedExtractorData<'a> {
fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> { fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> {
Ok(RefCell::new(BalancedCaches::new_in( Ok(RefCell::new(BalancedCaches::new_in(
self.buckets, self.buckets,
self.grenad_parameters.max_memory_by_thread(), self.grenad_parameters.max_memory,
extractor_alloc, extractor_alloc,
))) )))
} }
@ -156,7 +156,6 @@ impl FacetedDocidsExtractor {
res res
} }
#[allow(clippy::too_many_arguments)]
fn facet_fn_with_options<'extractor, 'doc>( fn facet_fn_with_options<'extractor, 'doc>(
doc_alloc: &'doc Bump, doc_alloc: &'doc Bump,
cached_sorter: &mut BalancedCaches<'extractor>, cached_sorter: &mut BalancedCaches<'extractor>,
@ -337,7 +336,6 @@ fn truncate_str(s: &str) -> &str {
} }
impl FacetedDocidsExtractor { impl FacetedDocidsExtractor {
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(level = "trace", skip_all, target = "indexing::extract::faceted")] #[tracing::instrument(level = "trace", skip_all, target = "indexing::extract::faceted")]
pub fn run_extraction< pub fn run_extraction<
'pl, 'pl,

View File

@ -11,9 +11,8 @@ use serde_json::Value;
use crate::error::GeoError; use crate::error::GeoError;
use crate::update::new::document::Document; use crate::update::new::document::Document;
use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor, MostlySend};
use crate::update::new::ref_cell_ext::RefCellExt as _; use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::thread_local::MostlySend;
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::update::GrenadParameters; use crate::update::GrenadParameters;
use crate::{lat_lng_to_xyz, DocumentId, GeoPoint, Index, InternalError, Result}; use crate::{lat_lng_to_xyz, DocumentId, GeoPoint, Index, InternalError, Result};
@ -151,7 +150,7 @@ impl<'extractor> Extractor<'extractor> for GeoExtractor {
) -> Result<()> { ) -> Result<()> {
let rtxn = &context.rtxn; let rtxn = &context.rtxn;
let index = context.index; let index = context.index;
let max_memory = self.grenad_parameters.max_memory_by_thread(); let max_memory = self.grenad_parameters.max_memory;
let db_fields_ids_map = context.db_fields_ids_map; let db_fields_ids_map = context.db_fields_ids_map;
let mut data_ref = context.data.borrow_mut_or_yield(); let mut data_ref = context.data.borrow_mut_or_yield();

View File

@ -13,8 +13,9 @@ pub use geo::*;
pub use searchable::*; pub use searchable::*;
pub use vectors::EmbeddingExtractor; pub use vectors::EmbeddingExtractor;
use super::indexer::document_changes::{DocumentChanges, IndexingContext, Progress}; use super::indexer::document_changes::{
use super::thread_local::{FullySend, ThreadLocal}; DocumentChanges, FullySend, IndexingContext, Progress, ThreadLocal,
};
use crate::update::GrenadParameters; use crate::update::GrenadParameters;
use crate::Result; use crate::Result;

View File

@ -11,10 +11,10 @@ use super::tokenize_document::{tokenizer_builder, DocumentTokenizer};
use crate::update::new::extract::cache::BalancedCaches; use crate::update::new::extract::cache::BalancedCaches;
use crate::update::new::extract::perm_json_p::contained_in; use crate::update::new::extract::perm_json_p::contained_in;
use crate::update::new::indexer::document_changes::{ use crate::update::new::indexer::document_changes::{
extract, DocumentChangeContext, DocumentChanges, Extractor, IndexingContext, Progress, extract, DocumentChangeContext, DocumentChanges, Extractor, FullySend, IndexingContext,
MostlySend, Progress, ThreadLocal,
}; };
use crate::update::new::ref_cell_ext::RefCellExt as _; use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::thread_local::{FullySend, MostlySend, ThreadLocal};
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::update::GrenadParameters; use crate::update::GrenadParameters;
use crate::{bucketed_position, DocumentId, FieldId, Index, Result, MAX_POSITION_PER_ATTRIBUTE}; use crate::{bucketed_position, DocumentId, FieldId, Index, Result, MAX_POSITION_PER_ATTRIBUTE};
@ -214,7 +214,7 @@ impl<'a, 'extractor> Extractor<'extractor> for WordDocidsExtractorData<'a> {
fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> { fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> {
Ok(RefCell::new(Some(WordDocidsBalancedCaches::new_in( Ok(RefCell::new(Some(WordDocidsBalancedCaches::new_in(
self.buckets, self.buckets,
self.grenad_parameters.max_memory_by_thread(), self.grenad_parameters.max_memory,
extractor_alloc, extractor_alloc,
)))) ))))
} }

View File

@ -14,9 +14,9 @@ use tokenize_document::{tokenizer_builder, DocumentTokenizer};
use super::cache::BalancedCaches; use super::cache::BalancedCaches;
use super::DocidsExtractor; use super::DocidsExtractor;
use crate::update::new::indexer::document_changes::{ use crate::update::new::indexer::document_changes::{
extract, DocumentChangeContext, DocumentChanges, Extractor, IndexingContext, Progress, extract, DocumentChangeContext, DocumentChanges, Extractor, FullySend, IndexingContext,
Progress, ThreadLocal,
}; };
use crate::update::new::thread_local::{FullySend, ThreadLocal};
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::update::GrenadParameters; use crate::update::GrenadParameters;
use crate::{Index, Result, MAX_POSITION_PER_ATTRIBUTE}; use crate::{Index, Result, MAX_POSITION_PER_ATTRIBUTE};
@ -36,7 +36,7 @@ impl<'a, 'extractor, EX: SearchableExtractor + Sync> Extractor<'extractor>
fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> { fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result<Self::Data> {
Ok(RefCell::new(BalancedCaches::new_in( Ok(RefCell::new(BalancedCaches::new_in(
self.buckets, self.buckets,
self.grenad_parameters.max_memory_by_thread(), self.grenad_parameters.max_memory,
extractor_alloc, extractor_alloc,
))) )))
} }

View File

@ -8,8 +8,7 @@ use super::cache::DelAddRoaringBitmap;
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::prompt::Prompt; use crate::prompt::Prompt;
use crate::update::new::channel::EmbeddingSender; use crate::update::new::channel::EmbeddingSender;
use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor, MostlySend};
use crate::update::new::thread_local::MostlySend;
use crate::update::new::vector_document::VectorDocument; use crate::update::new::vector_document::VectorDocument;
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::vector::error::{ use crate::vector::error::{

View File

@ -8,9 +8,182 @@ use rayon::iter::IndexedParallelIterator;
use super::super::document_change::DocumentChange; use super::super::document_change::DocumentChange;
use crate::fields_ids_map::metadata::FieldIdMapWithMetadata; use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
use crate::update::new::parallel_iterator_ext::ParallelIteratorExt as _; use crate::update::new::parallel_iterator_ext::ParallelIteratorExt as _;
use crate::update::new::thread_local::{FullySend, MostlySend, ThreadLocal};
use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result}; use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result};
/// A trait for types that are **not** [`Send`] only because they would then allow concurrent access to a type that is not [`Sync`].
///
/// The primary example of such a type is `&T`, with `T: !Sync`.
///
/// In the authors' understanding, a type can be `!Send` for two distinct reasons:
///
/// 1. Because it contains data that *genuinely* cannot be moved between threads, such as thread-local data.
/// 2. Because sending the type would allow concurrent access to a `!Sync` type, which is undefined behavior.
///
/// `MostlySend` exists to be used in bounds where you need a type whose data is **not** *attached* to a thread
/// because you might access it from a different thread, but where you will never access the type **concurrently** from
/// multiple threads.
///
/// Like [`Send`], `MostlySend` assumes properties on types that cannot be verified by the compiler, which is why implementing
/// this trait is unsafe.
///
/// # Safety
///
/// Implementers of this trait promises that the following properties hold on the implementing type:
///
/// 1. Its data can be accessed from any thread and will be the same regardless of the thread accessing it.
/// 2. Any operation that can be performed on the type does not depend on the thread that executes it.
///
/// As these properties are subtle and are not generally tracked by the Rust type system, great care should be taken before
/// implementing `MostlySend` on a type, especially a foreign type.
///
/// - An example of a type that verifies (1) and (2) is [`std::rc::Rc`] (when `T` is `Send` and `Sync`).
/// - An example of a type that doesn't verify (1) is thread-local data.
/// - An example of a type that doesn't verify (2) is [`std::sync::MutexGuard`]: a lot of mutex implementations require that
/// a lock is returned to the operating system on the same thread that initially locked the mutex, failing to uphold this
/// invariant will cause Undefined Behavior
/// (see last § in [the nomicon](https://doc.rust-lang.org/nomicon/send-and-sync.html)).
///
/// It is **always safe** to implement this trait on a type that is `Send`, but no placeholder impl is provided due to limitations in
/// coherency. Use the [`FullySend`] wrapper in this situation.
pub unsafe trait MostlySend {}
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct FullySend<T>(pub T);
// SAFETY: a type **fully** send is always mostly send as well.
unsafe impl<T> MostlySend for FullySend<T> where T: Send {}
unsafe impl<T> MostlySend for RefCell<T> where T: MostlySend {}
unsafe impl<T> MostlySend for Option<T> where T: MostlySend {}
impl<T> FullySend<T> {
pub fn into(self) -> T {
self.0
}
}
impl<T> From<T> for FullySend<T> {
fn from(value: T) -> Self {
Self(value)
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct MostlySendWrapper<T>(T);
impl<T: MostlySend> MostlySendWrapper<T> {
/// # Safety
///
/// - (P1) Users of this type will never access the type concurrently from multiple threads without synchronization
unsafe fn new(t: T) -> Self {
Self(t)
}
fn as_ref(&self) -> &T {
&self.0
}
fn as_mut(&mut self) -> &mut T {
&mut self.0
}
fn into_inner(self) -> T {
self.0
}
}
/// # Safety
///
/// 1. `T` is [`MostlySend`], so by its safety contract it can be accessed by any thread and all of its operations are available
/// from any thread.
/// 2. (P1) of `MostlySendWrapper::new` forces the user to never access the value from multiple threads concurrently.
unsafe impl<T: MostlySend> Send for MostlySendWrapper<T> {}
/// A wrapper around [`thread_local::ThreadLocal`] that accepts [`MostlySend`] `T`s.
#[derive(Default)]
pub struct ThreadLocal<T: MostlySend> {
inner: thread_local::ThreadLocal<MostlySendWrapper<T>>,
// FIXME: this should be necessary
//_no_send: PhantomData<*mut ()>,
}
impl<T: MostlySend> ThreadLocal<T> {
pub fn new() -> Self {
Self { inner: thread_local::ThreadLocal::new() }
}
pub fn with_capacity(capacity: usize) -> Self {
Self { inner: thread_local::ThreadLocal::with_capacity(capacity) }
}
pub fn clear(&mut self) {
self.inner.clear()
}
pub fn get(&self) -> Option<&T> {
self.inner.get().map(|t| t.as_ref())
}
pub fn get_or<F>(&self, create: F) -> &T
where
F: FnOnce() -> T,
{
/// TODO: move ThreadLocal, MostlySend, FullySend to a dedicated file
self.inner.get_or(|| unsafe { MostlySendWrapper::new(create()) }).as_ref()
}
pub fn get_or_try<F, E>(&self, create: F) -> std::result::Result<&T, E>
where
F: FnOnce() -> std::result::Result<T, E>,
{
self.inner
.get_or_try(|| unsafe { Ok(MostlySendWrapper::new(create()?)) })
.map(MostlySendWrapper::as_ref)
}
pub fn get_or_default(&self) -> &T
where
T: Default,
{
self.inner.get_or_default().as_ref()
}
pub fn iter_mut(&mut self) -> IterMut<T> {
IterMut(self.inner.iter_mut())
}
}
impl<T: MostlySend> IntoIterator for ThreadLocal<T> {
type Item = T;
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
IntoIter(self.inner.into_iter())
}
}
pub struct IterMut<'a, T: MostlySend>(thread_local::IterMut<'a, MostlySendWrapper<T>>);
impl<'a, T: MostlySend> Iterator for IterMut<'a, T> {
type Item = &'a mut T;
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(|t| t.as_mut())
}
}
pub struct IntoIter<T: MostlySend>(thread_local::IntoIter<MostlySendWrapper<T>>);
impl<T: MostlySend> Iterator for IntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(|t| t.into_inner())
}
}
pub struct DocumentChangeContext< pub struct DocumentChangeContext<
'doc, // covariant lifetime of a single `process` call 'doc, // covariant lifetime of a single `process` call
'extractor: 'doc, // invariant lifetime of the extractor_allocs 'extractor: 'doc, // invariant lifetime of the extractor_allocs

View File

@ -4,9 +4,8 @@ use rayon::iter::IndexedParallelIterator;
use rayon::slice::ParallelSlice as _; use rayon::slice::ParallelSlice as _;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use super::document_changes::{DocumentChangeContext, DocumentChanges}; use super::document_changes::{DocumentChangeContext, DocumentChanges, MostlySend};
use crate::documents::PrimaryKey; use crate::documents::PrimaryKey;
use crate::update::new::thread_local::MostlySend;
use crate::update::new::{Deletion, DocumentChange}; use crate::update::new::{Deletion, DocumentChange};
use crate::{DocumentId, Result}; use crate::{DocumentId, Result};
@ -93,10 +92,9 @@ mod test {
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder}; use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::index::tests::TempIndex; use crate::index::tests::TempIndex;
use crate::update::new::indexer::document_changes::{ use crate::update::new::indexer::document_changes::{
extract, DocumentChangeContext, Extractor, IndexingContext, extract, DocumentChangeContext, Extractor, IndexingContext, MostlySend, ThreadLocal,
}; };
use crate::update::new::indexer::DocumentDeletion; use crate::update::new::indexer::DocumentDeletion;
use crate::update::new::thread_local::{MostlySend, ThreadLocal};
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::DocumentId; use crate::DocumentId;

View File

@ -1,377 +1,27 @@
use bumpalo::collections::CollectIn; use bumpalo::collections::CollectIn;
use bumpalo::Bump; use bumpalo::Bump;
use hashbrown::hash_map::Entry;
use heed::RoTxn; use heed::RoTxn;
use memmap2::Mmap; use memmap2::Mmap;
use raw_collections::RawMap;
use rayon::slice::ParallelSlice; use rayon::slice::ParallelSlice;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use serde_json::Deserializer; use IndexDocumentsMethod as Idm;
use super::super::document_change::DocumentChange; use super::super::document_change::DocumentChange;
use super::document_changes::{DocumentChangeContext, DocumentChanges}; use super::document_changes::{DocumentChangeContext, DocumentChanges, MostlySend};
use super::retrieve_or_guess_primary_key;
use crate::documents::PrimaryKey; use crate::documents::PrimaryKey;
use crate::update::new::document::Versions; use crate::update::new::document::Versions;
use crate::update::new::thread_local::MostlySend;
use crate::update::new::{Deletion, Insertion, Update}; use crate::update::new::{Deletion, Insertion, Update};
use crate::update::{AvailableIds, IndexDocumentsMethod}; use crate::update::{AvailableIds, IndexDocumentsMethod};
use crate::{DocumentId, Error, FieldsIdsMap, Index, InternalError, Result, UserError}; use crate::{DocumentId, Error, FieldsIdsMap, Index, Result, UserError};
pub struct DocumentOperation<'pl> { pub struct DocumentOperation<'pl> {
operations: Vec<Payload<'pl>>, operations: Vec<Payload<'pl>>,
method: MergeMethod, index_documents_method: IndexDocumentsMethod,
}
impl<'pl> DocumentOperation<'pl> {
pub fn new(method: IndexDocumentsMethod) -> Self {
Self { operations: Default::default(), method: MergeMethod::from(method) }
}
/// TODO please give me a type
/// The payload is expected to be in the grenad format
pub fn add_documents(&mut self, payload: &'pl Mmap) -> Result<()> {
payload.advise(memmap2::Advice::Sequential)?;
self.operations.push(Payload::Addition(&payload[..]));
Ok(())
}
pub fn delete_documents(&mut self, to_delete: &'pl [&'pl str]) {
self.operations.push(Payload::Deletion(to_delete))
}
pub fn into_changes(
self,
indexer: &'pl Bump,
index: &Index,
rtxn: &'pl RoTxn<'pl>,
primary_key_from_op: Option<&'pl str>,
new_fields_ids_map: &mut FieldsIdsMap,
) -> Result<(DocumentOperationChanges<'pl>, Vec<PayloadStats>, Option<PrimaryKey<'pl>>)> {
let Self { operations, method } = self;
let documents_ids = index.documents_ids(rtxn)?;
let mut operations_stats = Vec::new();
let mut available_docids = AvailableIds::new(&documents_ids);
let mut docids_version_offsets = hashbrown::HashMap::new();
let mut primary_key = None;
for operation in operations {
let (bytes, document_count, result) = match operation {
Payload::Addition(payload) => extract_addition_payload_changes(
indexer,
index,
rtxn,
primary_key_from_op,
&mut primary_key,
new_fields_ids_map,
&mut available_docids,
&docids_version_offsets,
method,
payload,
),
Payload::Deletion(to_delete) => extract_deletion_payload_changes(
index,
rtxn,
&mut available_docids,
&docids_version_offsets,
method,
to_delete,
),
};
let error = match result {
Ok(new_docids_version_offsets) => {
// If we don't have any error then we can merge the content of this payload
// into to main payload. Else we just drop this payload extraction.
merge_version_offsets(&mut docids_version_offsets, new_docids_version_offsets);
None
}
Err(Error::UserError(user_error)) => Some(user_error),
Err(e) => return Err(e),
};
operations_stats.push(PayloadStats { document_count, bytes, error });
}
// TODO We must drain the HashMap into a Vec because rayon::hash_map::IntoIter: !Clone
let mut docids_version_offsets: bumpalo::collections::vec::Vec<_> =
docids_version_offsets.drain().collect_in(indexer);
// Reorder the offsets to make sure we iterate on the file sequentially
// And finally sort them
docids_version_offsets.sort_unstable_by_key(|(_, po)| method.sort_key(&po.operations));
let docids_version_offsets = docids_version_offsets.into_bump_slice();
Ok((DocumentOperationChanges { docids_version_offsets }, operations_stats, primary_key))
}
}
#[allow(clippy::too_many_arguments)]
fn extract_addition_payload_changes<'r, 'pl: 'r>(
indexer: &'pl Bump,
index: &Index,
rtxn: &'r RoTxn<'r>,
primary_key_from_op: Option<&'r str>,
primary_key: &mut Option<PrimaryKey<'r>>,
new_fields_ids_map: &mut FieldsIdsMap,
available_docids: &mut AvailableIds,
main_docids_version_offsets: &hashbrown::HashMap<&'pl str, PayloadOperations<'pl>>,
method: MergeMethod,
payload: &'pl [u8],
) -> (u64, u64, Result<hashbrown::HashMap<&'pl str, PayloadOperations<'pl>>>) {
let mut new_docids_version_offsets = hashbrown::HashMap::<&str, PayloadOperations<'pl>>::new();
/// TODO manage the error
let mut previous_offset = 0;
let mut iter = Deserializer::from_slice(payload).into_iter::<&RawValue>();
loop {
let optdoc = match iter.next().transpose() {
Ok(optdoc) => optdoc,
Err(e) => {
return (
payload.len() as u64,
new_docids_version_offsets.len() as u64,
Err(InternalError::SerdeJson(e).into()),
)
}
};
// Only guess the primary key if it is the first document
let retrieved_primary_key = if previous_offset == 0 {
let optdoc = match optdoc {
Some(doc) => match RawMap::from_raw_value(doc, indexer) {
Ok(docmap) => Some(docmap),
Err(error) => {
return (
payload.len() as u64,
new_docids_version_offsets.len() as u64,
Err(Error::UserError(UserError::SerdeJson(error))),
)
}
},
None => None,
};
let result = retrieve_or_guess_primary_key(
rtxn,
index,
new_fields_ids_map,
primary_key_from_op,
optdoc,
);
let (pk, _has_been_changed) = match result {
Ok(Ok(pk)) => pk,
Ok(Err(user_error)) => {
return (
payload.len() as u64,
new_docids_version_offsets.len() as u64,
Err(Error::UserError(user_error)),
)
}
Err(error) => {
return (
payload.len() as u64,
new_docids_version_offsets.len() as u64,
Err(error),
)
}
};
primary_key.get_or_insert(pk)
} else {
primary_key.as_ref().unwrap()
};
let doc = match optdoc {
Some(doc) => doc,
None => break,
};
let external_id = match retrieved_primary_key.extract_fields_and_docid(
doc,
new_fields_ids_map,
indexer,
) {
Ok(edi) => edi,
Err(e) => {
return (payload.len() as u64, new_docids_version_offsets.len() as u64, Err(e))
}
};
let external_id = external_id.to_de();
let current_offset = iter.byte_offset();
let document_offset = DocumentOffset { content: &payload[previous_offset..current_offset] };
match main_docids_version_offsets.get(external_id) {
None => {
let (docid, is_new) = match index.external_documents_ids().get(rtxn, external_id) {
Ok(Some(docid)) => (docid, false),
Ok(None) => (
match available_docids.next() {
Some(docid) => docid,
None => {
return (
payload.len() as u64,
new_docids_version_offsets.len() as u64,
Err(UserError::DocumentLimitReached.into()),
)
}
},
true,
),
Err(e) => {
return (
payload.len() as u64,
new_docids_version_offsets.len() as u64,
Err(e.into()),
)
}
};
match new_docids_version_offsets.entry(external_id) {
Entry::Occupied(mut entry) => entry.get_mut().push_addition(document_offset),
Entry::Vacant(entry) => {
entry.insert(PayloadOperations::new_addition(
method,
docid,
is_new,
document_offset,
));
}
}
}
Some(payload_operations) => match new_docids_version_offsets.entry(external_id) {
Entry::Occupied(mut entry) => entry.get_mut().push_addition(document_offset),
Entry::Vacant(entry) => {
entry.insert(PayloadOperations::new_addition(
method,
payload_operations.docid,
payload_operations.is_new,
document_offset,
));
}
},
}
previous_offset = iter.byte_offset();
}
(payload.len() as u64, new_docids_version_offsets.len() as u64, Ok(new_docids_version_offsets))
}
fn extract_deletion_payload_changes<'s, 'pl: 's>(
index: &Index,
rtxn: &RoTxn,
available_docids: &mut AvailableIds,
main_docids_version_offsets: &hashbrown::HashMap<&'s str, PayloadOperations<'pl>>,
method: MergeMethod,
to_delete: &'pl [&'pl str],
) -> (u64, u64, Result<hashbrown::HashMap<&'s str, PayloadOperations<'pl>>>) {
let mut new_docids_version_offsets = hashbrown::HashMap::<&str, PayloadOperations<'pl>>::new();
let mut document_count = 0;
for external_id in to_delete {
match main_docids_version_offsets.get(external_id) {
None => {
let (docid, is_new) = match index.external_documents_ids().get(rtxn, external_id) {
Ok(Some(docid)) => (docid, false),
Ok(None) => (
match available_docids.next() {
Some(docid) => docid,
None => {
return (
0,
new_docids_version_offsets.len() as u64,
Err(UserError::DocumentLimitReached.into()),
)
}
},
true,
),
Err(e) => return (0, new_docids_version_offsets.len() as u64, Err(e.into())),
};
match new_docids_version_offsets.entry(external_id) {
Entry::Occupied(mut entry) => entry.get_mut().push_deletion(),
Entry::Vacant(entry) => {
entry.insert(PayloadOperations::new_deletion(method, docid, is_new));
}
}
}
Some(payload_operations) => match new_docids_version_offsets.entry(external_id) {
Entry::Occupied(mut entry) => entry.get_mut().push_deletion(),
Entry::Vacant(entry) => {
entry.insert(PayloadOperations::new_deletion(
method,
payload_operations.docid,
payload_operations.is_new,
));
}
},
}
document_count += 1;
}
(0, document_count, Ok(new_docids_version_offsets))
}
fn merge_version_offsets<'s, 'pl>(
main: &mut hashbrown::HashMap<&'s str, PayloadOperations<'pl>>,
new: hashbrown::HashMap<&'s str, PayloadOperations<'pl>>,
) {
// We cannot swap like nothing because documents
// operations must be in the right order.
if main.is_empty() {
return *main = new;
}
for (key, new_payload) in new {
match main.entry(key) {
Entry::Occupied(mut entry) => entry.get_mut().append_operations(new_payload.operations),
Entry::Vacant(entry) => {
entry.insert(new_payload);
}
}
}
}
impl<'pl> DocumentChanges<'pl> for DocumentOperationChanges<'pl> {
type Item = (&'pl str, PayloadOperations<'pl>);
fn iter(
&self,
chunk_size: usize,
) -> impl rayon::prelude::IndexedParallelIterator<Item = impl AsRef<[Self::Item]>> {
self.docids_version_offsets.par_chunks(chunk_size)
}
fn item_to_document_change<'doc, T: MostlySend + 'doc>(
&'doc self,
context: &'doc DocumentChangeContext<T>,
item: &'doc Self::Item,
) -> Result<Option<DocumentChange<'doc>>>
where
'pl: 'doc,
{
let (external_doc, payload_operations) = item;
payload_operations.merge_method.merge(
payload_operations.docid,
external_doc,
payload_operations.is_new,
&context.doc_alloc,
&payload_operations.operations[..],
)
}
fn len(&self) -> usize {
self.docids_version_offsets.len()
}
} }
pub struct DocumentOperationChanges<'pl> { pub struct DocumentOperationChanges<'pl> {
docids_version_offsets: &'pl [(&'pl str, PayloadOperations<'pl>)], docids_version_offsets: &'pl [(&'pl str, ((u32, bool), &'pl [InnerDocOp<'pl>]))],
index_documents_method: IndexDocumentsMethod,
} }
pub enum Payload<'pl> { pub enum Payload<'pl> {
@ -380,57 +30,8 @@ pub enum Payload<'pl> {
} }
pub struct PayloadStats { pub struct PayloadStats {
pub document_count: usize,
pub bytes: u64, pub bytes: u64,
pub document_count: u64,
pub error: Option<UserError>,
}
pub struct PayloadOperations<'pl> {
/// The internal document id of the document.
pub docid: DocumentId,
/// Wether this document is not in the current database (visible by the rtxn).
pub is_new: bool,
/// The operations to perform, in order, on this document.
pub operations: Vec<InnerDocOp<'pl>>,
/// The merge method we are using to merge payloads and documents.
merge_method: MergeMethod,
}
impl<'pl> PayloadOperations<'pl> {
fn new_deletion(merge_method: MergeMethod, docid: DocumentId, is_new: bool) -> Self {
Self { docid, is_new, operations: vec![InnerDocOp::Deletion], merge_method }
}
fn new_addition(
merge_method: MergeMethod,
docid: DocumentId,
is_new: bool,
offset: DocumentOffset<'pl>,
) -> Self {
Self { docid, is_new, operations: vec![InnerDocOp::Addition(offset)], merge_method }
}
}
impl<'pl> PayloadOperations<'pl> {
fn push_addition(&mut self, offset: DocumentOffset<'pl>) {
if self.merge_method.useless_previous_changes() {
self.operations.clear();
}
self.operations.push(InnerDocOp::Addition(offset))
}
fn push_deletion(&mut self) {
self.operations.clear();
self.operations.push(InnerDocOp::Deletion);
}
fn append_operations(&mut self, mut operations: Vec<InnerDocOp<'pl>>) {
debug_assert!(!operations.is_empty());
if self.merge_method.useless_previous_changes() {
self.operations.clear();
}
self.operations.append(&mut operations);
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -447,15 +48,213 @@ pub struct DocumentOffset<'pl> {
pub content: &'pl [u8], pub content: &'pl [u8],
} }
impl<'pl> DocumentOperation<'pl> {
pub fn new(method: IndexDocumentsMethod) -> Self {
Self { operations: Default::default(), index_documents_method: method }
}
/// TODO please give me a type
/// The payload is expected to be in the grenad format
pub fn add_documents(&mut self, payload: &'pl Mmap) -> Result<PayloadStats> {
payload.advise(memmap2::Advice::Sequential)?;
let document_count =
memchr::memmem::find_iter(&payload[..], "}{").count().saturating_add(1);
self.operations.push(Payload::Addition(&payload[..]));
Ok(PayloadStats { bytes: payload.len() as u64, document_count })
}
pub fn delete_documents(&mut self, to_delete: &'pl [&'pl str]) {
self.operations.push(Payload::Deletion(to_delete))
}
pub fn into_changes(
self,
indexer: &'pl Bump,
index: &Index,
rtxn: &RoTxn,
primary_key: &PrimaryKey,
new_fields_ids_map: &mut FieldsIdsMap,
) -> Result<DocumentOperationChanges<'pl>> {
// will contain nodes from the intermediate hashmap
let document_changes_alloc = Bump::with_capacity(1024 * 1024 * 1024); // 1 MiB
let documents_ids = index.documents_ids(rtxn)?;
let mut available_docids = AvailableIds::new(&documents_ids);
let mut docids_version_offsets =
hashbrown::HashMap::<&'pl str, _, _, _>::new_in(&document_changes_alloc);
for operation in self.operations {
match operation {
Payload::Addition(payload) => {
let mut iter =
serde_json::Deserializer::from_slice(payload).into_iter::<&RawValue>();
/// TODO manage the error
let mut previous_offset = 0;
while let Some(document) =
iter.next().transpose().map_err(UserError::SerdeJson)?
{
let external_document_id = primary_key.extract_fields_and_docid(
document,
new_fields_ids_map,
indexer,
)?;
let external_document_id = external_document_id.to_de();
let current_offset = iter.byte_offset();
let document_operation = InnerDocOp::Addition(DocumentOffset {
content: &payload[previous_offset..current_offset],
});
match docids_version_offsets.get_mut(external_document_id) {
None => {
let (docid, is_new) = match index
.external_documents_ids()
.get(rtxn, external_document_id)?
{
Some(docid) => (docid, false),
None => (
available_docids.next().ok_or(Error::UserError(
UserError::DocumentLimitReached,
))?,
true,
),
};
docids_version_offsets.insert(
external_document_id,
(
(docid, is_new),
bumpalo::vec![in indexer; document_operation],
),
);
}
Some((_, offsets)) => {
let useless_previous_addition = match self.index_documents_method {
IndexDocumentsMethod::ReplaceDocuments => {
MergeDocumentForReplacement::USELESS_PREVIOUS_CHANGES
}
IndexDocumentsMethod::UpdateDocuments => {
MergeDocumentForUpdates::USELESS_PREVIOUS_CHANGES
}
};
if useless_previous_addition {
offsets.clear();
}
offsets.push(document_operation);
}
}
previous_offset = iter.byte_offset();
}
}
Payload::Deletion(to_delete) => {
for external_document_id in to_delete {
match docids_version_offsets.get_mut(external_document_id) {
None => {
let (docid, is_new) = match index
.external_documents_ids()
.get(rtxn, external_document_id)?
{
Some(docid) => (docid, false),
None => (
available_docids.next().ok_or(Error::UserError(
UserError::DocumentLimitReached,
))?,
true,
),
};
docids_version_offsets.insert(
external_document_id,
(
(docid, is_new),
bumpalo::vec![in indexer; InnerDocOp::Deletion],
),
);
}
Some((_, offsets)) => {
offsets.clear();
offsets.push(InnerDocOp::Deletion);
}
}
}
}
}
}
// TODO We must drain the HashMap into a Vec because rayon::hash_map::IntoIter: !Clone
let mut docids_version_offsets: bumpalo::collections::vec::Vec<_> = docids_version_offsets
.drain()
.map(|(item, (docid, v))| (item, (docid, v.into_bump_slice())))
.collect_in(indexer);
// Reorder the offsets to make sure we iterate on the file sequentially
let sort_function_key = match self.index_documents_method {
Idm::ReplaceDocuments => MergeDocumentForReplacement::sort_key,
Idm::UpdateDocuments => MergeDocumentForUpdates::sort_key,
};
// And finally sort them
docids_version_offsets.sort_unstable_by_key(|(_, (_, docops))| sort_function_key(docops));
let docids_version_offsets = docids_version_offsets.into_bump_slice();
Ok(DocumentOperationChanges {
docids_version_offsets,
index_documents_method: self.index_documents_method,
})
}
}
impl<'pl> DocumentChanges<'pl> for DocumentOperationChanges<'pl> {
type Item = (&'pl str, ((u32, bool), &'pl [InnerDocOp<'pl>]));
fn iter(
&self,
chunk_size: usize,
) -> impl rayon::prelude::IndexedParallelIterator<Item = impl AsRef<[Self::Item]>> {
self.docids_version_offsets.par_chunks(chunk_size)
}
fn item_to_document_change<'doc, T: MostlySend + 'doc>(
&'doc self,
context: &'doc DocumentChangeContext<T>,
item: &'doc Self::Item,
) -> Result<Option<DocumentChange<'doc>>>
where
'pl: 'doc,
{
let document_merge_function = match self.index_documents_method {
Idm::ReplaceDocuments => MergeDocumentForReplacement::merge,
Idm::UpdateDocuments => MergeDocumentForUpdates::merge,
};
let (external_doc, ((internal_docid, is_new), operations)) = *item;
let change = document_merge_function(
internal_docid,
external_doc,
is_new,
&context.doc_alloc,
operations,
)?;
Ok(change)
}
fn len(&self) -> usize {
self.docids_version_offsets.len()
}
}
trait MergeChanges { trait MergeChanges {
/// Whether the payloads in the list of operations are useless or not. /// Whether the payloads in the list of operations are useless or not.
fn useless_previous_changes(&self) -> bool; const USELESS_PREVIOUS_CHANGES: bool;
/// Returns a key that is used to order the payloads the right way. /// Returns a key that is used to order the payloads the right way.
fn sort_key(&self, docops: &[InnerDocOp]) -> usize; fn sort_key(docops: &[InnerDocOp]) -> usize;
fn merge<'doc>( fn merge<'doc>(
&self,
docid: DocumentId, docid: DocumentId,
external_docid: &'doc str, external_docid: &'doc str,
is_new: bool, is_new: bool,
@ -464,69 +263,13 @@ trait MergeChanges {
) -> Result<Option<DocumentChange<'doc>>>; ) -> Result<Option<DocumentChange<'doc>>>;
} }
#[derive(Debug, Clone, Copy)]
enum MergeMethod {
ForReplacement(MergeDocumentForReplacement),
ForUpdates(MergeDocumentForUpdates),
}
impl MergeChanges for MergeMethod {
fn useless_previous_changes(&self) -> bool {
match self {
MergeMethod::ForReplacement(merge) => merge.useless_previous_changes(),
MergeMethod::ForUpdates(merge) => merge.useless_previous_changes(),
}
}
fn sort_key(&self, docops: &[InnerDocOp]) -> usize {
match self {
MergeMethod::ForReplacement(merge) => merge.sort_key(docops),
MergeMethod::ForUpdates(merge) => merge.sort_key(docops),
}
}
fn merge<'doc>(
&self,
docid: DocumentId,
external_docid: &'doc str,
is_new: bool,
doc_alloc: &'doc Bump,
operations: &'doc [InnerDocOp],
) -> Result<Option<DocumentChange<'doc>>> {
match self {
MergeMethod::ForReplacement(merge) => {
merge.merge(docid, external_docid, is_new, doc_alloc, operations)
}
MergeMethod::ForUpdates(merge) => {
merge.merge(docid, external_docid, is_new, doc_alloc, operations)
}
}
}
}
impl From<IndexDocumentsMethod> for MergeMethod {
fn from(method: IndexDocumentsMethod) -> Self {
match method {
IndexDocumentsMethod::ReplaceDocuments => {
MergeMethod::ForReplacement(MergeDocumentForReplacement)
}
IndexDocumentsMethod::UpdateDocuments => {
MergeMethod::ForUpdates(MergeDocumentForUpdates)
}
}
}
}
#[derive(Debug, Clone, Copy)]
struct MergeDocumentForReplacement; struct MergeDocumentForReplacement;
impl MergeChanges for MergeDocumentForReplacement { impl MergeChanges for MergeDocumentForReplacement {
fn useless_previous_changes(&self) -> bool { const USELESS_PREVIOUS_CHANGES: bool = true;
true
}
/// Reorders to read only the last change. /// Reorders to read only the last change.
fn sort_key(&self, docops: &[InnerDocOp]) -> usize { fn sort_key(docops: &[InnerDocOp]) -> usize {
let f = |ido: &_| match ido { let f = |ido: &_| match ido {
InnerDocOp::Addition(add) => Some(add.content.as_ptr() as usize), InnerDocOp::Addition(add) => Some(add.content.as_ptr() as usize),
InnerDocOp::Deletion => None, InnerDocOp::Deletion => None,
@ -538,7 +281,6 @@ impl MergeChanges for MergeDocumentForReplacement {
/// ///
/// This function is only meant to be used when doing a replacement and not an update. /// This function is only meant to be used when doing a replacement and not an update.
fn merge<'doc>( fn merge<'doc>(
&self,
docid: DocumentId, docid: DocumentId,
external_doc: &'doc str, external_doc: &'doc str,
is_new: bool, is_new: bool,
@ -579,16 +321,13 @@ impl MergeChanges for MergeDocumentForReplacement {
} }
} }
#[derive(Debug, Clone, Copy)]
struct MergeDocumentForUpdates; struct MergeDocumentForUpdates;
impl MergeChanges for MergeDocumentForUpdates { impl MergeChanges for MergeDocumentForUpdates {
fn useless_previous_changes(&self) -> bool { const USELESS_PREVIOUS_CHANGES: bool = false;
false
}
/// Reorders to read the first changes first so that it's faster to read the first one and then the rest. /// Reorders to read the first changes first so that it's faster to read the first one and then the rest.
fn sort_key(&self, docops: &[InnerDocOp]) -> usize { fn sort_key(docops: &[InnerDocOp]) -> usize {
let f = |ido: &_| match ido { let f = |ido: &_| match ido {
InnerDocOp::Addition(add) => Some(add.content.as_ptr() as usize), InnerDocOp::Addition(add) => Some(add.content.as_ptr() as usize),
InnerDocOp::Deletion => None, InnerDocOp::Deletion => None,
@ -601,7 +340,6 @@ impl MergeChanges for MergeDocumentForUpdates {
/// ///
/// This function is only meant to be used when doing an update and not a replacement. /// This function is only meant to be used when doing an update and not a replacement.
fn merge<'doc>( fn merge<'doc>(
&self,
docid: DocumentId, docid: DocumentId,
external_docid: &'doc str, external_docid: &'doc str,
is_new: bool, is_new: bool,

View File

@ -3,9 +3,9 @@ use std::sync::{OnceLock, RwLock};
use std::thread::{self, Builder}; use std::thread::{self, Builder};
use big_s::S; use big_s::S;
use document_changes::{extract, DocumentChanges, IndexingContext, Progress}; use document_changes::{extract, DocumentChanges, IndexingContext, Progress, ThreadLocal};
pub use document_deletion::DocumentDeletion; pub use document_deletion::DocumentDeletion;
pub use document_operation::{DocumentOperation, PayloadStats}; pub use document_operation::DocumentOperation;
use hashbrown::HashMap; use hashbrown::HashMap;
use heed::types::{Bytes, DecodeIgnore, Str}; use heed::types::{Bytes, DecodeIgnore, Str};
use heed::{RoTxn, RwTxn}; use heed::{RoTxn, RwTxn};
@ -20,7 +20,6 @@ use super::channel::*;
use super::extract::*; use super::extract::*;
use super::facet_search_builder::FacetSearchBuilder; use super::facet_search_builder::FacetSearchBuilder;
use super::merger::FacetFieldIdsDelta; use super::merger::FacetFieldIdsDelta;
use super::thread_local::ThreadLocal;
use super::word_fst_builder::{PrefixData, PrefixDelta, WordFstBuilder}; use super::word_fst_builder::{PrefixData, PrefixDelta, WordFstBuilder};
use super::words_prefix_docids::{ use super::words_prefix_docids::{
compute_word_prefix_docids, compute_word_prefix_fid_docids, compute_word_prefix_position_docids, compute_word_prefix_docids, compute_word_prefix_fid_docids, compute_word_prefix_position_docids,
@ -133,7 +132,6 @@ mod steps {
pub fn index<'pl, 'indexer, 'index, DC, MSP, SP>( pub fn index<'pl, 'indexer, 'index, DC, MSP, SP>(
wtxn: &mut RwTxn, wtxn: &mut RwTxn,
index: &'index Index, index: &'index Index,
grenad_parameters: GrenadParameters,
db_fields_ids_map: &'indexer FieldsIdsMap, db_fields_ids_map: &'indexer FieldsIdsMap,
new_fields_ids_map: FieldsIdsMap, new_fields_ids_map: FieldsIdsMap,
new_primary_key: Option<PrimaryKey<'pl>>, new_primary_key: Option<PrimaryKey<'pl>>,
@ -211,6 +209,16 @@ where
field_distribution.retain(|_, v| *v != 0); field_distribution.retain(|_, v| *v != 0);
const TEN_GIB: usize = 10 * 1024 * 1024 * 1024;
let current_num_threads = rayon::current_num_threads();
let max_memory = TEN_GIB / current_num_threads;
eprintln!("A maximum of {max_memory} bytes will be used for each of the {current_num_threads} threads");
let grenad_parameters = GrenadParameters {
max_memory: Some(max_memory),
..GrenadParameters::default()
};
let facet_field_ids_delta; let facet_field_ids_delta;
{ {
@ -220,8 +228,7 @@ where
let (finished_steps, step_name) = steps::extract_facets(); let (finished_steps, step_name) = steps::extract_facets();
facet_field_ids_delta = merge_and_send_facet_docids( facet_field_ids_delta = merge_and_send_facet_docids(
FacetedDocidsExtractor::run_extraction( FacetedDocidsExtractor::run_extraction(grenad_parameters,
grenad_parameters,
document_changes, document_changes,
indexing_context, indexing_context,
&mut extractor_allocs, &mut extractor_allocs,
@ -337,8 +344,7 @@ where
let (finished_steps, step_name) = steps::extract_word_proximity(); let (finished_steps, step_name) = steps::extract_word_proximity();
let caches = <WordPairProximityDocidsExtractor as DocidsExtractor>::run_extraction( let caches = <WordPairProximityDocidsExtractor as DocidsExtractor>::run_extraction(grenad_parameters,
grenad_parameters,
document_changes, document_changes,
indexing_context, indexing_context,
&mut extractor_allocs, &mut extractor_allocs,
@ -392,8 +398,7 @@ where
}; };
let datastore = ThreadLocal::with_capacity(rayon::current_num_threads()); let datastore = ThreadLocal::with_capacity(rayon::current_num_threads());
let (finished_steps, step_name) = steps::extract_geo_points(); let (finished_steps, step_name) = steps::extract_geo_points();
extract( extract(document_changes,
document_changes,
&extractor, &extractor,
indexing_context, indexing_context,
&mut extractor_allocs, &mut extractor_allocs,
@ -772,10 +777,16 @@ pub fn retrieve_or_guess_primary_key<'a>(
match primary_key_from_op { match primary_key_from_op {
// we did, and it is different from the DB one // we did, and it is different from the DB one
Some(primary_key_from_op) if primary_key_from_op != primary_key_from_db => { Some(primary_key_from_op) if primary_key_from_op != primary_key_from_db => {
// is the index empty?
if index.number_of_documents(rtxn)? == 0 {
// change primary key
(primary_key_from_op, true)
} else {
return Ok(Err(UserError::PrimaryKeyCannotBeChanged( return Ok(Err(UserError::PrimaryKeyCannotBeChanged(
primary_key_from_db.to_string(), primary_key_from_db.to_string(),
))); )));
} }
}
_ => (primary_key_from_db, false), _ => (primary_key_from_db, false),
} }
} else { } else {

View File

@ -3,12 +3,11 @@ use std::ops::DerefMut;
use rayon::iter::IndexedParallelIterator; use rayon::iter::IndexedParallelIterator;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use super::document_changes::{DocumentChangeContext, DocumentChanges}; use super::document_changes::{DocumentChangeContext, DocumentChanges, MostlySend};
use crate::documents::PrimaryKey; use crate::documents::PrimaryKey;
use crate::update::concurrent_available_ids::ConcurrentAvailableIds; use crate::update::concurrent_available_ids::ConcurrentAvailableIds;
use crate::update::new::document::Versions; use crate::update::new::document::Versions;
use crate::update::new::ref_cell_ext::RefCellExt as _; use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::thread_local::MostlySend;
use crate::update::new::{DocumentChange, Insertion}; use crate::update::new::{DocumentChange, Insertion};
use crate::{Error, InternalError, Result, UserError}; use crate::{Error, InternalError, Result, UserError};

View File

@ -4,14 +4,13 @@ use rayon::slice::ParallelSlice as _;
use rhai::{Dynamic, Engine, OptimizationLevel, Scope, AST}; use rhai::{Dynamic, Engine, OptimizationLevel, Scope, AST};
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use super::document_changes::DocumentChangeContext; use super::document_changes::{DocumentChangeContext, MostlySend};
use super::DocumentChanges; use super::DocumentChanges;
use crate::documents::Error::InvalidDocumentFormat; use crate::documents::Error::InvalidDocumentFormat;
use crate::documents::PrimaryKey; use crate::documents::PrimaryKey;
use crate::error::{FieldIdMapMissingEntry, InternalError}; use crate::error::{FieldIdMapMissingEntry, InternalError};
use crate::update::new::document::Versions; use crate::update::new::document::Versions;
use crate::update::new::ref_cell_ext::RefCellExt as _; use crate::update::new::ref_cell_ext::RefCellExt as _;
use crate::update::new::thread_local::MostlySend;
use crate::update::new::{Deletion, DocumentChange, KvReaderFieldId, Update}; use crate::update::new::{Deletion, DocumentChange, KvReaderFieldId, Update};
use crate::{all_obkv_to_json, Error, FieldsIdsMap, Object, Result, UserError}; use crate::{all_obkv_to_json, Error, FieldsIdsMap, Object, Result, UserError};

View File

@ -17,7 +17,6 @@ pub mod indexer;
mod merger; mod merger;
mod parallel_iterator_ext; mod parallel_iterator_ext;
mod ref_cell_ext; mod ref_cell_ext;
pub(crate) mod thread_local;
mod top_level_map; mod top_level_map;
pub mod vector_document; pub mod vector_document;
mod word_fst_builder; mod word_fst_builder;

View File

@ -1,16 +1,37 @@
use std::cell::{RefCell, RefMut}; use std::cell::{Ref, RefCell, RefMut};
pub trait RefCellExt<T: ?Sized> { pub trait RefCellExt<T: ?Sized> {
fn try_borrow_or_yield(&self) -> std::result::Result<Ref<'_, T>, std::cell::BorrowError>;
fn try_borrow_mut_or_yield( fn try_borrow_mut_or_yield(
&self, &self,
) -> std::result::Result<RefMut<'_, T>, std::cell::BorrowMutError>; ) -> std::result::Result<RefMut<'_, T>, std::cell::BorrowMutError>;
fn borrow_or_yield(&self) -> Ref<'_, T> {
self.try_borrow_or_yield().unwrap()
}
fn borrow_mut_or_yield(&self) -> RefMut<'_, T> { fn borrow_mut_or_yield(&self) -> RefMut<'_, T> {
self.try_borrow_mut_or_yield().unwrap() self.try_borrow_mut_or_yield().unwrap()
} }
} }
impl<T: ?Sized> RefCellExt<T> for RefCell<T> { impl<T: ?Sized> RefCellExt<T> for RefCell<T> {
fn try_borrow_or_yield(&self) -> std::result::Result<Ref<'_, T>, std::cell::BorrowError> {
/// TODO: move this trait and impl elsewhere
loop {
match self.try_borrow() {
Ok(borrow) => break Ok(borrow),
Err(error) => {
tracing::warn!("dynamic borrow failed, yielding to local tasks");
match rayon::yield_local() {
Some(rayon::Yield::Executed) => continue,
_ => return Err(error),
}
}
}
}
}
fn try_borrow_mut_or_yield( fn try_borrow_mut_or_yield(
&self, &self,
) -> std::result::Result<RefMut<'_, T>, std::cell::BorrowMutError> { ) -> std::result::Result<RefMut<'_, T>, std::cell::BorrowMutError> {

View File

@ -1,174 +0,0 @@
use std::cell::RefCell;
/// A trait for types that are **not** [`Send`] only because they would then allow concurrent access to a type that is not [`Sync`].
///
/// The primary example of such a type is `&T`, with `T: !Sync`.
///
/// In the authors' understanding, a type can be `!Send` for two distinct reasons:
///
/// 1. Because it contains data that *genuinely* cannot be moved between threads, such as thread-local data.
/// 2. Because sending the type would allow concurrent access to a `!Sync` type, which is undefined behavior.
///
/// `MostlySend` exists to be used in bounds where you need a type whose data is **not** *attached* to a thread
/// because you might access it from a different thread, but where you will never access the type **concurrently** from
/// multiple threads.
///
/// Like [`Send`], `MostlySend` assumes properties on types that cannot be verified by the compiler, which is why implementing
/// this trait is unsafe.
///
/// # Safety
///
/// Implementers of this trait promises that the following properties hold on the implementing type:
///
/// 1. Its data can be accessed from any thread and will be the same regardless of the thread accessing it.
/// 2. Any operation that can be performed on the type does not depend on the thread that executes it.
///
/// As these properties are subtle and are not generally tracked by the Rust type system, great care should be taken before
/// implementing `MostlySend` on a type, especially a foreign type.
///
/// - An example of a type that verifies (1) and (2) is [`std::rc::Rc`] (when `T` is `Send` and `Sync`).
/// - An example of a type that doesn't verify (1) is thread-local data.
/// - An example of a type that doesn't verify (2) is [`std::sync::MutexGuard`]: a lot of mutex implementations require that
/// a lock is returned to the operating system on the same thread that initially locked the mutex, failing to uphold this
/// invariant will cause Undefined Behavior
/// (see last § in [the nomicon](https://doc.rust-lang.org/nomicon/send-and-sync.html)).
///
/// It is **always safe** to implement this trait on a type that is `Send`, but no placeholder impl is provided due to limitations in
/// coherency. Use the [`FullySend`] wrapper in this situation.
pub unsafe trait MostlySend {}
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct FullySend<T>(pub T);
// SAFETY: a type **fully** send is always mostly send as well.
unsafe impl<T> MostlySend for FullySend<T> where T: Send {}
unsafe impl<T> MostlySend for RefCell<T> where T: MostlySend {}
unsafe impl<T> MostlySend for Option<T> where T: MostlySend {}
impl<T> FullySend<T> {
pub fn into(self) -> T {
self.0
}
}
impl<T> From<T> for FullySend<T> {
fn from(value: T) -> Self {
Self(value)
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct MostlySendWrapper<T>(T);
impl<T: MostlySend> MostlySendWrapper<T> {
/// # Safety
///
/// - (P1) Users of this type will never access the type concurrently from multiple threads without synchronization
unsafe fn new(t: T) -> Self {
Self(t)
}
fn as_ref(&self) -> &T {
&self.0
}
fn as_mut(&mut self) -> &mut T {
&mut self.0
}
fn into_inner(self) -> T {
self.0
}
}
/// # Safety
///
/// 1. `T` is [`MostlySend`], so by its safety contract it can be accessed by any thread and all of its operations are available
/// from any thread.
/// 2. (P1) of `MostlySendWrapper::new` forces the user to never access the value from multiple threads concurrently.
unsafe impl<T: MostlySend> Send for MostlySendWrapper<T> {}
/// A wrapper around [`thread_local::ThreadLocal`] that accepts [`MostlySend`] `T`s.
#[derive(Default)]
pub struct ThreadLocal<T: MostlySend> {
inner: thread_local::ThreadLocal<MostlySendWrapper<T>>,
// FIXME: this should be necessary
//_no_send: PhantomData<*mut ()>,
}
impl<T: MostlySend> ThreadLocal<T> {
pub fn new() -> Self {
Self { inner: thread_local::ThreadLocal::new() }
}
pub fn with_capacity(capacity: usize) -> Self {
Self { inner: thread_local::ThreadLocal::with_capacity(capacity) }
}
pub fn clear(&mut self) {
self.inner.clear()
}
pub fn get(&self) -> Option<&T> {
self.inner.get().map(|t| t.as_ref())
}
pub fn get_or<F>(&self, create: F) -> &T
where
F: FnOnce() -> T,
{
self.inner.get_or(|| unsafe { MostlySendWrapper::new(create()) }).as_ref()
}
pub fn get_or_try<F, E>(&self, create: F) -> std::result::Result<&T, E>
where
F: FnOnce() -> std::result::Result<T, E>,
{
self.inner
.get_or_try(|| unsafe { Ok(MostlySendWrapper::new(create()?)) })
.map(MostlySendWrapper::as_ref)
}
pub fn get_or_default(&self) -> &T
where
T: Default,
{
self.inner.get_or_default().as_ref()
}
pub fn iter_mut(&mut self) -> IterMut<T> {
IterMut(self.inner.iter_mut())
}
}
impl<T: MostlySend> IntoIterator for ThreadLocal<T> {
type Item = T;
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
IntoIter(self.inner.into_iter())
}
}
pub struct IterMut<'a, T: MostlySend>(thread_local::IterMut<'a, MostlySendWrapper<T>>);
impl<'a, T: MostlySend> Iterator for IterMut<'a, T> {
type Item = &'a mut T;
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(|t| t.as_mut())
}
}
pub struct IntoIter<T: MostlySend>(thread_local::IntoIter<MostlySendWrapper<T>>);
impl<T: MostlySend> Iterator for IntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.0.next().map(|t| t.into_inner())
}
}

View File

@ -60,7 +60,7 @@ pub enum EmbedErrorKind {
ManualEmbed(String), ManualEmbed(String),
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))] #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))]
OllamaModelNotFoundError(Option<String>), OllamaModelNotFoundError(Option<String>),
#[error("error deserializing the response body as JSON:\n - {0}")] #[error("error deserialization the response body as JSON:\n - {0}")]
RestResponseDeserialization(std::io::Error), RestResponseDeserialization(std::io::Error),
#[error("expected a response containing {0} embeddings, got only {1}")] #[error("expected a response containing {0} embeddings, got only {1}")]
RestResponseEmbeddingCount(usize, usize), RestResponseEmbeddingCount(usize, usize),

View File

@ -1,6 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant;
use arroy::distances::{BinaryQuantizedCosine, Cosine}; use arroy::distances::{BinaryQuantizedCosine, Cosine};
use arroy::ItemId; use arroy::ItemId;
@ -596,26 +595,18 @@ impl Embedder {
/// Embed one or multiple texts. /// Embed one or multiple texts.
/// ///
/// Each text can be embedded as one or multiple embeddings. /// Each text can be embedded as one or multiple embeddings.
pub fn embed( pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
&self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> std::result::Result<Vec<Embedding>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline), Embedder::OpenAi(embedder) => embedder.embed(&texts),
Embedder::Ollama(embedder) => embedder.embed(&texts, deadline), Embedder::Ollama(embedder) => embedder.embed(&texts),
Embedder::UserProvided(embedder) => embedder.embed(&texts), Embedder::UserProvided(embedder) => embedder.embed(&texts),
Embedder::Rest(embedder) => embedder.embed(texts, deadline), Embedder::Rest(embedder) => embedder.embed(texts),
} }
} }
pub fn embed_one( pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
&self, let mut embedding = self.embed(vec![text])?;
text: String,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
let mut embedding = self.embed(vec![text], deadline)?;
let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?; let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?;
Ok(embedding) Ok(embedding)
} }

View File

@ -1,5 +1,3 @@
use std::time::Instant;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use rayon::slice::ParallelSlice as _; use rayon::slice::ParallelSlice as _;
@ -82,9 +80,8 @@ impl Embedder {
pub fn embed<S: AsRef<str> + serde::Serialize>( pub fn embed<S: AsRef<str> + serde::Serialize>(
&self, &self,
texts: &[S], texts: &[S],
deadline: Option<Instant>,
) -> Result<Vec<Embedding>, EmbedError> { ) -> Result<Vec<Embedding>, EmbedError> {
match self.rest_embedder.embed_ref(texts, deadline) { match self.rest_embedder.embed_ref(texts) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
Err(EmbedError::ollama_model_not_found(error)) Err(EmbedError::ollama_model_not_found(error))
@ -100,7 +97,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embedding>>, EmbedError> { ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -117,7 +114,7 @@ impl Embedder {
.install(move || { .install(move || {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.par_chunks(self.prompt_count_in_chunk_hint()) .par_chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed(chunk, None)) .map(move |chunk| self.embed(chunk))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;

View File

@ -1,5 +1,3 @@
use std::time::Instant;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use rayon::slice::ParallelSlice as _; use rayon::slice::ParallelSlice as _;
@ -213,23 +211,18 @@ impl Embedder {
pub fn embed<S: AsRef<str> + serde::Serialize>( pub fn embed<S: AsRef<str> + serde::Serialize>(
&self, &self,
texts: &[S], texts: &[S],
deadline: Option<Instant>,
) -> Result<Vec<Embedding>, EmbedError> { ) -> Result<Vec<Embedding>, EmbedError> {
match self.rest_embedder.embed_ref(texts, deadline) { match self.rest_embedder.embed_ref(texts) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
self.try_embed_tokenized(texts, deadline) self.try_embed_tokenized(texts)
} }
Err(error) => Err(error), Err(error) => Err(error),
} }
} }
fn try_embed_tokenized<S: AsRef<str>>( fn try_embed_tokenized<S: AsRef<str>>(&self, text: &[S]) -> Result<Vec<Embedding>, EmbedError> {
&self,
text: &[S],
deadline: Option<Instant>,
) -> Result<Vec<Embedding>, EmbedError> {
let mut all_embeddings = Vec::with_capacity(text.len()); let mut all_embeddings = Vec::with_capacity(text.len());
for text in text { for text in text {
let text = text.as_ref(); let text = text.as_ref();
@ -237,13 +230,13 @@ impl Embedder {
let encoded = self.tokenizer.encode_ordinary(text); let encoded = self.tokenizer.encode_ordinary(text);
let len = encoded.len(); let len = encoded.len();
if len < max_token_count { if len < max_token_count {
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?); all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
continue; continue;
} }
let tokens = &encoded.as_slice()[0..max_token_count]; let tokens = &encoded.as_slice()[0..max_token_count];
let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?; let embedding = self.rest_embedder.embed_tokens(tokens)?;
all_embeddings.push(embedding); all_embeddings.push(embedding);
} }
@ -257,7 +250,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embedding>>, EmbedError> { ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -274,7 +267,7 @@ impl Embedder {
.install(move || { .install(move || {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.par_chunks(self.prompt_count_in_chunk_hint()) .par_chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed(chunk, None)) .map(move |chunk| self.embed(chunk))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;

View File

@ -1,5 +1,4 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::time::Instant;
use deserr::Deserr; use deserr::Deserr;
use rand::Rng; use rand::Rng;
@ -154,31 +153,19 @@ impl Embedder {
Ok(Self { data, dimensions, distribution: options.distribution }) Ok(Self { data, dimensions, distribution: options.distribution })
} }
pub fn embed( pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embedding>, EmbedError> {
&self, embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embedding>, EmbedError> {
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
} }
pub fn embed_ref<S>( pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embedding>, EmbedError>
&self,
texts: &[S],
deadline: Option<Instant>,
) -> Result<Vec<Embedding>, EmbedError>
where where
S: AsRef<str> + Serialize, S: AsRef<str> + Serialize,
{ {
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline) embed(&self.data, texts, texts.len(), Some(self.dimensions))
} }
pub fn embed_tokens( pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, EmbedError> {
&self, let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
tokens: &[usize],
deadline: Option<Instant>,
) -> Result<Embedding, EmbedError> {
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap()) Ok(embeddings.pop().unwrap())
} }
@ -190,7 +177,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embedding>>, EmbedError> { ) -> Result<Vec<Vec<Embedding>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -207,7 +194,7 @@ impl Embedder {
.install(move || { .install(move || {
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
.par_chunks(self.prompt_count_in_chunk_hint()) .par_chunks(self.prompt_count_in_chunk_hint())
.map(move |chunk| self.embed_ref(chunk, None)) .map(move |chunk| self.embed_ref(chunk))
.collect(); .collect();
let embeddings = embeddings?; let embeddings = embeddings?;
@ -240,7 +227,7 @@ impl Embedder {
} }
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> { fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
let v = embed(data, ["test"].as_slice(), 1, None, None) let v = embed(data, ["test"].as_slice(), 1, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?; .map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().len()) Ok(v.first().unwrap().len())
@ -251,7 +238,6 @@ fn embed<S>(
inputs: &[S], inputs: &[S],
expected_count: usize, expected_count: usize,
expected_dimension: Option<usize>, expected_dimension: Option<usize>,
deadline: Option<Instant>,
) -> Result<Vec<Embedding>, EmbedError> ) -> Result<Vec<Embedding>, EmbedError>
where where
S: Serialize, S: Serialize,
@ -271,27 +257,16 @@ where
for attempt in 0..10 { for attempt in 0..10 {
let response = request.clone().send_json(&body); let response = request.clone().send_json(&body);
let result = check_response(response, data.configuration_source).and_then(|response| { let result = check_response(response, data.configuration_source);
response_to_embedding(response, data, expected_count, expected_dimension)
});
let retry_duration = match result { let retry_duration = match result {
Ok(response) => return Ok(response), Ok(response) => {
return response_to_embedding(response, data, expected_count, expected_dimension)
}
Err(retry) => { Err(retry) => {
tracing::warn!("Failed: {}", retry.error); tracing::warn!("Failed: {}", retry.error);
if let Some(deadline) = deadline {
let now = std::time::Instant::now();
if now > deadline {
tracing::warn!("Could not embed due to deadline");
return Err(retry.into_error());
}
let duration_to_deadline = deadline - now;
retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
} else {
retry.into_duration(attempt) retry.into_duration(attempt)
} }
}
}?; }?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
@ -308,7 +283,6 @@ where
let result = check_response(response, data.configuration_source); let result = check_response(response, data.configuration_source);
result.map_err(Retry::into_error).and_then(|response| { result.map_err(Retry::into_error).and_then(|response| {
response_to_embedding(response, data, expected_count, expected_dimension) response_to_embedding(response, data, expected_count, expected_dimension)
.map_err(Retry::into_error)
}) })
} }
@ -350,28 +324,20 @@ fn response_to_embedding(
data: &EmbedderData, data: &EmbedderData,
expected_count: usize, expected_count: usize,
expected_dimensions: Option<usize>, expected_dimensions: Option<usize>,
) -> Result<Vec<Embedding>, Retry> { ) -> Result<Vec<Embedding>, EmbedError> {
let response: serde_json::Value = response let response: serde_json::Value =
.into_json() response.into_json().map_err(EmbedError::rest_response_deserialization)?;
.map_err(EmbedError::rest_response_deserialization)
.map_err(Retry::retry_later)?;
let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?; let embeddings = data.response.extract_embeddings(response)?;
if embeddings.len() != expected_count { if embeddings.len() != expected_count {
return Err(Retry::give_up(EmbedError::rest_response_embedding_count( return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
expected_count,
embeddings.len(),
)));
} }
if let Some(dimensions) = expected_dimensions { if let Some(dimensions) = expected_dimensions {
for embedding in &embeddings { for embedding in &embeddings {
if embedding.len() != dimensions { if embedding.len() != dimensions {
return Err(Retry::give_up(EmbedError::rest_unexpected_dimension( return Err(EmbedError::rest_unexpected_dimension(dimensions, embedding.len()));
dimensions,
embedding.len(),
)));
} }
} }
} }