diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 87c6bc6db..6d659a7a2 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -2740,6 +2740,7 @@ mod tests { api_key: Setting::NotSet, dimensions: Setting::Set(3), document_template: Setting::NotSet, + document_template_max_bytes: Setting::NotSet, url: Setting::NotSet, request: Setting::NotSet, response: Setting::NotSet, diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 29470521e..8702e7ea6 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1,5 +1,6 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::convert::TryInto; +use std::num::NonZeroUsize; use std::result::Result as StdResult; use std::sync::Arc; @@ -19,6 +20,7 @@ use crate::index::{ IndexEmbeddingConfig, DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS, }; use crate::order_by_map::OrderByMap; +use crate::prompt::default_max_bytes; use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; @@ -1573,16 +1575,30 @@ fn validate_prompt( api_key, dimensions, document_template: Setting::Set(template), + document_template_max_bytes, url, request, response, distribution, headers, }) => { + let max_bytes = match document_template_max_bytes.set() { + Some(max_bytes) => NonZeroUsize::new(max_bytes).ok_or_else(|| { + crate::error::UserError::InvalidSettingsDocumentTemplateMaxBytes { + embedder_name: name.to_owned(), + } + })?, + None => default_max_bytes(), + }; + // validate - let template = crate::prompt::Prompt::new(template) - .map(|prompt| crate::prompt::PromptData::from(prompt).template) - .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; + let template = crate::prompt::Prompt::new( + template, + // always specify a max_bytes + Some(max_bytes), + ) + .map(|prompt| crate::prompt::PromptData::from(prompt).template) + .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; Ok(Setting::Set(EmbeddingSettings { source, @@ -1591,6 +1607,7 @@ fn validate_prompt( api_key, dimensions, document_template: Setting::Set(template), + document_template_max_bytes, url, request, response, @@ -1615,6 +1632,7 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + document_template_max_bytes, url, request, response, @@ -1654,6 +1672,7 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + document_template_max_bytes, url, request, response, @@ -1726,6 +1745,12 @@ pub fn validate_embedding_settings( inferred_source, name, )?; + check_unset( + &document_template_max_bytes, + EmbeddingSettings::DOCUMENT_TEMPLATE_MAX_BYTES, + inferred_source, + name, + )?; check_set(&dimensions, EmbeddingSettings::DIMENSIONS, inferred_source, name)?; check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?; @@ -1748,6 +1773,7 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + document_template_max_bytes, url, request, response, diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 3cb90cbdb..14e12da3e 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -1,11 +1,12 @@ use std::collections::BTreeMap; +use std::num::NonZeroUsize; use deserr::Deserr; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; use super::{ollama, openai, DistributionShift}; -use crate::prompt::PromptData; +use crate::prompt::{default_max_bytes, PromptData}; use crate::update::Setting; use crate::vector::EmbeddingConfig; use crate::UserError; @@ -34,6 +35,9 @@ pub struct EmbeddingSettings { pub document_template: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] + pub document_template_max_bytes: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] pub url: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] @@ -111,6 +115,7 @@ impl SettingsDiff { mut response, mut distribution, mut headers, + mut document_template_max_bytes, } = old; let EmbeddingSettings { @@ -125,6 +130,7 @@ impl SettingsDiff { response: new_response, distribution: new_distribution, headers: new_headers, + document_template_max_bytes: new_document_template_max_bytes, } = new; let mut reindex_action = None; @@ -142,6 +148,7 @@ impl SettingsDiff { &mut request, &mut response, &mut document_template, + &mut document_template_max_bytes, &mut headers, ) } @@ -189,6 +196,12 @@ impl SettingsDiff { ReindexAction::RegeneratePrompts, ); } + if document_template_max_bytes.apply(new_document_template_max_bytes) { + ReindexAction::push_action( + &mut reindex_action, + ReindexAction::RegeneratePrompts, + ) + } distribution.apply(new_distribution); api_key.apply(new_api_key); @@ -206,6 +219,7 @@ impl SettingsDiff { response, distribution, headers, + document_template_max_bytes, }; match reindex_action { @@ -239,6 +253,7 @@ fn apply_default_for_source( request: &mut Setting, response: &mut Setting, document_template: &mut Setting, + document_template_max_bytes: &mut Setting, headers: &mut Setting>, ) { match source { @@ -286,6 +301,7 @@ fn apply_default_for_source( *request = Setting::NotSet; *response = Setting::NotSet; *document_template = Setting::NotSet; + *document_template_max_bytes = Setting::NotSet; *headers = Setting::NotSet; } Setting::NotSet => {} @@ -316,6 +332,7 @@ impl EmbeddingSettings { pub const API_KEY: &'static str = "apiKey"; pub const DIMENSIONS: &'static str = "dimensions"; pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; + pub const DOCUMENT_TEMPLATE_MAX_BYTES: &'static str = "documentTemplateMaxBytes"; pub const URL: &'static str = "url"; pub const REQUEST: &'static str = "request"; @@ -459,6 +476,8 @@ impl std::fmt::Display for EmbedderSource { impl From for EmbeddingSettings { fn from(value: EmbeddingConfig) -> Self { let EmbeddingConfig { embedder_options, prompt } = value; + let document_template_max_bytes = + Setting::Set(prompt.max_bytes.unwrap_or(default_max_bytes()).get()); match embedder_options { super::EmbedderOptions::HuggingFace(super::hf::EmbedderOptions { model, @@ -471,6 +490,7 @@ impl From for EmbeddingSettings { api_key: Setting::NotSet, dimensions: Setting::NotSet, document_template: Setting::Set(prompt.template), + document_template_max_bytes, url: Setting::NotSet, request: Setting::NotSet, response: Setting::NotSet, @@ -490,6 +510,7 @@ impl From for EmbeddingSettings { api_key: Setting::some_or_not_set(api_key), dimensions: Setting::some_or_not_set(dimensions), document_template: Setting::Set(prompt.template), + document_template_max_bytes, url: Setting::some_or_not_set(url), request: Setting::NotSet, response: Setting::NotSet, @@ -509,6 +530,7 @@ impl From for EmbeddingSettings { api_key: Setting::some_or_not_set(api_key), dimensions: Setting::some_or_not_set(dimensions), document_template: Setting::Set(prompt.template), + document_template_max_bytes, url: Setting::some_or_not_set(url), request: Setting::NotSet, response: Setting::NotSet, @@ -525,6 +547,7 @@ impl From for EmbeddingSettings { api_key: Setting::NotSet, dimensions: Setting::Set(dimensions), document_template: Setting::NotSet, + document_template_max_bytes: Setting::NotSet, url: Setting::NotSet, request: Setting::NotSet, response: Setting::NotSet, @@ -546,6 +569,7 @@ impl From for EmbeddingSettings { api_key: Setting::some_or_not_set(api_key), dimensions: Setting::some_or_not_set(dimensions), document_template: Setting::Set(prompt.template), + document_template_max_bytes, url: Setting::Set(url), request: Setting::Set(request), response: Setting::Set(response), @@ -566,6 +590,7 @@ impl From for EmbeddingConfig { api_key, dimensions, document_template, + document_template_max_bytes, url, request, response, @@ -648,7 +673,12 @@ impl From for EmbeddingConfig { } if let Setting::Set(template) = document_template { - this.prompt = PromptData { template } + let max_bytes = document_template_max_bytes + .set() + .and_then(NonZeroUsize::new) + .unwrap_or(default_max_bytes()); + + this.prompt = PromptData { template, max_bytes: Some(max_bytes) } } this