2024-05-22 12:24:51 +02:00
|
|
|
use std::collections::{BTreeMap, BTreeSet};
|
2024-05-14 11:22:16 +02:00
|
|
|
|
|
|
|
use obkv::KvReader;
|
|
|
|
use serde_json::{from_slice, Value};
|
|
|
|
|
|
|
|
use super::Embedding;
|
2024-06-03 16:04:14 +02:00
|
|
|
use crate::index::IndexEmbeddingConfig;
|
2024-05-14 11:22:16 +02:00
|
|
|
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
|
2024-06-03 16:04:14 +02:00
|
|
|
use crate::{DocumentId, FieldId, InternalError, UserError};
|
2024-05-14 11:22:16 +02:00
|
|
|
|
|
|
|
/// Name of the reserved `_vectors` field.
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
|
|
|
|
|
|
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
|
|
|
#[serde(untagged)]
|
|
|
|
pub enum Vectors {
|
|
|
|
ImplicitlyUserProvided(VectorOrArrayOfVectors),
|
|
|
|
Explicit(ExplicitVectors),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Vectors {
|
2024-05-22 15:27:09 +02:00
|
|
|
pub fn is_user_provided(&self) -> bool {
|
|
|
|
match self {
|
|
|
|
Vectors::ImplicitlyUserProvided(_) => true,
|
|
|
|
Vectors::Explicit(ExplicitVectors { user_provided, .. }) => *user_provided,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-05-14 11:22:16 +02:00
|
|
|
pub fn into_array_of_vectors(self) -> Vec<Embedding> {
|
|
|
|
match self {
|
|
|
|
Vectors::ImplicitlyUserProvided(embeddings)
|
|
|
|
| Vectors::Explicit(ExplicitVectors { embeddings, user_provided: _ }) => {
|
|
|
|
embeddings.into_array_of_vectors().unwrap_or_default()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
|
|
|
#[serde(rename_all = "camelCase")]
|
|
|
|
pub struct ExplicitVectors {
|
|
|
|
pub embeddings: VectorOrArrayOfVectors,
|
|
|
|
pub user_provided: bool,
|
|
|
|
}
|
|
|
|
|
|
|
|
pub struct ParsedVectorsDiff {
|
2024-06-03 16:04:14 +02:00
|
|
|
pub old: BTreeMap<String, Option<Vectors>>,
|
2024-05-14 11:22:16 +02:00
|
|
|
pub new: Option<BTreeMap<String, Vectors>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl ParsedVectorsDiff {
|
|
|
|
pub fn new(
|
2024-06-03 16:04:14 +02:00
|
|
|
docid: DocumentId,
|
|
|
|
embedders_configs: &[IndexEmbeddingConfig],
|
2024-05-14 11:22:16 +02:00
|
|
|
documents_diff: KvReader<'_, FieldId>,
|
|
|
|
old_vectors_fid: Option<FieldId>,
|
|
|
|
new_vectors_fid: Option<FieldId>,
|
|
|
|
) -> Result<Self, Error> {
|
2024-06-03 16:04:14 +02:00
|
|
|
let mut old = match old_vectors_fid
|
2024-05-14 11:22:16 +02:00
|
|
|
.and_then(|vectors_fid| documents_diff.get(vectors_fid))
|
|
|
|
.map(KvReaderDelAdd::new)
|
|
|
|
.map(|obkv| to_vector_map(obkv, DelAdd::Deletion))
|
|
|
|
.transpose()
|
|
|
|
{
|
|
|
|
Ok(del) => del,
|
|
|
|
// ignore wrong shape for old version of documents, use an empty map in this case
|
|
|
|
Err(Error::InvalidMap(value)) => {
|
|
|
|
tracing::warn!(%value, "Previous version of the `_vectors` field had a wrong shape");
|
|
|
|
Default::default()
|
|
|
|
}
|
|
|
|
Err(error) => {
|
|
|
|
return Err(error);
|
|
|
|
}
|
|
|
|
}
|
2024-06-03 16:04:14 +02:00
|
|
|
.flatten().map_or(BTreeMap::default(), |del| del.into_iter().map(|(name, vec)| (name, Some(vec))).collect());
|
|
|
|
for embedding_config in embedders_configs {
|
|
|
|
if embedding_config.user_defined.contains(docid) {
|
|
|
|
old.entry(embedding_config.name.to_string()).or_insert(None);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-05-14 11:22:16 +02:00
|
|
|
let new = new_vectors_fid
|
|
|
|
.and_then(|vectors_fid| documents_diff.get(vectors_fid))
|
|
|
|
.map(KvReaderDelAdd::new)
|
|
|
|
.map(|obkv| to_vector_map(obkv, DelAdd::Addition))
|
|
|
|
.transpose()?
|
|
|
|
.flatten();
|
|
|
|
Ok(Self { old, new })
|
|
|
|
}
|
|
|
|
|
2024-06-03 16:04:14 +02:00
|
|
|
/// Return (Some(None), _) in case the vector is user defined and contained in the database.
|
|
|
|
pub fn remove(&mut self, embedder_name: &str) -> (Option<Option<Vectors>>, Option<Vectors>) {
|
|
|
|
let old = self.old.remove(embedder_name);
|
2024-05-14 11:22:16 +02:00
|
|
|
let new = self.new.as_mut().and_then(|new| new.remove(embedder_name));
|
|
|
|
(old, new)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// A parsed `_vectors` map: each key is associated with its [`Vectors`].
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
|
|
|
|
|
|
|
|
impl ParsedVectors {
|
|
|
|
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
|
|
|
|
let Ok(value) = from_slice(value) else {
|
|
|
|
let value = from_slice(value).map_err(Error::InternalSerdeJson)?;
|
|
|
|
return Err(Error::InvalidMap(value));
|
|
|
|
};
|
|
|
|
Ok(ParsedVectors(value))
|
|
|
|
}
|
|
|
|
|
2024-05-22 15:27:09 +02:00
|
|
|
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
|
|
|
|
self.0.retain(|k, _v| !embedders.contains(k))
|
2024-05-14 11:22:16 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub enum Error {
|
|
|
|
InvalidMap(Value),
|
|
|
|
InternalSerdeJson(serde_json::Error),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Error {
|
|
|
|
pub fn to_crate_error(self, document_id: String) -> crate::Error {
|
|
|
|
match self {
|
|
|
|
Error::InvalidMap(value) => {
|
|
|
|
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
|
|
|
|
}
|
|
|
|
Error::InternalSerdeJson(error) => {
|
|
|
|
crate::Error::InternalError(InternalError::SerdeJson(error))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn to_vector_map(
|
|
|
|
obkv: KvReaderDelAdd,
|
|
|
|
side: DelAdd,
|
|
|
|
) -> Result<Option<BTreeMap<String, Vectors>>, Error> {
|
|
|
|
Ok(if let Some(value) = obkv.get(side) {
|
|
|
|
let ParsedVectors(parsed_vectors) = ParsedVectors::from_bytes(value)?;
|
|
|
|
Some(parsed_vectors)
|
|
|
|
} else {
|
|
|
|
None
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Represents either a vector or an array of multiple vectors.
|
|
|
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
|
|
|
#[serde(transparent)]
|
|
|
|
pub struct VectorOrArrayOfVectors {
|
|
|
|
#[serde(with = "either::serde_untagged_optional")]
|
2024-05-22 12:25:21 +02:00
|
|
|
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
|
2024-05-14 11:22:16 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
impl VectorOrArrayOfVectors {
|
|
|
|
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
|
|
|
|
match self.inner? {
|
2024-05-22 12:25:21 +02:00
|
|
|
either::Either::Left(vectors) => Some(vectors),
|
|
|
|
either::Either::Right(vector) => Some(vec![vector]),
|
2024-05-14 11:22:16 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn from_array_of_vectors(array_of_vec: Vec<Embedding>) -> Self {
|
2024-05-22 12:25:21 +02:00
|
|
|
Self { inner: Some(either::Either::Left(array_of_vec)) }
|
2024-05-14 11:22:16 +02:00
|
|
|
}
|
|
|
|
}
|
2024-05-16 18:13:27 +02:00
|
|
|
|
|
|
|
#[cfg(test)]
mod test {
    use super::VectorOrArrayOfVectors;

    /// Deserializes `json` into a `VectorOrArrayOfVectors`, panicking if it is rejected.
    fn parse(json: &str) -> VectorOrArrayOfVectors {
        serde_json::from_str(json).unwrap()
    }

    /// Checks how `null`, an empty array, a single vector and arrays of vectors are
    /// normalized by `into_array_of_vectors`.
    #[test]
    fn array_of_vectors() {
        // Parse everything up front so a rejected input fails before any snapshot check.
        let null = parse("null");
        let empty = parse("[]");
        let one = parse("[0.1]");
        let two = parse("[0.1, 0.2]");
        let one_vec = parse("[[0.1, 0.2]]");
        let two_vecs = parse("[[0.1, 0.2], [0.3, 0.4]]");

        insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
        insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");
        // A flat array of numbers is a single vector and gets wrapped into an array.
        insta::assert_json_snapshot!(one.into_array_of_vectors(), @r###"
        [
          [
            0.1
          ]
        ]
        "###);
        insta::assert_json_snapshot!(two.into_array_of_vectors(), @r###"
        [
          [
            0.1,
            0.2
          ]
        ]
        "###);
        // Arrays of vectors are returned unchanged.
        insta::assert_json_snapshot!(one_vec.into_array_of_vectors(), @r###"
        [
          [
            0.1,
            0.2
          ]
        ]
        "###);
        insta::assert_json_snapshot!(two_vecs.into_array_of_vectors(), @r###"
        [
          [
            0.1,
            0.2
          ],
          [
            0.3,
            0.4
          ]
        ]
        "###);
    }
}
|