mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-25 11:35:05 +08:00
Fix hf embedder
This commit is contained in:
parent
e32677999f
commit
bef8fc6cf1
@ -183,14 +183,17 @@ impl Embedder {
|
||||
let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
|
||||
let token_ids =
|
||||
Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
|
||||
let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings =
|
||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (n_tokens, _hidden_size) = embeddings.dims2().map_err(EmbedError::tensor_shape)?;
|
||||
let embedding = (embeddings.sum(0).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||
let (_n_sentence, n_tokens, _hidden_size) =
|
||||
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||
let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||
.map_err(EmbedError::tensor_shape)?;
|
||||
let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
|
||||
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
|
||||
Ok(embedding)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user