mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-03-06 22:02:34 +08:00
Fix multiple embeddings in hf
This commit is contained in:
parent
e374b095a2
commit
24fe6cd205
@ -255,34 +255,8 @@ impl Embedder {
|
|||||||
Ok(this)
|
Ok(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, mut texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
|
pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
|
||||||
let tokens = match texts.len() {
|
texts.into_iter().map(|text| self.embed_one(&text)).collect()
|
||||||
1 => vec![self
|
|
||||||
.tokenizer
|
|
||||||
.encode(texts.pop().unwrap(), true)
|
|
||||||
.map_err(EmbedError::tokenize)?],
|
|
||||||
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
|
|
||||||
};
|
|
||||||
let token_ids = tokens
|
|
||||||
.iter()
|
|
||||||
.map(|tokens| {
|
|
||||||
let mut tokens = tokens.get_ids().to_vec();
|
|
||||||
tokens.truncate(512);
|
|
||||||
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>, EmbedError>>()?;
|
|
||||||
|
|
||||||
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
|
||||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
|
||||||
let embeddings = self
|
|
||||||
.model
|
|
||||||
.forward(&token_ids, &token_type_ids, None)
|
|
||||||
.map_err(EmbedError::model_forward)?;
|
|
||||||
|
|
||||||
let embeddings = Self::pooling(embeddings, self.pooling)?;
|
|
||||||
|
|
||||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
|
||||||
Ok(embeddings)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
|
fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user