Accept multiple vectors by documents using the _vectors field

This commit is contained in:
Kerollmops 2023-06-20 11:17:20 +02:00 committed by Clément Renault
parent 1b2923f7c0
commit 321ec5f3fa
No known key found for this signature in database
GPG Key ID: 92ADA4E935E71FA4
4 changed files with 31 additions and 16 deletions

View File

@ -1,20 +1,22 @@
use std::convert::TryFrom;
use std::fs::File; use std::fs::File;
use std::io; use std::io;
use bytemuck::cast_slice; use bytemuck::cast_slice;
use either::Either;
use serde_json::from_slice; use serde_json::from_slice;
use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::{FieldId, InternalError, Result}; use crate::{FieldId, InternalError, Result};
/// Extracts the embedding vector contained in each document under the `_vector` field. /// Extracts the embedding vector contained in each document under the `_vectors` field.
/// ///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32> /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
#[logging_timer::time] #[logging_timer::time]
pub fn extract_vector_points<R: io::Read + io::Seek>( pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>, obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters, indexer: GrenadParameters,
vector_fid: FieldId, vectors_fid: FieldId,
) -> Result<grenad::Reader<File>> { ) -> Result<grenad::Reader<File>> {
let mut writer = create_writer( let mut writer = create_writer(
indexer.chunk_compression_type, indexer.chunk_compression_type,
@ -26,14 +28,26 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
while let Some((docid_bytes, value)) = cursor.move_on_next()? { while let Some((docid_bytes, value)) = cursor.move_on_next()? {
let obkv = obkv::KvReader::new(value); let obkv = obkv::KvReader::new(value);
// first we get the _vector field // first we retrieve the _vectors field
if let Some(vector) = obkv.get(vector_fid) { if let Some(vectors) = obkv.get(vectors_fid) {
// try to extract the vector // extract the vectors
let vector: Vec<f32> = from_slice(vector).map_err(InternalError::SerdeJson).unwrap(); let vectors: Either<Vec<Vec<f32>>, Vec<f32>> =
let bytes = cast_slice(&vector); from_slice(vectors).map_err(InternalError::SerdeJson).unwrap();
writer.insert(docid_bytes, bytes)?; let vectors = vectors.map_right(|v| vec![v]).into_inner();
for (i, vector) in vectors.into_iter().enumerate() {
match u16::try_from(i) {
Ok(i) => {
let mut key = docid_bytes.to_vec();
key.extend_from_slice(&i.to_ne_bytes());
let bytes = cast_slice(&vector);
writer.insert(key, bytes)?;
}
Err(_) => continue,
}
}
} }
// else => the _vector object was `null`, there is nothing to do // else => the `_vectors` object was `null`, there is nothing to do
} }
writer_into_reader(writer) writer_into_reader(writer)

View File

@ -47,7 +47,7 @@ pub(crate) fn data_from_obkv_documents(
faceted_fields: HashSet<FieldId>, faceted_fields: HashSet<FieldId>,
primary_key_id: FieldId, primary_key_id: FieldId,
geo_fields_ids: Option<(FieldId, FieldId)>, geo_fields_ids: Option<(FieldId, FieldId)>,
vector_field_id: Option<FieldId>, vectors_field_id: Option<FieldId>,
stop_words: Option<fst::Set<&[u8]>>, stop_words: Option<fst::Set<&[u8]>>,
max_positions_per_attributes: Option<u32>, max_positions_per_attributes: Option<u32>,
exact_attributes: HashSet<FieldId>, exact_attributes: HashSet<FieldId>,
@ -72,7 +72,7 @@ pub(crate) fn data_from_obkv_documents(
&faceted_fields, &faceted_fields,
primary_key_id, primary_key_id,
geo_fields_ids, geo_fields_ids,
vector_field_id, vectors_field_id,
&stop_words, &stop_words,
max_positions_per_attributes, max_positions_per_attributes,
) )

View File

@ -304,8 +304,8 @@ where
} }
None => None, None => None,
}; };
// get the fid of the `_vector` field. // get the fid of the `_vectors` field.
let vector_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vector"); let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
let stop_words = self.index.stop_words(self.wtxn)?; let stop_words = self.index.stop_words(self.wtxn)?;
let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?; let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?;
@ -342,7 +342,7 @@ where
faceted_fields, faceted_fields,
primary_key_id, primary_key_id,
geo_fields_ids, geo_fields_ids,
vector_field_id, vectors_field_id,
stop_words, stop_words,
max_positions_per_attributes, max_positions_per_attributes,
exact_attributes, exact_attributes,

View File

@ -20,7 +20,7 @@ use super::{ClonableMmap, MergeFn};
use crate::error::UserError; use crate::error::UserError;
use crate::facet::FacetType; use crate::facet::FacetType;
use crate::update::facet::FacetsUpdate; use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::as_cloneable_grenad; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};
pub(crate) enum TypedChunk { pub(crate) enum TypedChunk {
@ -241,7 +241,8 @@ pub(crate) fn write_typed_chunk_into_index(
let mut cursor = vector_points.into_cursor()?; let mut cursor = vector_points.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? { while let Some((key, value)) = cursor.move_on_next()? {
// convert the key back to a u32 (4 bytes) // convert the key back to a u32 (4 bytes)
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); let (left, _index) = try_split_array_at(key).unwrap();
let docid = DocumentId::from_be_bytes(left);
// convert the vector back to a Vec<f32> // convert the vector back to a Vec<f32>
let vector: Vec<f32> = pod_collect_to_vec(value); let vector: Vec<f32> = pod_collect_to_vec(value);