Fix vector parsing

This commit is contained in:
Louis Dureuil 2024-11-07 22:35:06 +01:00
parent d97af4d8e6
commit 4706a0eb49
No known key found for this signature in database
2 changed files with 90 additions and 5 deletions

View File

@ -167,7 +167,7 @@ fn entry_from_raw_value(
value: &RawValue, value: &RawValue,
has_configured_embedder: bool, has_configured_embedder: bool,
) -> std::result::Result<VectorEntry<'_>, serde_json::Error> { ) -> std::result::Result<VectorEntry<'_>, serde_json::Error> {
let value: RawVectors = serde_json::from_str(value.get())?; let value: RawVectors = RawVectors::from_raw_value(value)?;
Ok(match value { Ok(match value {
RawVectors::Explicit(raw_explicit_vectors) => VectorEntry { RawVectors::Explicit(raw_explicit_vectors) => VectorEntry {
@ -177,7 +177,7 @@ fn entry_from_raw_value(
}, },
RawVectors::ImplicitlyUserProvided(value) => VectorEntry { RawVectors::ImplicitlyUserProvided(value) => VectorEntry {
has_configured_embedder, has_configured_embedder,
embeddings: Some(Embeddings::FromJsonImplicityUserProvided(value)), embeddings: value.map(Embeddings::FromJsonImplicityUserProvided),
regenerate: false, regenerate: false,
}, },
}) })

View File

@ -12,11 +12,96 @@ use crate::{DocumentId, FieldId, InternalError, UserError};
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors"; pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
#[derive(serde::Serialize, serde::Deserialize, Debug)] #[derive(serde::Serialize, Debug)]
#[serde(untagged)] #[serde(untagged)]
pub enum RawVectors<'doc> { pub enum RawVectors<'doc> {
Explicit(#[serde(borrow)] RawExplicitVectors<'doc>), Explicit(#[serde(borrow)] RawExplicitVectors<'doc>),
ImplicitlyUserProvided(#[serde(borrow)] &'doc RawValue), ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>),
}
impl<'doc> RawVectors<'doc> {
pub fn from_raw_value(raw: &'doc RawValue) -> Result<Self, serde_json::Error> {
use serde::de::Deserializer as _;
Ok(match raw.deserialize_any(RawVectorsVisitor)? {
RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None),
RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)),
RawVectorsVisitorValue::Explicit { regenerate, embeddings } => {
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate })
}
})
}
}
struct RawVectorsVisitor;
enum RawVectorsVisitorValue<'doc> {
ImplicitNone,
Implicit,
Explicit { regenerate: bool, embeddings: Option<&'doc RawValue> },
}
impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
type Value = RawVectorsVisitorValue<'doc>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a map containing at least `regenerate`, or an array of floats`")
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(RawVectorsVisitorValue::ImplicitNone)
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'doc>,
{
deserializer.deserialize_any(self)
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(RawVectorsVisitorValue::ImplicitNone)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'doc>,
{
// must consume all elements or parsing fails
while let Some(_) = seq.next_element::<&RawValue>()? {}
Ok(RawVectorsVisitorValue::Implicit)
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'doc>,
{
use serde::de::Error as _;
let mut regenerate = None;
let mut embeddings = None;
while let Some(s) = map.next_key()? {
match s {
"regenerate" => {
let value: bool = map.next_value()?;
regenerate = Some(value);
}
"embeddings" => {
let value: &RawValue = map.next_value()?;
embeddings = Some(value);
}
other => return Err(A::Error::unknown_field(other, &["regenerate", "embeddings"])),
}
}
let Some(regenerate) = regenerate else {
return Err(A::Error::missing_field("regenerate"));
};
Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings })
}
} }
#[derive(serde::Serialize, Debug)] #[derive(serde::Serialize, Debug)]
@ -86,7 +171,7 @@ impl<'doc> RawVectors<'doc> {
} }
pub fn embeddings(&self) -> Option<&'doc RawValue> { pub fn embeddings(&self) -> Option<&'doc RawValue> {
match self { match self {
RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings), RawVectors::ImplicitlyUserProvided(embeddings) => *embeddings,
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings, RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings,
} }
} }