From 24fe6cd2059e4674597a744cf80afac8d9ac5dcf Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 24 Feb 2025 16:24:04 +0100 Subject: [PATCH] Fix multiple embeddings in hf --- crates/milli/src/vector/hf.rs | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs index 3ec0a5b7c..60e40e367 100644 --- a/crates/milli/src/vector/hf.rs +++ b/crates/milli/src/vector/hf.rs @@ -255,34 +255,8 @@ impl Embedder { Ok(this) } - pub fn embed(&self, mut texts: Vec) -> std::result::Result, EmbedError> { - let tokens = match texts.len() { - 1 => vec![self - .tokenizer - .encode(texts.pop().unwrap(), true) - .map_err(EmbedError::tokenize)?], - _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, - }; - let token_ids = tokens - .iter() - .map(|tokens| { - let mut tokens = tokens.get_ids().to_vec(); - tokens.truncate(512); - Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) - }) - .collect::, EmbedError>>()?; - - let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; - let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; - let embeddings = self - .model - .forward(&token_ids, &token_type_ids, None) - .map_err(EmbedError::model_forward)?; - - let embeddings = Self::pooling(embeddings, self.pooling)?; - - let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; - Ok(embeddings) + pub fn embed(&self, texts: Vec) -> std::result::Result, EmbedError> { + texts.into_iter().map(|text| self.embed_one(&text)).collect() } fn pooling(embeddings: Tensor, pooling: Pooling) -> Result {