From 750c98833386169e29639964179110231bea97c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 9 Oct 2024 10:39:25 +0200 Subject: [PATCH] Also export vectors with the documents --- meilitool/src/main.rs | 63 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/meilitool/src/main.rs b/meilitool/src/main.rs index c94ff19da..4f4df8822 100644 --- a/meilitool/src/main.rs +++ b/meilitool/src/main.rs @@ -9,12 +9,15 @@ use file_store::FileStore; use meilisearch_auth::AuthController; use meilisearch_types::heed::types::{SerdeJson, Str}; use meilisearch_types::heed::{Database, Env, EnvOpenOptions, RoTxn, RwTxn, Unspecified}; -use meilisearch_types::milli::documents::{obkv_to_object, DocumentsBatchReader}; -use meilisearch_types::milli::index::{db_name, main_key}; -use meilisearch_types::milli::{obkv_to_json, BEU32}; +use meilisearch_types::milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME; use meilisearch_types::tasks::{Status, Task}; use meilisearch_types::versioning::{create_version_file, get_version, parse_version}; -use meilisearch_types::Index; +use meilisearch_types::{milli, Index}; +use milli::documents::{obkv_to_object, DocumentsBatchReader}; +use milli::index::{db_name, main_key}; +use milli::vector::parsed_vectors::{ExplicitVectors, VectorOrArrayOfVectors}; +use milli::{obkv_to_json, BEU32}; +use serde_json::Value::Object; use time::macros::format_description; use time::OffsetDateTime; use uuid_codec::UuidCodec; @@ -782,11 +785,59 @@ fn export_documents(db_path: PathBuf, index_name: String) -> anyhow::Result<()> let rtxn = index.read_txn()?; let fields_ids_map = index.fields_ids_map(&rtxn)?; let all_fields: Vec<_> = fields_ids_map.iter().map(|(id, _)| id).collect(); + let embedding_configs = index.embedding_configs(&rtxn)?; let mut stdout = BufWriter::new(std::io::stdout()); for ret in index.all_documents(&rtxn)? { - let (_id, doc) = ret?; - let document = obkv_to_json(&all_fields, &fields_ids_map, doc)?; + let (id, doc) = ret?; + let mut document = obkv_to_json(&all_fields, &fields_ids_map, doc)?; + + 'inject_vectors: { + let embeddings = index.embeddings(&rtxn, id)?; + + if embeddings.is_empty() { + break 'inject_vectors; + } + + let vectors = document + .entry(RESERVED_VECTORS_FIELD_NAME) + .or_insert(Object(Default::default())); + + let Object(vectors) = vectors else { + return Err(meilisearch_types::milli::Error::UserError( + meilisearch_types::milli::UserError::InvalidVectorsMapType { + document_id: { + if let Ok(Some(Ok(index))) = index + .external_id_of(&rtxn, std::iter::once(id)) + .map(|it| it.into_iter().next()) + { + index + } else { + format!("internal docid={id}") + } + }, + value: vectors.clone(), + }, + ) + .into()); + }; + + for (embedder_name, embeddings) in embeddings { + let user_provided = embedding_configs + .iter() + .find(|conf| conf.name == embedder_name) + .is_some_and(|conf| conf.user_provided.contains(id)); + + let embeddings = ExplicitVectors { + embeddings: Some(VectorOrArrayOfVectors::from_array_of_vectors( + embeddings, + )), + regenerate: !user_provided, + }; + vectors.insert(embedder_name, serde_json::to_value(embeddings).unwrap()); + } + } + serde_json::to_writer(&mut stdout, &document)?; }