diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 306c1c1e9..d3d05a1c1 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -13,6 +13,7 @@ pub mod error;
 pub mod hf;
 pub mod manual;
 pub mod openai;
+pub mod parsed_vectors;
 pub mod settings;
 
 pub mod ollama;
diff --git a/milli/src/vector/parsed_vectors.rs b/milli/src/vector/parsed_vectors.rs
new file mode 100644
index 000000000..bf4b9ea83
--- /dev/null
+++ b/milli/src/vector/parsed_vectors.rs
@@ -0,0 +1,198 @@
+//! Parsing of the reserved `_vectors` document field.
+
+use std::collections::BTreeMap;
+
+use obkv::KvReader;
+use serde_json::{from_slice, Value};
+
+use super::Embedding;
+use crate::update::del_add::{DelAdd, KvReaderDelAdd};
+use crate::{FieldId, InternalError, UserError};
+
+/// Name of the reserved document field that carries vectors.
+pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
+
+/// One entry of the `_vectors` map, in either of its two accepted shapes.
+///
+/// The enum is `untagged`: serde tries the variants in declaration order.
+#[derive(serde::Serialize, serde::Deserialize, Debug)]
+#[serde(untagged)]
+pub enum Vectors {
+    /// Bare embeddings, without any flag; counted as user provided.
+    ImplicitlyUserProvided(VectorOrArrayOfVectors),
+    /// Embeddings accompanied by an explicit `user_provided` flag.
+    Explicit(ExplicitVectors),
+}
+
+impl Vectors {
+    /// Consumes the entry and returns its embeddings as a list of vectors.
+    ///
+    /// A single vector becomes a one-element list; absent embeddings yield an empty list.
+    pub fn into_array_of_vectors(self) -> Vec<Embedding> {
+        match self {
+            Vectors::ImplicitlyUserProvided(embeddings)
+            | Vectors::Explicit(ExplicitVectors { embeddings, user_provided: _ }) => {
+                embeddings.into_array_of_vectors().unwrap_or_default()
+            }
+        }
+    }
+}
+
+/// Explicit form of a `_vectors` entry; field names are camelCase when (de)serialized.
+#[derive(serde::Serialize, serde::Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct ExplicitVectors {
+    /// The embeddings stored for this entry.
+    pub embeddings: VectorOrArrayOfVectors,
+    /// Whether the embeddings are flagged as provided by the user.
+    pub user_provided: bool,
+}
+
+/// Parsed `_vectors` maps for the deletion and addition sides of a document diff.
+///
+/// Each map is keyed by embedder name. A side is `None` when its field id is not given or
+/// the field has no value on that side.
+pub struct ParsedVectorsDiff {
+    /// Vectors parsed from the `DelAdd::Deletion` side, if any.
+    pub old: Option<BTreeMap<String, Vectors>>,
+    /// Vectors parsed from the `DelAdd::Addition` side, if any.
+    pub new: Option<BTreeMap<String, Vectors>>,
+}
+
+impl ParsedVectorsDiff {
+    /// Parses both sides of the vectors field found in `documents_diff`.
+    ///
+    /// `old_vectors_fid` is looked up for the deletion side and `new_vectors_fid` for the
+    /// addition side.
+    ///
+    /// # Errors
+    ///
+    /// Fails when the addition side cannot be parsed, or when the deletion side is not valid
+    /// JSON. A deletion side that is valid JSON of the wrong shape is only logged and treated
+    /// as absent.
+    pub fn new(
+        documents_diff: KvReader<'_, FieldId>,
+        old_vectors_fid: Option<FieldId>,
+        new_vectors_fid: Option<FieldId>,
+    ) -> Result<Self, Error> {
+        let old = match old_vectors_fid
+            .and_then(|vectors_fid| documents_diff.get(vectors_fid))
+            .map(KvReaderDelAdd::new)
+            .map(|obkv| to_vector_map(obkv, DelAdd::Deletion))
+            .transpose()
+        {
+            Ok(del) => del,
+            // ignore wrong shape for old version of documents, treat the field as absent
+            Err(Error::InvalidMap(value)) => {
+                tracing::warn!(%value, "Previous version of the `_vectors` field had a wrong shape");
+                Default::default()
+            }
+            Err(error) => {
+                return Err(error);
+            }
+        }
+        .flatten();
+        let new = new_vectors_fid
+            .and_then(|vectors_fid| documents_diff.get(vectors_fid))
+            .map(KvReaderDelAdd::new)
+            .map(|obkv| to_vector_map(obkv, DelAdd::Addition))
+            .transpose()?
+            .flatten();
+        Ok(Self { old, new })
+    }
+
+    /// Removes and returns the `(old, new)` entries stored under `embedder_name`.
+    pub fn remove(&mut self, embedder_name: &str) -> (Option<Vectors>, Option<Vectors>) {
+        let old = self.old.as_mut().and_then(|old| old.remove(embedder_name));
+        let new = self.new.as_mut().and_then(|new| new.remove(embedder_name));
+        (old, new)
+    }
+}
+
+/// The `_vectors` field parsed as a map from embedder name to [`Vectors`].
+pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
+
+impl ParsedVectors {
+    /// Parses the raw JSON bytes of a `_vectors` field.
+    ///
+    /// # Errors
+    ///
+    /// - [`Error::InvalidMap`] when the bytes are valid JSON but not the expected map.
+    /// - [`Error::InternalSerdeJson`] when the bytes are not valid JSON.
+    pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
+        let Ok(value) = from_slice(value) else {
+            // Parse again as an untyped value so the error can carry what was received.
+            let value = from_slice(value).map_err(Error::InternalSerdeJson)?;
+            return Err(Error::InvalidMap(value));
+        };
+        Ok(ParsedVectors(value))
+    }
+
+    /// Keeps only the entries that are user provided, implicitly or through their flag.
+    pub fn retain_user_provided_vectors(&mut self) {
+        self.0.retain(|_k, v| match v {
+            Vectors::ImplicitlyUserProvided(_) => true,
+            Vectors::Explicit(ExplicitVectors { embeddings: _, user_provided }) => *user_provided,
+        });
+    }
+}
+
+/// Errors raised while parsing the `_vectors` field.
+pub enum Error {
+    /// The field is valid JSON but not a map of the expected shape; holds the parsed value.
+    InvalidMap(Value),
+    /// The field could not be parsed as JSON.
+    InternalSerdeJson(serde_json::Error),
+}
+
+impl Error {
+    /// Converts this error into a [`crate::Error`], attaching `document_id` to user errors.
+    pub fn to_crate_error(self, document_id: String) -> crate::Error {
+        match self {
+            Error::InvalidMap(value) => {
+                crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
+            }
+            Error::InternalSerdeJson(error) => {
+                crate::Error::InternalError(InternalError::SerdeJson(error))
+            }
+        }
+    }
+}
+
+/// Parses the `side` entry of `obkv` as a vectors map; `Ok(None)` when that side is absent.
+fn to_vector_map(
+    obkv: KvReaderDelAdd,
+    side: DelAdd,
+) -> Result<Option<BTreeMap<String, Vectors>>, Error> {
+    Ok(if let Some(value) = obkv.get(side) {
+        let ParsedVectors(parsed_vectors) = ParsedVectors::from_bytes(value)?;
+        Some(parsed_vectors)
+    } else {
+        None
+    })
+}
+
+/// Represents either a vector or an array of multiple vectors.
+#[derive(serde::Serialize, serde::Deserialize, Debug)]
+#[serde(transparent)]
+pub struct VectorOrArrayOfVectors {
+    #[serde(with = "either::serde_untagged_optional")]
+    inner: Option<either::Either<Embedding, Vec<Embedding>>>,
+}
+
+impl VectorOrArrayOfVectors {
+    /// Returns the content as a list of vectors, or `None` when nothing is stored.
+    ///
+    /// A single vector is wrapped into a one-element list.
+    pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
+        match self.inner? {
+            either::Either::Left(vector) => Some(vec![vector]),
+            either::Either::Right(vectors) => Some(vectors),
+        }
+    }
+
+    /// Builds a value holding the given array of vectors.
+    pub fn from_array_of_vectors(array_of_vec: Vec<Embedding>) -> Self {
+        Self { inner: Some(either::Either::Right(array_of_vec)) }
+    }
+}