From 4706a0eb49e4e1861f3ed37773b17abf33c93067 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 7 Nov 2024 22:35:06 +0100 Subject: [PATCH] Fix vector parsing --- .../milli/src/update/new/vector_document.rs | 4 +- crates/milli/src/vector/parsed_vectors.rs | 91 ++++++++++++++++++- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/crates/milli/src/update/new/vector_document.rs b/crates/milli/src/update/new/vector_document.rs index dc73c5268..4a27361a9 100644 --- a/crates/milli/src/update/new/vector_document.rs +++ b/crates/milli/src/update/new/vector_document.rs @@ -167,7 +167,7 @@ fn entry_from_raw_value( value: &RawValue, has_configured_embedder: bool, ) -> std::result::Result, serde_json::Error> { - let value: RawVectors = serde_json::from_str(value.get())?; + let value: RawVectors = RawVectors::from_raw_value(value)?; Ok(match value { RawVectors::Explicit(raw_explicit_vectors) => VectorEntry { @@ -177,7 +177,7 @@ fn entry_from_raw_value( }, RawVectors::ImplicitlyUserProvided(value) => VectorEntry { has_configured_embedder, - embeddings: Some(Embeddings::FromJsonImplicityUserProvided(value)), + embeddings: value.map(Embeddings::FromJsonImplicityUserProvided), regenerate: false, }, }) diff --git a/crates/milli/src/vector/parsed_vectors.rs b/crates/milli/src/vector/parsed_vectors.rs index 40e823f17..6ae6c1c9e 100644 --- a/crates/milli/src/vector/parsed_vectors.rs +++ b/crates/milli/src/vector/parsed_vectors.rs @@ -12,11 +12,96 @@ use crate::{DocumentId, FieldId, InternalError, UserError}; pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors"; -#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[derive(serde::Serialize, Debug)] #[serde(untagged)] pub enum RawVectors<'doc> { Explicit(#[serde(borrow)] RawExplicitVectors<'doc>), - ImplicitlyUserProvided(#[serde(borrow)] &'doc RawValue), + ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>), +} + +impl<'doc> RawVectors<'doc> { + pub fn from_raw_value(raw: &'doc RawValue) -> Result { + use serde::de::Deserializer as _; + Ok(match raw.deserialize_any(RawVectorsVisitor)? { + RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None), + RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)), + RawVectorsVisitorValue::Explicit { regenerate, embeddings } => { + RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate }) + } + }) + } +} + +struct RawVectorsVisitor; + +enum RawVectorsVisitorValue<'doc> { + ImplicitNone, + Implicit, + Explicit { regenerate: bool, embeddings: Option<&'doc RawValue> }, +} + +impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor { + type Value = RawVectorsVisitorValue<'doc>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a map containing at least `regenerate`, or an array of floats`") + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(RawVectorsVisitorValue::ImplicitNone) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'doc>, + { + deserializer.deserialize_any(self) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(RawVectorsVisitorValue::ImplicitNone) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'doc>, + { + // must consume all elements or parsing fails + while let Some(_) = seq.next_element::<&RawValue>()? {} + Ok(RawVectorsVisitorValue::Implicit) + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'doc>, + { + use serde::de::Error as _; + let mut regenerate = None; + let mut embeddings = None; + while let Some(s) = map.next_key()? { + match s { + "regenerate" => { + let value: bool = map.next_value()?; + regenerate = Some(value); + } + "embeddings" => { + let value: &RawValue = map.next_value()?; + embeddings = Some(value); + } + other => return Err(A::Error::unknown_field(other, &["regenerate", "embeddings"])), + } + } + let Some(regenerate) = regenerate else { + return Err(A::Error::missing_field("regenerate")); + }; + Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings }) + } } #[derive(serde::Serialize, Debug)] @@ -86,7 +171,7 @@ impl<'doc> RawVectors<'doc> { } pub fn embeddings(&self) -> Option<&'doc RawValue> { match self { - RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings), + RawVectors::ImplicitlyUserProvided(embeddings) => *embeddings, RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings, } }