diff --git a/index-scheduler/src/batch.rs b/index-scheduler/src/batch.rs
index fdf213a6b..60393e51d 100644
--- a/index-scheduler/src/batch.rs
+++ b/index-scheduler/src/batch.rs
@@ -1300,6 +1300,8 @@ impl IndexScheduler {
                 let mut content_files_iter = content_files.iter();
                 let mut indexer = indexer::DocumentOperation::new(method);
+                let embedders = index.embedding_configs(index_wtxn)?;
+                let embedders = self.embedders(embedders)?;
                 for (operation, task) in operations.into_iter().zip(tasks.iter_mut()) {
                     match operation {
                         DocumentOperation::Add(_content_uuid) => {
@@ -1374,6 +1376,7 @@ impl IndexScheduler {
                         primary_key_has_been_set.then_some(primary_key),
                         &pool,
                         &document_changes,
+                        embedders,
                     )?;

                     // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");
@@ -1460,6 +1463,8 @@ impl IndexScheduler {
                     let indexer = UpdateByFunction::new(candidates, context.clone(), code.clone());
                     let document_changes = indexer.into_changes(&primary_key)?;
+                    let embedders = index.embedding_configs(index_wtxn)?;
+                    let embedders = self.embedders(embedders)?;

                     indexer::index(
                         index_wtxn,
@@ -1469,6 +1474,7 @@ impl IndexScheduler {
                         None, // cannot change primary key in DocumentEdition
                         &pool,
                         &document_changes,
+                        embedders,
                     )?;

                     // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");
@@ -1596,6 +1602,8 @@ impl IndexScheduler {
                     let mut indexer = indexer::DocumentDeletion::new();
                     indexer.delete_documents_by_docids(to_delete);
                     let document_changes = indexer.into_changes(&indexer_alloc, primary_key);
+                    let embedders = index.embedding_configs(index_wtxn)?;
+                    let embedders = self.embedders(embedders)?;

                     indexer::index(
                         index_wtxn,
@@ -1605,6 +1613,7 @@ impl IndexScheduler {
                         None, // document deletion never changes primary key
                         &pool,
                         &document_changes,
+                        embedders,
                     )?;

                     // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");
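
Each of the three call sites in `batch.rs` repeats the same pattern: read the persisted embedder configs out of the index, resolve them into runtime embedders through the scheduler's cache, and hand the result to `indexer::index`. A toy sketch of that flow follows; the types here are stand-ins for illustration, not the real milli/index-scheduler API:

```rust
use std::collections::HashMap;

// Stand-ins for milli's persisted config and runtime embedder types (hypothetical).
struct EmbeddingConfig { name: String, dimensions: usize }
struct Embedder { dimensions: usize }
type EmbeddingConfigs = HashMap<String, Embedder>;

// Resolve stored configs into live embedders once per batch...
fn embedders(configs: Vec<EmbeddingConfig>) -> EmbeddingConfigs {
    configs.into_iter().map(|c| (c.name, Embedder { dimensions: c.dimensions })).collect()
}

// ...then thread them through the indexing call by value.
fn index(embedders: EmbeddingConfigs) {
    for (name, embedder) in &embedders {
        println!("embedder {name}: {} dimensions", embedder.dimensions);
    }
}

fn main() {
    let configs = vec![EmbeddingConfig { name: "default".into(), dimensions: 768 }];
    index(embedders(configs));
}
```
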
diff --git a/milli/src/update/new/channel.rs b/milli/src/update/new/channel.rs
index 657c00141..92f692a88 100644
--- a/milli/src/update/new/channel.rs
+++ b/milli/src/update/new/channel.rs
@@ -3,6 +3,7 @@ use std::marker::PhantomData;

 use crossbeam_channel::{IntoIter, Receiver, SendError, Sender};
 use grenad::Merger;
+use hashbrown::HashMap;
 use heed::types::Bytes;
 use memmap2::Mmap;
 use roaring::RoaringBitmap;
@@ -124,7 +125,32 @@ impl DocumentDeletionEntry {
     }
 }

-pub struct WriterOperation {
+pub enum WriterOperation {
+    DbOperation(DbOperation),
+    ArroyOperation(ArroyOperation),
+}
+
+pub enum ArroyOperation {
+    /// TODO: call when deleting regular documents
+    DeleteVectors {
+        docid: DocumentId,
+    },
+    SetVectors {
+        docid: DocumentId,
+        embedder_id: u8,
+        embeddings: Vec<Embedding>,
+    },
+    SetVector {
+        docid: DocumentId,
+        embedder_id: u8,
+        embedding: Embedding,
+    },
+    Finish {
+        user_provided: HashMap<String, RoaringBitmap>,
+    },
+}
+
+pub struct DbOperation {
     database: Database,
     entry: EntryOperation,
 }
@@ -180,7 +206,7 @@ impl From<FacetKind> for Database {
     }
 }

-impl WriterOperation {
+impl DbOperation {
     pub fn database(&self, index: &Index) -> heed::Database<Bytes, Bytes> {
         self.database.database(index)
     }
@@ -246,13 +272,13 @@ impl MergerSender {
             DOCUMENTS_IDS_KEY.as_bytes(),
             documents_ids,
         ));
-        match self.send(WriterOperation { database: Database::Main, entry }) {
+        match self.send_db_operation(DbOperation { database: Database::Main, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
     }

-    fn send(&self, op: WriterOperation) -> StdResult<(), SendError<()>> {
+    fn send_db_operation(&self, op: DbOperation) -> StdResult<(), SendError<()>> {
         if self.sender.is_full() {
             self.writer_contentious_count.set(self.writer_contentious_count.get() + 1);
         }
@@ -260,7 +286,7 @@ impl MergerSender {
             self.merger_contentious_count.set(self.merger_contentious_count.get() + 1);
         }
         self.send_count.set(self.send_count.get() + 1);
-        match self.sender.send(op) {
+        match self.sender.send(WriterOperation::DbOperation(op)) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -275,7 +301,7 @@ impl MainSender<'_> {
             WORDS_FST_KEY.as_bytes(),
             value,
         ));
-        match self.0.send(WriterOperation { database: Database::Main, entry }) {
+        match self.0.send_db_operation(DbOperation { database: Database::Main, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -286,7 +312,7 @@ impl MainSender<'_> {
             WORDS_PREFIXES_FST_KEY.as_bytes(),
             value,
         ));
-        match self.0.send(WriterOperation { database: Database::Main, entry }) {
+        match self.0.send_db_operation(DbOperation { database: Database::Main, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -294,7 +320,7 @@ impl MainSender<'_> {

     pub fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> {
         let entry = EntryOperation::Delete(KeyEntry::from_key(key));
-        match self.0.send(WriterOperation { database: Database::Main, entry }) {
+        match self.0.send_db_operation(DbOperation { database: Database::Main, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -396,7 +422,7 @@ pub struct WordDocidsSender<'a, D> {
 impl<D: DatabaseType> DocidsSender for WordDocidsSender<'_, D> {
     fn write(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>> {
         let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value));
-        match self.sender.send(WriterOperation { database: D::DATABASE, entry }) {
+        match self.sender.send_db_operation(DbOperation { database: D::DATABASE, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -404,7 +430,7 @@ impl<D: DatabaseType> DocidsSender for WordDocidsSender<'_, D> {

     fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> {
         let entry = EntryOperation::Delete(KeyEntry::from_key(key));
-        match self.sender.send(WriterOperation { database: D::DATABASE, entry }) {
+        match self.sender.send_db_operation(DbOperation { database: D::DATABASE, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -429,7 +455,7 @@ impl DocidsSender for FacetDocidsSender<'_> {
             }
             _ => EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)),
         };
-        match self.sender.send(WriterOperation { database, entry }) {
+        match self.sender.send_db_operation(DbOperation { database, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -439,7 +465,7 @@ impl DocidsSender for FacetDocidsSender<'_> {
         let (facet_kind, key) = FacetKind::extract_from_key(key);
         let database = Database::from(facet_kind);
         let entry = EntryOperation::Delete(KeyEntry::from_key(key));
-        match self.sender.send(WriterOperation { database, entry }) {
+        match self.sender.send_db_operation(DbOperation { database, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -460,7 +486,7 @@ impl DocumentsSender<'_> {
             &docid.to_be_bytes(),
             document.as_bytes(),
         ));
-        match self.0.send(WriterOperation { database: Database::Documents, entry }) {
+        match self.0.send_db_operation(DbOperation { database: Database::Documents, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }?;
@@ -469,7 +495,10 @@ impl DocumentsSender<'_> {
             external_id.as_bytes(),
             &docid.to_be_bytes(),
         ));
-        match self.0.send(WriterOperation { database: Database::ExternalDocumentsIds, entry }) {
+        match self
+            .0
+            .send_db_operation(DbOperation { database: Database::ExternalDocumentsIds, entry })
+        {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
@@ -477,33 +506,38 @@ impl DocumentsSender<'_> {

     pub fn delete(&self, docid: DocumentId, external_id: String) -> StdResult<(), SendError<()>> {
         let entry = EntryOperation::Delete(KeyEntry::from_key(&docid.to_be_bytes()));
-        match self.0.send(WriterOperation { database: Database::Documents, entry }) {
+        match self.0.send_db_operation(DbOperation { database: Database::Documents, entry }) {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }?;

         let entry = EntryOperation::Delete(KeyEntry::from_key(external_id.as_bytes()));
-        match self.0.send(WriterOperation { database: Database::ExternalDocumentsIds, entry }) {
+        match self
+            .0
+            .send_db_operation(DbOperation { database: Database::ExternalDocumentsIds, entry })
+        {
             Ok(()) => Ok(()),
             Err(SendError(_)) => Err(SendError(())),
         }
     }
 }

-pub struct EmbeddingSender<'a>(Option<&'a Sender<WriterOperation>>);
+pub struct EmbeddingSender<'a>(&'a Sender<WriterOperation>);

 impl EmbeddingSender<'_> {
-    pub fn delete(&self, docid: DocumentId, embedder_id: u8) -> StdResult<(), SendError<()>> {
-        todo!()
-    }
-
     pub fn set_vectors(
         &self,
         docid: DocumentId,
         embedder_id: u8,
         embeddings: Vec<Embedding>,
     ) -> StdResult<(), SendError<()>> {
-        todo!()
+        self.0
+            .send(WriterOperation::ArroyOperation(ArroyOperation::SetVectors {
+                docid,
+                embedder_id,
+                embeddings,
+            }))
+            .map_err(|_| SendError(()))
     }

     pub fn set_vector(
@@ -512,19 +546,24 @@ impl EmbeddingSender<'_> {
         embedder_id: u8,
         embedding: Embedding,
     ) -> StdResult<(), SendError<()>> {
-        todo!()
+        self.0
+            .send(WriterOperation::ArroyOperation(ArroyOperation::SetVector {
+                docid,
+                embedder_id,
+                embedding,
+            }))
+            .map_err(|_| SendError(()))
     }

-    pub fn set_user_provided(
-        &self,
-        docid: DocumentId,
-        embedder_id: u8,
-        regenerate: bool,
+    /// Marks all embedders as "to be built"
+    pub fn finish(
+        self,
+        user_provided: HashMap<String, RoaringBitmap>,
     ) -> StdResult<(), SendError<()>> {
-        todo!()
+        self.0
+            .send(WriterOperation::ArroyOperation(ArroyOperation::Finish { user_provided }))
+            .map_err(|_| SendError(()))
     }
-
-    pub fn finish(self, embedder_id: u8) {}
 }

 pub enum MergerOperation {
     ExactWordDocidsMerger(Merger<CursorClonableMmap, MergeDeladdCboRoaringBitmaps>),
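
The channel change above is the heart of the patch: the writer channel now carries an enum with a database branch and an arroy branch, so the one thread that owns the write transaction also performs all vector writes. A self-contained model of the protocol, using `std::sync::mpsc` in place of crossbeam-channel and simplified payloads:

```rust
use std::sync::mpsc;

// Two kinds of work travel over the same channel.
enum WriterOperation {
    DbOperation(DbOperation),
    ArroyOperation(ArroyOperation),
}

struct DbOperation { key: Vec<u8>, value: Option<Vec<u8>> }

enum ArroyOperation {
    SetVector { docid: u32, embedder_id: u8, embedding: Vec<f32> },
    Finish,
}

fn main() {
    let (sender, receiver) = mpsc::channel();
    sender.send(WriterOperation::DbOperation(DbOperation {
        key: b"word".to_vec(),
        value: Some(b"docids".to_vec()),
    })).unwrap();
    sender.send(WriterOperation::ArroyOperation(ArroyOperation::SetVector {
        docid: 0,
        embedder_id: 0,
        embedding: vec![0.1, 0.2],
    })).unwrap();
    sender.send(WriterOperation::ArroyOperation(ArroyOperation::Finish)).unwrap();
    drop(sender);

    // The single writer loop dispatches on the variant, as indexer/mod.rs does below.
    for op in receiver {
        match op {
            WriterOperation::DbOperation(op) => drop((op.key, op.value)), // LMDB put/delete
            WriterOperation::ArroyOperation(ArroyOperation::SetVector { embedding, .. }) => {
                drop(embedding) // arroy add_item
            }
            WriterOperation::ArroyOperation(ArroyOperation::Finish) => break, // build trees
        }
    }
}
```
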
diff --git a/milli/src/update/new/document.rs b/milli/src/update/new/document.rs
index 0a5172d36..068268c4e 100644
--- a/milli/src/update/new/document.rs
+++ b/milli/src/update/new/document.rs
@@ -4,7 +4,7 @@ use heed::RoTxn;
 use raw_collections::RawMap;
 use serde_json::value::RawValue;

-use super::vector_document::{VectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions};
+use super::vector_document::VectorDocument;
 use super::{KvReaderFieldId, KvWriterFieldId};
 use crate::documents::FieldIdMapper;
 use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME;
diff --git a/milli/src/update/new/extract/cache.rs b/milli/src/update/new/extract/cache.rs
index 2fbe427f3..cbb42af8b 100644
--- a/milli/src/update/new/extract/cache.rs
+++ b/milli/src/update/new/extract/cache.rs
@@ -267,7 +267,7 @@ impl Stats {
     }
 }

-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Default)]
 pub struct DelAddRoaringBitmap {
     pub(crate) del: Option<RoaringBitmap>,
     pub(crate) add: Option<RoaringBitmap>,
 }
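
The new `Default` derive on `DelAddRoaringBitmap` is what lets the extractor create entries lazily with `or_insert(Default::default())`. The del/add halves are later folded into the stored `user_provided` bitmap, as the merge loop in `indexer/mod.rs` does below. A small sketch of that fold, assuming only the `roaring` crate:

```rust
use roaring::RoaringBitmap;

#[derive(Debug, Clone, Default)]
struct DelAddRoaringBitmap {
    del: Option<RoaringBitmap>,
    add: Option<RoaringBitmap>,
}

fn main() {
    let mut deladd = DelAddRoaringBitmap::default();
    // docid 1 supplied its own vector; docid 2 switched to regenerate: true.
    deladd.add.get_or_insert_with(RoaringBitmap::new).insert(1);
    deladd.del.get_or_insert_with(RoaringBitmap::new).insert(2);

    // Fold into the previously stored user-provided set {2, 3}.
    let mut user_provided = RoaringBitmap::from_iter([2u32, 3]);
    if let Some(del) = deladd.del { user_provided -= del; }
    if let Some(add) = deladd.add { user_provided |= add; }
    assert_eq!(user_provided.iter().collect::<Vec<_>>(), vec![1, 3]);
}
```
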
diff --git a/milli/src/update/new/extract/mod.rs b/milli/src/update/new/extract/mod.rs
index 5a63dccfa..8a18eb074 100644
--- a/milli/src/update/new/extract/mod.rs
+++ b/milli/src/update/new/extract/mod.rs
@@ -11,6 +11,7 @@ use bumpalo::Bump;
 pub use faceted::*;
 use grenad::Merger;
 pub use searchable::*;
+pub use vectors::EmbeddingExtractor;

 use super::indexer::document_changes::{DocumentChanges, FullySend, IndexingContext, ThreadLocal};
 use crate::update::{GrenadParameters, MergeDeladdCboRoaringBitmaps};
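
`EmbeddingExtractor` is re-exported because, as the next file shows, one extractor instance now walks the document changes once and feeds every configured embedder, instead of re-parsing documents once per embedder. A toy model of the per-embedder buffering, with illustrative names rather than the milli API:

```rust
// One prompt buffer per embedder; flushed when full.
struct Chunks { embedder_name: String, texts: Vec<String>, capacity: usize }

impl Chunks {
    fn set_autogenerated(&mut self, rendered: String) {
        self.texts.push(rendered);
        if self.texts.len() == self.capacity {
            // A real implementation would hand the batch to the embedder here.
            println!("{}: embedding {} prompts", self.embedder_name, self.texts.len());
            self.texts.clear();
        }
    }
}

fn main() {
    let mut all_chunks: Vec<Chunks> = ["default", "openai"]
        .iter()
        .map(|name| Chunks { embedder_name: name.to_string(), texts: Vec::new(), capacity: 2 })
        .collect();

    // A single pass over each document change feeds every embedder's buffer.
    for docid in 0..4u32 {
        for chunks in &mut all_chunks {
            chunks.set_autogenerated(format!("prompt for doc {docid}"));
        }
    }
}
```
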
diff --git a/milli/src/update/new/extract/vectors/mod.rs b/milli/src/update/new/extract/vectors/mod.rs
index a2762ae7a..96b03a25b 100644
--- a/milli/src/update/new/extract/vectors/mod.rs
+++ b/milli/src/update/new/extract/vectors/mod.rs
@@ -1,3 +1,10 @@
+use std::cell::RefCell;
+
+use bumpalo::collections::Vec as BVec;
+use bumpalo::Bump;
+use hashbrown::HashMap;
+
+use super::cache::DelAddRoaringBitmap;
 use crate::error::FaultSource;
 use crate::prompt::Prompt;
 use crate::update::new::channel::EmbeddingSender;
@@ -5,26 +12,34 @@ use crate::update::new::indexer::document_changes::{Extractor, FullySend};
 use crate::update::new::vector_document::VectorDocument;
 use crate::update::new::DocumentChange;
 use crate::vector::error::EmbedErrorKind;
-use crate::vector::Embedder;
-use crate::{DocumentId, Result, ThreadPoolNoAbort, UserError};
+use crate::vector::{Embedder, Embedding, EmbeddingConfigs};
+use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, UserError};

 pub struct EmbeddingExtractor<'a> {
-    embedder: &'a Embedder,
-    prompt: &'a Prompt,
-    embedder_id: u8,
-    embedder_name: &'a str,
+    embedders: &'a EmbeddingConfigs,
     sender: &'a EmbeddingSender<'a>,
     threads: &'a ThreadPoolNoAbort,
 }

+impl<'a> EmbeddingExtractor<'a> {
+    pub fn new(
+        embedders: &'a EmbeddingConfigs,
+        sender: &'a EmbeddingSender<'a>,
+        threads: &'a ThreadPoolNoAbort,
+    ) -> Self {
+        Self { embedders, sender, threads }
+    }
+}
+
 impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
-    type Data = FullySend<()>;
+    type Data = FullySend<RefCell<HashMap<String, DelAddRoaringBitmap>>>;

     fn init_data<'doc>(
         &'doc self,
         _extractor_alloc: raw_collections::alloc::RefBump<'extractor>,
     ) -> crate::Result<Self::Data> {
-        Ok(FullySend(()))
+        // TODO: use the extractor_alloc in the hashbrown once you merge the branch where it is no longer a RefBump
+        Ok(FullySend(Default::default()))
     }

     fn process<'doc>(
@@ -34,63 +49,90 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
             Self::Data,
         >,
     ) -> crate::Result<()> {
-        let embedder_name: &str = self.embedder_name;
-        let embedder: &Embedder = self.embedder;
-        let prompt: &Prompt = self.prompt;
+        let embedders = self.embedders.inner_as_ref();

-        let mut chunks = Chunks::new(
-            embedder,
-            self.embedder_id,
-            embedder_name,
-            self.threads,
-            self.sender,
-            &context.doc_alloc,
-        );
+        let mut all_chunks = BVec::with_capacity_in(embedders.len(), &context.doc_alloc);
+        for (embedder_name, (embedder, prompt, _is_quantized)) in embedders {
+            let embedder_id =
+                context.index.embedder_category_id.get(&context.txn, embedder_name)?.ok_or_else(
+                    || InternalError::DatabaseMissingEntry {
+                        db_name: "embedder_category_id",
+                        key: None,
+                    },
+                )?;
+            all_chunks.push(Chunks::new(
+                embedder,
+                embedder_id,
+                embedder_name,
+                prompt,
+                &context.data.0,
+                self.threads,
+                self.sender,
+                &context.doc_alloc,
+            ))
+        }

         for change in changes {
             let change = change?;
             match change {
-                DocumentChange::Deletion(deletion) => {
-                    self.sender.delete(deletion.docid(), self.embedder_id).unwrap();
+                DocumentChange::Deletion(_deletion) => {
+                    // handled by document sender
                 }
                 DocumentChange::Update(update) => {
-                    /// FIXME: this will force the parsing/retrieval of VectorDocument once per embedder
-                    /// consider doing all embedders at once?
                     let old_vectors = update.current_vectors(
                         &context.txn,
                         context.index,
                         context.db_fields_ids_map,
                         &context.doc_alloc,
                     )?;
-                    let old_vectors = old_vectors.vectors_for_key(embedder_name)?.unwrap();
                     let new_vectors = update.updated_vectors(&context.doc_alloc)?;
-                    if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
-                        new_vectors.vectors_for_key(embedder_name).transpose()
-                    }) {
-                        let new_vectors = new_vectors?;
-                        match (old_vectors.regenerate, new_vectors.regenerate) {
-                            (true, true) | (false, false) => todo!(),
-                            _ => {
-                                self.sender
-                                    .set_user_provided(
-                                        update.docid(),
-                                        self.embedder_id,
-                                        !new_vectors.regenerate,
-                                    )
-                                    .unwrap();
+
+                    for chunks in &mut all_chunks {
+                        let embedder_name = chunks.embedder_name();
+                        let prompt = chunks.prompt();
+
+                        let old_vectors = old_vectors.vectors_for_key(embedder_name)?.unwrap();
+                        if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
+                            new_vectors.vectors_for_key(embedder_name).transpose()
+                        }) {
+                            let new_vectors = new_vectors?;
+                            match (old_vectors.regenerate, new_vectors.regenerate) {
+                                (true, true) | (false, false) => todo!(),
+                                _ => {
+                                    chunks.set_regenerate(update.docid(), new_vectors.regenerate);
+                                }
                             }
-                        }
-                        // do we have set embeddings?
-                        if let Some(embeddings) = new_vectors.embeddings {
-                            self.sender
-                                .set_vectors(
+                            // do we have set embeddings?
+                            if let Some(embeddings) = new_vectors.embeddings {
+                                chunks.set_vectors(
                                     update.docid(),
-                                    self.embedder_id,
                                     embeddings.into_vec().map_err(UserError::SerdeJson)?,
-                                )
-                                .unwrap();
-                        } else if new_vectors.regenerate {
-                            let new_rendered = prompt.render_document(
+                                );
+                            } else if new_vectors.regenerate {
+                                let new_rendered = prompt.render_document(
+                                    update.current(
+                                        &context.txn,
+                                        context.index,
+                                        context.db_fields_ids_map,
+                                    )?,
+                                    context.new_fields_ids_map,
+                                    &context.doc_alloc,
+                                )?;
+                                let old_rendered = prompt.render_document(
+                                    update.merged(
+                                        &context.txn,
+                                        context.index,
+                                        context.db_fields_ids_map,
+                                    )?,
+                                    context.new_fields_ids_map,
+                                    &context.doc_alloc,
+                                )?;
+                                if new_rendered != old_rendered {
+                                    chunks.set_autogenerated(update.docid(), new_rendered)?;
+                                }
+                            }
+                        } else if old_vectors.regenerate {
+                            let old_rendered = prompt.render_document(
                                 update.current(
                                     &context.txn,
                                     context.index,
                                     context.db_fields_ids_map,
                                 )?,
                                 context.new_fields_ids_map,
                                 &context.doc_alloc,
                             )?;
-                            let old_rendered = prompt.render_document(
+                            let new_rendered = prompt.render_document(
                                 update.merged(
                                     &context.txn,
                                     context.index,
                                     context.db_fields_ids_map,
                                 )?,
                                 context.new_fields_ids_map,
                                 &context.doc_alloc,
                             )?;
                             if new_rendered != old_rendered {
-                                chunks.push(update.docid(), new_rendered)?;
+                                chunks.set_autogenerated(update.docid(), new_rendered)?;
                             }
                         }
-                    } else if old_vectors.regenerate {
-                        let old_rendered = prompt.render_document(
-                            update.current(
-                                &context.txn,
-                                context.index,
-                                context.db_fields_ids_map,
-                            )?,
-                            context.new_fields_ids_map,
-                            &context.doc_alloc,
-                        )?;
-                        let new_rendered = prompt.render_document(
-                            update.merged(
-                                &context.txn,
-                                context.index,
-                                context.db_fields_ids_map,
-                            )?,
-                            context.new_fields_ids_map,
-                            &context.doc_alloc,
-                        )?;
-                        if new_rendered != old_rendered {
-                            chunks.push(update.docid(), new_rendered)?;
-                        }
                     }
                 }
                 DocumentChange::Insertion(insertion) => {
-                    // if no inserted vectors, then regenerate: true + no embeddings => autogenerate
-                    let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?;
-                    if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
-                        new_vectors.vectors_for_key(embedder_name).transpose()
-                    }) {
-                        let new_vectors = new_vectors?;
-                        self.sender
-                            .set_user_provided(
-                                insertion.docid(),
-                                self.embedder_id,
-                                !new_vectors.regenerate,
-                            )
-                            .unwrap();
-                        if let Some(embeddings) = new_vectors.embeddings {
-                            self.sender
-                                .set_vectors(
+                    for chunks in &mut all_chunks {
+                        let embedder_name = chunks.embedder_name();
+                        let prompt = chunks.prompt();
+                        // if no inserted vectors, then regenerate: true + no embeddings => autogenerate
+                        let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?;
+                        if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
+                            new_vectors.vectors_for_key(embedder_name).transpose()
+                        }) {
+                            let new_vectors = new_vectors?;
+                            chunks.set_regenerate(insertion.docid(), new_vectors.regenerate);
+                            if let Some(embeddings) = new_vectors.embeddings {
+                                chunks.set_vectors(
                                     insertion.docid(),
-                                    self.embedder_id,
                                     embeddings.into_vec().map_err(UserError::SerdeJson)?,
-                                )
-                                .unwrap();
-                        } else if new_vectors.regenerate {
+                                );
+                            } else if new_vectors.regenerate {
+                                let rendered = prompt.render_document(
+                                    insertion.inserted(),
+                                    context.new_fields_ids_map,
+                                    &context.doc_alloc,
+                                )?;
+                                chunks.set_autogenerated(insertion.docid(), rendered)?;
+                            }
+                        } else {
                             let rendered = prompt.render_document(
                                 insertion.inserted(),
                                 context.new_fields_ids_map,
                                 &context.doc_alloc,
                             )?;
-                            chunks.push(insertion.docid(), rendered)?;
+                            chunks.set_autogenerated(insertion.docid(), rendered)?;
                         }
-                    } else {
-                        let rendered = prompt.render_document(
-                            insertion.inserted(),
-                            context.new_fields_ids_map,
-                            &context.doc_alloc,
-                        )?;
-                        chunks.push(insertion.docid(), rendered)?;
                     }
                 }
             }
         }
-        chunks.drain()
+        for chunk in all_chunks {
+            chunk.drain()?;
+        }
+
+        Ok(())
     }
 }

-use bumpalo::collections::Vec as BVec;
-use bumpalo::Bump;
-
 // **Warning**: the destructor of this struct is not normally run, make sure that all its fields:
 // 1. don't have side effects tied to they destructors
 // 2. if allocated, are allocated inside of the bumpalo
@@ -199,15 +214,21 @@ struct Chunks<'a> {
     embedder: &'a Embedder,
     embedder_id: u8,
     embedder_name: &'a str,
+    prompt: &'a Prompt,
+
+    user_provided: &'a RefCell<HashMap<String, DelAddRoaringBitmap>>,
     threads: &'a ThreadPoolNoAbort,
     sender: &'a EmbeddingSender<'a>,
 }

 impl<'a> Chunks<'a> {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         embedder: &'a Embedder,
         embedder_id: u8,
         embedder_name: &'a str,
+        prompt: &'a Prompt,
+        user_provided: &'a RefCell<HashMap<String, DelAddRoaringBitmap>>,
         threads: &'a ThreadPoolNoAbort,
         sender: &'a EmbeddingSender<'a>,
         doc_alloc: &'a Bump,
     ) -> Self {
@@ -215,10 +236,20 @@ impl<'a> Chunks<'a> {
         let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint();
         let texts = BVec::with_capacity_in(capacity, doc_alloc);
         let ids = BVec::with_capacity_in(capacity, doc_alloc);
-        Self { texts, ids, embedder, threads, sender, embedder_id, embedder_name }
+        Self {
+            texts,
+            ids,
+            embedder,
+            prompt,
+            threads,
+            sender,
+            embedder_id,
+            embedder_name,
+            user_provided,
+        }
     }

-    pub fn push(&mut self, docid: DocumentId, rendered: &'a str) -> Result<()> {
+    pub fn set_autogenerated(&mut self, docid: DocumentId, rendered: &'a str) -> Result<()> {
         if self.texts.len() < self.texts.capacity() {
             self.texts.push(rendered);
             self.ids.push(docid);
@@ -316,4 +347,28 @@ impl<'a> Chunks<'a> {
         ids.clear();
         res
     }
+
+    pub fn prompt(&self) -> &'a Prompt {
+        self.prompt
+    }
+
+    pub fn embedder_name(&self) -> &'a str {
+        self.embedder_name
+    }
+
+    fn set_regenerate(&self, docid: DocumentId, regenerate: bool) {
+        let mut user_provided = self.user_provided.borrow_mut();
+        let user_provided =
+            user_provided.entry_ref(self.embedder_name).or_insert(Default::default());
+        if regenerate {
+            // regenerate == !user_provided
+            user_provided.del.get_or_insert(Default::default()).insert(docid);
+        } else {
+            user_provided.add.get_or_insert(Default::default()).insert(docid);
+        }
+    }
+
+    fn set_vectors(&self, docid: DocumentId, embeddings: Vec<Embedding>) {
+        self.sender.set_vectors(docid, self.embedder_id, embeddings).unwrap();
+    }
 }
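
Every `Chunks` built in `process` borrows the same per-thread datastore (the `FullySend<RefCell<...>>` from `init_data`), which is how `set_regenerate` can record del/add decisions for its embedder without threading a mutable reference through the loop. A minimal std-only sketch of that sharing, with toy types and `std::collections::HashMap` standing in for hashbrown's `entry_ref`:

```rust
use std::cell::RefCell;
use std::collections::HashMap;

#[derive(Default)]
struct DelAdd { del: Vec<u32>, add: Vec<u32> }

struct ChunksHandle<'a> {
    embedder_name: &'a str,
    user_provided: &'a RefCell<HashMap<String, DelAdd>>,
}

impl ChunksHandle<'_> {
    fn set_regenerate(&self, docid: u32, regenerate: bool) {
        let mut map = self.user_provided.borrow_mut();
        let entry = map.entry(self.embedder_name.to_string()).or_default();
        // regenerate == !user_provided
        if regenerate { entry.del.push(docid) } else { entry.add.push(docid) }
    }
}

fn main() {
    let datastore = RefCell::new(HashMap::new());
    let a = ChunksHandle { embedder_name: "default", user_provided: &datastore };
    let b = ChunksHandle { embedder_name: "openai", user_provided: &datastore };
    a.set_regenerate(1, false); // the user provided a vector for "default"
    b.set_regenerate(1, true); // "openai" must regenerate it
    let map = datastore.borrow();
    assert_eq!(map["default"].add, vec![1]);
    assert_eq!(map["openai"].del, vec![1]);
}
```
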
diff --git a/milli/src/update/new/indexer/mod.rs b/milli/src/update/new/indexer/mod.rs
index b316cbc34..d0be88e34 100644
--- a/milli/src/update/new/indexer/mod.rs
+++ b/milli/src/update/new/indexer/mod.rs
@@ -1,5 +1,5 @@
 use std::cell::RefCell;
-use std::sync::RwLock;
+use std::sync::{OnceLock, RwLock};
 use std::thread::{self, Builder};

 use big_s::S;
@@ -10,9 +10,13 @@ use document_changes::{
 };
 pub use document_deletion::DocumentDeletion;
 pub use document_operation::DocumentOperation;
+use hashbrown::HashMap;
 use heed::{RoTxn, RwTxn};
+use itertools::{EitherOrBoth, Itertools};
 pub use partial_dump::PartialDump;
+use rand::SeedableRng as _;
 use rayon::ThreadPool;
+use roaring::RoaringBitmap;
 use time::OffsetDateTime;
 pub use update_by_function::UpdateByFunction;

@@ -31,10 +35,15 @@ use crate::facet::FacetType;
 use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
 use crate::proximity::ProximityPrecision;
 use crate::update::new::channel::ExtractorSender;
+use crate::update::new::extract::EmbeddingExtractor;
 use crate::update::new::words_prefix_docids::compute_exact_word_prefix_docids;
 use crate::update::settings::InnerIndexSettings;
 use crate::update::{FacetsUpdateBulk, GrenadParameters};
-use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, Result, UserError};
+use crate::vector::{ArroyWrapper, EmbeddingConfigs};
+use crate::{
+    FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result, ThreadPoolNoAbort,
+    ThreadPoolNoAbortBuilder, UserError,
+};

 pub(crate) mod de;
 pub mod document_changes;
@@ -119,6 +128,7 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentExtractor<'a> {
 /// Give it the output of the [`Indexer::document_changes`] method and it will execute it in the [`rayon::ThreadPool`].
 ///
 /// TODO return stats
+#[allow(clippy::too_many_arguments)] // clippy: 😝
 pub fn index<'pl, 'indexer, 'index, DC>(
     wtxn: &mut RwTxn,
     index: &'index Index,
@@ -127,6 +137,7 @@ pub fn index<'pl, 'indexer, 'index, DC>(
     new_primary_key: Option<PrimaryKey<'pl>>,
     pool: &ThreadPool,
     document_changes: &DC,
+    embedders: EmbeddingConfigs,
 ) -> Result<()>
 where
     DC: DocumentChanges<'pl>,
@@ -153,8 +164,9 @@ where
         fields_ids_map_store: &fields_ids_map_store,
     };

-    thread::scope(|s| {
+    thread::scope(|s| -> Result<()> {
         let indexer_span = tracing::Span::current();
+        let embedders = &embedders;
         // TODO manage the errors correctly
         let handle = Builder::new().name(S("indexer-extractors")).spawn_scoped(s, move || {
             pool.in_place_scope(|_s| {
@@ -238,9 +250,29 @@ where
                     if index_embeddings.is_empty() {
                         break 'vectors;
                     }
-                    for index_embedding in index_embeddings {
+                    // FIXME: need access to `merger_sender`
+                    let embedding_sender = todo!();
+                    let extractor = EmbeddingExtractor::new(&embedders, &embedding_sender, request_threads());
+                    let datastore = ThreadLocal::with_capacity(pool.current_num_threads());
+                    for_each_document_change(document_changes, &extractor, indexing_context, &mut extractor_allocs, &datastore)?;
+
+                    let mut user_provided = HashMap::new();
+                    for data in datastore {
+                        let data = data.0.into_inner();
+                        for (embedder, deladd) in data.into_iter() {
+                            let user_provided = user_provided.entry(embedder).or_insert(Default::default());
+                            if let Some(del) = deladd.del {
+                                *user_provided -= del;
+                            }
+                            if let Some(add) = deladd.add {
+                                *user_provided |= add;
+                            }
+                        }
                     }
+
+                    embedding_sender.finish(user_provided).unwrap();
                 }

                 {
@@ -285,15 +317,137 @@ where
             )
         })?;

+        let vector_arroy = index.vector_arroy;
+        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
+        let indexer_span = tracing::Span::current();
+        let arroy_writers: Result<HashMap<_, _>> = embedders
+            .inner_as_ref()
+            .iter()
+            .map(|(embedder_name, (embedder, _, was_quantized))| {
+                let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
+                    InternalError::DatabaseMissingEntry {
+                        db_name: "embedder_category_id",
+                        key: None,
+                    },
+                )?;
+
+                let dimensions = embedder.dimensions();
+
+                let writers: Vec<_> = crate::vector::arroy_db_range_for_embedder(embedder_index)
+                    .map(|k| ArroyWrapper::new(vector_arroy, k, *was_quantized))
+                    .collect();
+
+                Ok((
+                    embedder_index,
+                    (embedder_name.as_str(), embedder.as_ref(), writers, dimensions),
+                ))
+            })
+            .collect();
+
+        let mut arroy_writers = arroy_writers?;
         for operation in writer_receiver {
-            let database = operation.database(index);
-            match operation.entry() {
-                EntryOperation::Delete(e) => {
-                    if !database.delete(wtxn, e.entry())? {
-                        unreachable!("We tried to delete an unknown key")
+            match operation {
+                WriterOperation::DbOperation(db_operation) => {
+                    let database = db_operation.database(index);
+                    match db_operation.entry() {
+                        EntryOperation::Delete(e) => {
+                            if !database.delete(wtxn, e.entry())? {
+                                unreachable!("We tried to delete an unknown key")
+                            }
+                        }
+                        EntryOperation::Write(e) => database.put(wtxn, e.key(), e.value())?,
                     }
                 }
-                EntryOperation::Write(e) => database.put(wtxn, e.key(), e.value())?,
+                WriterOperation::ArroyOperation(arroy_operation) => match arroy_operation {
+                    ArroyOperation::DeleteVectors { docid } => {
+                        for (_embedder_index, (_embedder_name, _embedder, writers, dimensions)) in
+                            &mut arroy_writers
+                        {
+                            let dimensions = *dimensions;
+                            for writer in writers {
+                                // Uses invariant: vectors are packed in the first writers.
+                                if !writer.del_item(wtxn, dimensions, docid)? {
+                                    break;
+                                }
+                            }
+                        }
+                    }
+                    ArroyOperation::SetVectors { docid, embedder_id, embeddings } => {
+                        let (_, _, writers, dimensions) =
+                            arroy_writers.get(&embedder_id).expect("requested a missing embedder");
+                        for res in writers.iter().zip_longest(&embeddings) {
+                            match res {
+                                EitherOrBoth::Both(writer, embedding) => {
+                                    writer.add_item(wtxn, *dimensions, docid, embedding)?;
+                                }
+                                EitherOrBoth::Left(writer) => {
+                                    let deleted = writer.del_item(wtxn, *dimensions, docid)?;
+                                    if !deleted {
+                                        break;
+                                    }
+                                }
+                                EitherOrBoth::Right(_embedding) => {
+                                    let external_document_id = index
+                                        .external_id_of(wtxn, std::iter::once(docid))?
+                                        .into_iter()
+                                        .next()
+                                        .unwrap()?;
+                                    return Err(UserError::TooManyVectors(
+                                        external_document_id,
+                                        embeddings.len(),
+                                    )
+                                    .into());
+                                }
+                            }
+                        }
+                    }
+                    ArroyOperation::SetVector { docid, embedder_id, embedding } => {
+                        let (_, _, writers, dimensions) =
+                            arroy_writers.get(&embedder_id).expect("requested a missing embedder");
+                        for res in writers.iter().zip_longest(std::iter::once(&embedding)) {
+                            match res {
+                                EitherOrBoth::Both(writer, embedding) => {
+                                    writer.add_item(wtxn, *dimensions, docid, embedding)?;
+                                }
+                                EitherOrBoth::Left(writer) => {
+                                    let deleted = writer.del_item(wtxn, *dimensions, docid)?;
+                                    if !deleted {
+                                        break;
+                                    }
+                                }
+                                EitherOrBoth::Right(_embedding) => {
+                                    unreachable!("1 vs 256 vectors")
+                                }
+                            }
+                        }
+                    }
+                    ArroyOperation::Finish { mut user_provided } => {
+                        let span = tracing::trace_span!(target: "indexing::vectors", parent: &indexer_span, "build");
+                        let _entered = span.enter();
+                        for (_embedder_index, (_embedder_name, _embedder, writers, dimensions)) in
+                            &mut arroy_writers
+                        {
+                            let dimensions = *dimensions;
+                            for writer in writers {
+                                if writer.need_build(wtxn, dimensions)? {
+                                    writer.build(wtxn, &mut rng, dimensions)?;
+                                } else if writer.is_empty(wtxn, dimensions)? {
{ + break; + } + } + } + + let mut configs = index.embedding_configs(wtxn)?; + + for config in &mut configs { + if let Some(user_provided) = user_provided.remove(&config.name) { + config.user_provided = user_provided; + } + } + + index.put_embedding_configs(wtxn, configs)?; + } + }, } } @@ -483,3 +637,15 @@ pub fn retrieve_or_guess_primary_key<'a>( Err(err) => Ok(Err(err)), } } + +fn request_threads() -> &'static ThreadPoolNoAbort { + static REQUEST_THREADS: OnceLock = OnceLock::new(); + + REQUEST_THREADS.get_or_init(|| { + ThreadPoolNoAbortBuilder::new() + .num_threads(crate::vector::REQUEST_PARALLELISM) + .thread_name(|index| format!("embedding-request-{index}")) + .build() + .unwrap() + }) +} diff --git a/milli/src/update/new/merger.rs b/milli/src/update/new/merger.rs index 6183beb63..14e947686 100644 --- a/milli/src/update/new/merger.rs +++ b/milli/src/update/new/merger.rs @@ -149,6 +149,7 @@ pub fn merge_grenad_entries( } } MergerOperation::DeleteDocument { docid, external_id } => { + /// TODO: delete vectors let span = tracing::trace_span!(target: "indexing::documents::merge", "delete_document"); let _entered = span.enter();