From a25456120d352fd10729ae2751f4774f61840f9d Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 27 Mar 2024 11:51:04 +0100 Subject: [PATCH] Expose distribution in settings --- milli/src/vector/settings.rs | 55 ++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index c277dd0cf..b13b84178 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -2,7 +2,7 @@ use deserr::Deserr; use serde::{Deserialize, Serialize}; use super::rest::InputType; -use super::{ollama, openai}; +use super::{ollama, openai, DistributionShift}; use crate::prompt::PromptData; use crate::update::Setting; use crate::vector::EmbeddingConfig; @@ -48,6 +48,9 @@ pub struct EmbeddingSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub input_type: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub distribution: Setting, } pub fn check_unset( @@ -101,6 +104,8 @@ impl EmbeddingSettings { pub const EMBEDDING_OBJECT: &'static str = "embeddingObject"; pub const INPUT_TYPE: &'static str = "inputType"; + pub const DISTRIBUTION: &'static str = "distribution"; + pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { match field { Self::SOURCE => &[ @@ -132,6 +137,13 @@ impl EmbeddingSettings { Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest], Self::INPUT_TYPE => &[EmbedderSource::Rest], + Self::DISTRIBUTION => &[ + EmbedderSource::HuggingFace, + EmbedderSource::Ollama, + EmbedderSource::OpenAi, + EmbedderSource::Rest, + EmbedderSource::UserProvided, + ], _other => unreachable!("unknown field"), } } @@ -144,14 +156,24 @@ impl EmbeddingSettings { Self::API_KEY, Self::DOCUMENT_TEMPLATE, Self::DIMENSIONS, + Self::DISTRIBUTION, ], - EmbedderSource::HuggingFace => { - &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] - } - EmbedderSource::Ollama => { - &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE, Self::URL, Self::API_KEY] - } - EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], + EmbedderSource::HuggingFace => &[ + Self::SOURCE, + Self::MODEL, + Self::REVISION, + Self::DOCUMENT_TEMPLATE, + Self::DISTRIBUTION, + ], + EmbedderSource::Ollama => &[ + Self::SOURCE, + Self::MODEL, + Self::DOCUMENT_TEMPLATE, + Self::URL, + Self::API_KEY, + Self::DISTRIBUTION, + ], + EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION], EmbedderSource::Rest => &[ Self::SOURCE, Self::API_KEY, @@ -163,6 +185,7 @@ impl EmbeddingSettings { Self::PATH_TO_EMBEDDINGS, Self::EMBEDDING_OBJECT, Self::INPUT_TYPE, + Self::DISTRIBUTION, ], } } @@ -283,6 +306,7 @@ impl From for EmbeddingSettings { path_to_embeddings: Setting::NotSet, embedding_object: Setting::NotSet, input_type: Setting::NotSet, + distribution: options.distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::OpenAi(options) => Self { source: Setting::Set(EmbedderSource::OpenAi), @@ -297,6 +321,7 @@ impl From for EmbeddingSettings { path_to_embeddings: Setting::NotSet, embedding_object: Setting::NotSet, input_type: Setting::NotSet, + distribution: options.distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::Ollama(options) => Self { source: Setting::Set(EmbedderSource::Ollama), @@ -311,6 +336,7 @@ impl From for EmbeddingSettings { path_to_embeddings: Setting::NotSet, embedding_object: Setting::NotSet, input_type: Setting::NotSet, + distribution: options.distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::UserProvided(options) => Self { source: Setting::Set(EmbedderSource::UserProvided), @@ -325,11 +351,10 @@ impl From for EmbeddingSettings { path_to_embeddings: Setting::NotSet, embedding_object: Setting::NotSet, input_type: Setting::NotSet, + distribution: options.distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::Rest(super::rest::EmbedderOptions { api_key, - // TODO: support distribution - distribution: _, dimensions, url, query, @@ -337,6 +362,7 @@ impl From for EmbeddingSettings { path_to_embeddings, embedding_object, input_type, + distribution, }) => Self { source: Setting::Set(EmbedderSource::Rest), model: Setting::NotSet, @@ -350,6 +376,7 @@ impl From for EmbeddingSettings { path_to_embeddings: Setting::Set(path_to_embeddings), embedding_object: Setting::Set(embedding_object), input_type: Setting::Set(input_type), + distribution: distribution.map(Setting::Set).unwrap_or_default(), }, } } @@ -371,7 +398,9 @@ impl From for EmbeddingConfig { path_to_embeddings, embedding_object, input_type, + distribution, } = value; + if let Some(source) = source.set() { match source { EmbedderSource::OpenAi => { @@ -387,6 +416,7 @@ impl From for EmbeddingConfig { if let Some(dimensions) = dimensions.set() { options.dimensions = Some(dimensions); } + options.distribution = distribution.set(); this.embedder_options = super::EmbedderOptions::OpenAi(options); } EmbedderSource::Ollama => { @@ -399,6 +429,7 @@ impl From for EmbeddingConfig { options.embedding_model = model; } + options.distribution = distribution.set(); this.embedder_options = super::EmbedderOptions::Ollama(options); } EmbedderSource::HuggingFace => { @@ -415,12 +446,14 @@ impl From for EmbeddingConfig { if let Some(revision) = revision.set() { options.revision = Some(revision); } + options.distribution = distribution.set(); this.embedder_options = super::EmbedderOptions::HuggingFace(options); } EmbedderSource::UserProvided => { this.embedder_options = super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { dimensions: dimensions.set().unwrap(), + distribution: distribution.set(), }); } EmbedderSource::Rest => { @@ -429,7 +462,6 @@ impl From for EmbeddingConfig { this.embedder_options = super::EmbedderOptions::Rest(super::rest::EmbedderOptions { api_key: api_key.set(), - distribution: None, dimensions: dimensions.set(), url: url.set().unwrap(), query: query.set().unwrap_or(embedder_options.query), @@ -441,6 +473,7 @@ impl From for EmbeddingConfig { .set() .unwrap_or(embedder_options.embedding_object), input_type: input_type.set().unwrap_or(embedder_options.input_type), + distribution: distribution.set(), }) } }