From 8708cbef2538d28c65b7511e9706b9c1a093762a Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Thu, 14 Mar 2024 14:44:43 +0100
Subject: [PATCH] Add RestEmbedder

---
 milli/src/vector/error.rs | 109 +++++++++++++++++++
 milli/src/vector/mod.rs   |   1 +
 milli/src/vector/rest.rs  | 221 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 331 insertions(+)
 create mode 100644 milli/src/vector/rest.rs

diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index 1def4f7a9..b2eb37e81 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -83,6 +83,32 @@ pub enum EmbedErrorKind {
     OllamaModelNotFoundError(OllamaError),
     #[error("received unhandled HTTP status code {0} from Ollama")]
     OllamaUnhandledStatusCode(u16),
+    #[error("error serializing template context: {0}")]
+    RestTemplateContextSerialization(liquid::Error),
+    #[error(
+        "error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}"
+    )]
+    RestTemplateError(liquid::Error),
+    #[error("error deserializing the response body as JSON: {0}")]
+    RestResponseDeserialization(std::io::Error),
+    #[error("component `{0}` not found in path `{1}` in response: `{2}`")]
+    RestResponseMissingEmbeddings(String, String, String),
+    #[error("expected a response parseable as a vector or an array of vectors: {0}")]
+    RestResponseFormat(serde_json::Error),
+    #[error("expected a response containing {0} embeddings, got {1}")]
+    RestResponseEmbeddingCount(usize, usize),
+    #[error("could not authenticate against embedding server: {0:?}")]
+    RestUnauthorized(Option<String>),
+    #[error("sent too many requests to embedding server: {0:?}")]
+    RestTooManyRequests(Option<String>),
+    #[error("sent a bad request to embedding server: {0:?}")]
+    RestBadRequest(Option<String>),
+    #[error("received internal error HTTP {0} from embedding server: {1:?}")]
+    RestInternalServerError(u16, Option<String>),
+    #[error("received HTTP {0} from embedding server: {1:?}")]
+    RestOtherStatusCode(u16, Option<String>),
+    #[error("could not reach embedding server: {0}")]
+    RestNetwork(ureq::Transport),
 }
 
 impl EmbedError {
@@ -161,6 +187,89 @@ impl EmbedError {
     pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
         Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
     }
+
+    pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestTemplateContextSerialization(error),
+            fault: FaultSource::Bug,
+        }
+    }
+
+    pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestResponseDeserialization(error),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
+        response: serde_json::Value,
+        component: &str,
+        response_field: &[S],
+    ) -> EmbedError {
+        let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
+        let response_field = response_field.join(".");
+
+        Self {
+            kind: EmbedErrorKind::RestResponseMissingEmbeddings(
+                component.to_owned(),
+                response_field,
+                serde_json::to_string_pretty(&response).unwrap_or_default(),
+            ),
+            fault: FaultSource::Undecided,
+        }
+    }
+
+    pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
+    }
+
+    pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestTooManyRequests(error_response),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_internal_server_error(
+        code: u16,
+        error_response: Option<String>,
+    ) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestInternalServerError(code, error_response),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
+            fault: FaultSource::Undecided,
+        }
+    }
+
+    pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
+    }
 }
 
 #[derive(Debug, thiserror::Error)]
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 86dde8ad4..7eef3d442 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -11,6 +11,7 @@ pub mod openai;
 pub mod settings;
 
 pub mod ollama;
+pub mod rest;
 
 pub use self::error::Error;
 
diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs
new file mode 100644
index 000000000..975bd3790
--- /dev/null
+++ b/milli/src/vector/rest.rs
@@ -0,0 +1,221 @@
+use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
+
+use super::openai::Retry;
+use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError};
+use crate::VectorOrArrayOfVectors;
+
+/// Embedder that obtains embeddings by sending a rendered request template to a REST endpoint.
+pub struct Embedder {
+    client: ureq::Agent,
+    options: EmbedderOptions,
+    /// Precomputed value of the `Authorization` header, when an API key is configured.
+    bearer: Option<String>,
+    /// Number of dimensions of the embeddings, either configured or inferred at creation.
+    dimensions: usize,
+}
+
+/// Configuration of a REST [`Embedder`].
+pub struct EmbedderOptions {
+    /// API key sent as a bearer token, if any.
+    api_key: Option<String>,
+    distribution: Option<DistributionShift>,
+    /// Expected number of dimensions; inferred with a test request when `None`.
+    dimensions: Option<usize>,
+    /// URL the embedding requests are POSTed to.
+    url: String,
+    /// Template of the request body; rendered with the texts bound to the `input` variable.
+    query: liquid::Template,
+    /// Successive object keys leading to the embeddings in the JSON response.
+    response_field: Vec<String>,
+}
+
+impl Embedder {
+    /// Creates an embedder from `options`.
+    ///
+    /// When `options.dimensions` is not set, a test request is sent to the server to infer it.
+    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
+        // The scheme and the credentials are separated by a single space, without a colon.
+        let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
+
+        let client = ureq::agent();
+
+        let dimensions = if let Some(dimensions) = options.dimensions {
+            dimensions
+        } else {
+            infer_dimensions(&client, &options, bearer.as_deref())?
+        };
+
+        Ok(Self { client, dimensions, options, bearer })
+    }
+
+    /// Embeds `texts` with one request, returning exactly one embedding per text on success.
+    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice())
+    }
+
+    /// Embeds each chunk of texts with its own request, in parallel on the `threads` pool.
+    pub fn embed_chunks(
+        &self,
+        text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
+    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+        // Collect inside `install` so that the requests actually run on `threads`.
+        threads.install(move || {
+            text_chunks.into_par_iter().map(|chunk| self.embed(chunk)).collect()
+        })
+    }
+
+    pub fn chunk_count_hint(&self) -> usize {
+        10
+    }
+
+    pub fn prompt_count_in_chunk_hint(&self) -> usize {
+        10
+    }
+
+    /// Number of dimensions of the embeddings produced by this embedder.
+    pub fn dimensions(&self) -> usize {
+        self.dimensions
+    }
+
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        self.options.distribution
+    }
+}
+
+/// Sends a single test input to the server and returns the dimension of the resulting embedding.
+fn infer_dimensions(
+    client: &ureq::Agent,
+    options: &EmbedderOptions,
+    bearer: Option<&str>,
+) -> Result<usize, NewEmbedderError> {
+    let v = embed(client, options, bearer, ["test"].as_slice())
+        .map_err(NewEmbedderError::could_not_determine_dimension)?;
+    // unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error
+    Ok(v.first().unwrap().dimension())
+}
+
+/// Renders the request template with `inputs`, POSTs it and parses the embeddings.
+///
+/// Retries failed requests (8 attempts at most), waiting at most one minute between attempts.
+fn embed<S>(
+    client: &ureq::Agent,
+    options: &EmbedderOptions,
+    bearer: Option<&str>,
+    inputs: &[S],
+) -> Result<Vec<Embeddings<f32>>, EmbedError>
+where
+    S: serde::Serialize,
+{
+    let request = client.post(&options.url);
+    let request =
+        if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
+    let request = request.set("Content-Type", "application/json");
+
+    let body = options
+        .query
+        .render(
+            &liquid::to_object(&serde_json::json!({
+                "input": inputs,
+            }))
+            .map_err(EmbedError::rest_template_context_serialization)?,
+        )
+        .map_err(EmbedError::rest_template_render)?;
+
+    for attempt in 0..7 {
+        // `send_string` consumes the request: clone it so that it can be sent again on retry.
+        let response = request.clone().send_string(&body);
+        let result = check_response(response);
+
+        let retry_duration = match result {
+            Ok(response) => {
+                return response_to_embedding(response, &options.response_field, inputs.len())
+            }
+            Err(retry) => {
+                tracing::warn!("Failed: {}", retry.error);
+                retry.into_duration(attempt)
+            }
+        }?;
+
+        let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
+        tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
+        std::thread::sleep(retry_duration);
+    }
+
+    // Last attempt: any failure is now returned to the caller.
+    let response = request.send_string(&body);
+    let result = check_response(response);
+    result
+        .map_err(Retry::into_error)
+        .and_then(|response| response_to_embedding(response, &options.response_field, inputs.len()))
+}
+
+/// Maps a failed HTTP response or transport error to a [`Retry`] carrying the matching error.
+fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
+    match response {
+        Ok(response) => Ok(response),
+        Err(ureq::Error::Status(code, response)) => {
+            let error_response: Option<String> = response.into_string().ok();
+            Err(match code {
+                401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
+                429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
+                400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
+                500..=599 => {
+                    Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
+                }
+                _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
+            })
+        }
+        Err(ureq::Error::Transport(transport)) => {
+            Err(Retry::retry_later(EmbedError::rest_network(transport)))
+        }
+    }
+}
+
+/// Extracts the embeddings found at the `response_field` path of the JSON response.
+///
+/// Fails when the path is missing, when the value at that path is not a vector or an array of
+/// vectors, or when the number of embeddings differs from `expected_count`.
+fn response_to_embedding<S: AsRef<str>>(
+    response: ureq::Response,
+    response_field: &[S],
+    expected_count: usize,
+) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    let response: serde_json::Value =
+        response.into_json().map_err(EmbedError::rest_response_deserialization)?;
+
+    let mut current_value = &response;
+    for component in response_field {
+        let component = component.as_ref();
+        // Assign (rather than shadow) so that each component descends one level deeper.
+        current_value = match current_value.get(component) {
+            Some(value) => value,
+            None => {
+                return Err(EmbedError::rest_response_missing_embeddings(
+                    response,
+                    component,
+                    response_field,
+                ))
+            }
+        };
+    }
+
+    let embeddings = current_value.to_owned();
+
+    let embeddings: VectorOrArrayOfVectors =
+        serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
+
+    let embeddings = embeddings.into_array_of_vectors();
+
+    let embeddings: Vec<Embeddings<f32>> = embeddings
+        .into_iter()
+        .flatten()
+        .map(Embeddings::from_single_embedding)
+        .collect();
+
+    if embeddings.len() != expected_count {
+        return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
+    }
+
+    Ok(embeddings)
+}