Remove some settings

This commit is contained in:
Louis Dureuil 2023-12-13 23:09:50 +01:00
parent 3c1a14f1cd
commit 5b51cb04af
No known key found for this signature in database
2 changed files with 15 additions and 53 deletions

View File

@@ -23,7 +23,7 @@ use super::{Embedding, Embeddings};
 )]
 #[serde(deny_unknown_fields, rename_all = "camelCase")]
 #[deserr(rename_all = camelCase, deny_unknown_fields)]
-pub enum WeightSource {
+enum WeightSource {
     #[default]
     Safetensors,
     Pytorch,
@@ -33,20 +33,13 @@ pub enum WeightSource {
 pub struct EmbedderOptions {
     pub model: String,
     pub revision: Option<String>,
-    pub weight_source: WeightSource,
-    pub normalize_embeddings: bool,
 }

 impl EmbedderOptions {
     pub fn new() -> Self {
         Self {
-            //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
             model: "BAAI/bge-base-en-v1.5".to_string(),
-            //revision: Some("refs/pr/21".to_string()),
-            revision: None,
-            //weight_source: Default::default(),
-            weight_source: WeightSource::Pytorch,
-            normalize_embeddings: true,
+            revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
         }
     }
 }
@@ -82,20 +75,21 @@ impl Embedder {
             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
             None => Repo::model(options.model.clone()),
         };
-        let (config_filename, tokenizer_filename, weights_filename) = {
+        let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
             let api = api.repo(repo);
             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
             let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
-            let weights = match options.weight_source {
-                WeightSource::Pytorch => {
-                    api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
-                }
-                WeightSource::Safetensors => {
-                    api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
-                }
+            let (weights, source) = {
+                api.get("pytorch_model.bin")
+                    .map(|filename| (filename, WeightSource::Pytorch))
+                    .or_else(|_| {
+                        api.get("model.safetensors")
+                            .map(|filename| (filename, WeightSource::Safetensors))
+                    })
+                    .map_err(NewEmbedderError::api_get)?
             };
-            (config, tokenizer, weights)
+            (config, tokenizer, weights, source)
         };

         let config = std::fs::read_to_string(&config_filename)
@@ -106,7 +100,7 @@ impl Embedder {
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
            .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;

-        let vb = match options.weight_source {
+        let vb = match weight_source {
             WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
                 .map_err(NewEmbedderError::pytorch_weight)?,
             WeightSource::Safetensors => unsafe {
@@ -168,12 +162,6 @@ impl Embedder {
         let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
             .map_err(EmbedError::tensor_shape)?;

-        let embeddings: Tensor = if self.options.normalize_embeddings {
-            normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
-        } else {
-            embeddings
-        };
-
         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
         Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
     }
@@ -197,7 +185,3 @@ impl Embedder {
         self.dimensions
     }
 }
-
-fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
-    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
-}

View File

@@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize};
 use crate::prompt::PromptData;
 use crate::update::Setting;
-use crate::vector::hf::WeightSource;
 use crate::vector::EmbeddingConfig;

 #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
@@ -204,26 +203,13 @@ pub struct HfEmbedderSettings {
     #[serde(default, skip_serializing_if = "Setting::is_not_set")]
     #[deserr(default)]
     pub revision: Setting<String>,
-    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
-    #[deserr(default)]
-    pub weight_source: Setting<WeightSource>,
-    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
-    #[deserr(default)]
-    pub normalize_embeddings: Setting<bool>,
 }

 impl HfEmbedderSettings {
     pub fn apply(&mut self, new: Self) {
-        let HfEmbedderSettings {
-            model,
-            revision,
-            weight_source,
-            normalize_embeddings: normalize_embedding,
-        } = new;
+        let HfEmbedderSettings { model, revision } = new;
         self.model.apply(model);
         self.revision.apply(revision);
-        self.weight_source.apply(weight_source);
-        self.normalize_embeddings.apply(normalize_embedding);
     }
 }
@@ -232,15 +218,13 @@ impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
         Self {
             model: Setting::Set(value.model),
             revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
-            weight_source: Setting::Set(value.weight_source),
-            normalize_embeddings: Setting::Set(value.normalize_embeddings),
         }
     }
 }

 impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
     fn from(value: HfEmbedderSettings) -> Self {
-        let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value;
+        let HfEmbedderSettings { model, revision } = value;
         let mut this = Self::default();
         if let Some(model) = model.set() {
             this.model = model;
@@ -248,12 +232,6 @@ impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
         if let Some(revision) = revision.set() {
             this.revision = Some(revision);
         }
-        if let Some(weight_source) = weight_source.set() {
-            this.weight_source = weight_source;
-        }
-        if let Some(normalize_embeddings) = normalize_embeddings.set() {
-            this.normalize_embeddings = normalize_embeddings;
-        }
         this
     }
 }