Expose distribution in settings

This commit is contained in:
Louis Dureuil 2024-03-27 11:51:04 +01:00
parent 168ded3b9d
commit a25456120d
No known key found for this signature in database

View File

@ -2,7 +2,7 @@ use deserr::Deserr;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::rest::InputType; use super::rest::InputType;
use super::{ollama, openai}; use super::{ollama, openai, DistributionShift};
use crate::prompt::PromptData; use crate::prompt::PromptData;
use crate::update::Setting; use crate::update::Setting;
use crate::vector::EmbeddingConfig; use crate::vector::EmbeddingConfig;
@ -48,6 +48,9 @@ pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
pub input_type: Setting<InputType>, pub input_type: Setting<InputType>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub distribution: Setting<DistributionShift>,
} }
pub fn check_unset<T>( pub fn check_unset<T>(
@ -101,6 +104,8 @@ impl EmbeddingSettings {
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject"; pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
pub const INPUT_TYPE: &'static str = "inputType"; pub const INPUT_TYPE: &'static str = "inputType";
pub const DISTRIBUTION: &'static str = "distribution";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field { match field {
Self::SOURCE => &[ Self::SOURCE => &[
@ -132,6 +137,13 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest], Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
Self::INPUT_TYPE => &[EmbedderSource::Rest], Self::INPUT_TYPE => &[EmbedderSource::Rest],
Self::DISTRIBUTION => &[
EmbedderSource::HuggingFace,
EmbedderSource::Ollama,
EmbedderSource::OpenAi,
EmbedderSource::Rest,
EmbedderSource::UserProvided,
],
_other => unreachable!("unknown field"), _other => unreachable!("unknown field"),
} }
} }
@ -144,14 +156,24 @@ impl EmbeddingSettings {
Self::API_KEY, Self::API_KEY,
Self::DOCUMENT_TEMPLATE, Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS, Self::DIMENSIONS,
Self::DISTRIBUTION,
], ],
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => &[
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] Self::SOURCE,
} Self::MODEL,
EmbedderSource::Ollama => { Self::REVISION,
&[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE, Self::URL, Self::API_KEY] Self::DOCUMENT_TEMPLATE,
} Self::DISTRIBUTION,
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], ],
EmbedderSource::Ollama => &[
Self::SOURCE,
Self::MODEL,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::API_KEY,
Self::DISTRIBUTION,
],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION],
EmbedderSource::Rest => &[ EmbedderSource::Rest => &[
Self::SOURCE, Self::SOURCE,
Self::API_KEY, Self::API_KEY,
@ -163,6 +185,7 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS, Self::PATH_TO_EMBEDDINGS,
Self::EMBEDDING_OBJECT, Self::EMBEDDING_OBJECT,
Self::INPUT_TYPE, Self::INPUT_TYPE,
Self::DISTRIBUTION,
], ],
} }
} }
@ -283,6 +306,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::OpenAi(options) => Self { super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi), source: Setting::Set(EmbedderSource::OpenAi),
@ -297,6 +321,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::Ollama(options) => Self { super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama), source: Setting::Set(EmbedderSource::Ollama),
@ -311,6 +336,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::UserProvided(options) => Self { super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided), source: Setting::Set(EmbedderSource::UserProvided),
@ -325,11 +351,10 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::Rest(super::rest::EmbedderOptions { super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key, api_key,
// TODO: support distribution
distribution: _,
dimensions, dimensions,
url, url,
query, query,
@ -337,6 +362,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
}) => Self { }) => Self {
source: Setting::Set(EmbedderSource::Rest), source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet, model: Setting::NotSet,
@ -350,6 +376,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::Set(path_to_embeddings), path_to_embeddings: Setting::Set(path_to_embeddings),
embedding_object: Setting::Set(embedding_object), embedding_object: Setting::Set(embedding_object),
input_type: Setting::Set(input_type), input_type: Setting::Set(input_type),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
}, },
} }
} }
@ -371,7 +398,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
} = value; } = value;
if let Some(source) = source.set() { if let Some(source) = source.set() {
match source { match source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
@ -387,6 +416,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(dimensions) = dimensions.set() { if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions); options.dimensions = Some(dimensions);
} }
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::OpenAi(options); this.embedder_options = super::EmbedderOptions::OpenAi(options);
} }
EmbedderSource::Ollama => { EmbedderSource::Ollama => {
@ -399,6 +429,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
options.embedding_model = model; options.embedding_model = model;
} }
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::Ollama(options); this.embedder_options = super::EmbedderOptions::Ollama(options);
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {
@ -415,12 +446,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(revision) = revision.set() { if let Some(revision) = revision.set() {
options.revision = Some(revision); options.revision = Some(revision);
} }
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::HuggingFace(options); this.embedder_options = super::EmbedderOptions::HuggingFace(options);
} }
EmbedderSource::UserProvided => { EmbedderSource::UserProvided => {
this.embedder_options = this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(), dimensions: dimensions.set().unwrap(),
distribution: distribution.set(),
}); });
} }
EmbedderSource::Rest => { EmbedderSource::Rest => {
@ -429,7 +462,6 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
this.embedder_options = this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions { super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(), api_key: api_key.set(),
distribution: None,
dimensions: dimensions.set(), dimensions: dimensions.set(),
url: url.set().unwrap(), url: url.set().unwrap(),
query: query.set().unwrap_or(embedder_options.query), query: query.set().unwrap_or(embedder_options.query),
@ -441,6 +473,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
.set() .set()
.unwrap_or(embedder_options.embedding_object), .unwrap_or(embedder_options.embedding_object),
input_type: input_type.set().unwrap_or(embedder_options.input_type), input_type: input_type.set().unwrap_or(embedder_options.input_type),
distribution: distribution.set(),
}) })
} }
} }