From 7058959a4644a6ea49c482277f3bfa37f3784c71 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Mon, 28 Oct 2024 16:18:48 +0100
Subject: [PATCH] Write into documents

---
 milli/src/update/new/document.rs              | 55 +++++++------
 milli/src/update/new/document_change.rs       | 22 ++++-
 .../new/extract/faceted/extract_facets.rs     |  4 +-
 .../extract/searchable/extract_word_docids.rs |  4 +-
 .../extract_word_pair_proximity_docids.rs     |  4 +-
 milli/src/update/new/extract/vectors/mod.rs   | 12 ++-
 milli/src/update/new/indexer/mod.rs           | 32 +++++---
 milli/src/update/new/vector_document.rs       | 81 ++++++++++++++---
 8 files changed, 154 insertions(+), 60 deletions(-)

diff --git a/milli/src/update/new/document.rs b/milli/src/update/new/document.rs
index be09feb5a..0a5172d36 100644
--- a/milli/src/update/new/document.rs
+++ b/milli/src/update/new/document.rs
@@ -1,13 +1,14 @@
-use std::collections::BTreeSet;
+use std::collections::{BTreeMap, BTreeSet};
 
 use heed::RoTxn;
 use raw_collections::RawMap;
 use serde_json::value::RawValue;
 
+use super::vector_document::{VectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions};
 use super::{KvReaderFieldId, KvWriterFieldId};
 use crate::documents::FieldIdMapper;
 use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME;
-use crate::{DocumentId, Index, InternalError, Result};
+use crate::{DocumentId, GlobalFieldsIdsMap, Index, InternalError, Result, UserError};
 
 /// A view into a document that can represent either the current version from the DB,
 /// the update data from payload or other means, or the merged updated version.
@@ -69,17 +70,22 @@ impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> {
 
         std::iter::from_fn(move || {
             let (fid, value) = it.next()?;
-            let res = (|| {
-                let value =
-                    serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
-
+            let res = (|| loop {
                 let name = self.fields_ids_map.name(fid).ok_or(
                     InternalError::FieldIdMapMissingEntry(crate::FieldIdMapMissingEntry::FieldId {
                         field_id: fid,
                         process: "getting current document",
                     }),
                 )?;
-                Ok((name, value))
+
+                if name == RESERVED_VECTORS_FIELD_NAME || name == "_geo" {
+                    continue;
+                }
+
+                let value =
+                    serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?;
+
+                return Ok((name, value));
             })();
             Some(res)
         })
@@ -164,13 +170,6 @@ pub struct MergedDocument<'a, 'doc, 't, Mapper: FieldIdMapper> {
 }
 
 impl<'a, 'doc, 't, Mapper: FieldIdMapper> MergedDocument<'a, 'doc, 't, Mapper> {
-    pub fn new(
-        new_doc: DocumentFromVersions<'a, 'doc>,
-        db: Option<DocumentFromDb<'t, Mapper>>,
-    ) -> Self {
-        Self { new_doc, db }
-    }
-
     pub fn with_db(
         docid: DocumentId,
         rtxn: &'t RoTxn,
@@ -287,15 +286,14 @@ where
 ///
 /// - If the document contains a top-level field that is not present in `fields_ids_map`.
 ///
-pub fn write_to_obkv<'s, 'a, 'b>(
+pub fn write_to_obkv<'s, 'a, 'map>(
     document: &'s impl Document<'s>,
-    vector_document: Option<()>,
-    fields_ids_map: &'a impl FieldIdMapper,
+    vector_document: Option<&'s impl VectorDocument<'s>>,
+    fields_ids_map: &'a mut GlobalFieldsIdsMap<'map>,
     mut document_buffer: &'a mut Vec<u8>,
 ) -> Result<&'a KvReaderFieldId>
 where
     's: 'a,
-    's: 'b,
 {
     // will be used in 'inject_vectors
     let vectors_value: Box<RawValue>;
@@ -308,19 +306,21 @@ where
 
     for res in document.iter_top_level_fields() {
         let (field_name, value) = res?;
-        let field_id = fields_ids_map.id(field_name).unwrap();
+        let field_id =
+            fields_ids_map.id_or_insert(field_name).ok_or(UserError::AttributeLimitReached)?;
         unordered_field_buffer.push((field_id, value));
     }
 
     'inject_vectors: {
         let Some(vector_document) = vector_document else { break 'inject_vectors };
 
-        let Some(vectors_fid) = fields_ids_map.id(RESERVED_VECTORS_FIELD_NAME) else {
-            break 'inject_vectors;
-        };
-        /*
+        let vectors_fid = fields_ids_map
+            .id_or_insert(RESERVED_VECTORS_FIELD_NAME)
+            .ok_or(UserError::AttributeLimitReached)?;
+
         let mut vectors = BTreeMap::new();
-        for (name, entry) in vector_document.iter_vectors() {
+        for res in vector_document.iter_vectors() {
+            let (name, entry) = res?;
             if entry.has_configured_embedder {
                 continue; // we don't write vectors with configured embedder in documents
             }
@@ -335,7 +335,7 @@ where
         }
 
         vectors_value = serde_json::value::to_raw_value(&vectors).unwrap();
-        unordered_field_buffer.push((vectors_fid, &vectors_value));*/
+        unordered_field_buffer.push((vectors_fid, &vectors_value));
     }
 
     unordered_field_buffer.sort_by_key(|(fid, _)| *fid);
@@ -373,9 +373,8 @@ impl<'doc> Versions<'doc> {
         Self { data: version }
     }
 
-    pub fn iter_top_level_fields(&self) -> raw_collections::map::iter::Iter<'doc, '_> {
-        /// FIXME: ignore vectors and _geo
-        self.data.iter()
+    pub fn iter_top_level_fields(&self) -> impl Iterator<Item = (&'doc str, &'doc RawValue)> + '_ {
+        self.data.iter().filter(|(k, _)| *k != RESERVED_VECTORS_FIELD_NAME && *k != "_geo")
     }
 
     pub fn vectors_field(&self) -> Option<&'doc RawValue> {
diff --git a/milli/src/update/new/document_change.rs b/milli/src/update/new/document_change.rs
index c55113b74..bb1fc9441 100644
--- a/milli/src/update/new/document_change.rs
+++ b/milli/src/update/new/document_change.rs
@@ -2,7 +2,9 @@ use bumpalo::Bump;
 use heed::RoTxn;
 
 use super::document::{DocumentFromDb, DocumentFromVersions, MergedDocument, Versions};
-use super::vector_document::{VectorDocumentFromDb, VectorDocumentFromVersions};
+use super::vector_document::{
+    MergedVectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions,
+};
 use crate::documents::FieldIdMapper;
 use crate::{DocumentId, Index, Result};
 
@@ -85,7 +87,7 @@ impl<'doc> Insertion<'doc> {
     pub fn external_document_id(&self) -> &'doc str {
         self.external_document_id
     }
-    pub fn new(&self) -> DocumentFromVersions<'_, 'doc> {
+    pub fn inserted(&self) -> DocumentFromVersions<'_, 'doc> {
         DocumentFromVersions::new(&self.new)
     }
 
@@ -141,7 +143,7 @@ impl<'doc> Update<'doc> {
         DocumentFromVersions::new(&self.new)
     }
 
-    pub fn new<'t, Mapper: FieldIdMapper>(
+    pub fn merged<'t, Mapper: FieldIdMapper>(
         &self,
         rtxn: &'t RoTxn,
         index: &'t Index,
@@ -166,4 +168,18 @@ impl<'doc> Update<'doc> {
     ) -> Result<Option<VectorDocumentFromVersions<'doc>>> {
         VectorDocumentFromVersions::new(&self.new, doc_alloc)
     }
+
+    pub fn merged_vectors<Mapper: FieldIdMapper>(
+        &self,
+        rtxn: &'doc RoTxn,
+        index: &'doc Index,
+        mapper: &'doc Mapper,
+        doc_alloc: &'doc Bump,
+    ) -> Result<Option<MergedVectorDocument<'doc>>> {
+        if self.has_deletion {
+            MergedVectorDocument::without_db(&self.new, doc_alloc)
+        } else {
+            MergedVectorDocument::with_db(self.docid, index, rtxn, mapper, &self.new, doc_alloc)
+        }
+    }
 }
diff --git a/milli/src/update/new/extract/faceted/extract_facets.rs b/milli/src/update/new/extract/faceted/extract_facets.rs
index 9fae1839e..f2cbad6ff 100644
--- a/milli/src/update/new/extract/faceted/extract_facets.rs
+++ b/milli/src/update/new/extract/faceted/extract_facets.rs
@@ -120,7 +120,7 @@ impl FacetedDocidsExtractor {
 
                 extract_document_facets(
                     attributes_to_extract,
-                    inner.new(rtxn, index, context.db_fields_ids_map)?,
+                    inner.merged(rtxn, index, context.db_fields_ids_map)?,
                     new_fields_ids_map.deref_mut(),
                     &mut |fid, value| {
                         Self::facet_fn_with_options(
@@ -136,7 +136,7 @@ impl FacetedDocidsExtractor {
             }
             DocumentChange::Insertion(inner) => extract_document_facets(
                 attributes_to_extract,
-                inner.new(),
+                inner.inserted(),
                 new_fields_ids_map.deref_mut(),
                 &mut |fid, value| {
                     Self::facet_fn_with_options(
diff --git a/milli/src/update/new/extract/searchable/extract_word_docids.rs b/milli/src/update/new/extract/searchable/extract_word_docids.rs
index 5eb9692d6..80f36b01d 100644
--- a/milli/src/update/new/extract/searchable/extract_word_docids.rs
+++ b/milli/src/update/new/extract/searchable/extract_word_docids.rs
@@ -481,7 +481,7 @@ impl WordDocidsExtractors {
                         .map_err(crate::Error::from)
                 };
                 document_tokenizer.tokenize_document(
-                    inner.new(rtxn, index, context.db_fields_ids_map)?,
+                    inner.merged(rtxn, index, context.db_fields_ids_map)?,
                     new_fields_ids_map,
                     &mut token_fn,
                 )?;
@@ -500,7 +500,7 @@ impl WordDocidsExtractors {
                         .map_err(crate::Error::from)
                 };
                 document_tokenizer.tokenize_document(
-                    inner.new(),
+                    inner.inserted(),
                     new_fields_ids_map,
                     &mut token_fn,
                 )?;
diff --git a/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs b/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs
index 53e6515a9..1bd3aee36 100644
--- a/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs
+++ b/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs
@@ -80,7 +80,7 @@ impl SearchableExtractor for WordPairProximityDocidsExtractor {
                         del_word_pair_proximity.push(((w1, w2), prox));
                     },
                 )?;
-                let document = inner.new(rtxn, index, context.db_fields_ids_map)?;
+                let document = inner.merged(rtxn, index, context.db_fields_ids_map)?;
                 process_document_tokens(
                     document,
                     document_tokenizer,
@@ -92,7 +92,7 @@ impl SearchableExtractor for WordPairProximityDocidsExtractor {
                 )?;
             }
             DocumentChange::Insertion(inner) => {
-                let document = inner.new();
+                let document = inner.inserted();
                 process_document_tokens(
                     document,
                     document_tokenizer,
diff --git a/milli/src/update/new/extract/vectors/mod.rs b/milli/src/update/new/extract/vectors/mod.rs
index 87b126207..a2762ae7a 100644
--- a/milli/src/update/new/extract/vectors/mod.rs
+++ b/milli/src/update/new/extract/vectors/mod.rs
@@ -100,7 +100,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                             &context.doc_alloc,
                         )?;
                         let old_rendered = prompt.render_document(
-                            update.new(
+                            update.merged(
                                 &context.txn,
                                 context.index,
                                 context.db_fields_ids_map,
@@ -123,7 +123,11 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                             &context.doc_alloc,
                         )?;
                         let new_rendered = prompt.render_document(
-                            update.new(&context.txn, context.index, context.db_fields_ids_map)?,
+                            update.merged(
+                                &context.txn,
+                                context.index,
+                                context.db_fields_ids_map,
+                            )?,
                             context.new_fields_ids_map,
                             &context.doc_alloc,
                         )?;
@@ -156,7 +160,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                             .unwrap();
                     } else if new_vectors.regenerate {
                         let rendered = prompt.render_document(
-                            insertion.new(),
+                            insertion.inserted(),
                             context.new_fields_ids_map,
                             &context.doc_alloc,
                         )?;
@@ -164,7 +168,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
                         }
                     } else {
                         let rendered = prompt.render_document(
-                            insertion.new(),
+                            insertion.inserted(),
                             context.new_fields_ids_map,
                             &context.doc_alloc,
                         )?;
diff --git a/milli/src/update/new/indexer/mod.rs b/milli/src/update/new/indexer/mod.rs
index dd2506ef9..b316cbc34 100644
--- a/milli/src/update/new/indexer/mod.rs
+++ b/milli/src/update/new/indexer/mod.rs
@@ -64,9 +64,7 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentExtractor<'a> {
     ) -> Result<()> {
         let mut document_buffer = Vec::new();
 
-        let new_fields_ids_map = context.new_fields_ids_map.borrow_or_yield();
-        let new_fields_ids_map = &*new_fields_ids_map;
-        let new_fields_ids_map = new_fields_ids_map.local_map();
+        let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield();
 
         for change in changes {
             let change = change?;
@@ -78,20 +76,34 @@ impl<'a, 'extractor> Extractor<'extractor> for DocumentExtractor<'a> {
                     let docid = deletion.docid();
                     self.document_sender.delete(docid, external_docid).unwrap();
                 }
-                /// TODO: change NONE by SOME(vector) when implemented
                 DocumentChange::Update(update) => {
                     let docid = update.docid();
                     let content =
-                        update.new(&context.txn, context.index, &context.db_fields_ids_map)?;
-                    let content =
-                        write_to_obkv(&content, None, new_fields_ids_map, &mut document_buffer)?;
+                        update.merged(&context.txn, context.index, &context.db_fields_ids_map)?;
+                    let vector_content = update.merged_vectors(
+                        &context.txn,
+                        context.index,
+                        &context.db_fields_ids_map,
+                        &context.doc_alloc,
+                    )?;
+                    let content = write_to_obkv(
+                        &content,
+                        vector_content.as_ref(),
+                        &mut new_fields_ids_map,
+                        &mut document_buffer,
+                    )?;
                     self.document_sender.insert(docid, external_docid, content.boxed()).unwrap();
                 }
                 DocumentChange::Insertion(insertion) => {
                     let docid = insertion.docid();
-                    let content = insertion.new();
-                    let content =
-                        write_to_obkv(&content, None, new_fields_ids_map, &mut document_buffer)?;
+                    let content = insertion.inserted();
+                    let inserted_vectors = insertion.inserted_vectors(&context.doc_alloc)?;
+                    let content = write_to_obkv(
+                        &content,
+                        inserted_vectors.as_ref(),
+                        &mut new_fields_ids_map,
+                        &mut document_buffer,
+                    )?;
                     self.document_sender.insert(docid, external_docid, content.boxed()).unwrap();
                     // extracted_dictionary_sender.send(self, dictionary: &[u8]);
                 }
diff --git a/milli/src/update/new/vector_document.rs b/milli/src/update/new/vector_document.rs
index 782076716..a5519a025 100644
--- a/milli/src/update/new/vector_document.rs
+++ b/milli/src/update/new/vector_document.rs
@@ -1,3 +1,5 @@
+use std::collections::BTreeSet;
+
 use bumpalo::Bump;
 use heed::RoTxn;
 use raw_collections::RawMap;
@@ -106,14 +108,9 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> {
                 let config_name = self.doc_alloc.alloc_str(config.name.as_str());
                 Ok((&*config_name, entry))
             })
-            .chain(self.vectors_field.iter().map(|map| map.iter()).flatten().map(
-                |(name, value)| {
-                    Ok((
-                        name.as_ref(),
-                        entry_from_raw_value(value).map_err(InternalError::SerdeJson)?,
-                    ))
-                },
-            ))
+            .chain(self.vectors_field.iter().flat_map(|map| map.iter()).map(|(name, value)| {
+                Ok((name, entry_from_raw_value(value).map_err(InternalError::SerdeJson)?))
+            }))
     }
 
     fn vectors_for_key(&self, key: &str) -> Result<Option<VectorEntry<'t>>> {
@@ -139,7 +136,7 @@ fn entry_from_raw_value(
     let value: RawVectors = serde_json::from_str(value.get())?;
     Ok(VectorEntry {
         has_configured_embedder: false,
-        embeddings: value.embeddings().map(|embeddings| Embeddings::FromJson(embeddings)),
+        embeddings: value.embeddings().map(Embeddings::FromJson),
         regenerate: value.must_regenerate(),
     })
 }
@@ -175,3 +172,69 @@ impl<'doc> VectorDocument<'doc> for VectorDocumentFromVersions<'doc> {
         Ok(Some(vectors))
     }
 }
+
+pub struct MergedVectorDocument<'doc> {
+    new_doc: Option<VectorDocumentFromVersions<'doc>>,
+    db: Option<VectorDocumentFromDb<'doc>>,
+}
+
+impl<'doc> MergedVectorDocument<'doc> {
+    pub fn with_db<Mapper: FieldIdMapper>(
+        docid: DocumentId,
+        index: &'doc Index,
+        rtxn: &'doc RoTxn,
+        db_fields_ids_map: &'doc Mapper,
+        versions: &Versions<'doc>,
+        doc_alloc: &'doc Bump,
+    ) -> Result<Option<Self>> {
+        let db = VectorDocumentFromDb::new(docid, index, rtxn, db_fields_ids_map, doc_alloc)?;
+        let new_doc = VectorDocumentFromVersions::new(versions, doc_alloc)?;
+        Ok(if db.is_none() && new_doc.is_none() { None } else { Some(Self { new_doc, db }) })
+    }
+
+    pub fn without_db(versions: &Versions<'doc>, doc_alloc: &'doc Bump) -> Result<Option<Self>> {
+        let Some(new_doc) = VectorDocumentFromVersions::new(versions, doc_alloc)? else {
+            return Ok(None);
+        };
+        Ok(Some(Self { new_doc: Some(new_doc), db: None }))
+    }
+}
+
+impl<'doc> VectorDocument<'doc> for MergedVectorDocument<'doc> {
+    fn iter_vectors(&self) -> impl Iterator<Item = Result<(&'doc str, VectorEntry<'doc>)>> {
+        let mut new_doc_it = self.new_doc.iter().flat_map(|new_doc| new_doc.iter_vectors());
+        let mut db_it = self.db.iter().flat_map(|db| db.iter_vectors());
+        let mut seen_fields = BTreeSet::new();
+
+        std::iter::from_fn(move || {
+            if let Some(next) = new_doc_it.next() {
+                if let Ok((name, _)) = next {
+                    seen_fields.insert(name);
+                }
+                return Some(next);
+            }
+            loop {
+                match db_it.next()? {
+                    Ok((name, value)) => {
+                        if seen_fields.contains(name) {
+                            continue;
+                        }
+                        return Some(Ok((name, value)));
+                    }
+                    Err(err) => return Some(Err(err)),
+                }
+            }
+        })
+    }
+
+    fn vectors_for_key(&self, key: &str) -> Result<Option<VectorEntry<'doc>>> {
+        if let Some(new_doc) = &self.new_doc {
+            if let Some(entry) = new_doc.vectors_for_key(key)? {
+                return Ok(Some(entry));
+            }
+        }
+
+        let Some(db) = self.db.as_ref() else { return Ok(None) };
+        db.vectors_for_key(key)
+    }
+}
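The new `MergedVectorDocument::iter_vectors` above is the densest part of the patch: it drains the entries of the freshly written version first, records their names, and only then emits the database entries that were not overridden. The standalone sketch below reproduces that merge pattern with plain string pairs so it can be run in isolation; `merge_entries` and its inputs are hypothetical stand-ins, not part of milli, which iterates `VectorEntry` values and propagates `Result`s instead.

```rust
use std::collections::BTreeSet;

// Sketch of the "new version shadows the stored one" merge used by
// MergedVectorDocument::iter_vectors in the patch above.
fn merge_entries<'a>(
    new_doc: &'a [(&'a str, &'a str)],
    db: &'a [(&'a str, &'a str)],
) -> impl Iterator<Item = (&'a str, &'a str)> + 'a {
    let mut seen = BTreeSet::new();
    let mut new_it = new_doc.iter().copied();
    let mut db_it = db.iter().copied();

    std::iter::from_fn(move || {
        // Entries from the new version win and their names are recorded.
        if let Some((name, value)) = new_it.next() {
            seen.insert(name);
            return Some((name, value));
        }
        // DB entries are only emitted if the new version did not override them.
        loop {
            let (name, value) = db_it.next()?;
            if !seen.contains(name) {
                return Some((name, value));
            }
        }
    })
}

fn main() {
    let new_doc = [("manual", "[1.0, 2.0]")];
    let db = [("manual", "[0.0, 0.0]"), ("default", "[3.0, 4.0]")];

    for (name, value) in merge_entries(&new_doc, &db) {
        println!("{name}: {value}");
    }
    // Prints the updated "manual" entry followed by the untouched "default" one.
}
```

Tracking the already-seen names in a `BTreeSet` is what lets the merged view deduplicate lazily, without materialising either side of the merge up front.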