This commit is contained in:
Louis Dureuil 2024-10-29 17:43:36 +01:00
parent 7058959a46
commit 1075dd34bb
No known key found for this signature in database
8 changed files with 420 additions and 149 deletions

View File

@ -1300,6 +1300,8 @@ impl IndexScheduler {
let mut content_files_iter = content_files.iter(); let mut content_files_iter = content_files.iter();
let mut indexer = indexer::DocumentOperation::new(method); let mut indexer = indexer::DocumentOperation::new(method);
let embedders = index.embedding_configs(index_wtxn)?;
let embedders = self.embedders(embedders)?;
for (operation, task) in operations.into_iter().zip(tasks.iter_mut()) { for (operation, task) in operations.into_iter().zip(tasks.iter_mut()) {
match operation { match operation {
DocumentOperation::Add(_content_uuid) => { DocumentOperation::Add(_content_uuid) => {
@ -1374,6 +1376,7 @@ impl IndexScheduler {
primary_key_has_been_set.then_some(primary_key), primary_key_has_been_set.then_some(primary_key),
&pool, &pool,
&document_changes, &document_changes,
embedders,
)?; )?;
// tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");
@ -1460,6 +1463,8 @@ impl IndexScheduler {
let indexer = UpdateByFunction::new(candidates, context.clone(), code.clone()); let indexer = UpdateByFunction::new(candidates, context.clone(), code.clone());
let document_changes = indexer.into_changes(&primary_key)?; let document_changes = indexer.into_changes(&primary_key)?;
let embedders = index.embedding_configs(index_wtxn)?;
let embedders = self.embedders(embedders)?;
indexer::index( indexer::index(
index_wtxn, index_wtxn,
@ -1469,6 +1474,7 @@ impl IndexScheduler {
None, // cannot change primary key in DocumentEdition None, // cannot change primary key in DocumentEdition
&pool, &pool,
&document_changes, &document_changes,
embedders,
)?; )?;
// tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");
@ -1596,6 +1602,8 @@ impl IndexScheduler {
let mut indexer = indexer::DocumentDeletion::new(); let mut indexer = indexer::DocumentDeletion::new();
indexer.delete_documents_by_docids(to_delete); indexer.delete_documents_by_docids(to_delete);
let document_changes = indexer.into_changes(&indexer_alloc, primary_key); let document_changes = indexer.into_changes(&indexer_alloc, primary_key);
let embedders = index.embedding_configs(index_wtxn)?;
let embedders = self.embedders(embedders)?;
indexer::index( indexer::index(
index_wtxn, index_wtxn,
@ -1605,6 +1613,7 @@ impl IndexScheduler {
None, // document deletion never changes primary key None, // document deletion never changes primary key
&pool, &pool,
&document_changes, &document_changes,
embedders,
)?; )?;
// tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done");

View File

@ -3,6 +3,7 @@ use std::marker::PhantomData;
use crossbeam_channel::{IntoIter, Receiver, SendError, Sender}; use crossbeam_channel::{IntoIter, Receiver, SendError, Sender};
use grenad::Merger; use grenad::Merger;
use hashbrown::HashMap;
use heed::types::Bytes; use heed::types::Bytes;
use memmap2::Mmap; use memmap2::Mmap;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
@ -124,7 +125,32 @@ impl DocumentDeletionEntry {
} }
} }
pub struct WriterOperation { pub enum WriterOperation {
DbOperation(DbOperation),
ArroyOperation(ArroyOperation),
}
pub enum ArroyOperation {
/// TODO: call when deleting regular documents
DeleteVectors {
docid: DocumentId,
},
SetVectors {
docid: DocumentId,
embedder_id: u8,
embeddings: Vec<Embedding>,
},
SetVector {
docid: DocumentId,
embedder_id: u8,
embedding: Embedding,
},
Finish {
user_provided: HashMap<String, RoaringBitmap>,
},
}
pub struct DbOperation {
database: Database, database: Database,
entry: EntryOperation, entry: EntryOperation,
} }
@ -180,7 +206,7 @@ impl From<FacetKind> for Database {
} }
} }
impl WriterOperation { impl DbOperation {
pub fn database(&self, index: &Index) -> heed::Database<Bytes, Bytes> { pub fn database(&self, index: &Index) -> heed::Database<Bytes, Bytes> {
self.database.database(index) self.database.database(index)
} }
@ -246,13 +272,13 @@ impl MergerSender {
DOCUMENTS_IDS_KEY.as_bytes(), DOCUMENTS_IDS_KEY.as_bytes(),
documents_ids, documents_ids,
)); ));
match self.send(WriterOperation { database: Database::Main, entry }) { match self.send_db_operation(DbOperation { database: Database::Main, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
} }
fn send(&self, op: WriterOperation) -> StdResult<(), SendError<()>> { fn send_db_operation(&self, op: DbOperation) -> StdResult<(), SendError<()>> {
if self.sender.is_full() { if self.sender.is_full() {
self.writer_contentious_count.set(self.writer_contentious_count.get() + 1); self.writer_contentious_count.set(self.writer_contentious_count.get() + 1);
} }
@ -260,7 +286,7 @@ impl MergerSender {
self.merger_contentious_count.set(self.merger_contentious_count.get() + 1); self.merger_contentious_count.set(self.merger_contentious_count.get() + 1);
} }
self.send_count.set(self.send_count.get() + 1); self.send_count.set(self.send_count.get() + 1);
match self.sender.send(op) { match self.sender.send(WriterOperation::DbOperation(op)) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -275,7 +301,7 @@ impl MainSender<'_> {
WORDS_FST_KEY.as_bytes(), WORDS_FST_KEY.as_bytes(),
value, value,
)); ));
match self.0.send(WriterOperation { database: Database::Main, entry }) { match self.0.send_db_operation(DbOperation { database: Database::Main, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -286,7 +312,7 @@ impl MainSender<'_> {
WORDS_PREFIXES_FST_KEY.as_bytes(), WORDS_PREFIXES_FST_KEY.as_bytes(),
value, value,
)); ));
match self.0.send(WriterOperation { database: Database::Main, entry }) { match self.0.send_db_operation(DbOperation { database: Database::Main, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -294,7 +320,7 @@ impl MainSender<'_> {
pub fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> { pub fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> {
let entry = EntryOperation::Delete(KeyEntry::from_key(key)); let entry = EntryOperation::Delete(KeyEntry::from_key(key));
match self.0.send(WriterOperation { database: Database::Main, entry }) { match self.0.send_db_operation(DbOperation { database: Database::Main, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -396,7 +422,7 @@ pub struct WordDocidsSender<'a, D> {
impl<D: DatabaseType> DocidsSender for WordDocidsSender<'_, D> { impl<D: DatabaseType> DocidsSender for WordDocidsSender<'_, D> {
fn write(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>> { fn write(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>> {
let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)); let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value));
match self.sender.send(WriterOperation { database: D::DATABASE, entry }) { match self.sender.send_db_operation(DbOperation { database: D::DATABASE, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -404,7 +430,7 @@ impl<D: DatabaseType> DocidsSender for WordDocidsSender<'_, D> {
fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> { fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> {
let entry = EntryOperation::Delete(KeyEntry::from_key(key)); let entry = EntryOperation::Delete(KeyEntry::from_key(key));
match self.sender.send(WriterOperation { database: D::DATABASE, entry }) { match self.sender.send_db_operation(DbOperation { database: D::DATABASE, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -429,7 +455,7 @@ impl DocidsSender for FacetDocidsSender<'_> {
} }
_ => EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)), _ => EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)),
}; };
match self.sender.send(WriterOperation { database, entry }) { match self.sender.send_db_operation(DbOperation { database, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -439,7 +465,7 @@ impl DocidsSender for FacetDocidsSender<'_> {
let (facet_kind, key) = FacetKind::extract_from_key(key); let (facet_kind, key) = FacetKind::extract_from_key(key);
let database = Database::from(facet_kind); let database = Database::from(facet_kind);
let entry = EntryOperation::Delete(KeyEntry::from_key(key)); let entry = EntryOperation::Delete(KeyEntry::from_key(key));
match self.sender.send(WriterOperation { database, entry }) { match self.sender.send_db_operation(DbOperation { database, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -460,7 +486,7 @@ impl DocumentsSender<'_> {
&docid.to_be_bytes(), &docid.to_be_bytes(),
document.as_bytes(), document.as_bytes(),
)); ));
match self.0.send(WriterOperation { database: Database::Documents, entry }) { match self.0.send_db_operation(DbOperation { database: Database::Documents, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
}?; }?;
@ -469,7 +495,10 @@ impl DocumentsSender<'_> {
external_id.as_bytes(), external_id.as_bytes(),
&docid.to_be_bytes(), &docid.to_be_bytes(),
)); ));
match self.0.send(WriterOperation { database: Database::ExternalDocumentsIds, entry }) { match self
.0
.send_db_operation(DbOperation { database: Database::ExternalDocumentsIds, entry })
{
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
@ -477,33 +506,38 @@ impl DocumentsSender<'_> {
pub fn delete(&self, docid: DocumentId, external_id: String) -> StdResult<(), SendError<()>> { pub fn delete(&self, docid: DocumentId, external_id: String) -> StdResult<(), SendError<()>> {
let entry = EntryOperation::Delete(KeyEntry::from_key(&docid.to_be_bytes())); let entry = EntryOperation::Delete(KeyEntry::from_key(&docid.to_be_bytes()));
match self.0.send(WriterOperation { database: Database::Documents, entry }) { match self.0.send_db_operation(DbOperation { database: Database::Documents, entry }) {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
}?; }?;
let entry = EntryOperation::Delete(KeyEntry::from_key(external_id.as_bytes())); let entry = EntryOperation::Delete(KeyEntry::from_key(external_id.as_bytes()));
match self.0.send(WriterOperation { database: Database::ExternalDocumentsIds, entry }) { match self
.0
.send_db_operation(DbOperation { database: Database::ExternalDocumentsIds, entry })
{
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(SendError(_)) => Err(SendError(())), Err(SendError(_)) => Err(SendError(())),
} }
} }
} }
pub struct EmbeddingSender<'a>(Option<&'a Sender<MergerOperation>>); pub struct EmbeddingSender<'a>(&'a Sender<WriterOperation>);
impl EmbeddingSender<'_> { impl EmbeddingSender<'_> {
pub fn delete(&self, docid: DocumentId, embedder_id: u8) -> StdResult<(), SendError<()>> {
todo!()
}
pub fn set_vectors( pub fn set_vectors(
&self, &self,
docid: DocumentId, docid: DocumentId,
embedder_id: u8, embedder_id: u8,
embeddings: Vec<Embedding>, embeddings: Vec<Embedding>,
) -> StdResult<(), SendError<()>> { ) -> StdResult<(), SendError<()>> {
todo!() self.0
.send(WriterOperation::ArroyOperation(ArroyOperation::SetVectors {
docid,
embedder_id,
embeddings,
}))
.map_err(|_| SendError(()))
} }
pub fn set_vector( pub fn set_vector(
@ -512,19 +546,24 @@ impl EmbeddingSender<'_> {
embedder_id: u8, embedder_id: u8,
embedding: Embedding, embedding: Embedding,
) -> StdResult<(), SendError<()>> { ) -> StdResult<(), SendError<()>> {
todo!() self.0
.send(WriterOperation::ArroyOperation(ArroyOperation::SetVector {
docid,
embedder_id,
embedding,
}))
.map_err(|_| SendError(()))
} }
pub fn set_user_provided( /// Marks all embedders as "to be built"
&self, pub fn finish(
docid: DocumentId, self,
embedder_id: u8, user_provided: HashMap<String, RoaringBitmap>,
regenerate: bool,
) -> StdResult<(), SendError<()>> { ) -> StdResult<(), SendError<()>> {
todo!() self.0
.send(WriterOperation::ArroyOperation(ArroyOperation::Finish { user_provided }))
.map_err(|_| SendError(()))
} }
pub fn finish(self, embedder_id: u8) {}
} }
pub enum MergerOperation { pub enum MergerOperation {
ExactWordDocidsMerger(Merger<File, MergeDeladdCboRoaringBitmaps>), ExactWordDocidsMerger(Merger<File, MergeDeladdCboRoaringBitmaps>),

View File

@ -4,7 +4,7 @@ use heed::RoTxn;
use raw_collections::RawMap; use raw_collections::RawMap;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use super::vector_document::{VectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions}; use super::vector_document::VectorDocument;
use super::{KvReaderFieldId, KvWriterFieldId}; use super::{KvReaderFieldId, KvWriterFieldId};
use crate::documents::FieldIdMapper; use crate::documents::FieldIdMapper;
use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME; use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME;

View File

@ -267,7 +267,7 @@ impl Stats {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Default)]
pub struct DelAddRoaringBitmap { pub struct DelAddRoaringBitmap {
pub(crate) del: Option<RoaringBitmap>, pub(crate) del: Option<RoaringBitmap>,
pub(crate) add: Option<RoaringBitmap>, pub(crate) add: Option<RoaringBitmap>,

View File

@ -11,6 +11,7 @@ use bumpalo::Bump;
pub use faceted::*; pub use faceted::*;
use grenad::Merger; use grenad::Merger;
pub use searchable::*; pub use searchable::*;
pub use vectors::EmbeddingExtractor;
use super::indexer::document_changes::{DocumentChanges, FullySend, IndexingContext, ThreadLocal}; use super::indexer::document_changes::{DocumentChanges, FullySend, IndexingContext, ThreadLocal};
use crate::update::{GrenadParameters, MergeDeladdCboRoaringBitmaps}; use crate::update::{GrenadParameters, MergeDeladdCboRoaringBitmaps};

View File

@ -1,3 +1,10 @@
use std::cell::RefCell;
use bumpalo::collections::Vec as BVec;
use bumpalo::Bump;
use hashbrown::HashMap;
use super::cache::DelAddRoaringBitmap;
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::prompt::Prompt; use crate::prompt::Prompt;
use crate::update::new::channel::EmbeddingSender; use crate::update::new::channel::EmbeddingSender;
@ -5,26 +12,34 @@ use crate::update::new::indexer::document_changes::{Extractor, FullySend};
use crate::update::new::vector_document::VectorDocument; use crate::update::new::vector_document::VectorDocument;
use crate::update::new::DocumentChange; use crate::update::new::DocumentChange;
use crate::vector::error::EmbedErrorKind; use crate::vector::error::EmbedErrorKind;
use crate::vector::Embedder; use crate::vector::{Embedder, Embedding, EmbeddingConfigs};
use crate::{DocumentId, Result, ThreadPoolNoAbort, UserError}; use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, UserError};
pub struct EmbeddingExtractor<'a> { pub struct EmbeddingExtractor<'a> {
embedder: &'a Embedder, embedders: &'a EmbeddingConfigs,
prompt: &'a Prompt,
embedder_id: u8,
embedder_name: &'a str,
sender: &'a EmbeddingSender<'a>, sender: &'a EmbeddingSender<'a>,
threads: &'a ThreadPoolNoAbort, threads: &'a ThreadPoolNoAbort,
} }
impl<'a> EmbeddingExtractor<'a> {
pub fn new(
embedders: &'a EmbeddingConfigs,
sender: &'a EmbeddingSender<'a>,
threads: &'a ThreadPoolNoAbort,
) -> Self {
Self { embedders, sender, threads }
}
}
impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
type Data = FullySend<()>; type Data = FullySend<RefCell<HashMap<String, DelAddRoaringBitmap>>>;
fn init_data<'doc>( fn init_data<'doc>(
&'doc self, &'doc self,
_extractor_alloc: raw_collections::alloc::RefBump<'extractor>, _extractor_alloc: raw_collections::alloc::RefBump<'extractor>,
) -> crate::Result<Self::Data> { ) -> crate::Result<Self::Data> {
Ok(FullySend(())) /// TODO: use the extractor_alloc in the hashbrown once you merge the branch where it is no longer a RefBump
Ok(FullySend(Default::default()))
} }
fn process<'doc>( fn process<'doc>(
@ -34,36 +49,49 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
Self::Data, Self::Data,
>, >,
) -> crate::Result<()> { ) -> crate::Result<()> {
let embedder_name: &str = self.embedder_name; let embedders = self.embedders.inner_as_ref();
let embedder: &Embedder = self.embedder;
let prompt: &Prompt = self.prompt;
let mut chunks = Chunks::new( let mut all_chunks = BVec::with_capacity_in(embedders.len(), &context.doc_alloc);
for (embedder_name, (embedder, prompt, _is_quantized)) in embedders {
let embedder_id =
context.index.embedder_category_id.get(&context.txn, embedder_name)?.ok_or_else(
|| InternalError::DatabaseMissingEntry {
db_name: "embedder_category_id",
key: None,
},
)?;
all_chunks.push(Chunks::new(
embedder, embedder,
self.embedder_id, embedder_id,
embedder_name, embedder_name,
prompt,
&context.data.0,
self.threads, self.threads,
self.sender, self.sender,
&context.doc_alloc, &context.doc_alloc,
); ))
}
for change in changes { for change in changes {
let change = change?; let change = change?;
match change { match change {
DocumentChange::Deletion(deletion) => { DocumentChange::Deletion(_deletion) => {
self.sender.delete(deletion.docid(), self.embedder_id).unwrap(); // handled by document sender
} }
DocumentChange::Update(update) => { DocumentChange::Update(update) => {
/// FIXME: this will force the parsing/retrieval of VectorDocument once per embedder
/// consider doing all embedders at once?
let old_vectors = update.current_vectors( let old_vectors = update.current_vectors(
&context.txn, &context.txn,
context.index, context.index,
context.db_fields_ids_map, context.db_fields_ids_map,
&context.doc_alloc, &context.doc_alloc,
)?; )?;
let old_vectors = old_vectors.vectors_for_key(embedder_name)?.unwrap();
let new_vectors = update.updated_vectors(&context.doc_alloc)?; let new_vectors = update.updated_vectors(&context.doc_alloc)?;
for chunks in &mut all_chunks {
let embedder_name = chunks.embedder_name();
let prompt = chunks.prompt();
let old_vectors = old_vectors.vectors_for_key(embedder_name)?.unwrap();
if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| { if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
new_vectors.vectors_for_key(embedder_name).transpose() new_vectors.vectors_for_key(embedder_name).transpose()
}) { }) {
@ -71,24 +99,15 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
match (old_vectors.regenerate, new_vectors.regenerate) { match (old_vectors.regenerate, new_vectors.regenerate) {
(true, true) | (false, false) => todo!(), (true, true) | (false, false) => todo!(),
_ => { _ => {
self.sender chunks.set_regenerate(update.docid(), new_vectors.regenerate);
.set_user_provided(
update.docid(),
self.embedder_id,
!new_vectors.regenerate,
)
.unwrap();
} }
} }
// do we have set embeddings? // do we have set embeddings?
if let Some(embeddings) = new_vectors.embeddings { if let Some(embeddings) = new_vectors.embeddings {
self.sender chunks.set_vectors(
.set_vectors(
update.docid(), update.docid(),
self.embedder_id,
embeddings.into_vec().map_err(UserError::SerdeJson)?, embeddings.into_vec().map_err(UserError::SerdeJson)?,
) );
.unwrap();
} else if new_vectors.regenerate { } else if new_vectors.regenerate {
let new_rendered = prompt.render_document( let new_rendered = prompt.render_document(
update.current( update.current(
@ -109,7 +128,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
&context.doc_alloc, &context.doc_alloc,
)?; )?;
if new_rendered != old_rendered { if new_rendered != old_rendered {
chunks.push(update.docid(), new_rendered)?; chunks.set_autogenerated(update.docid(), new_rendered)?;
} }
} }
} else if old_vectors.regenerate { } else if old_vectors.regenerate {
@ -132,39 +151,34 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
&context.doc_alloc, &context.doc_alloc,
)?; )?;
if new_rendered != old_rendered { if new_rendered != old_rendered {
chunks.push(update.docid(), new_rendered)?; chunks.set_autogenerated(update.docid(), new_rendered)?;
}
} }
} }
} }
DocumentChange::Insertion(insertion) => { DocumentChange::Insertion(insertion) => {
for chunks in &mut all_chunks {
let embedder_name = chunks.embedder_name();
let prompt = chunks.prompt();
// if no inserted vectors, then regenerate: true + no embeddings => autogenerate // if no inserted vectors, then regenerate: true + no embeddings => autogenerate
let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?; let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?;
if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| { if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
new_vectors.vectors_for_key(embedder_name).transpose() new_vectors.vectors_for_key(embedder_name).transpose()
}) { }) {
let new_vectors = new_vectors?; let new_vectors = new_vectors?;
self.sender chunks.set_regenerate(insertion.docid(), new_vectors.regenerate);
.set_user_provided(
insertion.docid(),
self.embedder_id,
!new_vectors.regenerate,
)
.unwrap();
if let Some(embeddings) = new_vectors.embeddings { if let Some(embeddings) = new_vectors.embeddings {
self.sender chunks.set_vectors(
.set_vectors(
insertion.docid(), insertion.docid(),
self.embedder_id,
embeddings.into_vec().map_err(UserError::SerdeJson)?, embeddings.into_vec().map_err(UserError::SerdeJson)?,
) );
.unwrap();
} else if new_vectors.regenerate { } else if new_vectors.regenerate {
let rendered = prompt.render_document( let rendered = prompt.render_document(
insertion.inserted(), insertion.inserted(),
context.new_fields_ids_map, context.new_fields_ids_map,
&context.doc_alloc, &context.doc_alloc,
)?; )?;
chunks.push(insertion.docid(), rendered)?; chunks.set_autogenerated(insertion.docid(), rendered)?;
} }
} else { } else {
let rendered = prompt.render_document( let rendered = prompt.render_document(
@ -172,19 +186,20 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
context.new_fields_ids_map, context.new_fields_ids_map,
&context.doc_alloc, &context.doc_alloc,
)?; )?;
chunks.push(insertion.docid(), rendered)?; chunks.set_autogenerated(insertion.docid(), rendered)?;
}
} }
} }
} }
} }
chunks.drain() for chunk in all_chunks {
chunk.drain()?;
}
Ok(())
} }
} }
use bumpalo::collections::Vec as BVec;
use bumpalo::Bump;
// **Warning**: the destructor of this struct is not normally run, make sure that all its fields: // **Warning**: the destructor of this struct is not normally run, make sure that all its fields:
// 1. don't have side effects tied to they destructors // 1. don't have side effects tied to they destructors
// 2. if allocated, are allocated inside of the bumpalo // 2. if allocated, are allocated inside of the bumpalo
@ -199,15 +214,21 @@ struct Chunks<'a> {
embedder: &'a Embedder, embedder: &'a Embedder,
embedder_id: u8, embedder_id: u8,
embedder_name: &'a str, embedder_name: &'a str,
prompt: &'a Prompt,
user_provided: &'a RefCell<HashMap<String, DelAddRoaringBitmap>>,
threads: &'a ThreadPoolNoAbort, threads: &'a ThreadPoolNoAbort,
sender: &'a EmbeddingSender<'a>, sender: &'a EmbeddingSender<'a>,
} }
impl<'a> Chunks<'a> { impl<'a> Chunks<'a> {
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
embedder: &'a Embedder, embedder: &'a Embedder,
embedder_id: u8, embedder_id: u8,
embedder_name: &'a str, embedder_name: &'a str,
prompt: &'a Prompt,
user_provided: &'a RefCell<HashMap<String, DelAddRoaringBitmap>>,
threads: &'a ThreadPoolNoAbort, threads: &'a ThreadPoolNoAbort,
sender: &'a EmbeddingSender<'a>, sender: &'a EmbeddingSender<'a>,
doc_alloc: &'a Bump, doc_alloc: &'a Bump,
@ -215,10 +236,20 @@ impl<'a> Chunks<'a> {
let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint(); let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint();
let texts = BVec::with_capacity_in(capacity, doc_alloc); let texts = BVec::with_capacity_in(capacity, doc_alloc);
let ids = BVec::with_capacity_in(capacity, doc_alloc); let ids = BVec::with_capacity_in(capacity, doc_alloc);
Self { texts, ids, embedder, threads, sender, embedder_id, embedder_name } Self {
texts,
ids,
embedder,
prompt,
threads,
sender,
embedder_id,
embedder_name,
user_provided,
}
} }
pub fn push(&mut self, docid: DocumentId, rendered: &'a str) -> Result<()> { pub fn set_autogenerated(&mut self, docid: DocumentId, rendered: &'a str) -> Result<()> {
if self.texts.len() < self.texts.capacity() { if self.texts.len() < self.texts.capacity() {
self.texts.push(rendered); self.texts.push(rendered);
self.ids.push(docid); self.ids.push(docid);
@ -316,4 +347,28 @@ impl<'a> Chunks<'a> {
ids.clear(); ids.clear();
res res
} }
pub fn prompt(&self) -> &'a Prompt {
self.prompt
}
pub fn embedder_name(&self) -> &'a str {
self.embedder_name
}
fn set_regenerate(&self, docid: DocumentId, regenerate: bool) {
let mut user_provided = self.user_provided.borrow_mut();
let user_provided =
user_provided.entry_ref(self.embedder_name).or_insert(Default::default());
if regenerate {
// regenerate == !user_provided
user_provided.del.get_or_insert(Default::default()).insert(docid);
} else {
user_provided.add.get_or_insert(Default::default()).insert(docid);
}
}
fn set_vectors(&self, docid: DocumentId, embeddings: Vec<Embedding>) {
self.sender.set_vectors(docid, self.embedder_id, embeddings).unwrap();
}
} }

