Allow overriding pooling method

This commit is contained in:
Louis Dureuil 2025-02-18 17:12:23 +01:00
parent 11759c4be4
commit 7b4ce468a6
No known key found for this signature in database
4 changed files with 78 additions and 1 deletions

View File

@ -2763,6 +2763,7 @@ mod tests {
source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided), source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided),
model: Setting::NotSet, model: Setting::NotSet,
revision: Setting::NotSet, revision: Setting::NotSet,
pooling: Setting::NotSet,
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::Set(3), dimensions: Setting::Set(3),
document_template: Setting::NotSet, document_template: Setting::NotSet,

View File

@ -1676,6 +1676,7 @@ fn validate_prompt(
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template: Setting::Set(template), document_template: Setting::Set(template),
@ -1709,6 +1710,7 @@ fn validate_prompt(
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template: Setting::Set(template), document_template: Setting::Set(template),
@ -1735,6 +1737,7 @@ pub fn validate_embedding_settings(
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
@ -1776,6 +1779,7 @@ pub fn validate_embedding_settings(
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
@ -1791,6 +1795,7 @@ pub fn validate_embedding_settings(
match inferred_source { match inferred_source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
@ -1829,6 +1834,7 @@ pub fn validate_embedding_settings(
EmbedderSource::Ollama => { EmbedderSource::Ollama => {
check_set(&model, EmbeddingSettings::MODEL, inferred_source, name)?; check_set(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
@ -1846,6 +1852,7 @@ pub fn validate_embedding_settings(
EmbedderSource::UserProvided => { EmbedderSource::UserProvided => {
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?; check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?; check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?;
check_unset( check_unset(
&document_template, &document_template,
@ -1869,6 +1876,7 @@ pub fn validate_embedding_settings(
EmbedderSource::Rest => { EmbedderSource::Rest => {
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?; check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
check_set(&url, EmbeddingSettings::URL, inferred_source, name)?; check_set(&url, EmbeddingSettings::URL, inferred_source, name)?;
check_set(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_set(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
check_set(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; check_set(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
@ -1878,6 +1886,7 @@ pub fn validate_embedding_settings(
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template, document_template,

View File

@ -34,6 +34,30 @@ pub struct EmbedderOptions {
pub model: String, pub model: String,
pub revision: Option<String>, pub revision: Option<String>,
pub distribution: Option<DistributionShift>, pub distribution: Option<DistributionShift>,
#[serde(default)]
pub pooling: OverridePooling,
}
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Deserialize,
serde::Serialize,
utoipa::ToSchema,
deserr::Deserr,
)]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub enum OverridePooling {
UseModel,
ForceCls,
#[default]
ForceMean,
} }
impl EmbedderOptions { impl EmbedderOptions {
@ -42,6 +66,7 @@ impl EmbedderOptions {
model: "BAAI/bge-base-en-v1.5".to_string(), model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
distribution: None, distribution: None,
pooling: OverridePooling::UseModel,
} }
} }
} }
@ -95,6 +120,15 @@ pub enum Pooling {
MeanSqrtLen, MeanSqrtLen,
LastToken, LastToken,
} }
impl Pooling {
fn override_with(&mut self, pooling: OverridePooling) {
match pooling {
OverridePooling::UseModel => {}
OverridePooling::ForceCls => *self = Pooling::Cls,
OverridePooling::ForceMean => *self = Pooling::Mean,
}
}
}
impl From<PoolingConfig> for Pooling { impl From<PoolingConfig> for Pooling {
fn from(value: PoolingConfig) -> Self { fn from(value: PoolingConfig) -> Self {
@ -151,7 +185,7 @@ impl Embedder {
} }
Err(error) => return Err(NewEmbedderError::api_get(error)), Err(error) => return Err(NewEmbedderError::api_get(error)),
}; };
let pooling: Pooling = match pooling { let mut pooling: Pooling = match pooling {
Some(pooling_filename) => { Some(pooling_filename) => {
let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| { let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner) NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
@ -170,6 +204,8 @@ impl Embedder {
None => Pooling::default(), None => Pooling::default(),
}; };
pooling.override_with(options.pooling);
(config, tokenizer, weights, source, pooling) (config, tokenizer, weights, source, pooling)
}; };

View File

@ -6,6 +6,7 @@ use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use utoipa::ToSchema; use utoipa::ToSchema;
use super::hf::OverridePooling;
use super::{ollama, openai, DistributionShift}; use super::{ollama, openai, DistributionShift};
use crate::prompt::{default_max_bytes, PromptData}; use crate::prompt::{default_max_bytes, PromptData};
use crate::update::Setting; use crate::update::Setting;
@ -30,6 +31,10 @@ pub struct EmbeddingSettings {
pub revision: Setting<String>, pub revision: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
#[schema(value_type = Option<OverridePooling>)]
pub pooling: Setting<OverridePooling>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
#[schema(value_type = Option<String>)] #[schema(value_type = Option<String>)]
pub api_key: Setting<String>, pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
@ -164,6 +169,7 @@ impl SettingsDiff {
mut source, mut source,
mut model, mut model,
mut revision, mut revision,
mut pooling,
mut api_key, mut api_key,
mut dimensions, mut dimensions,
mut document_template, mut document_template,
@ -180,6 +186,7 @@ impl SettingsDiff {
source: new_source, source: new_source,
model: new_model, model: new_model,
revision: new_revision, revision: new_revision,
pooling: new_pooling,
api_key: new_api_key, api_key: new_api_key,
dimensions: new_dimensions, dimensions: new_dimensions,
document_template: new_document_template, document_template: new_document_template,
@ -210,6 +217,7 @@ impl SettingsDiff {
&source, &source,
&mut model, &mut model,
&mut revision, &mut revision,
&mut pooling,
&mut dimensions, &mut dimensions,
&mut url, &mut url,
&mut request, &mut request,
@ -225,6 +233,9 @@ impl SettingsDiff {
if revision.apply(new_revision) { if revision.apply(new_revision) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
} }
if pooling.apply(new_pooling) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
}
if dimensions.apply(new_dimensions) { if dimensions.apply(new_dimensions) {
match source { match source {
// regenerate on dimensions change in OpenAI since truncation is supported // regenerate on dimensions change in OpenAI since truncation is supported
@ -290,6 +301,7 @@ impl SettingsDiff {
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
@ -338,6 +350,7 @@ fn apply_default_for_source(
source: &Setting<EmbedderSource>, source: &Setting<EmbedderSource>,
model: &mut Setting<String>, model: &mut Setting<String>,
revision: &mut Setting<String>, revision: &mut Setting<String>,
pooling: &mut Setting<OverridePooling>,
dimensions: &mut Setting<usize>, dimensions: &mut Setting<usize>,
url: &mut Setting<String>, url: &mut Setting<String>,
request: &mut Setting<serde_json::Value>, request: &mut Setting<serde_json::Value>,
@ -350,6 +363,7 @@ fn apply_default_for_source(
Setting::Set(EmbedderSource::HuggingFace) => { Setting::Set(EmbedderSource::HuggingFace) => {
*model = Setting::Reset; *model = Setting::Reset;
*revision = Setting::Reset; *revision = Setting::Reset;
*pooling = Setting::Reset;
*dimensions = Setting::NotSet; *dimensions = Setting::NotSet;
*url = Setting::NotSet; *url = Setting::NotSet;
*request = Setting::NotSet; *request = Setting::NotSet;
@ -359,6 +373,7 @@ fn apply_default_for_source(
Setting::Set(EmbedderSource::Ollama) => { Setting::Set(EmbedderSource::Ollama) => {
*model = Setting::Reset; *model = Setting::Reset;
*revision = Setting::NotSet; *revision = Setting::NotSet;
*pooling = Setting::NotSet;
*dimensions = Setting::Reset; *dimensions = Setting::Reset;
*url = Setting::NotSet; *url = Setting::NotSet;
*request = Setting::NotSet; *request = Setting::NotSet;
@ -368,6 +383,7 @@ fn apply_default_for_source(
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => { Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
*model = Setting::Reset; *model = Setting::Reset;
*revision = Setting::NotSet; *revision = Setting::NotSet;
*pooling = Setting::NotSet;
*dimensions = Setting::NotSet; *dimensions = Setting::NotSet;
*url = Setting::Reset; *url = Setting::Reset;
*request = Setting::NotSet; *request = Setting::NotSet;
@ -377,6 +393,7 @@ fn apply_default_for_source(
Setting::Set(EmbedderSource::Rest) => { Setting::Set(EmbedderSource::Rest) => {
*model = Setting::NotSet; *model = Setting::NotSet;
*revision = Setting::NotSet; *revision = Setting::NotSet;
*pooling = Setting::NotSet;
*dimensions = Setting::Reset; *dimensions = Setting::Reset;
*url = Setting::Reset; *url = Setting::Reset;
*request = Setting::Reset; *request = Setting::Reset;
@ -386,6 +403,7 @@ fn apply_default_for_source(
Setting::Set(EmbedderSource::UserProvided) => { Setting::Set(EmbedderSource::UserProvided) => {
*model = Setting::NotSet; *model = Setting::NotSet;
*revision = Setting::NotSet; *revision = Setting::NotSet;
*pooling = Setting::NotSet;
*dimensions = Setting::Reset; *dimensions = Setting::Reset;
*url = Setting::NotSet; *url = Setting::NotSet;
*request = Setting::NotSet; *request = Setting::NotSet;
@ -419,6 +437,7 @@ impl EmbeddingSettings {
pub const SOURCE: &'static str = "source"; pub const SOURCE: &'static str = "source";
pub const MODEL: &'static str = "model"; pub const MODEL: &'static str = "model";
pub const REVISION: &'static str = "revision"; pub const REVISION: &'static str = "revision";
pub const POOLING: &'static str = "pooling";
pub const API_KEY: &'static str = "apiKey"; pub const API_KEY: &'static str = "apiKey";
pub const DIMENSIONS: &'static str = "dimensions"; pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
@ -446,6 +465,7 @@ impl EmbeddingSettings {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
} }
Self::REVISION => &[EmbedderSource::HuggingFace], Self::REVISION => &[EmbedderSource::HuggingFace],
Self::POOLING => &[EmbedderSource::HuggingFace],
Self::API_KEY => { Self::API_KEY => {
&[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest] &[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest]
} }
@ -500,6 +520,7 @@ impl EmbeddingSettings {
Self::SOURCE, Self::SOURCE,
Self::MODEL, Self::MODEL,
Self::REVISION, Self::REVISION,
Self::POOLING,
Self::DOCUMENT_TEMPLATE, Self::DOCUMENT_TEMPLATE,
Self::DOCUMENT_TEMPLATE_MAX_BYTES, Self::DOCUMENT_TEMPLATE_MAX_BYTES,
Self::DISTRIBUTION, Self::DISTRIBUTION,
@ -592,10 +613,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
model, model,
revision, revision,
distribution, distribution,
pooling,
}) => Self { }) => Self {
source: Setting::Set(EmbedderSource::HuggingFace), source: Setting::Set(EmbedderSource::HuggingFace),
model: Setting::Set(model), model: Setting::Set(model),
revision: Setting::some_or_not_set(revision), revision: Setting::some_or_not_set(revision),
pooling: Setting::Set(pooling),
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::NotSet, dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
@ -617,6 +640,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
source: Setting::Set(EmbedderSource::OpenAi), source: Setting::Set(EmbedderSource::OpenAi),
model: Setting::Set(embedding_model.name().to_owned()), model: Setting::Set(embedding_model.name().to_owned()),
revision: Setting::NotSet, revision: Setting::NotSet,
pooling: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key), api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions), dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
@ -638,6 +662,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
source: Setting::Set(EmbedderSource::Ollama), source: Setting::Set(EmbedderSource::Ollama),
model: Setting::Set(embedding_model), model: Setting::Set(embedding_model),
revision: Setting::NotSet, revision: Setting::NotSet,
pooling: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key), api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions), dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
@ -656,6 +681,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
source: Setting::Set(EmbedderSource::UserProvided), source: Setting::Set(EmbedderSource::UserProvided),
model: Setting::NotSet, model: Setting::NotSet,
revision: Setting::NotSet, revision: Setting::NotSet,
pooling: Setting::NotSet,
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::Set(dimensions), dimensions: Setting::Set(dimensions),
document_template: Setting::NotSet, document_template: Setting::NotSet,
@ -679,6 +705,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
source: Setting::Set(EmbedderSource::Rest), source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet, model: Setting::NotSet,
revision: Setting::NotSet, revision: Setting::NotSet,
pooling: Setting::NotSet,
api_key: Setting::some_or_not_set(api_key), api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions), dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
@ -701,6 +728,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
source, source,
model, model,
revision, revision,
pooling,
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
@ -764,6 +792,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(revision) = revision.set() { if let Some(revision) = revision.set() {
options.revision = Some(revision); options.revision = Some(revision);
} }
if let Some(pooling) = pooling.set() {
options.pooling = pooling;
}
options.distribution = distribution.set(); options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::HuggingFace(options); this.embedder_options = super::EmbedderOptions::HuggingFace(options);
} }