From d731fa661b207754afed735d6e306e88abc5e2cb Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 16 Jul 2024 15:17:49 +0200 Subject: [PATCH] ollama and openai use new EmbedderOptions --- milli/src/vector/ollama.rs | 29 +++++++++++++++------------ milli/src/vector/openai.rs | 41 ++++++++++++++++++++++---------------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index 2c29cc816..84baac1ba 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -28,19 +28,22 @@ impl EmbedderOptions { impl Embedder { pub fn new(options: EmbedderOptions) -> Result { let model = options.embedding_model.as_str(); - let rest_embedder = match RestEmbedder::new(RestEmbedderOptions { - api_key: options.api_key, - dimensions: None, - distribution: options.distribution, - url: options.url.unwrap_or_else(get_ollama_path), - query: serde_json::json!({ - "model": model, - }), - input_field: vec!["prompt".to_owned()], - path_to_embeddings: Default::default(), - embedding_object: vec!["embedding".to_owned()], - input_type: super::rest::InputType::Text, - }) { + let rest_embedder = match RestEmbedder::new( + RestEmbedderOptions { + api_key: options.api_key, + dimensions: None, + distribution: options.distribution, + url: options.url.unwrap_or_else(get_ollama_path), + request: serde_json::json!({ + "model": model, + "prompt": super::rest::REQUEST_PLACEHOLDER, + }), + response: serde_json::json!({ + "embedding": super::rest::RESPONSE_PLACEHOLDER, + }), + }, + super::rest::ConfigurationSource::Ollama, + ) { Ok(embedder) => embedder, Err(NewEmbedderError { kind: diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index ade9e51fc..514ad4a3b 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -26,20 +26,21 @@ impl EmbedderOptions { } } - pub fn query(&self) -> serde_json::Value { + pub fn request(&self) -> serde_json::Value { let model = self.embedding_model.name(); - let mut query = serde_json::json!({ + let mut request = serde_json::json!({ "model": model, + "input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER] }); if self.embedding_model.supports_overriding_dimensions() { if let Some(dimensions) = self.dimensions { - query["dimensions"] = dimensions.into(); + request["dimensions"] = dimensions.into(); } } - query + request } pub fn distribution(&self) -> Option { @@ -180,17 +181,23 @@ impl Embedder { let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned(); - let rest_embedder = RestEmbedder::new(RestEmbedderOptions { - api_key: Some(api_key.clone()), - distribution: None, - dimensions: Some(options.dimensions()), - url, - query: options.query(), - input_field: vec!["input".to_owned()], - input_type: crate::vector::rest::InputType::TextArray, - path_to_embeddings: vec!["data".to_owned()], - embedding_object: vec!["embedding".to_owned()], - })?; + let rest_embedder = RestEmbedder::new( + RestEmbedderOptions { + api_key: Some(api_key.clone()), + distribution: None, + dimensions: Some(options.dimensions()), + url, + request: options.request(), + response: serde_json::json!({ + "data": [{ + "embedding": super::rest::RESPONSE_PLACEHOLDER + }, + super::rest::REPEAT_PLACEHOLDER + ] + }), + }, + super::rest::ConfigurationSource::OpenAi, + )?; // looking at the code it is very unclear that this can actually fail. let tokenizer = tiktoken_rs::cl100k_base().unwrap(); @@ -201,7 +208,7 @@ impl Embedder { pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { match self.rest_embedder.embed_ref(&texts) { Ok(embeddings) => Ok(embeddings), - Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => { + Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); self.try_embed_tokenized(&texts) } @@ -225,7 +232,7 @@ impl Embedder { let embedding = self.rest_embedder.embed_tokens(tokens)?; embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { - EmbedError::openai_unexpected_dimension(self.dimensions(), got.len()) + EmbedError::rest_unexpected_dimension(self.dimensions(), got.len()) })?; all_embeddings.push(embeddings_for_prompt);