diff --git a/milli/src/update/new/mod.rs b/milli/src/update/new/mod.rs index 37ccc75cd..6b59b5b59 100644 --- a/milli/src/update/new/mod.rs +++ b/milli/src/update/new/mod.rs @@ -12,6 +12,7 @@ pub mod indexer; mod merger; mod parallel_iterator_ext; mod top_level_map; +pub mod vector_document; mod word_fst_builder; mod words_prefix_docids; diff --git a/milli/src/update/new/vector_document.rs b/milli/src/update/new/vector_document.rs new file mode 100644 index 000000000..375d4f2ce --- /dev/null +++ b/milli/src/update/new/vector_document.rs @@ -0,0 +1,134 @@ +use bumpalo::Bump; +use heed::RoTxn; +use raw_collections::RawMap; +use serde::Serialize; +use serde_json::value::RawValue; + +use super::document::{Document, DocumentFromDb}; +use crate::documents::FieldIdMapper; +use crate::index::IndexEmbeddingConfig; +use crate::vector::parsed_vectors::RawVectors; +use crate::vector::Embedding; +use crate::{DocumentId, Index, InternalError, Result}; + +#[derive(Serialize)] +#[serde(untagged)] +pub enum Embeddings<'doc> { + FromJson(&'doc RawValue), + FromDb(Vec), +} + +pub struct VectorEntry<'doc> { + pub has_configured_embedder: bool, + pub embeddings: Option>, + pub regenerate: bool, +} + +pub trait VectorDocument<'doc> { + fn iter_vectors(&self) -> impl Iterator)>>; + + fn vectors_for_key(&self, key: &str) -> Result>>; +} + +pub struct VectorDocumentFromDb<'t> { + docid: DocumentId, + embedding_config: Vec, + index: &'t Index, + vectors_field: Option>, + rtxn: &'t RoTxn<'t>, + doc_alloc: &'t Bump, +} + +impl<'t> VectorDocumentFromDb<'t> { + pub fn new( + docid: DocumentId, + index: &'t Index, + rtxn: &'t RoTxn, + db_fields_ids_map: &'t Mapper, + doc_alloc: &'t Bump, + ) -> Result { + let document = DocumentFromDb::new(docid, rtxn, index, db_fields_ids_map)?.unwrap(); + let vectors = document.vectors_field()?; + let vectors_field = match vectors { + Some(vectors) => { + Some(RawMap::from_raw_value(vectors, doc_alloc).map_err(InternalError::SerdeJson)?) + } + None => None, + }; + + let embedding_config = index.embedding_configs(rtxn)?; + + Ok(Self { docid, embedding_config, index, vectors_field, rtxn, doc_alloc }) + } + + fn entry_from_db( + &self, + embedder_id: u8, + config: &IndexEmbeddingConfig, + ) -> Result> { + let readers = self.index.arroy_readers(self.rtxn, embedder_id, config.config.quantized()); + let mut vectors = Vec::new(); + for reader in readers { + let reader = reader?; + let Some(vector) = reader.item_vector(self.rtxn, self.docid)? else { + break; + }; + + vectors.push(vector); + } + Ok(VectorEntry { + has_configured_embedder: true, + embeddings: Some(Embeddings::FromDb(vectors)), + regenerate: !config.user_provided.contains(self.docid), + }) + } +} + +impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { + fn iter_vectors(&self) -> impl Iterator)>> { + self.embedding_config + .iter() + .map(|config| { + let embedder_id = + self.index.embedder_category_id.get(self.rtxn, &config.name)?.unwrap(); + let entry = self.entry_from_db(embedder_id, config)?; + let config_name = self.doc_alloc.alloc_str(config.name.as_str()); + Ok((&*config_name, entry)) + }) + .chain(self.vectors_field.iter().map(|map| map.iter()).flatten().map( + |(name, value)| { + Ok(( + name.as_ref(), + entry_from_raw_value(value).map_err(InternalError::SerdeJson)?, + )) + }, + )) + } + + fn vectors_for_key(&self, key: &str) -> Result>> { + Ok(match self.index.embedder_category_id.get(self.rtxn, key)? { + Some(embedder_id) => { + let config = + self.embedding_config.iter().find(|config| config.name == key).unwrap(); + Some(self.entry_from_db(embedder_id, config)?) + } + None => match self.vectors_field.as_ref().and_then(|obkv| obkv.get(key)) { + Some(embedding_from_doc) => Some( + entry_from_raw_value(embedding_from_doc).map_err(InternalError::SerdeJson)?, + ), + None => None, + }, + }) + } +} + +fn entry_from_raw_value( + value: &RawValue, +) -> std::result::Result, serde_json::Error> { + let value: RawVectors = serde_json::from_str(value.get())?; + Ok(VectorEntry { + has_configured_embedder: false, + embeddings: value.embeddings().map(|embeddings| Embeddings::FromJson(embeddings)), + regenerate: value.must_regenerate(), + }) +}