Allow url parameter for ollama embedder

This commit is contained in:
Louis Dureuil 2024-03-25 11:13:21 +01:00
parent dfa5e41ea6
commit 58972f35cb
No known key found for this signature in database
4 changed files with 16 additions and 13 deletions

View File

@ -1271,7 +1271,6 @@ pub fn validate_embedding_settings(
check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?; check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?; check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;

View File

@ -201,8 +201,8 @@ impl EmbedderOptions {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
} }
pub fn ollama() -> Self { pub fn ollama(url: Option<String>) -> Self {
Self::Ollama(ollama::EmbedderOptions::with_default_model()) Self::Ollama(ollama::EmbedderOptions::with_default_model(url))
} }
} }

View File

@ -12,15 +12,12 @@ pub struct Embedder {
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub embedding_model: String, pub embedding_model: String,
pub url: Option<String>,
} }
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model() -> Self { pub fn with_default_model(url: Option<String>) -> Self {
Self { embedding_model: "nomic-embed-text".into() } Self { embedding_model: "nomic-embed-text".into(), url }
}
pub fn with_embedding_model(embedding_model: String) -> Self {
Self { embedding_model }
} }
} }
@ -31,7 +28,7 @@ impl Embedder {
api_key: None, api_key: None,
distribution: None, distribution: None,
dimensions: None, dimensions: None,
url: get_ollama_path(), url: options.url.unwrap_or_else(get_ollama_path),
query: serde_json::json!({ query: serde_json::json!({
"model": model, "model": model,
}), }),

View File

@ -124,7 +124,7 @@ impl EmbeddingSettings {
EmbedderSource::Ollama, EmbedderSource::Ollama,
EmbedderSource::Rest, EmbedderSource::Rest,
], ],
Self::URL => &[EmbedderSource::Rest], Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest],
Self::QUERY => &[EmbedderSource::Rest], Self::QUERY => &[EmbedderSource::Rest],
Self::INPUT_FIELD => &[EmbedderSource::Rest], Self::INPUT_FIELD => &[EmbedderSource::Rest],
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
@ -146,7 +146,9 @@ impl EmbeddingSettings {
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
} }
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE], EmbedderSource::Ollama => {
&[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE, Self::URL]
}
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
EmbedderSource::Rest => &[ EmbedderSource::Rest => &[
Self::SOURCE, Self::SOURCE,
@ -387,10 +389,15 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
} }
EmbedderSource::Ollama => { EmbedderSource::Ollama => {
let mut options: ollama::EmbedderOptions = let mut options: ollama::EmbedderOptions =
super::ollama::EmbedderOptions::with_default_model(); super::ollama::EmbedderOptions::with_default_model(None);
if let Some(model) = model.set() { if let Some(model) = model.set() {
options.embedding_model = model; options.embedding_model = model;
} }
if let Some(url) = url.set() {
options.url = Some(url)
}
this.embedder_options = super::EmbedderOptions::Ollama(options); this.embedder_options = super::EmbedderOptions::Ollama(options);
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {