From b8e4709dfa6377ab4c84540e2e08069750be82e6 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 22:06:39 +0100 Subject: [PATCH] Remove prompt strategy and fallback --- milli/src/prompt/mod.rs | 74 ++++++++---------------------------- milli/src/update/settings.rs | 7 +--- milli/src/vector/settings.rs | 26 ++----------- 3 files changed, 21 insertions(+), 86 deletions(-) diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs index 67ef8b4f6..97ccbfb61 100644 --- a/milli/src/prompt/mod.rs +++ b/milli/src/prompt/mod.rs @@ -16,20 +16,16 @@ use crate::FieldsIdsMap; pub struct Prompt { template: liquid::Template, template_text: String, - strategy: PromptFallbackStrategy, - fallback: String, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PromptData { pub template: String, - pub strategy: PromptFallbackStrategy, - pub fallback: String, } impl From for PromptData { fn from(value: Prompt) -> Self { - Self { template: value.template_text, strategy: value.strategy, fallback: value.fallback } + Self { template: value.template_text } } } @@ -37,19 +33,14 @@ impl TryFrom for Prompt { type Error = NewPromptError; fn try_from(value: PromptData) -> Result { - Prompt::new(value.template, Some(value.strategy), Some(value.fallback)) + Prompt::new(value.template) } } impl Clone for Prompt { fn clone(&self) -> Self { let template_text = self.template_text.clone(); - Self { - template: new_template(&template_text).unwrap(), - template_text, - strategy: self.strategy, - fallback: self.fallback.clone(), - } + Self { template: new_template(&template_text).unwrap(), template_text } } } @@ -67,37 +58,20 @@ fn default_template_text() -> &'static str { {% endfor %}" } -fn default_fallback() -> &'static str { - "" -} - impl Default for Prompt { fn default() -> Self { - Self { - template: default_template(), - template_text: default_template_text().into(), - strategy: Default::default(), - fallback: default_fallback().into(), - } + Self { template: default_template(), template_text: default_template_text().into() } } } impl Default for PromptData { fn default() -> Self { - Self { - template: default_template_text().into(), - strategy: Default::default(), - fallback: default_fallback().into(), - } + Self { template: default_template_text().into() } } } impl Prompt { - pub fn new( - template: String, - strategy: Option, - fallback: Option, - ) -> Result { + pub fn new(template: String) -> Result { let this = Self { template: liquid::ParserBuilder::with_stdlib() .build() @@ -105,8 +79,6 @@ impl Prompt { .parse(&template) .map_err(NewPromptError::cannot_parse_template)?, template_text: template, - strategy: strategy.unwrap_or_default(), - fallback: fallback.unwrap_or_default(), }; // render template with special object that's OK with `doc.*` and `fields.*` @@ -130,18 +102,6 @@ impl Prompt { } } -#[derive( - Debug, Default, Clone, PartialEq, Eq, Copy, serde::Serialize, serde::Deserialize, deserr::Deserr, -)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub enum PromptFallbackStrategy { - Fallback, - Skip, - #[default] - Error, -} - #[cfg(test)] mod test { use super::Prompt; @@ -156,18 +116,18 @@ mod test { #[test] fn empty_template() { - Prompt::new("".into(), None, None).unwrap(); + Prompt::new("".into()).unwrap(); } #[test] fn template_ok() { - Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None, None).unwrap(); + Prompt::new("{{doc.title}}: {{doc.overview}}".into()).unwrap(); } #[test] fn template_syntax() { assert!(matches!( - Prompt::new("{{doc.title: {{doc.overview}}".into(), None, None), + Prompt::new("{{doc.title: {{doc.overview}}".into()), Err(NewPromptError { kind: NewPromptErrorKind::CannotParseTemplate(_), fault: FaultSource::User @@ -178,7 +138,7 @@ mod test { #[test] fn template_missing_doc() { assert!(matches!( - Prompt::new("{{title}}: {{overview}}".into(), None, None), + Prompt::new("{{title}}: {{overview}}".into()), Err(NewPromptError { kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), fault: FaultSource::User @@ -188,29 +148,25 @@ mod test { #[test] fn template_nested_doc() { - Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None, None).unwrap(); + Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into()).unwrap(); } #[test] fn template_fields() { - Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None, None).unwrap(); + Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into()).unwrap(); } #[test] fn template_fields_ok() { - Prompt::new( - "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(), - None, - None, - ) - .unwrap(); + Prompt::new("{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into()) + .unwrap(); } #[test] fn template_fields_invalid() { assert!(matches!( // intentionally garbled field - Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None, None), + Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()), Err(NewPromptError { kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), fault: FaultSource::User diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index e9f345e42..d406c121c 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1073,11 +1073,10 @@ fn validate_prompt( match new { Setting::Set(EmbeddingSettings { embedder_options, - document_template: - Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), + document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), }) => { // validate - let template = crate::prompt::Prompt::new(template, None, None) + let template = crate::prompt::Prompt::new(template) .map(|prompt| crate::prompt::PromptData::from(prompt).template) .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; @@ -1085,8 +1084,6 @@ fn validate_prompt( embedder_options, document_template: Setting::Set(PromptSettings { template: Setting::Set(template), - strategy, - fallback, }), })) } diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index f90c3cc71..bd385e3f3 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -1,7 +1,7 @@ use deserr::Deserr; use serde::{Deserialize, Serialize}; -use crate::prompt::{PromptData, PromptFallbackStrategy}; +use crate::prompt::PromptData; use crate::update::Setting; use crate::vector::hf::WeightSource; use crate::vector::EmbeddingConfig; @@ -56,46 +56,28 @@ pub struct PromptSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub template: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub strategy: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub fallback: Setting, } impl PromptSettings { pub fn apply(&mut self, new: Self) { - let PromptSettings { template, strategy, fallback } = new; + let PromptSettings { template } = new; self.template.apply(template); - self.strategy.apply(strategy); - self.fallback.apply(fallback); } } impl From for PromptSettings { fn from(value: PromptData) -> Self { - Self { - template: Setting::Set(value.template), - strategy: Setting::Set(value.strategy), - fallback: Setting::Set(value.fallback), - } + Self { template: Setting::Set(value.template) } } } impl From for PromptData { fn from(value: PromptSettings) -> Self { let mut this = PromptData::default(); - let PromptSettings { template, strategy, fallback } = value; + let PromptSettings { template } = value; if let Some(template) = template.set() { this.template = template; } - if let Some(strategy) = strategy.set() { - this.strategy = strategy; - } - if let Some(fallback) = fallback.set() { - this.fallback = fallback; - } this } }