Fix hf embedder

This commit is contained in:
Louis Dureuil 2024-11-08 13:10:17 +01:00
parent e32677999f
commit bef8fc6cf1
No known key found for this signature in database

View File

@ -183,14 +183,17 @@ impl Embedder {
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
let token_ids =
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings =
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?;
let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
Ok(embedding)
}