Remove the `weight_source` and `normalize_embeddings` embedder settings

This commit is contained in:
Louis Dureuil 2023-12-13 23:09:50 +01:00
parent 3c1a14f1cd
commit 5b51cb04af
No known key found for this signature in database
2 changed files with 15 additions and 53 deletions

View File

@ -23,7 +23,7 @@ use super::{Embedding, Embeddings};
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum WeightSource {
enum WeightSource {
#[default]
Safetensors,
Pytorch,
@ -33,20 +33,13 @@ pub enum WeightSource {
/// Options describing which model the embedder loads and how it post-processes embeddings.
///
/// NOTE(review): this span comes from a rendered diff whose `+`/`-` markers were lost.
/// `weight_source` and `normalize_embeddings` look like the fields being removed by the
/// change — confirm against the repository before relying on them.
pub struct EmbedderOptions {
/// Identifier of the model repository to fetch.
pub model: String,
/// Revision of the model repository to fetch; `None` opens the repository without an
/// explicit revision.
pub revision: Option<String>,
/// Weight file format to load: the embedder fetches `pytorch_model.bin` for `Pytorch`
/// and `model.safetensors` for `Safetensors`.
pub weight_source: WeightSource,
/// When `true`, embeddings are passed through `normalize_l2` before being returned.
pub normalize_embeddings: bool,
}
impl EmbedderOptions {
/// Builds an `EmbedderOptions` filled with hard-coded values.
///
/// NOTE(review): this span comes from a rendered diff whose `+`/`-` markers were lost,
/// so lines from before and after the change are interleaved; that is presumably why
/// `revision` is assigned twice below — confirm which lines survive in the repository.
pub fn new() -> Self {
Self {
//model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
model: "BAAI/bge-base-en-v1.5".to_string(),
//revision: Some("refs/pr/21".to_string()),
revision: None,
//weight_source: Default::default(),
weight_source: WeightSource::Pytorch,
normalize_embeddings: true,
// Pins the model repository to one specific revision identifier.
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
}
}
}
@ -82,20 +75,21 @@ impl Embedder {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()),
};
let (config_filename, tokenizer_filename, weights_filename) = {
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
let api = api.repo(repo);
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
let weights = match options.weight_source {
WeightSource::Pytorch => {
api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
}
WeightSource::Safetensors => {
api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
}
let (weights, source) = {
api.get("pytorch_model.bin")
.map(|filename| (filename, WeightSource::Pytorch))
.or_else(|_| {
api.get("model.safetensors")
.map(|filename| (filename, WeightSource::Safetensors))
})
.map_err(NewEmbedderError::api_get)?
};
(config, tokenizer, weights)
(config, tokenizer, weights, source)
};
let config = std::fs::read_to_string(&config_filename)
@ -106,7 +100,7 @@ impl Embedder {
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
let vb = match options.weight_source {
let vb = match weight_source {
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
.map_err(NewEmbedderError::pytorch_weight)?,
WeightSource::Safetensors => unsafe {
@ -168,12 +162,6 @@ impl Embedder {
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embeddings: Tensor = if self.options.normalize_embeddings {
normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
} else {
embeddings
};
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
}
@ -197,7 +185,3 @@ impl Embedder {
self.dimensions
}
}
/// Divides `v` by the square root of the sum of its squared values taken along dimension 1.
///
/// The reduced dimension is kept (`sum_keepdim`), and the division broadcasts the norms
/// back over `v`.
///
/// # Errors
///
/// Returns the `candle_core::Error` raised by any of the underlying tensor operations.
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
    let squared = v.sqr()?;
    let norms = squared.sum_keepdim(1)?.sqrt()?;
    v.broadcast_div(&norms)
}

View File

@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::hf::WeightSource;
use crate::vector::EmbeddingConfig;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
@ -204,26 +203,13 @@ pub struct HfEmbedderSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub revision: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub weight_source: Setting<WeightSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub normalize_embeddings: Setting<bool>,
}
impl HfEmbedderSettings {
/// Merges `new` into `self` by calling `apply` on each individual setting with the
/// matching field of `new`.
///
/// NOTE(review): this span comes from a rendered diff whose `+`/`-` markers were lost,
/// so it shows two destructurings of `new`: a four-field one and a two-field one.
/// Presumably only the two-field form and the `model`/`revision` calls remain after the
/// change — confirm against the repository.
pub fn apply(&mut self, new: Self) {
// Destructuring without `..` makes the compiler reject this method if a field is added
// to the struct without being handled here.
let HfEmbedderSettings {
model,
revision,
weight_source,
normalize_embeddings: normalize_embedding,
} = new;
let HfEmbedderSettings { model, revision } = new;
self.model.apply(model);
self.revision.apply(revision);
self.weight_source.apply(weight_source);
self.normalize_embeddings.apply(normalize_embedding);
}
}
@ -232,15 +218,13 @@ impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
Self {
model: Setting::Set(value.model),
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
weight_source: Setting::Set(value.weight_source),
normalize_embeddings: Setting::Set(value.normalize_embeddings),
}
}
}
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
fn from(value: HfEmbedderSettings) -> Self {
let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value;
let HfEmbedderSettings { model, revision } = value;
let mut this = Self::default();
if let Some(model) = model.set() {
this.model = model;
@ -248,12 +232,6 @@ impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
if let Some(revision) = revision.set() {
this.revision = Some(revision);
}
if let Some(weight_source) = weight_source.set() {
this.weight_source = weight_source;
}
if let Some(normalize_embeddings) = normalize_embeddings.set() {
this.normalize_embeddings = normalize_embeddings;
}
this
}
}