mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-26 12:05:05 +08:00
Remove some settings
This commit is contained in:
parent
3c1a14f1cd
commit
5b51cb04af
@ -23,7 +23,7 @@ use super::{Embedding, Embeddings};
|
|||||||
)]
|
)]
|
||||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||||
pub enum WeightSource {
|
enum WeightSource {
|
||||||
#[default]
|
#[default]
|
||||||
Safetensors,
|
Safetensors,
|
||||||
Pytorch,
|
Pytorch,
|
||||||
@ -33,20 +33,13 @@ pub enum WeightSource {
|
|||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub revision: Option<String>,
|
pub revision: Option<String>,
|
||||||
pub weight_source: WeightSource,
|
|
||||||
pub normalize_embeddings: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
//model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
|
||||||
model: "BAAI/bge-base-en-v1.5".to_string(),
|
model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||||
//revision: Some("refs/pr/21".to_string()),
|
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
|
||||||
revision: None,
|
|
||||||
//weight_source: Default::default(),
|
|
||||||
weight_source: WeightSource::Pytorch,
|
|
||||||
normalize_embeddings: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -82,20 +75,21 @@ impl Embedder {
|
|||||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||||
None => Repo::model(options.model.clone()),
|
None => Repo::model(options.model.clone()),
|
||||||
};
|
};
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
|
||||||
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||||
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
||||||
let weights = match options.weight_source {
|
let (weights, source) = {
|
||||||
WeightSource::Pytorch => {
|
api.get("pytorch_model.bin")
|
||||||
api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
|
.map(|filename| (filename, WeightSource::Pytorch))
|
||||||
}
|
.or_else(|_| {
|
||||||
WeightSource::Safetensors => {
|
api.get("model.safetensors")
|
||||||
api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
|
.map(|filename| (filename, WeightSource::Safetensors))
|
||||||
}
|
})
|
||||||
|
.map_err(NewEmbedderError::api_get)?
|
||||||
};
|
};
|
||||||
(config, tokenizer, weights)
|
(config, tokenizer, weights, source)
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = std::fs::read_to_string(&config_filename)
|
let config = std::fs::read_to_string(&config_filename)
|
||||||
@ -106,7 +100,7 @@ impl Embedder {
|
|||||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||||
|
|
||||||
let vb = match options.weight_source {
|
let vb = match weight_source {
|
||||||
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
||||||
.map_err(NewEmbedderError::pytorch_weight)?,
|
.map_err(NewEmbedderError::pytorch_weight)?,
|
||||||
WeightSource::Safetensors => unsafe {
|
WeightSource::Safetensors => unsafe {
|
||||||
@ -168,12 +162,6 @@ impl Embedder {
|
|||||||
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||||
.map_err(EmbedError::tensor_shape)?;
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
|
||||||
let embeddings: Tensor = if self.options.normalize_embeddings {
|
|
||||||
normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
|
|
||||||
} else {
|
|
||||||
embeddings
|
|
||||||
};
|
|
||||||
|
|
||||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||||
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
||||||
}
|
}
|
||||||
@ -197,7 +185,3 @@ impl Embedder {
|
|||||||
self.dimensions
|
self.dimensions
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
|
|
||||||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
|
||||||
}
|
|
||||||
|
@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::prompt::PromptData;
|
use crate::prompt::PromptData;
|
||||||
use crate::update::Setting;
|
use crate::update::Setting;
|
||||||
use crate::vector::hf::WeightSource;
|
|
||||||
use crate::vector::EmbeddingConfig;
|
use crate::vector::EmbeddingConfig;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||||
@ -204,26 +203,13 @@ pub struct HfEmbedderSettings {
|
|||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
#[deserr(default)]
|
#[deserr(default)]
|
||||||
pub revision: Setting<String>,
|
pub revision: Setting<String>,
|
||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
||||||
#[deserr(default)]
|
|
||||||
pub weight_source: Setting<WeightSource>,
|
|
||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
||||||
#[deserr(default)]
|
|
||||||
pub normalize_embeddings: Setting<bool>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HfEmbedderSettings {
|
impl HfEmbedderSettings {
|
||||||
pub fn apply(&mut self, new: Self) {
|
pub fn apply(&mut self, new: Self) {
|
||||||
let HfEmbedderSettings {
|
let HfEmbedderSettings { model, revision } = new;
|
||||||
model,
|
|
||||||
revision,
|
|
||||||
weight_source,
|
|
||||||
normalize_embeddings: normalize_embedding,
|
|
||||||
} = new;
|
|
||||||
self.model.apply(model);
|
self.model.apply(model);
|
||||||
self.revision.apply(revision);
|
self.revision.apply(revision);
|
||||||
self.weight_source.apply(weight_source);
|
|
||||||
self.normalize_embeddings.apply(normalize_embedding);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,15 +218,13 @@ impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
|
|||||||
Self {
|
Self {
|
||||||
model: Setting::Set(value.model),
|
model: Setting::Set(value.model),
|
||||||
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
|
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
|
||||||
weight_source: Setting::Set(value.weight_source),
|
|
||||||
normalize_embeddings: Setting::Set(value.normalize_embeddings),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
|
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
|
||||||
fn from(value: HfEmbedderSettings) -> Self {
|
fn from(value: HfEmbedderSettings) -> Self {
|
||||||
let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value;
|
let HfEmbedderSettings { model, revision } = value;
|
||||||
let mut this = Self::default();
|
let mut this = Self::default();
|
||||||
if let Some(model) = model.set() {
|
if let Some(model) = model.set() {
|
||||||
this.model = model;
|
this.model = model;
|
||||||
@ -248,12 +232,6 @@ impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
|
|||||||
if let Some(revision) = revision.set() {
|
if let Some(revision) = revision.set() {
|
||||||
this.revision = Some(revision);
|
this.revision = Some(revision);
|
||||||
}
|
}
|
||||||
if let Some(weight_source) = weight_source.set() {
|
|
||||||
this.weight_source = weight_source;
|
|
||||||
}
|
|
||||||
if let Some(normalize_embeddings) = normalize_embeddings.set() {
|
|
||||||
this.normalize_embeddings = normalize_embeddings;
|
|
||||||
}
|
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user