From 34349faeae779a4ab8687ab16408615ceccba03b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Thu, 8 Jun 2023 11:35:36 +0200 Subject: [PATCH] Create a new _vector extractor --- Cargo.lock | 1 + milli/Cargo.toml | 1 + .../extract/extract_vector_points.rs | 40 +++++++++++++++++++ .../src/update/index_documents/extract/mod.rs | 2 + 4 files changed, 44 insertions(+) create mode 100644 milli/src/update/index_documents/extract/extract_vector_points.rs diff --git a/Cargo.lock b/Cargo.lock index 46218fc34..9d09fef9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2683,6 +2683,7 @@ dependencies = [ "bimap", "bincode", "bstr", + "bytemuck", "byteorder", "charabia", "concat-arrays", diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 138103723..5ff73303a 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -15,6 +15,7 @@ license.workspace = true bimap = { version = "0.6.3", features = ["serde"] } bincode = "1.3.3" bstr = "1.4.0" +bytemuck = "1.13.1" byteorder = "1.4.3" charabia = { version = "0.7.2", default-features = false } concat-arrays = "0.1.2" diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs new file mode 100644 index 000000000..409df5dbd --- /dev/null +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -0,0 +1,40 @@ +use std::fs::File; +use std::io; + +use bytemuck::cast_slice; +use serde_json::from_slice; + +use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; +use crate::{FieldId, InternalError, Result}; + +/// Extracts the embedding vector contained in each document under the `_vector` field. +/// +/// Returns the generated grenad reader containing the docid as key associated to the Vec +#[logging_timer::time] +pub fn extract_vector_points( + obkv_documents: grenad::Reader, + indexer: GrenadParameters, + vector_fid: FieldId, +) -> Result> { + let mut writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let mut cursor = obkv_documents.into_cursor()?; + while let Some((docid_bytes, value)) = cursor.move_on_next()? { + let obkv = obkv::KvReader::new(value); + + // first we get the _vector field + if let Some(vector) = obkv.get(vector_fid) { + // try to extract the vector + let vector: Vec = from_slice(vector).map_err(InternalError::SerdeJson).unwrap(); + let bytes = cast_slice(&vector); + writer.insert(docid_bytes, bytes)?; + } + // else => the _vector object was `null`, there is nothing to do + } + + writer_into_reader(writer) +} diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 632f568ab..128fc29c0 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -4,6 +4,7 @@ mod extract_facet_string_docids; mod extract_fid_docid_facet_values; mod extract_fid_word_count_docids; mod extract_geo_points; +mod extract_vector_points; mod extract_word_docids; mod extract_word_fid_docids; mod extract_word_pair_proximity_docids; @@ -22,6 +23,7 @@ use self::extract_facet_string_docids::extract_facet_string_docids; use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; use self::extract_fid_word_count_docids::extract_fid_word_count_docids; use self::extract_geo_points::extract_geo_points; +use self::extract_vector_points::extract_vector_points; use self::extract_word_docids::extract_word_docids; use self::extract_word_fid_docids::extract_word_fid_docids; use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;