From bef8fc6cf14ae3078ec5d0af19fb352637cda443 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Fri, 8 Nov 2024 13:10:17 +0100 Subject: [PATCH] Fix hf embedder --- crates/milli/src/vector/hf.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs index ea892ca57..3fe28e53a 100644 --- a/crates/milli/src/vector/hf.rs +++ b/crates/milli/src/vector/hf.rs @@ -183,14 +183,17 @@ impl Embedder { let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids }; let token_ids = Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; + let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?; let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; let embeddings = self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) - let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?; - let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + let (_n_sentence, n_tokens, _hidden_size) = + embeddings.dims3().map_err(EmbedError::tensor_shape)?; + let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) .map_err(EmbedError::tensor_shape)?; + let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?; let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?; Ok(embedding) }