Fix multiple embeddings in hf

This commit is contained in:
Louis Dureuil 2025-02-24 16:24:04 +01:00
parent e374b095a2
commit 24fe6cd205
No known key found for this signature in database

View File

@ -255,34 +255,8 @@ impl Embedder {
Ok(this)
}
/// Embeds a batch of texts with a single forward pass of the model.
///
/// The texts are tokenized, each token-id sequence is cut to at most 512 ids,
/// the sequences are stacked into one tensor, run through the model, pooled
/// with `self.pooling`, and converted to one `Embedding` per row.
///
/// # Errors
///
/// Returns an `EmbedError` when tokenization, tensor construction/stacking,
/// the model forward pass, or the final conversion fails.
///
/// NOTE(review): the stacking step presumably requires every token-id sequence
/// to have the same length, so a batch of texts that tokenize to different
/// lengths would likely fail here — confirm against the tensor library docs.
pub fn embed(&self, mut texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
// A single text goes through `encode`; any other count goes through `encode_batch`.
let tokens = match texts.len() {
// `pop().unwrap()` cannot panic in this arm: the length was just matched as 1.
1 => vec![self
.tokenizer
.encode(texts.pop().unwrap(), true)
.map_err(EmbedError::tokenize)?],
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
};
// Build one tensor per text on the model's device, stopping at the first failure.
let token_ids = tokens
.iter()
.map(|tokens| {
let mut tokens = tokens.get_ids().to_vec();
// Keep at most the first 512 token ids of each text.
tokens.truncate(512);
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
})
.collect::<Result<Vec<_>, EmbedError>>()?;
// Combine the per-text tensors along a new leading dimension (index 0).
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
// Token type ids are all zeros, with the same shape as `token_ids`.
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings = self
.model
.forward(&token_ids, &token_type_ids, None)
.map_err(EmbedError::model_forward)?;
// Reduce the model output using the configured pooling strategy.
let embeddings = Self::pooling(embeddings, self.pooling)?;
// Convert the pooled tensor into nested vectors, one `Embedding` per row.
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
Ok(embeddings)
/// Embeds every text independently by delegating to `embed_one`.
///
/// The returned embeddings are in the same order as `texts`.
///
/// # Errors
///
/// Returns the first `EmbedError` produced by `embed_one`; texts after the
/// failing one are not processed.
pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
    let mut embeddings = Vec::with_capacity(texts.len());
    for text in texts {
        // `?` stops at the first failing text, like collecting into a `Result`.
        embeddings.push(self.embed_one(&text)?);
    }
    Ok(embeddings)
}
fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {