2023-11-15 22:46:37 +08:00
|
|
|
use deserr::Deserr;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
2023-12-14 05:06:39 +08:00
|
|
|
use crate::prompt::PromptData;
|
2023-11-15 22:46:37 +08:00
|
|
|
use crate::update::Setting;
|
|
|
|
use crate::vector::EmbeddingConfig;
|
2023-12-21 00:08:32 +08:00
|
|
|
use crate::UserError;
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
|
|
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
|
|
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
|
|
|
pub struct EmbeddingSettings {
|
|
|
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
|
|
#[deserr(default)]
|
2023-12-21 00:08:32 +08:00
|
|
|
pub source: Setting<EmbedderSource>,
|
2023-11-15 22:46:37 +08:00
|
|
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
|
|
#[deserr(default)]
|
2023-12-21 00:08:32 +08:00
|
|
|
pub model: Setting<String>,
|
|
|
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
|
|
#[deserr(default)]
|
|
|
|
pub revision: Setting<String>,
|
|
|
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
|
|
#[deserr(default)]
|
|
|
|
pub api_key: Setting<String>,
|
|
|
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
|
|
#[deserr(default)]
|
|
|
|
pub dimensions: Setting<usize>,
|
|
|
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
|
|
|
#[deserr(default)]
|
|
|
|
pub document_template: Setting<String>,
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn check_unset<T>(
|
|
|
|
key: &Setting<T>,
|
|
|
|
field: &'static str,
|
|
|
|
source: EmbedderSource,
|
|
|
|
embedder_name: &str,
|
|
|
|
) -> Result<(), UserError> {
|
|
|
|
if matches!(key, Setting::NotSet) {
|
|
|
|
Ok(())
|
|
|
|
} else {
|
|
|
|
Err(UserError::InvalidFieldForSource {
|
|
|
|
embedder_name: embedder_name.to_owned(),
|
|
|
|
source_: source,
|
|
|
|
field,
|
|
|
|
allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source),
|
|
|
|
allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field),
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn check_set<T>(
|
|
|
|
key: &Setting<T>,
|
|
|
|
field: &'static str,
|
|
|
|
source: EmbedderSource,
|
|
|
|
embedder_name: &str,
|
|
|
|
) -> Result<(), UserError> {
|
|
|
|
if matches!(key, Setting::Set(_)) {
|
|
|
|
Ok(())
|
|
|
|
} else {
|
|
|
|
Err(UserError::MissingFieldForSource {
|
|
|
|
field,
|
|
|
|
source_: source,
|
|
|
|
embedder_name: embedder_name.to_owned(),
|
|
|
|
})
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
impl EmbeddingSettings {
|
2023-12-21 00:48:09 +08:00
|
|
|
pub const SOURCE: &'static str = "source";
|
|
|
|
pub const MODEL: &'static str = "model";
|
|
|
|
pub const REVISION: &'static str = "revision";
|
|
|
|
pub const API_KEY: &'static str = "apiKey";
|
|
|
|
pub const DIMENSIONS: &'static str = "dimensions";
|
|
|
|
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
|
2023-12-21 00:08:32 +08:00
|
|
|
|
|
|
|
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
|
|
|
|
match field {
|
|
|
|
Self::SOURCE => {
|
|
|
|
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
2023-12-21 00:08:32 +08:00
|
|
|
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
|
|
|
|
Self::REVISION => &[EmbedderSource::HuggingFace],
|
|
|
|
Self::API_KEY => &[EmbedderSource::OpenAi],
|
|
|
|
Self::DIMENSIONS => &[EmbedderSource::UserProvided],
|
|
|
|
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
|
|
|
|
_other => unreachable!("unknown field"),
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
|
|
|
|
match source {
|
|
|
|
EmbedderSource::OpenAi => {
|
|
|
|
&[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE]
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
2023-12-21 00:08:32 +08:00
|
|
|
EmbedderSource::HuggingFace => {
|
|
|
|
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
|
2023-12-13 04:19:48 +08:00
|
|
|
}
|
2023-12-21 00:08:32 +08:00
|
|
|
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
pub(crate) fn apply_default_source(setting: &mut Setting<EmbeddingSettings>) {
|
|
|
|
if let Setting::Set(EmbeddingSettings {
|
|
|
|
source: source @ (Setting::NotSet | Setting::Reset),
|
|
|
|
..
|
|
|
|
}) = setting
|
|
|
|
{
|
|
|
|
*source = Setting::Set(EmbedderSource::default())
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
2023-11-15 22:46:37 +08:00
|
|
|
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
|
|
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
2023-12-21 00:08:32 +08:00
|
|
|
pub enum EmbedderSource {
|
|
|
|
#[default]
|
|
|
|
OpenAi,
|
|
|
|
HuggingFace,
|
|
|
|
UserProvided,
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
impl std::fmt::Display for EmbedderSource {
|
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
let s = match self {
|
|
|
|
EmbedderSource::OpenAi => "openAi",
|
|
|
|
EmbedderSource::HuggingFace => "huggingFace",
|
|
|
|
EmbedderSource::UserProvided => "userProvided",
|
|
|
|
};
|
|
|
|
f.write_str(s)
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
impl EmbeddingSettings {
|
|
|
|
pub fn apply(&mut self, new: Self) {
|
|
|
|
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
|
|
|
new;
|
|
|
|
let old_source = self.source;
|
|
|
|
self.source.apply(source);
|
|
|
|
// Reinitialize the whole setting object on a source change
|
|
|
|
if old_source != self.source {
|
|
|
|
*self = EmbeddingSettings {
|
|
|
|
source,
|
|
|
|
model,
|
|
|
|
revision,
|
|
|
|
api_key,
|
|
|
|
dimensions,
|
|
|
|
document_template,
|
|
|
|
};
|
|
|
|
return;
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
self.model.apply(model);
|
|
|
|
self.revision.apply(revision);
|
2023-11-15 22:46:37 +08:00
|
|
|
self.api_key.apply(api_key);
|
2023-12-21 00:08:32 +08:00
|
|
|
self.dimensions.apply(dimensions);
|
|
|
|
self.document_template.apply(document_template);
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
impl From<EmbeddingConfig> for EmbeddingSettings {
|
|
|
|
fn from(value: EmbeddingConfig) -> Self {
|
|
|
|
let EmbeddingConfig { embedder_options, prompt } = value;
|
|
|
|
match embedder_options {
|
|
|
|
super::EmbedderOptions::HuggingFace(options) => Self {
|
|
|
|
source: Setting::Set(EmbedderSource::HuggingFace),
|
|
|
|
model: Setting::Set(options.model),
|
|
|
|
revision: options.revision.map(Setting::Set).unwrap_or_default(),
|
|
|
|
api_key: Setting::NotSet,
|
|
|
|
dimensions: Setting::NotSet,
|
|
|
|
document_template: Setting::Set(prompt.template),
|
|
|
|
},
|
|
|
|
super::EmbedderOptions::OpenAi(options) => Self {
|
|
|
|
source: Setting::Set(EmbedderSource::OpenAi),
|
|
|
|
model: Setting::Set(options.embedding_model.name().to_owned()),
|
|
|
|
revision: Setting::NotSet,
|
|
|
|
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
|
2024-02-07 17:37:59 +08:00
|
|
|
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
|
2023-12-21 00:08:32 +08:00
|
|
|
document_template: Setting::Set(prompt.template),
|
|
|
|
},
|
|
|
|
super::EmbedderOptions::UserProvided(options) => Self {
|
|
|
|
source: Setting::Set(EmbedderSource::UserProvided),
|
|
|
|
model: Setting::NotSet,
|
|
|
|
revision: Setting::NotSet,
|
|
|
|
api_key: Setting::NotSet,
|
|
|
|
dimensions: Setting::Set(options.dimensions),
|
|
|
|
document_template: Setting::NotSet,
|
|
|
|
},
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
impl From<EmbeddingSettings> for EmbeddingConfig {
|
|
|
|
fn from(value: EmbeddingSettings) -> Self {
|
|
|
|
let mut this = Self::default();
|
|
|
|
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
|
|
|
value;
|
|
|
|
if let Some(source) = source.set() {
|
|
|
|
match source {
|
|
|
|
EmbedderSource::OpenAi => {
|
|
|
|
let mut options = super::openai::EmbedderOptions::with_default_model(None);
|
|
|
|
if let Some(model) = model.set() {
|
|
|
|
if let Some(model) = super::openai::EmbeddingModel::from_name(&model) {
|
|
|
|
options.embedding_model = model;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if let Some(api_key) = api_key.set() {
|
|
|
|
options.api_key = Some(api_key);
|
|
|
|
}
|
2024-01-30 23:32:57 +08:00
|
|
|
if let Some(dimensions) = dimensions.set() {
|
|
|
|
options.dimensions = Some(dimensions);
|
|
|
|
}
|
2023-12-21 00:08:32 +08:00
|
|
|
this.embedder_options = super::EmbedderOptions::OpenAi(options);
|
|
|
|
}
|
|
|
|
EmbedderSource::HuggingFace => {
|
|
|
|
let mut options = super::hf::EmbedderOptions::default();
|
|
|
|
if let Some(model) = model.set() {
|
|
|
|
options.model = model;
|
|
|
|
// Reset the revision if we are setting the model.
|
|
|
|
// This allows the following:
|
|
|
|
// "huggingFace": {} -> default model with default revision
|
|
|
|
// "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision
|
|
|
|
// "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision
|
|
|
|
options.revision = None;
|
|
|
|
}
|
|
|
|
if let Some(revision) = revision.set() {
|
|
|
|
options.revision = Some(revision);
|
|
|
|
}
|
|
|
|
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
|
|
|
|
}
|
|
|
|
EmbedderSource::UserProvided => {
|
|
|
|
this.embedder_options =
|
|
|
|
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
|
|
|
|
dimensions: dimensions.set().unwrap(),
|
|
|
|
});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2023-12-13 04:19:48 +08:00
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
if let Setting::Set(template) = document_template {
|
|
|
|
this.prompt = PromptData { template }
|
|
|
|
}
|
2023-12-13 04:19:48 +08:00
|
|
|
|
2023-12-21 00:08:32 +08:00
|
|
|
this
|
2023-12-13 04:19:48 +08:00
|
|
|
}
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|