diff --git a/milli/src/vector/parsed_vectors.rs b/milli/src/vector/parsed_vectors.rs index 8e5ccf690..526516fef 100644 --- a/milli/src/vector/parsed_vectors.rs +++ b/milli/src/vector/parsed_vectors.rs @@ -2,6 +2,7 @@ use std::collections::{BTreeMap, BTreeSet}; use deserr::{take_cf_content, DeserializeError, Deserr, Sequence}; use obkv::KvReader; +use serde_json::value::RawValue; use serde_json::{from_slice, Value}; use super::Embedding; @@ -11,6 +12,13 @@ use crate::{DocumentId, FieldId, InternalError, UserError}; pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors"; +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(untagged)] +pub enum RawVectors<'doc> { + Explicit(#[serde(borrow)] RawExplicitVectors<'doc>), + ImplicitlyUserProvided(#[serde(borrow)] &'doc RawValue), +} + #[derive(serde::Serialize, Debug)] #[serde(untagged)] pub enum Vectors { @@ -69,6 +77,22 @@ impl Vectors { } } +impl<'doc> RawVectors<'doc> { + pub fn must_regenerate(&self) -> bool { + match self { + RawVectors::ImplicitlyUserProvided(_) => false, + RawVectors::Explicit(RawExplicitVectors { regenerate, .. }) => *regenerate, + } + } + + pub fn embeddings(&self) -> Option<&'doc RawValue> { + match self { + RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings), + RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings, + } + } +} + #[derive(serde::Serialize, Deserr, Debug)] #[serde(rename_all = "camelCase")] pub struct ExplicitVectors { @@ -78,6 +102,15 @@ pub struct ExplicitVectors { pub regenerate: bool, } +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct RawExplicitVectors<'doc> { + #[serde(borrow)] + #[serde(default)] + pub embeddings: Option<&'doc RawValue>, + pub regenerate: bool, +} + pub enum VectorState { Inline(Vectors), Manual,