diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index c11e6ddc6..53e8a041b 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -34,6 +34,9 @@ pub struct EmbedderOptions { #[serde(deny_unknown_fields, rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] pub enum EmbeddingModel { + // # WARNING + // + // If ever adding a model, make sure to add it to the list of supported models below. #[default] #[serde(rename = "text-embedding-ada-002")] #[deserr(rename = "text-embedding-ada-002")] @@ -41,6 +44,10 @@ pub enum EmbeddingModel { } impl EmbeddingModel { + pub fn supported_models() -> &'static [&'static str] { + &["text-embedding-ada-002"] + } + pub fn max_token(&self) -> usize { match self { EmbeddingModel::TextEmbeddingAda002 => 8191, @@ -59,7 +66,7 @@ impl EmbeddingModel { } } - pub fn from_name(name: &'static str) -> Option { + pub fn from_name(name: &str) -> Option { match name { "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), _ => None, diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 1826c040d..945fc62c0 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -4,32 +4,189 @@ use serde::{Deserialize, Serialize}; use crate::prompt::PromptData; use crate::update::Setting; use crate::vector::EmbeddingConfig; +use crate::UserError; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] #[serde(deny_unknown_fields, rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] pub struct EmbeddingSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] - #[deserr(default, rename = "source")] - pub embedder_options: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] - pub document_template: Setting, + pub source: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub model: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub revision: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub api_key: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub dimensions: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub document_template: Setting, +} + +pub fn check_unset( + key: &Setting, + field: &'static str, + source: EmbedderSource, + embedder_name: &str, +) -> Result<(), UserError> { + if matches!(key, Setting::NotSet) { + Ok(()) + } else { + Err(UserError::InvalidFieldForSource { + embedder_name: embedder_name.to_owned(), + source_: source, + field, + allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source), + allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field), + }) + } +} + +pub fn check_set( + key: &Setting, + field: &'static str, + source: EmbedderSource, + embedder_name: &str, +) -> Result<(), UserError> { + if matches!(key, Setting::Set(_)) { + Ok(()) + } else { + Err(UserError::MissingFieldForSource { + field, + source_: source, + embedder_name: embedder_name.to_owned(), + }) + } +} + +impl EmbeddingSettings { + pub const SOURCE: &str = "source"; + pub const MODEL: &str = "model"; + pub const REVISION: &str = "revision"; + pub const API_KEY: &str = "apiKey"; + pub const DIMENSIONS: &str = "dimensions"; + pub const DOCUMENT_TEMPLATE: &str = "documentTemplate"; + + pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { + match field { + Self::SOURCE => { + &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided] + } + Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], + Self::REVISION => &[EmbedderSource::HuggingFace], + Self::API_KEY => &[EmbedderSource::OpenAi], + Self::DIMENSIONS => &[EmbedderSource::UserProvided], + Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], + _other => unreachable!("unknown field"), + } + } + + pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] { + match source { + EmbedderSource::OpenAi => { + &[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE] + } + EmbedderSource::HuggingFace => { + &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] + } + EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], + } + } + + pub(crate) fn apply_default_source(setting: &mut Setting) { + if let Setting::Set(EmbeddingSettings { + source: source @ (Setting::NotSet | Setting::Reset), + .. + }) = setting + { + *source = Setting::Set(EmbedderSource::default()) + } + } +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum EmbedderSource { + #[default] + OpenAi, + HuggingFace, + UserProvided, +} + +impl std::fmt::Display for EmbedderSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + EmbedderSource::OpenAi => "openAi", + EmbedderSource::HuggingFace => "huggingFace", + EmbedderSource::UserProvided => "userProvided", + }; + f.write_str(s) + } } impl EmbeddingSettings { pub fn apply(&mut self, new: Self) { - let EmbeddingSettings { embedder_options, document_template: prompt } = new; - self.embedder_options.apply(embedder_options); - self.document_template.apply(prompt); + let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = + new; + let old_source = self.source; + self.source.apply(source); + // Reinitialize the whole setting object on a source change + if old_source != self.source { + *self = EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + }; + return; + } + + self.model.apply(model); + self.revision.apply(revision); + self.api_key.apply(api_key); + self.dimensions.apply(dimensions); + self.document_template.apply(document_template); } } impl From for EmbeddingSettings { fn from(value: EmbeddingConfig) -> Self { - Self { - embedder_options: Setting::Set(value.embedder_options.into()), - document_template: Setting::Set(value.prompt.into()), + let EmbeddingConfig { embedder_options, prompt } = value; + match embedder_options { + super::EmbedderOptions::HuggingFace(options) => Self { + source: Setting::Set(EmbedderSource::HuggingFace), + model: Setting::Set(options.model), + revision: options.revision.map(Setting::Set).unwrap_or_default(), + api_key: Setting::NotSet, + dimensions: Setting::NotSet, + document_template: Setting::Set(prompt.template), + }, + super::EmbedderOptions::OpenAi(options) => Self { + source: Setting::Set(EmbedderSource::OpenAi), + model: Setting::Set(options.embedding_model.name().to_owned()), + revision: Setting::NotSet, + api_key: options.api_key.map(Setting::Set).unwrap_or_default(), + dimensions: Setting::NotSet, + document_template: Setting::Set(prompt.template), + }, + super::EmbedderOptions::UserProvided(options) => Self { + source: Setting::Set(EmbedderSource::UserProvided), + model: Setting::NotSet, + revision: Setting::NotSet, + api_key: Setting::NotSet, + dimensions: Setting::Set(options.dimensions), + document_template: Setting::NotSet, + }, } } } @@ -37,262 +194,51 @@ impl From for EmbeddingSettings { impl From for EmbeddingConfig { fn from(value: EmbeddingSettings) -> Self { let mut this = Self::default(); - let EmbeddingSettings { embedder_options, document_template: prompt } = value; - if let Some(embedder_options) = embedder_options.set() { - this.embedder_options = embedder_options.into(); - } - if let Some(prompt) = prompt.set() { - this.prompt = prompt.into(); - } - this - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct PromptSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub template: Setting, -} - -impl PromptSettings { - pub fn apply(&mut self, new: Self) { - let PromptSettings { template } = new; - self.template.apply(template); - } -} - -impl From for PromptSettings { - fn from(value: PromptData) -> Self { - Self { template: Setting::Set(value.template) } - } -} - -impl From for PromptData { - fn from(value: PromptSettings) -> Self { - let mut this = PromptData::default(); - let PromptSettings { template } = value; - if let Some(template) = template.set() { - this.template = template; - } - this - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -pub enum EmbedderSettings { - HuggingFace(Setting), - OpenAi(Setting), - UserProvided(UserProvidedSettings), -} - -impl Deserr for EmbedderSettings -where - E: deserr::DeserializeError, -{ - fn deserialize_from_value( - value: deserr::Value, - location: deserr::ValuePointerRef, - ) -> Result { - match value { - deserr::Value::Map(map) => { - if deserr::Map::len(&map) != 1 { - return Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::Unexpected { - msg: format!( - "Expected a single field, got {} fields", - deserr::Map::len(&map) - ), - }, - location, - ))); + let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = + value; + if let Some(source) = source.set() { + match source { + EmbedderSource::OpenAi => { + let mut options = super::openai::EmbedderOptions::with_default_model(None); + if let Some(model) = model.set() { + if let Some(model) = super::openai::EmbeddingModel::from_name(&model) { + options.embedding_model = model; + } + } + if let Some(api_key) = api_key.set() { + options.api_key = Some(api_key); + } + this.embedder_options = super::EmbedderOptions::OpenAi(options); } - let mut it = deserr::Map::into_iter(map); - let (k, v) = it.next().unwrap(); - - match k.as_str() { - "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set( - HfEmbedderSettings::deserialize_from_value( - v.into_value(), - location.push_key(&k), - )?, - ))), - "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set( - OpenAiEmbedderSettings::deserialize_from_value( - v.into_value(), - location.push_key(&k), - )?, - ))), - "userProvided" => Ok(EmbedderSettings::UserProvided( - UserProvidedSettings::deserialize_from_value( - v.into_value(), - location.push_key(&k), - )?, - )), - other => Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::UnknownKey { - key: other, - accepted: &["huggingFace", "openAi", "userProvided"], - }, - location, - ))), + EmbedderSource::HuggingFace => { + let mut options = super::hf::EmbedderOptions::default(); + if let Some(model) = model.set() { + options.model = model; + // Reset the revision if we are setting the model. + // This allows the following: + // "huggingFace": {} -> default model with default revision + // "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision + // "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision + options.revision = None; + } + if let Some(revision) = revision.set() { + options.revision = Some(revision); + } + this.embedder_options = super::EmbedderOptions::HuggingFace(options); + } + EmbedderSource::UserProvided => { + this.embedder_options = + super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { + dimensions: dimensions.set().unwrap(), + }); } } - _ => Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::IncorrectValueKind { - actual: value, - accepted: &[deserr::ValueKind::Map], - }, - location, - ))), } - } -} -impl Default for EmbedderSettings { - fn default() -> Self { - Self::OpenAi(Default::default()) - } -} - -impl From for EmbedderSettings { - fn from(value: crate::vector::EmbedderOptions) -> Self { - match value { - crate::vector::EmbedderOptions::HuggingFace(hf) => { - Self::HuggingFace(Setting::Set(hf.into())) - } - crate::vector::EmbedderOptions::OpenAi(openai) => { - Self::OpenAi(Setting::Set(openai.into())) - } - crate::vector::EmbedderOptions::UserProvided(user_provided) => { - Self::UserProvided(user_provided.into()) - } + if let Setting::Set(template) = document_template { + this.prompt = PromptData { template } } - } -} -impl From for crate::vector::EmbedderOptions { - fn from(value: EmbedderSettings) -> Self { - match value { - EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), - EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), - EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), - EmbedderSettings::OpenAi(_setting) => { - Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None)) - } - EmbedderSettings::UserProvided(user_provided) => { - Self::UserProvided(user_provided.into()) - } - } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct HfEmbedderSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub model: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub revision: Setting, -} - -impl HfEmbedderSettings { - pub fn apply(&mut self, new: Self) { - let HfEmbedderSettings { model, revision } = new; - self.model.apply(model); - self.revision.apply(revision); - } -} - -impl From for HfEmbedderSettings { - fn from(value: crate::vector::hf::EmbedderOptions) -> Self { - Self { - model: Setting::Set(value.model), - revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), - } - } -} - -impl From for crate::vector::hf::EmbedderOptions { - fn from(value: HfEmbedderSettings) -> Self { - let HfEmbedderSettings { model, revision } = value; - let mut this = Self::default(); - if let Some(model) = model.set() { - this.model = model; - // Reset the revision if we are setting the model. - // This allows the following: - // "huggingFace": {} -> default model with default revision - // "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision - // "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision - this.revision = None; - } - if let Some(revision) = revision.set() { - this.revision = Some(revision); - } this } } - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct OpenAiEmbedderSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub api_key: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")] - #[deserr(default, rename = "model")] - pub embedding_model: Setting, -} - -impl OpenAiEmbedderSettings { - pub fn apply(&mut self, new: Self) { - let Self { api_key, embedding_model: embedding_mode } = new; - self.api_key.apply(api_key); - self.embedding_model.apply(embedding_mode); - } -} - -impl From for OpenAiEmbedderSettings { - fn from(value: crate::vector::openai::EmbedderOptions) -> Self { - Self { - api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset), - embedding_model: Setting::Set(value.embedding_model), - } - } -} - -impl From for crate::vector::openai::EmbedderOptions { - fn from(value: OpenAiEmbedderSettings) -> Self { - let OpenAiEmbedderSettings { api_key, embedding_model } = value; - Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct UserProvidedSettings { - pub dimensions: usize, -} - -impl From for crate::vector::manual::EmbedderOptions { - fn from(value: UserProvidedSettings) -> Self { - Self { dimensions: value.dimensions } - } -} - -impl From for UserProvidedSettings { - fn from(value: crate::vector::manual::EmbedderOptions) -> Self { - Self { dimensions: value.dimensions } - } -}