From 5b51cb04afd4a005f269527b7f88f47055835784 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 23:09:50 +0100 Subject: [PATCH] Remove some settings --- milli/src/vector/hf.rs | 42 +++++++++++------------------------- milli/src/vector/settings.rs | 26 ++-------------------- 2 files changed, 15 insertions(+), 53 deletions(-) diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 07185d25c..3162dadec 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -23,7 +23,7 @@ use super::{Embedding, Embeddings}; )] #[serde(deny_unknown_fields, rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] -pub enum WeightSource { +enum WeightSource { #[default] Safetensors, Pytorch, @@ -33,20 +33,13 @@ pub enum WeightSource { pub struct EmbedderOptions { pub model: String, pub revision: Option, - pub weight_source: WeightSource, - pub normalize_embeddings: bool, } impl EmbedderOptions { pub fn new() -> Self { Self { - //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(), model: "BAAI/bge-base-en-v1.5".to_string(), - //revision: Some("refs/pr/21".to_string()), - revision: None, - //weight_source: Default::default(), - weight_source: WeightSource::Pytorch, - normalize_embeddings: true, + revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), } } } @@ -82,20 +75,21 @@ impl Embedder { Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), None => Repo::model(options.model.clone()), }; - let (config_filename, tokenizer_filename, weights_filename) = { + let (config_filename, tokenizer_filename, weights_filename, weight_source) = { let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; let api = api.repo(repo); let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; - let weights = match options.weight_source { - WeightSource::Pytorch => { - api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)? - } - WeightSource::Safetensors => { - api.get("model.safetensors").map_err(NewEmbedderError::api_get)? - } + let (weights, source) = { + api.get("pytorch_model.bin") + .map(|filename| (filename, WeightSource::Pytorch)) + .or_else(|_| { + api.get("model.safetensors") + .map(|filename| (filename, WeightSource::Safetensors)) + }) + .map_err(NewEmbedderError::api_get)? }; - (config, tokenizer, weights) + (config, tokenizer, weights, source) }; let config = std::fs::read_to_string(&config_filename) @@ -106,7 +100,7 @@ impl Embedder { let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; - let vb = match options.weight_source { + let vb = match weight_source { WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) .map_err(NewEmbedderError::pytorch_weight)?, WeightSource::Safetensors => unsafe { @@ -168,12 +162,6 @@ impl Embedder { let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) .map_err(EmbedError::tensor_shape)?; - let embeddings: Tensor = if self.options.normalize_embeddings { - normalize_l2(&embeddings).map_err(EmbedError::tensor_value)? - } else { - embeddings - }; - let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) } @@ -197,7 +185,3 @@ impl Embedder { self.dimensions } } - -fn normalize_l2(v: &Tensor) -> Result { - v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) -} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index bd385e3f3..e37b0fde7 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize}; use crate::prompt::PromptData; use crate::update::Setting; -use crate::vector::hf::WeightSource; use crate::vector::EmbeddingConfig; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] @@ -204,26 +203,13 @@ pub struct HfEmbedderSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub revision: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub weight_source: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub normalize_embeddings: Setting, } impl HfEmbedderSettings { pub fn apply(&mut self, new: Self) { - let HfEmbedderSettings { - model, - revision, - weight_source, - normalize_embeddings: normalize_embedding, - } = new; + let HfEmbedderSettings { model, revision } = new; self.model.apply(model); self.revision.apply(revision); - self.weight_source.apply(weight_source); - self.normalize_embeddings.apply(normalize_embedding); } } @@ -232,15 +218,13 @@ impl From for HfEmbedderSettings { Self { model: Setting::Set(value.model), revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), - weight_source: Setting::Set(value.weight_source), - normalize_embeddings: Setting::Set(value.normalize_embeddings), } } } impl From for crate::vector::hf::EmbedderOptions { fn from(value: HfEmbedderSettings) -> Self { - let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value; + let HfEmbedderSettings { model, revision } = value; let mut this = Self::default(); if let Some(model) = model.set() { this.model = model; @@ -248,12 +232,6 @@ impl From for crate::vector::hf::EmbedderOptions { if let Some(revision) = revision.set() { this.revision = Some(revision); } - if let Some(weight_source) = weight_source.set() { - this.weight_source = weight_source; - } - if let Some(normalize_embeddings) = normalize_embeddings.set() { - this.normalize_embeddings = normalize_embeddings; - } this } }