From 8b1fcfd7f8ff918789757a1489915098bb1d843e Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 13 Jan 2025 14:34:11 +0100 Subject: [PATCH] Parse ollama URL to adapt configuration depending on the endpoint --- crates/milli/src/vector/error.rs | 8 ++++- crates/milli/src/vector/ollama.rs | 50 +++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/crates/milli/src/vector/error.rs b/crates/milli/src/vector/error.rs index 5edabed0d..d1b2516f5 100644 --- a/crates/milli/src/vector/error.rs +++ b/crates/milli/src/vector/error.rs @@ -67,7 +67,7 @@ pub enum EmbedErrorKind { #[error("could not authenticate against {embedding} server{server_reply}{hint}", embedding=match *.1 { ConfigurationSource::User => "embedding", ConfigurationSource::OpenAi => "OpenAI", - ConfigurationSource::Ollama => "ollama" + ConfigurationSource::Ollama => "Ollama" }, server_reply=option_info(.0.as_deref(), "server replied with "), hint=match *.1 { @@ -306,6 +306,10 @@ impl NewEmbedderError { fault: FaultSource::User, } } + + pub(crate) fn ollama_unsupported_url(url: String) -> NewEmbedderError { + Self { kind: NewEmbedderErrorKind::OllamaUnsupportedUrl(url), fault: FaultSource::User } + } } #[derive(Debug, thiserror::Error)] @@ -369,6 +373,8 @@ pub enum NewEmbedderErrorKind { LoadModel(candle_core::Error), #[error("{0}")] CouldNotParseTemplate(String), + #[error("unsupported Ollama URL.\n - For `ollama` sources, the URL must end with `/api/embed` or `/api/embeddings`\n - Got `{0}`")] + OllamaUnsupportedUrl(String), } pub struct PossibleEmbeddingMistakes { diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs index 7ee775cbf..fb0f3fb82 100644 --- a/crates/milli/src/vector/ollama.rs +++ b/crates/milli/src/vector/ollama.rs @@ -38,26 +38,46 @@ impl EmbedderOptions { dimensions, } } + + fn into_rest_embedder_config(self) -> Result { + let url = self.url.unwrap_or_else(get_ollama_path); + let model = self.embedding_model.as_str(); 
+ + // **warning**: do not swap these two `if`s, as the second one is always true when the first one is. + let (request, response) = if url.ends_with("/api/embed") { + ( + serde_json::json!({"model": model, "input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]}), + serde_json::json!({"embeddings": [super::rest::RESPONSE_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]}), + ) + } else if url.ends_with("/api/embeddings") { + ( + serde_json::json!({ + "model": model, + "prompt": super::rest::REQUEST_PLACEHOLDER, + }), + serde_json::json!({ + "embedding": super::rest::RESPONSE_PLACEHOLDER, + }), + ) + } else { + return Err(NewEmbedderError::ollama_unsupported_url(url)); + }; + Ok(RestEmbedderOptions { + api_key: self.api_key, + dimensions: self.dimensions, + distribution: self.distribution, + url, + request, + response, + headers: Default::default(), + }) + } } impl Embedder { pub fn new(options: EmbedderOptions) -> Result { - let model = options.embedding_model.as_str(); let rest_embedder = match RestEmbedder::new( - RestEmbedderOptions { - api_key: options.api_key, - dimensions: options.dimensions, - distribution: options.distribution, - url: options.url.unwrap_or_else(get_ollama_path), - request: serde_json::json!({ - "model": model, - "prompt": super::rest::REQUEST_PLACEHOLDER, - }), - response: serde_json::json!({ - "embedding": super::rest::RESPONSE_PLACEHOLDER, - }), - headers: Default::default(), - }, + options.into_rest_embedder_config()?, super::rest::ConfigurationSource::Ollama, ) { Ok(embedder) => embedder,