diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs index ea892ca57..3fe28e53a 100644 --- a/crates/milli/src/vector/hf.rs +++ b/crates/milli/src/vector/hf.rs @@ -183,14 +183,17 @@ impl Embedder { let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids }; let token_ids = Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; + let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?; let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; let embeddings = self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) - let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?; - let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + let (_n_sentence, n_tokens, _hidden_size) = + embeddings.dims3().map_err(EmbedError::tensor_shape)?; + let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) .map_err(EmbedError::tensor_shape)?; + let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?; let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?; Ok(embedding) }