Update older embedding

This commit is contained in:
Louis Dureuil 2024-10-28 14:22:45 +01:00
parent 1960003805
commit af9f96e2af
No known key found for this signature in database

View File

@ -21,7 +21,7 @@ use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::error::{EmbedErrorKind, PossibleEmbeddingMistakes, UnusedVectorsDistribution}; use crate::vector::error::{EmbedErrorKind, PossibleEmbeddingMistakes, UnusedVectorsDistribution};
use crate::vector::parsed_vectors::{ParsedVectorsDiff, VectorState, RESERVED_VECTORS_FIELD_NAME}; use crate::vector::parsed_vectors::{ParsedVectorsDiff, VectorState, RESERVED_VECTORS_FIELD_NAME};
use crate::vector::settings::ReindexAction; use crate::vector::settings::ReindexAction;
use crate::vector::{Embedder, Embeddings}; use crate::vector::{Embedder, Embedding};
use crate::{try_split_array_at, DocumentId, FieldId, Result, ThreadPoolNoAbort}; use crate::{try_split_array_at, DocumentId, FieldId, Result, ThreadPoolNoAbort};
/// The length of the elements that are always in the buffer when inserting new values. /// The length of the elements that are always in the buffer when inserting new values.
@ -536,9 +536,11 @@ fn extract_vector_document_diff(
} }
// Don't give up if the old prompt was failing // Don't give up if the old prompt was failing
let old_prompt = Some(&prompt).map(|p| { let old_prompt = Some(&prompt).map(|p| {
p.render(obkv, DelAdd::Deletion, old_fields_ids_map).unwrap_or_default() p.render_kvdeladd(obkv, DelAdd::Deletion, old_fields_ids_map)
.unwrap_or_default()
}); });
let new_prompt = prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?; let new_prompt =
prompt.render_kvdeladd(obkv, DelAdd::Addition, new_fields_ids_map)?;
if old_prompt.as_ref() != Some(&new_prompt) { if old_prompt.as_ref() != Some(&new_prompt) {
let old_prompt = old_prompt.unwrap_or_default(); let old_prompt = old_prompt.unwrap_or_default();
tracing::trace!( tracing::trace!(
@ -570,7 +572,7 @@ fn extract_vector_document_diff(
return Ok(VectorStateDelta::NoChange); return Ok(VectorStateDelta::NoChange);
} }
// becomes autogenerated // becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render( VectorStateDelta::NowGenerated(prompt.render_kvdeladd(
obkv, obkv,
DelAdd::Addition, DelAdd::Addition,
new_fields_ids_map, new_fields_ids_map,
@ -613,9 +615,10 @@ fn regenerate_if_prompt_changed(
&FieldsIdsMapWithMetadata, &FieldsIdsMapWithMetadata,
), ),
) -> Result<VectorStateDelta> { ) -> Result<VectorStateDelta> {
let old_prompt = let old_prompt = old_prompt
old_prompt.render(obkv, DelAdd::Deletion, old_fields_ids_map).unwrap_or(Default::default()); .render_kvdeladd(obkv, DelAdd::Deletion, old_fields_ids_map)
let new_prompt = new_prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?; .unwrap_or(Default::default());
let new_prompt = new_prompt.render_kvdeladd(obkv, DelAdd::Addition, new_fields_ids_map)?;
if new_prompt == old_prompt { if new_prompt == old_prompt {
return Ok(VectorStateDelta::NoChange); return Ok(VectorStateDelta::NoChange);
@ -628,7 +631,7 @@ fn regenerate_prompt(
prompt: &Prompt, prompt: &Prompt,
new_fields_ids_map: &FieldsIdsMapWithMetadata, new_fields_ids_map: &FieldsIdsMapWithMetadata,
) -> Result<VectorStateDelta> { ) -> Result<VectorStateDelta> {
let prompt = prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?; let prompt = prompt.render_kvdeladd(obkv, DelAdd::Addition, new_fields_ids_map)?;
Ok(VectorStateDelta::NowGenerated(prompt)) Ok(VectorStateDelta::NowGenerated(prompt))
} }
@ -738,7 +741,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.flat_map(|docids| docids.iter()) .flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{ {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings))?;
} }
chunks_ids.clear(); chunks_ids.clear();
} }
@ -759,7 +762,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.flat_map(|docids| docids.iter()) .flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{ {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings))?;
} }
} }
@ -775,7 +778,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if let Some(embeds) = embeds.first() { if let Some(embeds) = embeds.first() {
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings))?;
} }
} }
} }
@ -790,7 +793,7 @@ fn embed_chunks(
possible_embedding_mistakes: &PossibleEmbeddingMistakes, possible_embedding_mistakes: &PossibleEmbeddingMistakes,
unused_vectors_distribution: &UnusedVectorsDistribution, unused_vectors_distribution: &UnusedVectorsDistribution,
request_threads: &ThreadPoolNoAbort, request_threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>> { ) -> Result<Vec<Vec<Embedding>>> {
match embedder.embed_chunks(text_chunks, request_threads) { match embedder.embed_chunks(text_chunks, request_threads) {
Ok(chunks) => Ok(chunks), Ok(chunks) => Ok(chunks),
Err(error) => { Err(error) => {