From 0ee4671a9154e51eb2fd15c2aaf8b4ef048754b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Wed, 8 Jan 2025 15:59:56 +0100
Subject: [PATCH] Fix after upgrading candle

---
 crates/milli/src/vector/hf.rs     | 12 ++++++++----
 crates/milli/src/vector/openai.rs | 12 +++++++++++-
 crates/milli/src/vector/rest.rs   |  2 +-
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs
index 3fe28e53a..447a88f5d 100644
--- a/crates/milli/src/vector/hf.rs
+++ b/crates/milli/src/vector/hf.rs
@@ -163,8 +163,10 @@ impl Embedder {
 
         let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
-        let embeddings =
-            self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
+        let embeddings = self
+            .model
+            .forward(&token_ids, &token_type_ids, None)
+            .map_err(EmbedError::model_forward)?;
 
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) =
@@ -185,8 +187,10 @@ impl Embedder {
             Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
         let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
-        let embeddings =
-            self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
+        let embeddings = self
+            .model
+            .forward(&token_ids, &token_type_ids, None)
+            .map_err(EmbedError::model_forward)?;
 
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) =
diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs
index 7262bfef8..938c04fe3 100644
--- a/crates/milli/src/vector/openai.rs
+++ b/crates/milli/src/vector/openai.rs
@@ -1,3 +1,4 @@
+use std::fmt;
 use std::time::Instant;
 
 use ordered_float::OrderedFloat;
@@ -168,7 +169,6 @@ fn infer_api_key() -> String {
         .unwrap_or_default()
 }
 
-#[derive(Debug)]
 pub struct Embedder {
     tokenizer: tiktoken_rs::CoreBPE,
     rest_embedder: RestEmbedder,
@@ -302,3 +302,13 @@ impl Embedder {
         self.options.distribution()
     }
 }
+
+impl fmt::Debug for Embedder {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Embedder")
+            .field("tokenizer", &"CoreBPE")
+            .field("rest_embedder", &self.rest_embedder)
+            .field("options", &self.options)
+            .finish()
+    }
+}
diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs
index 98be311d4..eb05bac64 100644
--- a/crates/milli/src/vector/rest.rs
+++ b/crates/milli/src/vector/rest.rs
@@ -175,7 +175,7 @@ impl Embedder {
 
     pub fn embed_tokens(
         &self,
-        tokens: &[usize],
+        tokens: &[u32],
         deadline: Option<Instant>,
     ) -> Result<Embedding, EmbedError> {
         let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;