mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-03-06 22:02:34 +08:00
Allow overriding pooling method
This commit is contained in:
parent
11759c4be4
commit
7b4ce468a6
@ -2763,6 +2763,7 @@ mod tests {
|
||||
source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided),
|
||||
model: Setting::NotSet,
|
||||
revision: Setting::NotSet,
|
||||
pooling: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::Set(3),
|
||||
document_template: Setting::NotSet,
|
||||
|
@ -1676,6 +1676,7 @@ fn validate_prompt(
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template: Setting::Set(template),
|
||||
@ -1709,6 +1710,7 @@ fn validate_prompt(
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template: Setting::Set(template),
|
||||
@ -1735,6 +1737,7 @@ pub fn validate_embedding_settings(
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
@ -1776,6 +1779,7 @@ pub fn validate_embedding_settings(
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
@ -1791,6 +1795,7 @@ pub fn validate_embedding_settings(
|
||||
match inferred_source {
|
||||
EmbedderSource::OpenAi => {
|
||||
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
|
||||
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
|
||||
|
||||
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||
@ -1829,6 +1834,7 @@ pub fn validate_embedding_settings(
|
||||
EmbedderSource::Ollama => {
|
||||
check_set(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
||||
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
|
||||
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
|
||||
|
||||
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||
@ -1846,6 +1852,7 @@ pub fn validate_embedding_settings(
|
||||
EmbedderSource::UserProvided => {
|
||||
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
||||
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
|
||||
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
|
||||
check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?;
|
||||
check_unset(
|
||||
&document_template,
|
||||
@ -1869,6 +1876,7 @@ pub fn validate_embedding_settings(
|
||||
EmbedderSource::Rest => {
|
||||
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
||||
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
|
||||
check_unset(&pooling, EmbeddingSettings::POOLING, inferred_source, name)?;
|
||||
check_set(&url, EmbeddingSettings::URL, inferred_source, name)?;
|
||||
check_set(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||
check_set(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||
@ -1878,6 +1886,7 @@ pub fn validate_embedding_settings(
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
|
@ -34,6 +34,30 @@ pub struct EmbedderOptions {
|
||||
pub model: String,
|
||||
pub revision: Option<String>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
#[serde(default)]
|
||||
pub pooling: OverridePooling,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
Copy,
|
||||
Default,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
serde::Deserialize,
|
||||
serde::Serialize,
|
||||
utoipa::ToSchema,
|
||||
deserr::Deserr,
|
||||
)]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum OverridePooling {
|
||||
UseModel,
|
||||
ForceCls,
|
||||
#[default]
|
||||
ForceMean,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
@ -42,6 +66,7 @@ impl EmbedderOptions {
|
||||
model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
|
||||
distribution: None,
|
||||
pooling: OverridePooling::UseModel,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -95,6 +120,15 @@ pub enum Pooling {
|
||||
MeanSqrtLen,
|
||||
LastToken,
|
||||
}
|
||||
impl Pooling {
|
||||
fn override_with(&mut self, pooling: OverridePooling) {
|
||||
match pooling {
|
||||
OverridePooling::UseModel => {}
|
||||
OverridePooling::ForceCls => *self = Pooling::Cls,
|
||||
OverridePooling::ForceMean => *self = Pooling::Mean,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PoolingConfig> for Pooling {
|
||||
fn from(value: PoolingConfig) -> Self {
|
||||
@ -151,7 +185,7 @@ impl Embedder {
|
||||
}
|
||||
Err(error) => return Err(NewEmbedderError::api_get(error)),
|
||||
};
|
||||
let pooling: Pooling = match pooling {
|
||||
let mut pooling: Pooling = match pooling {
|
||||
Some(pooling_filename) => {
|
||||
let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
|
||||
NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
|
||||
@ -170,6 +204,8 @@ impl Embedder {
|
||||
None => Pooling::default(),
|
||||
};
|
||||
|
||||
pooling.override_with(options.pooling);
|
||||
|
||||
(config, tokenizer, weights, source, pooling)
|
||||
};
|
||||
|
||||
|
@ -6,6 +6,7 @@ use roaring::RoaringBitmap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use super::hf::OverridePooling;
|
||||
use super::{ollama, openai, DistributionShift};
|
||||
use crate::prompt::{default_max_bytes, PromptData};
|
||||
use crate::update::Setting;
|
||||
@ -30,6 +31,10 @@ pub struct EmbeddingSettings {
|
||||
pub revision: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<OverridePooling>)]
|
||||
pub pooling: Setting<OverridePooling>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
#[schema(value_type = Option<String>)]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
@ -164,6 +169,7 @@ impl SettingsDiff {
|
||||
mut source,
|
||||
mut model,
|
||||
mut revision,
|
||||
mut pooling,
|
||||
mut api_key,
|
||||
mut dimensions,
|
||||
mut document_template,
|
||||
@ -180,6 +186,7 @@ impl SettingsDiff {
|
||||
source: new_source,
|
||||
model: new_model,
|
||||
revision: new_revision,
|
||||
pooling: new_pooling,
|
||||
api_key: new_api_key,
|
||||
dimensions: new_dimensions,
|
||||
document_template: new_document_template,
|
||||
@ -210,6 +217,7 @@ impl SettingsDiff {
|
||||
&source,
|
||||
&mut model,
|
||||
&mut revision,
|
||||
&mut pooling,
|
||||
&mut dimensions,
|
||||
&mut url,
|
||||
&mut request,
|
||||
@ -225,6 +233,9 @@ impl SettingsDiff {
|
||||
if revision.apply(new_revision) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
}
|
||||
if pooling.apply(new_pooling) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
}
|
||||
if dimensions.apply(new_dimensions) {
|
||||
match source {
|
||||
// regenerate on dimensions change in OpenAI since truncation is supported
|
||||
@ -290,6 +301,7 @@ impl SettingsDiff {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
@ -338,6 +350,7 @@ fn apply_default_for_source(
|
||||
source: &Setting<EmbedderSource>,
|
||||
model: &mut Setting<String>,
|
||||
revision: &mut Setting<String>,
|
||||
pooling: &mut Setting<OverridePooling>,
|
||||
dimensions: &mut Setting<usize>,
|
||||
url: &mut Setting<String>,
|
||||
request: &mut Setting<serde_json::Value>,
|
||||
@ -350,6 +363,7 @@ fn apply_default_for_source(
|
||||
Setting::Set(EmbedderSource::HuggingFace) => {
|
||||
*model = Setting::Reset;
|
||||
*revision = Setting::Reset;
|
||||
*pooling = Setting::Reset;
|
||||
*dimensions = Setting::NotSet;
|
||||
*url = Setting::NotSet;
|
||||
*request = Setting::NotSet;
|
||||
@ -359,6 +373,7 @@ fn apply_default_for_source(
|
||||
Setting::Set(EmbedderSource::Ollama) => {
|
||||
*model = Setting::Reset;
|
||||
*revision = Setting::NotSet;
|
||||
*pooling = Setting::NotSet;
|
||||
*dimensions = Setting::Reset;
|
||||
*url = Setting::NotSet;
|
||||
*request = Setting::NotSet;
|
||||
@ -368,6 +383,7 @@ fn apply_default_for_source(
|
||||
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
|
||||
*model = Setting::Reset;
|
||||
*revision = Setting::NotSet;
|
||||
*pooling = Setting::NotSet;
|
||||
*dimensions = Setting::NotSet;
|
||||
*url = Setting::Reset;
|
||||
*request = Setting::NotSet;
|
||||
@ -377,6 +393,7 @@ fn apply_default_for_source(
|
||||
Setting::Set(EmbedderSource::Rest) => {
|
||||
*model = Setting::NotSet;
|
||||
*revision = Setting::NotSet;
|
||||
*pooling = Setting::NotSet;
|
||||
*dimensions = Setting::Reset;
|
||||
*url = Setting::Reset;
|
||||
*request = Setting::Reset;
|
||||
@ -386,6 +403,7 @@ fn apply_default_for_source(
|
||||
Setting::Set(EmbedderSource::UserProvided) => {
|
||||
*model = Setting::NotSet;
|
||||
*revision = Setting::NotSet;
|
||||
*pooling = Setting::NotSet;
|
||||
*dimensions = Setting::Reset;
|
||||
*url = Setting::NotSet;
|
||||
*request = Setting::NotSet;
|
||||
@ -419,6 +437,7 @@ impl EmbeddingSettings {
|
||||
pub const SOURCE: &'static str = "source";
|
||||
pub const MODEL: &'static str = "model";
|
||||
pub const REVISION: &'static str = "revision";
|
||||
pub const POOLING: &'static str = "pooling";
|
||||
pub const API_KEY: &'static str = "apiKey";
|
||||
pub const DIMENSIONS: &'static str = "dimensions";
|
||||
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
|
||||
@ -446,6 +465,7 @@ impl EmbeddingSettings {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
||||
}
|
||||
Self::REVISION => &[EmbedderSource::HuggingFace],
|
||||
Self::POOLING => &[EmbedderSource::HuggingFace],
|
||||
Self::API_KEY => {
|
||||
&[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest]
|
||||
}
|
||||
@ -500,6 +520,7 @@ impl EmbeddingSettings {
|
||||
Self::SOURCE,
|
||||
Self::MODEL,
|
||||
Self::REVISION,
|
||||
Self::POOLING,
|
||||
Self::DOCUMENT_TEMPLATE,
|
||||
Self::DOCUMENT_TEMPLATE_MAX_BYTES,
|
||||
Self::DISTRIBUTION,
|
||||
@ -592,10 +613,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
model,
|
||||
revision,
|
||||
distribution,
|
||||
pooling,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::HuggingFace),
|
||||
model: Setting::Set(model),
|
||||
revision: Setting::some_or_not_set(revision),
|
||||
pooling: Setting::Set(pooling),
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::NotSet,
|
||||
document_template: Setting::Set(prompt.template),
|
||||
@ -617,6 +640,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
source: Setting::Set(EmbedderSource::OpenAi),
|
||||
model: Setting::Set(embedding_model.name().to_owned()),
|
||||
revision: Setting::NotSet,
|
||||
pooling: Setting::NotSet,
|
||||
api_key: Setting::some_or_not_set(api_key),
|
||||
dimensions: Setting::some_or_not_set(dimensions),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
@ -638,6 +662,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
source: Setting::Set(EmbedderSource::Ollama),
|
||||
model: Setting::Set(embedding_model),
|
||||
revision: Setting::NotSet,
|
||||
pooling: Setting::NotSet,
|
||||
api_key: Setting::some_or_not_set(api_key),
|
||||
dimensions: Setting::some_or_not_set(dimensions),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
@ -656,6 +681,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
source: Setting::Set(EmbedderSource::UserProvided),
|
||||
model: Setting::NotSet,
|
||||
revision: Setting::NotSet,
|
||||
pooling: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::Set(dimensions),
|
||||
document_template: Setting::NotSet,
|
||||
@ -679,6 +705,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
source: Setting::Set(EmbedderSource::Rest),
|
||||
model: Setting::NotSet,
|
||||
revision: Setting::NotSet,
|
||||
pooling: Setting::NotSet,
|
||||
api_key: Setting::some_or_not_set(api_key),
|
||||
dimensions: Setting::some_or_not_set(dimensions),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
@ -701,6 +728,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
pooling,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
@ -764,6 +792,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
if let Some(revision) = revision.set() {
|
||||
options.revision = Some(revision);
|
||||
}
|
||||
if let Some(pooling) = pooling.set() {
|
||||
options.pooling = pooling;
|
||||
}
|
||||
options.distribution = distribution.set();
|
||||
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user