From 321ec5f3fa01107829b49e11ccefbb2ac2490bc0 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 11:17:20 +0200 Subject: [PATCH] Accept multiple vectors by documents using the _vectors field --- .../extract/extract_vector_points.rs | 32 +++++++++++++------ .../src/update/index_documents/extract/mod.rs | 4 +-- milli/src/update/index_documents/mod.rs | 6 ++-- .../src/update/index_documents/typed_chunk.rs | 5 +-- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 409df5dbd..7e2bd25c5 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,20 +1,22 @@ +use std::convert::TryFrom; use std::fs::File; use std::io; use bytemuck::cast_slice; +use either::Either; use serde_json::from_slice; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::{FieldId, InternalError, Result}; -/// Extracts the embedding vector contained in each document under the `_vector` field. +/// Extracts the embedding vector contained in each document under the `_vectors` field. /// /// Returns the generated grenad reader containing the docid as key associated to the Vec #[logging_timer::time] pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - vector_fid: FieldId, + vectors_fid: FieldId, ) -> Result> { let mut writer = create_writer( indexer.chunk_compression_type, @@ -26,14 +28,26 @@ pub fn extract_vector_points( while let Some((docid_bytes, value)) = cursor.move_on_next()? { let obkv = obkv::KvReader::new(value); - // first we get the _vector field - if let Some(vector) = obkv.get(vector_fid) { - // try to extract the vector - let vector: Vec = from_slice(vector).map_err(InternalError::SerdeJson).unwrap(); - let bytes = cast_slice(&vector); - writer.insert(docid_bytes, bytes)?; + // first we retrieve the _vectors field + if let Some(vectors) = obkv.get(vectors_fid) { + // extract the vectors + let vectors: Either>, Vec> = + from_slice(vectors).map_err(InternalError::SerdeJson).unwrap(); + let vectors = vectors.map_right(|v| vec![v]).into_inner(); + + for (i, vector) in vectors.into_iter().enumerate() { + match u16::try_from(i) { + Ok(i) => { + let mut key = docid_bytes.to_vec(); + key.extend_from_slice(&i.to_ne_bytes()); + let bytes = cast_slice(&vector); + writer.insert(key, bytes)?; + } + Err(_) => continue, + } + } } - // else => the _vector object was `null`, there is nothing to do + // else => the `_vectors` object was `null`, there is nothing to do } writer_into_reader(writer) diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index fdc6f5616..325d52279 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -47,7 +47,7 @@ pub(crate) fn data_from_obkv_documents( faceted_fields: HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, - vector_field_id: Option, + vectors_field_id: Option, stop_words: Option>, max_positions_per_attributes: Option, exact_attributes: HashSet, @@ -72,7 +72,7 @@ pub(crate) fn data_from_obkv_documents( &faceted_fields, primary_key_id, geo_fields_ids, - vector_field_id, + vectors_field_id, &stop_words, max_positions_per_attributes, ) diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index adbab54db..5b6e03637 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -304,8 +304,8 @@ where } None => None, }; - // get the fid of the `_vector` field. - let vector_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vector"); + // get the fid of the `_vectors` field. + let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); let stop_words = self.index.stop_words(self.wtxn)?; let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?; @@ -342,7 +342,7 @@ where faceted_fields, primary_key_id, geo_fields_ids, - vector_field_id, + vectors_field_id, stop_words, max_positions_per_attributes, exact_attributes, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 0e2e85c1c..7d23ef320 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -20,7 +20,7 @@ use super::{ClonableMmap, MergeFn}; use crate::error::UserError; use crate::facet::FacetType; use crate::update::facet::FacetsUpdate; -use crate::update::index_documents::helpers::as_cloneable_grenad; +use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; pub(crate) enum TypedChunk { @@ -241,7 +241,8 @@ pub(crate) fn write_typed_chunk_into_index( let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { // convert the key back to a u32 (4 bytes) - let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + let (left, _index) = try_split_array_at(key).unwrap(); + let docid = DocumentId::from_be_bytes(left); // convert the vector back to a Vec let vector: Vec = pod_collect_to_vec(value);