From c9082130c85252da6ff08eedd2efeb158a00a19f Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 30 Oct 2024 13:50:51 +0100 Subject: [PATCH] support vectors or array of vectors --- milli/src/update/new/document_change.rs | 14 +- milli/src/update/new/extract/vectors/mod.rs | 21 +- milli/src/update/new/indexer/de.rs | 291 ++++++++++++++++++++ milli/src/update/new/vector_document.rs | 98 +++++-- milli/src/vector/mod.rs | 4 + milli/src/vector/parsed_vectors.rs | 1 - 6 files changed, 401 insertions(+), 28 deletions(-) diff --git a/milli/src/update/new/document_change.rs b/milli/src/update/new/document_change.rs index bb1fc9441..4a61c110d 100644 --- a/milli/src/update/new/document_change.rs +++ b/milli/src/update/new/document_change.rs @@ -6,6 +6,7 @@ use super::vector_document::{ MergedVectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions, }; use crate::documents::FieldIdMapper; +use crate::vector::EmbeddingConfigs; use crate::{DocumentId, Index, Result}; pub enum DocumentChange<'doc> { @@ -94,8 +95,9 @@ impl<'doc> Insertion<'doc> { pub fn inserted_vectors( &self, doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, ) -> Result>> { - VectorDocumentFromVersions::new(&self.new, doc_alloc) + VectorDocumentFromVersions::new(&self.new, doc_alloc, embedders) } } @@ -165,8 +167,9 @@ impl<'doc> Update<'doc> { pub fn updated_vectors( &self, doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, ) -> Result>> { - VectorDocumentFromVersions::new(&self.new, doc_alloc) + VectorDocumentFromVersions::new(&self.new, doc_alloc, embedders) } pub fn merged_vectors( @@ -175,11 +178,14 @@ impl<'doc> Update<'doc> { index: &'doc Index, mapper: &'doc Mapper, doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, ) -> Result>> { if self.has_deletion { - MergedVectorDocument::without_db(&self.new, doc_alloc) + MergedVectorDocument::without_db(&self.new, doc_alloc, embedders) } else { - MergedVectorDocument::with_db(self.docid, index, rtxn, mapper, &self.new, doc_alloc) + MergedVectorDocument::with_db( + self.docid, index, rtxn, mapper, &self.new, doc_alloc, embedders, + ) } } } diff --git a/milli/src/update/new/extract/vectors/mod.rs b/milli/src/update/new/extract/vectors/mod.rs index 92c355710..70bd4d42d 100644 --- a/milli/src/update/new/extract/vectors/mod.rs +++ b/milli/src/update/new/extract/vectors/mod.rs @@ -93,7 +93,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { context.db_fields_ids_map, &context.doc_alloc, )?; - let new_vectors = update.updated_vectors(&context.doc_alloc)?; + let new_vectors = update.updated_vectors(&context.doc_alloc, self.embedders)?; if let Some(new_vectors) = &new_vectors { unused_vectors_distribution.append(new_vectors); @@ -118,7 +118,12 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { if let Some(embeddings) = new_vectors.embeddings { chunks.set_vectors( update.docid(), - embeddings.into_vec().map_err(UserError::SerdeJson)?, + embeddings + .into_vec(&context.doc_alloc, embedder_name) + .map_err(|error| UserError::InvalidVectorsEmbedderConf { + document_id: update.external_document_id().to_string(), + error, + })?, ); } else if new_vectors.regenerate { let new_rendered = prompt.render_document( @@ -177,7 +182,8 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { } } DocumentChange::Insertion(insertion) => { - let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?; + let new_vectors = + insertion.inserted_vectors(&context.doc_alloc, self.embedders)?; if let Some(new_vectors) = &new_vectors { unused_vectors_distribution.append(new_vectors); } @@ -194,7 +200,14 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { if let Some(embeddings) = new_vectors.embeddings { chunks.set_vectors( insertion.docid(), - embeddings.into_vec().map_err(UserError::SerdeJson)?, + embeddings + .into_vec(&context.doc_alloc, embedder_name) + .map_err(|error| UserError::InvalidVectorsEmbedderConf { + document_id: insertion + .external_document_id() + .to_string(), + error, + })?, ); } else if new_vectors.regenerate { let rendered = prompt.render_document( diff --git a/milli/src/update/new/indexer/de.rs b/milli/src/update/new/indexer/de.rs index 3da4fc239..94ab4c2c1 100644 --- a/milli/src/update/new/indexer/de.rs +++ b/milli/src/update/new/indexer/de.rs @@ -326,3 +326,294 @@ pub fn match_component<'de, 'indexer: 'de>( } ControlFlow::Continue(()) } + +pub struct DeserrRawValue<'a> { + value: &'a RawValue, + alloc: &'a Bump, +} + +impl<'a> DeserrRawValue<'a> { + pub fn new_in(value: &'a RawValue, alloc: &'a Bump) -> Self { + Self { value, alloc } + } +} + +pub struct DeserrRawVec<'a> { + vec: raw_collections::RawVec<'a>, + alloc: &'a Bump, +} + +impl<'a> deserr::Sequence for DeserrRawVec<'a> { + type Value = DeserrRawValue<'a>; + + type Iter = DeserrRawVecIter<'a>; + + fn len(&self) -> usize { + self.vec.len() + } + + fn into_iter(self) -> Self::Iter { + DeserrRawVecIter { it: self.vec.into_iter(), alloc: self.alloc } + } +} + +pub struct DeserrRawVecIter<'a> { + it: raw_collections::vec::iter::IntoIter<'a>, + alloc: &'a Bump, +} + +impl<'a> Iterator for DeserrRawVecIter<'a> { + type Item = DeserrRawValue<'a>; + + fn next(&mut self) -> Option { + let next = self.it.next()?; + Some(DeserrRawValue { value: next, alloc: self.alloc }) + } +} + +pub struct DeserrRawMap<'a> { + map: raw_collections::RawMap<'a>, + alloc: &'a Bump, +} + +impl<'a> deserr::Map for DeserrRawMap<'a> { + type Value = DeserrRawValue<'a>; + + type Iter = DeserrRawMapIter<'a>; + + fn len(&self) -> usize { + self.map.len() + } + + fn remove(&mut self, _key: &str) -> Option { + unimplemented!() + } + + fn into_iter(self) -> Self::Iter { + DeserrRawMapIter { it: self.map.into_iter(), alloc: self.alloc } + } +} + +pub struct DeserrRawMapIter<'a> { + it: raw_collections::map::iter::IntoIter<'a>, + alloc: &'a Bump, +} + +impl<'a> Iterator for DeserrRawMapIter<'a> { + type Item = (String, DeserrRawValue<'a>); + + fn next(&mut self) -> Option { + let (name, value) = self.it.next()?; + Some((name.to_string(), DeserrRawValue { value, alloc: self.alloc })) + } +} + +impl<'a> deserr::IntoValue for DeserrRawValue<'a> { + type Sequence = DeserrRawVec<'a>; + + type Map = DeserrRawMap<'a>; + + fn kind(&self) -> deserr::ValueKind { + self.value.deserialize_any(DeserrKindVisitor).unwrap() + } + + fn into_value(self) -> deserr::Value { + self.value.deserialize_any(DeserrRawValueVisitor { alloc: self.alloc }).unwrap() + } +} + +pub struct DeserrKindVisitor; + +impl<'de> Visitor<'de> for DeserrKindVisitor { + type Value = deserr::ValueKind; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "any value") + } + + fn visit_bool(self, _v: bool) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Boolean) + } + + fn visit_i64(self, _v: i64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::NegativeInteger) + } + + fn visit_u64(self, _v: u64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Integer) + } + + fn visit_f64(self, _v: f64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Float) + } + + fn visit_str(self, _v: &str) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::String) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Null) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Null) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_seq(self, _seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + Ok(deserr::ValueKind::Sequence) + } + + fn visit_map(self, _map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + Ok(deserr::ValueKind::Map) + } +} + +pub struct DeserrRawValueVisitor<'a> { + alloc: &'a Bump, +} + +impl<'de> Visitor<'de> for DeserrRawValueVisitor<'de> { + type Value = deserr::Value>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "any value") + } + + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Boolean(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::NegativeInteger(v)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Integer(v)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Float(v)) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::String(v.to_string())) + } + + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::String(v)) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Null) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Null) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut raw_vec = raw_collections::RawVec::new_in(&self.alloc); + while let Some(next) = seq.next_element()? { + raw_vec.push(next); + } + Ok(deserr::Value::Sequence(DeserrRawVec { vec: raw_vec, alloc: self.alloc })) + } + + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let _ = map; + Err(serde::de::Error::invalid_type(serde::de::Unexpected::Map, &self)) + } + + fn visit_enum(self, data: A) -> Result + where + A: serde::de::EnumAccess<'de>, + { + let _ = data; + Err(serde::de::Error::invalid_type(serde::de::Unexpected::Enum, &self)) + } +} diff --git a/milli/src/update/new/vector_document.rs b/milli/src/update/new/vector_document.rs index a5519a025..6796134db 100644 --- a/milli/src/update/new/vector_document.rs +++ b/milli/src/update/new/vector_document.rs @@ -1,29 +1,67 @@ use std::collections::BTreeSet; use bumpalo::Bump; +use deserr::{Deserr, IntoValue}; use heed::RoTxn; use raw_collections::RawMap; use serde::Serialize; use serde_json::value::RawValue; use super::document::{Document, DocumentFromDb, DocumentFromVersions, Versions}; +use super::indexer::de::DeserrRawValue; use crate::documents::FieldIdMapper; use crate::index::IndexEmbeddingConfig; -use crate::vector::parsed_vectors::RawVectors; -use crate::vector::Embedding; +use crate::vector::parsed_vectors::{ + RawVectors, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME, +}; +use crate::vector::{Embedding, EmbeddingConfigs}; use crate::{DocumentId, Index, InternalError, Result, UserError}; #[derive(Serialize)] #[serde(untagged)] pub enum Embeddings<'doc> { - FromJson(&'doc RawValue), + FromJsonExplicit(&'doc RawValue), + FromJsonImplicityUserProvided(&'doc RawValue), FromDb(Vec), } impl<'doc> Embeddings<'doc> { - pub fn into_vec(self) -> std::result::Result, serde_json::Error> { + pub fn into_vec( + self, + doc_alloc: &'doc Bump, + embedder_name: &str, + ) -> std::result::Result, deserr::errors::JsonError> { match self { - /// FIXME: this should be a VecOrArrayOfVec - Embeddings::FromJson(value) => serde_json::from_str(value.get()), + Embeddings::FromJsonExplicit(value) => { + let vectors_ref = deserr::ValuePointerRef::Key { + key: RESERVED_VECTORS_FIELD_NAME, + prev: &deserr::ValuePointerRef::Origin, + }; + let embedders_ref = + deserr::ValuePointerRef::Key { key: embedder_name, prev: &vectors_ref }; + + let embeddings_ref = + deserr::ValuePointerRef::Key { key: "embeddings", prev: &embedders_ref }; + + let v: VectorOrArrayOfVectors = VectorOrArrayOfVectors::deserialize_from_value( + DeserrRawValue::new_in(value, doc_alloc).into_value(), + embeddings_ref, + )?; + Ok(v.into_array_of_vectors().unwrap_or_default()) + } + Embeddings::FromJsonImplicityUserProvided(value) => { + let vectors_ref = deserr::ValuePointerRef::Key { + key: RESERVED_VECTORS_FIELD_NAME, + prev: &deserr::ValuePointerRef::Origin, + }; + let embedders_ref = + deserr::ValuePointerRef::Key { key: embedder_name, prev: &vectors_ref }; + + let v: VectorOrArrayOfVectors = VectorOrArrayOfVectors::deserialize_from_value( + DeserrRawValue::new_in(value, doc_alloc).into_value(), + embedders_ref, + )?; + Ok(v.into_array_of_vectors().unwrap_or_default()) + } Embeddings::FromDb(vec) => Ok(vec), } } @@ -109,7 +147,7 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { Ok((&*config_name, entry)) }) .chain(self.vectors_field.iter().flat_map(|map| map.iter()).map(|(name, value)| { - Ok((name, entry_from_raw_value(value).map_err(InternalError::SerdeJson)?)) + Ok((name, entry_from_raw_value(value, false).map_err(InternalError::SerdeJson)?)) })) } @@ -122,7 +160,8 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { } None => match self.vectors_field.as_ref().and_then(|obkv| obkv.get(key)) { Some(embedding_from_doc) => Some( - entry_from_raw_value(embedding_from_doc).map_err(InternalError::SerdeJson)?, + entry_from_raw_value(embedding_from_doc, false) + .map_err(InternalError::SerdeJson)?, ), None => None, }, @@ -132,26 +171,40 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { fn entry_from_raw_value( value: &RawValue, + has_configured_embedder: bool, ) -> std::result::Result, serde_json::Error> { let value: RawVectors = serde_json::from_str(value.get())?; - Ok(VectorEntry { - has_configured_embedder: false, - embeddings: value.embeddings().map(Embeddings::FromJson), - regenerate: value.must_regenerate(), + + Ok(match value { + RawVectors::Explicit(raw_explicit_vectors) => VectorEntry { + has_configured_embedder, + embeddings: raw_explicit_vectors.embeddings.map(Embeddings::FromJsonExplicit), + regenerate: raw_explicit_vectors.regenerate, + }, + RawVectors::ImplicitlyUserProvided(value) => VectorEntry { + has_configured_embedder, + embeddings: Some(Embeddings::FromJsonImplicityUserProvided(value)), + regenerate: false, + }, }) } pub struct VectorDocumentFromVersions<'doc> { vectors: RawMap<'doc>, + embedders: &'doc EmbeddingConfigs, } impl<'doc> VectorDocumentFromVersions<'doc> { - pub fn new(versions: &Versions<'doc>, bump: &'doc Bump) -> Result> { + pub fn new( + versions: &Versions<'doc>, + bump: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result> { let document = DocumentFromVersions::new(versions); if let Some(vectors_field) = document.vectors_field()? { let vectors = RawMap::from_raw_value(vectors_field, bump).map_err(UserError::SerdeJson)?; - Ok(Some(Self { vectors })) + Ok(Some(Self { vectors, embedders })) } else { Ok(None) } @@ -161,14 +214,16 @@ impl<'doc> VectorDocumentFromVersions<'doc> { impl<'doc> VectorDocument<'doc> for VectorDocumentFromVersions<'doc> { fn iter_vectors(&self) -> impl Iterator)>> { self.vectors.iter().map(|(embedder, vectors)| { - let vectors = entry_from_raw_value(vectors).map_err(UserError::SerdeJson)?; + let vectors = entry_from_raw_value(vectors, self.embedders.contains(embedder)) + .map_err(UserError::SerdeJson)?; Ok((embedder, vectors)) }) } fn vectors_for_key(&self, key: &str) -> Result>> { let Some(vectors) = self.vectors.get(key) else { return Ok(None) }; - let vectors = entry_from_raw_value(vectors).map_err(UserError::SerdeJson)?; + let vectors = entry_from_raw_value(vectors, self.embedders.contains(key)) + .map_err(UserError::SerdeJson)?; Ok(Some(vectors)) } } @@ -186,14 +241,19 @@ impl<'doc> MergedVectorDocument<'doc> { db_fields_ids_map: &'doc Mapper, versions: &Versions<'doc>, doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, ) -> Result> { let db = VectorDocumentFromDb::new(docid, index, rtxn, db_fields_ids_map, doc_alloc)?; - let new_doc = VectorDocumentFromVersions::new(versions, doc_alloc)?; + let new_doc = VectorDocumentFromVersions::new(versions, doc_alloc, embedders)?; Ok(if db.is_none() && new_doc.is_none() { None } else { Some(Self { new_doc, db }) }) } - pub fn without_db(versions: &Versions<'doc>, doc_alloc: &'doc Bump) -> Result> { - let Some(new_doc) = VectorDocumentFromVersions::new(versions, doc_alloc)? else { + pub fn without_db( + versions: &Versions<'doc>, + doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result> { + let Some(new_doc) = VectorDocumentFromVersions::new(versions, doc_alloc, embedders)? else { return Ok(None); }; Ok(Some(Self { new_doc: Some(new_doc), db: None })) diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 2e9a498c0..a21e9e2ca 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -316,6 +316,10 @@ impl EmbeddingConfigs { Self(data) } + pub fn contains(&self, name: &str) -> bool { + self.0.contains_key(name) + } + /// Get an embedder configuration and template from its name. pub fn get(&self, name: &str) -> Option<(Arc, Arc, bool)> { self.0.get(name).cloned() diff --git a/milli/src/vector/parsed_vectors.rs b/milli/src/vector/parsed_vectors.rs index 526516fef..40e823f17 100644 --- a/milli/src/vector/parsed_vectors.rs +++ b/milli/src/vector/parsed_vectors.rs @@ -84,7 +84,6 @@ impl<'doc> RawVectors<'doc> { RawVectors::Explicit(RawExplicitVectors { regenerate, .. }) => *regenerate, } } - pub fn embeddings(&self) -> Option<&'doc RawValue> { match self { RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings),