diff --git a/crates/milli/src/update/index_documents/mod.rs b/crates/milli/src/update/index_documents/mod.rs index 56c26ed29..d62128eaa 100644 --- a/crates/milli/src/update/index_documents/mod.rs +++ b/crates/milli/src/update/index_documents/mod.rs @@ -2763,6 +2763,7 @@ mod tests { source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided), model: Setting::NotSet, revision: Setting::NotSet, + pooling: Setting::NotSet, api_key: Setting::NotSet, dimensions: Setting::Set(3), document_template: Setting::NotSet, diff --git a/crates/milli/src/update/settings.rs b/crates/milli/src/update/settings.rs index 85259c2d0..0d0648fc8 100644 --- a/crates/milli/src/update/settings.rs +++ b/crates/milli/src/update/settings.rs @@ -1676,6 +1676,7 @@ fn validate_prompt( source, model, revision, + pooling, api_key, dimensions, document_template: Setting::Set(template), @@ -1709,6 +1710,7 @@ fn validate_prompt( source, model, revision, + pooling, api_key, dimensions, document_template: Setting::Set(template), @@ -1735,6 +1737,7 @@ pub fn validate_embedding_settings( source, model, revision, + pooling, api_key, dimensions, document_template, @@ -1776,6 +1779,7 @@ pub fn validate_embedding_settings( source, model, revision, + pooling, api_key, dimensions, document_template, @@ -1791,6 +1795,7 @@ pub fn validate_embedding_settings( match inferred_source { EmbedderSource::OpenAi => { check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; + check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?; check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; @@ -1829,6 +1834,7 @@ pub fn validate_embedding_settings( EmbedderSource::Ollama => { check_set(&model, EmbeddingSettings::MODEL, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; + check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?; check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; @@ -1846,6 +1852,7 @@ pub fn validate_embedding_settings( EmbedderSource::UserProvided => { check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; + check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?; check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?; check_unset( &document_template, @@ -1869,6 +1876,7 @@ pub fn validate_embedding_settings( EmbedderSource::Rest => { check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; + check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?; check_set(&url, EmbeddingSettings::URL, inferred_source, name)?; check_set(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_set(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; @@ -1878,6 +1886,7 @@ pub fn validate_embedding_settings( source, model, revision, + pooling, api_key, dimensions, document_template, diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs index 9ec34daef..b01a66255 100644 --- a/crates/milli/src/vector/hf.rs +++ b/crates/milli/src/vector/hf.rs @@ -34,6 +34,30 @@ pub struct EmbedderOptions { pub model: String, pub revision: Option, pub distribution: Option, + #[serde(default)] + pub pooling: OverridePooling, +} + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Deserialize, + serde::Serialize, + utoipa::ToSchema, + deserr::Deserr, +)] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +#[serde(rename_all = "camelCase")] +pub enum OverridePooling { + UseModel, + ForceCls, + #[default] + ForceMean, } impl EmbedderOptions { @@ -42,6 +66,7 @@ impl EmbedderOptions { model: "BAAI/bge-base-en-v1.5".to_string(), revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), distribution: None, + pooling: OverridePooling::UseModel, } } } @@ -95,6 +120,15 @@ pub enum Pooling { MeanSqrtLen, LastToken, } +impl Pooling { + fn override_with(&mut self, pooling: OverridePooling) { + match pooling { + OverridePooling::UseModel => {} + OverridePooling::ForceCls => *self = Pooling::Cls, + OverridePooling::ForceMean => *self = Pooling::Mean, + } + } +} impl From for Pooling { fn from(value: PoolingConfig) -> Self { @@ -151,7 +185,7 @@ impl Embedder { } Err(error) => return Err(NewEmbedderError::api_get(error)), }; - let pooling: Pooling = match pooling { + let mut pooling: Pooling = match pooling { Some(pooling_filename) => { let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| { NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner) @@ -170,6 +204,8 @@ impl Embedder { None => Pooling::default(), }; + pooling.override_with(options.pooling); + (config, tokenizer, weights, source, pooling) }; diff --git a/crates/milli/src/vector/settings.rs b/crates/milli/src/vector/settings.rs index 86028c1c4..f10407e42 100644 --- a/crates/milli/src/vector/settings.rs +++ b/crates/milli/src/vector/settings.rs @@ -6,6 +6,7 @@ use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; +use super::hf::OverridePooling; use super::{ollama, openai, DistributionShift}; use crate::prompt::{default_max_bytes, PromptData}; use crate::update::Setting; @@ -30,6 +31,10 @@ pub struct EmbeddingSettings { pub revision: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] + #[schema(value_type = Option)] + pub pooling: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] #[schema(value_type = Option)] pub api_key: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] @@ -164,6 +169,7 @@ impl SettingsDiff { mut source, mut model, mut revision, + mut pooling, mut api_key, mut dimensions, mut document_template, @@ -180,6 +186,7 @@ impl SettingsDiff { source: new_source, model: new_model, revision: new_revision, + pooling: new_pooling, api_key: new_api_key, dimensions: new_dimensions, document_template: new_document_template, @@ -210,6 +217,7 @@ impl SettingsDiff { &source, &mut model, &mut revision, + &mut pooling, &mut dimensions, &mut url, &mut request, @@ -225,6 +233,9 @@ impl SettingsDiff { if revision.apply(new_revision) { ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); } + if pooling.apply(new_pooling) { + ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); + } if dimensions.apply(new_dimensions) { match source { // regenerate on dimensions change in OpenAI since truncation is supported @@ -290,6 +301,7 @@ impl SettingsDiff { source, model, revision, + pooling, api_key, dimensions, document_template, @@ -338,6 +350,7 @@ fn apply_default_for_source( source: &Setting, model: &mut Setting, revision: &mut Setting, + pooling: &mut Setting, dimensions: &mut Setting, url: &mut Setting, request: &mut Setting, @@ -350,6 +363,7 @@ fn apply_default_for_source( Setting::Set(EmbedderSource::HuggingFace) => { *model = Setting::Reset; *revision = Setting::Reset; + *pooling = Setting::Reset; *dimensions = Setting::NotSet; *url = Setting::NotSet; *request = Setting::NotSet; @@ -359,6 +373,7 @@ fn apply_default_for_source( Setting::Set(EmbedderSource::Ollama) => { *model = Setting::Reset; *revision = Setting::NotSet; + *pooling = Setting::NotSet; *dimensions = Setting::Reset; *url = Setting::NotSet; *request = Setting::NotSet; @@ -368,6 +383,7 @@ fn apply_default_for_source( Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => { *model = Setting::Reset; *revision = Setting::NotSet; + *pooling = Setting::NotSet; *dimensions = Setting::NotSet; *url = Setting::Reset; *request = Setting::NotSet; @@ -377,6 +393,7 @@ fn apply_default_for_source( Setting::Set(EmbedderSource::Rest) => { *model = Setting::NotSet; *revision = Setting::NotSet; + *pooling = Setting::NotSet; *dimensions = Setting::Reset; *url = Setting::Reset; *request = Setting::Reset; @@ -386,6 +403,7 @@ fn apply_default_for_source( Setting::Set(EmbedderSource::UserProvided) => { *model = Setting::NotSet; *revision = Setting::NotSet; + *pooling = Setting::NotSet; *dimensions = Setting::Reset; *url = Setting::NotSet; *request = Setting::NotSet; @@ -419,6 +437,7 @@ impl EmbeddingSettings { pub const SOURCE: &'static str = "source"; pub const MODEL: &'static str = "model"; pub const REVISION: &'static str = "revision"; + pub const POOLING: &'static str = "pooling"; pub const API_KEY: &'static str = "apiKey"; pub const DIMENSIONS: &'static str = "dimensions"; pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; @@ -446,6 +465,7 @@ impl EmbeddingSettings { &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] } Self::REVISION => &[EmbedderSource::HuggingFace], + Self::POOLING => &[EmbedderSource::HuggingFace], Self::API_KEY => { &[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest] } @@ -500,6 +520,7 @@ impl EmbeddingSettings { Self::SOURCE, Self::MODEL, Self::REVISION, + Self::POOLING, Self::DOCUMENT_TEMPLATE, Self::DOCUMENT_TEMPLATE_MAX_BYTES, Self::DISTRIBUTION, @@ -592,10 +613,12 @@ impl From for EmbeddingSettings { model, revision, distribution, + pooling, }) => Self { source: Setting::Set(EmbedderSource::HuggingFace), model: Setting::Set(model), revision: Setting::some_or_not_set(revision), + pooling: Setting::Set(pooling), api_key: Setting::NotSet, dimensions: Setting::NotSet, document_template: Setting::Set(prompt.template), @@ -617,6 +640,7 @@ impl From for EmbeddingSettings { source: Setting::Set(EmbedderSource::OpenAi), model: Setting::Set(embedding_model.name().to_owned()), revision: Setting::NotSet, + pooling: Setting::NotSet, api_key: Setting::some_or_not_set(api_key), dimensions: Setting::some_or_not_set(dimensions), document_template: Setting::Set(prompt.template), @@ -638,6 +662,7 @@ impl From for EmbeddingSettings { source: Setting::Set(EmbedderSource::Ollama), model: Setting::Set(embedding_model), revision: Setting::NotSet, + pooling: Setting::NotSet, api_key: Setting::some_or_not_set(api_key), dimensions: Setting::some_or_not_set(dimensions), document_template: Setting::Set(prompt.template), @@ -656,6 +681,7 @@ impl From for EmbeddingSettings { source: Setting::Set(EmbedderSource::UserProvided), model: Setting::NotSet, revision: Setting::NotSet, + pooling: Setting::NotSet, api_key: Setting::NotSet, dimensions: Setting::Set(dimensions), document_template: Setting::NotSet, @@ -679,6 +705,7 @@ impl From for EmbeddingSettings { source: Setting::Set(EmbedderSource::Rest), model: Setting::NotSet, revision: Setting::NotSet, + pooling: Setting::NotSet, api_key: Setting::some_or_not_set(api_key), dimensions: Setting::some_or_not_set(dimensions), document_template: Setting::Set(prompt.template), @@ -701,6 +728,7 @@ impl From for EmbeddingConfig { source, model, revision, + pooling, api_key, dimensions, document_template, @@ -764,6 +792,9 @@ impl From for EmbeddingConfig { if let Some(revision) = revision.set() { options.revision = Some(revision); } + if let Some(pooling) = pooling.set() { + options.pooling = pooling; + } options.distribution = distribution.set(); this.embedder_options = super::EmbedderOptions::HuggingFace(options); }