View File

@ -1,5 +1,5 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::sync::RwLock; use std::sync::{OnceLock, RwLock};
use std::thread::{self, Builder}; use std::thread::{self, Builder};
use big_s::S; use big_s::S;
@ -10,9 +10,13 @@ use document_changes::{
}; };
pub use document_deletion::DocumentDeletion; pub use document_deletion::DocumentDeletion;
pub use document_operation::DocumentOperation; pub use document_operation::DocumentOperation;
use hashbrown::HashMap;
use heed::{RoTxn, RwTxn}; use heed::{RoTxn, RwTxn};
use itertools::{EitherOrBoth, Itertools};
pub use partial_dump::PartialDump; pub use partial_dump::PartialDump;
use rand::SeedableRng as _;
use rayon::ThreadPool; use rayon::ThreadPool;
use roaring::RoaringBitmap;
use time::OffsetDateTime; use time::OffsetDateTime;
pub use update_by_function::UpdateByFunction; pub use update_by_function::UpdateByFunction;
@ -31,10 +35,15 @@ use crate::facet::FacetType;
use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder}; use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder};
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::update::new::channel::ExtractorSender; use crate::update::new::channel::ExtractorSender;
use crate::update::new::extract::EmbeddingExtractor;
use crate::update::new::words_prefix_docids::compute_exact_word_prefix_docids; use crate::update::new::words_prefix_docids::compute_exact_word_prefix_docids;
use crate::update::settings::InnerIndexSettings; use crate::update::settings::InnerIndexSettings;
use crate::update::{FacetsUpdateBulk, GrenadParameters}; use crate::update::{FacetsUpdateBulk, GrenadParameters};
use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, Result, UserError}; use crate::vector::{ArroyWrapper, EmbeddingConfigs};
use crate::{
FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result, ThreadPoolNoAbort,
ThreadPoolNoAbortBuilder, UserError,
};
pub(crate) mod de; pub(crate) mod de;
pub mod document_changes; pub mod document_changes;
@ -119,6 +128,7 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentExtractor<'a> {
/// Give it the output of the [`Indexer::document_changes`] method and it will execute it in the [`rayon::ThreadPool`]. /// Give it the output of the [`Indexer::document_changes`] method and it will execute it in the [`rayon::ThreadPool`].
/// ///
/// TODO return stats /// TODO return stats
#[allow(clippy::too_many_arguments)] // clippy: 😝
pub fn index<'pl, 'indexer, 'index, DC>( pub fn index<'pl, 'indexer, 'index, DC>(
wtxn: &mut RwTxn, wtxn: &mut RwTxn,
index: &'index Index, index: &'index Index,
@ -127,6 +137,7 @@ pub fn index<'pl, 'indexer, 'index, DC>(
new_primary_key: Option<PrimaryKey<'pl>>, new_primary_key: Option<PrimaryKey<'pl>>,
pool: &ThreadPool, pool: &ThreadPool,
document_changes: &DC, document_changes: &DC,
embedders: EmbeddingConfigs,
) -> Result<()> ) -> Result<()>
where where
DC: DocumentChanges<'pl>, DC: DocumentChanges<'pl>,
@ -153,8 +164,9 @@ where
fields_ids_map_store: &fields_ids_map_store, fields_ids_map_store: &fields_ids_map_store,
}; };
thread::scope(|s| { thread::scope(|s| -> Result<()> {
let indexer_span = tracing::Span::current(); let indexer_span = tracing::Span::current();
let embedders = &embedders;
// TODO manage the errors correctly // TODO manage the errors correctly
let handle = Builder::new().name(S("indexer-extractors")).spawn_scoped(s, move || { let handle = Builder::new().name(S("indexer-extractors")).spawn_scoped(s, move || {
pool.in_place_scope(|_s| { pool.in_place_scope(|_s| {
@ -238,9 +250,29 @@ where
if index_embeddings.is_empty() { if index_embeddings.is_empty() {
break 'vectors; break 'vectors;
} }
for index_embedding in index_embeddings { /// FIXME: need access to `merger_sender`
let embedding_sender = todo!();
let extractor = EmbeddingExtractor::new(&embedders, &embedding_sender, request_threads());
let datastore = ThreadLocal::with_capacity(pool.current_num_threads());
for_each_document_change(document_changes, &extractor, indexing_context, &mut extractor_allocs, &datastore)?;
let mut user_provided = HashMap::new();
for data in datastore {
let data = data.0.into_inner();
for (embedder, deladd) in data.into_iter() {
let user_provided = user_provided.entry(embedder).or_insert(Default::default());
if let Some(del) = deladd.del {
*user_provided -= del;
} }
if let Some(add) = deladd.add {
*user_provided |= add;
}
}
}
embedding_sender.finish(user_provided).unwrap();
} }
{ {
@ -285,9 +317,39 @@ where
) )
})?; })?;
let vector_arroy = index.vector_arroy;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let indexer_span = tracing::Span::current();
let arroy_writers: Result<HashMap<_, _>> = embedders
.inner_as_ref()
.iter()
.map(|(embedder_name, (embedder, _, was_quantized))| {
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry {
db_name: "embedder_category_id",
key: None,
},
)?;
let dimensions = embedder.dimensions();
let writers: Vec<_> = crate::vector::arroy_db_range_for_embedder(embedder_index)
.map(|k| ArroyWrapper::new(vector_arroy, k, *was_quantized))
.collect();
Ok((
embedder_index,
(embedder_name.as_str(), embedder.as_ref(), writers, dimensions),
))
})
.collect();
let mut arroy_writers = arroy_writers?;
for operation in writer_receiver { for operation in writer_receiver {
let database = operation.database(index); match operation {
match operation.entry() { WriterOperation::DbOperation(db_operation) => {
let database = db_operation.database(index);
match db_operation.entry() {
EntryOperation::Delete(e) => { EntryOperation::Delete(e) => {
if !database.delete(wtxn, e.entry())? { if !database.delete(wtxn, e.entry())? {
unreachable!("We tried to delete an unknown key") unreachable!("We tried to delete an unknown key")
@ -296,6 +358,98 @@ where
EntryOperation::Write(e) => database.put(wtxn, e.key(), e.value())?, EntryOperation::Write(e) => database.put(wtxn, e.key(), e.value())?,
} }
} }
WriterOperation::ArroyOperation(arroy_operation) => match arroy_operation {
ArroyOperation::DeleteVectors { docid } => {
for (_embedder_index, (_embedder_name, _embedder, writers, dimensions)) in
&mut arroy_writers
{
let dimensions = *dimensions;
for writer in writers {
// Uses invariant: vectors are packed in the first writers.
if !writer.del_item(wtxn, dimensions, docid)? {
break;
}
}
}
}
ArroyOperation::SetVectors { docid, embedder_id, embeddings } => {
let (_, _, writers, dimensions) =
arroy_writers.get(&embedder_id).expect("requested a missing embedder");
for res in writers.iter().zip_longest(&embeddings) {
match res {
EitherOrBoth::Both(writer, embedding) => {
writer.add_item(wtxn, *dimensions, docid, embedding)?;
}
EitherOrBoth::Left(writer) => {
let deleted = writer.del_item(wtxn, *dimensions, docid)?;
if !deleted {
break;
}
}
EitherOrBoth::Right(_embedding) => {
let external_document_id = index
.external_id_of(wtxn, std::iter::once(docid))?
.into_iter()
.next()
.unwrap()?;
return Err(UserError::TooManyVectors(
external_document_id,
embeddings.len(),
)
.into());
}
}
}
}
ArroyOperation::SetVector { docid, embedder_id, embedding } => {
let (_, _, writers, dimensions) =
arroy_writers.get(&embedder_id).expect("requested a missing embedder");
for res in writers.iter().zip_longest(std::iter::once(&embedding)) {
match res {
EitherOrBoth::Both(writer, embedding) => {
writer.add_item(wtxn, *dimensions, docid, embedding)?;
}
EitherOrBoth::Left(writer) => {
let deleted = writer.del_item(wtxn, *dimensions, docid)?;
if !deleted {
break;
}
}
EitherOrBoth::Right(_embedding) => {
unreachable!("1 vs 256 vectors")
}
}
}
}
ArroyOperation::Finish { mut user_provided } => {
let span = tracing::trace_span!(target: "indexing::vectors", parent: &indexer_span, "build");
let _entered = span.enter();
for (_embedder_index, (_embedder_name, _embedder, writers, dimensions)) in
&mut arroy_writers
{
let dimensions = *dimensions;
for writer in writers {
if writer.need_build(wtxn, dimensions)? {
writer.build(wtxn, &mut rng, dimensions)?;
} else if writer.is_empty(wtxn, dimensions)? {
break;
}
}
}
let mut configs = index.embedding_configs(wtxn)?;
for config in &mut configs {
if let Some(user_provided) = user_provided.remove(&config.name) {
config.user_provided = user_provided;
}
}
index.put_embedding_configs(wtxn, configs)?;
}
},
}
}
/// TODO handle the panicking threads /// TODO handle the panicking threads
handle.join().unwrap()?; handle.join().unwrap()?;
@ -483,3 +637,15 @@ pub fn retrieve_or_guess_primary_key<'a>(
Err(err) => Ok(Err(err)), Err(err) => Ok(Err(err)),
} }
} }
fn request_threads() -> &'static ThreadPoolNoAbort {
static REQUEST_THREADS: OnceLock<ThreadPoolNoAbort> = OnceLock::new();
REQUEST_THREADS.get_or_init(|| {
ThreadPoolNoAbortBuilder::new()
.num_threads(crate::vector::REQUEST_PARALLELISM)
.thread_name(|index| format!("embedding-request-{index}"))
.build()
.unwrap()
})
}

View File

@ -149,6 +149,7 @@ pub fn merge_grenad_entries(
} }
} }
MergerOperation::DeleteDocument { docid, external_id } => { MergerOperation::DeleteDocument { docid, external_id } => {
/// TODO: delete vectors
let span = let span =
tracing::trace_span!(target: "indexing::documents::merge", "delete_document"); tracing::trace_span!(target: "indexing::documents::merge", "delete_document");
let _entered = span.enter(); let _entered = span.enter();