4509: Rest embedder r=ManyTheFish a=dureuill

Fixes #4531 

See [Usage page](https://meilisearch.notion.site/v1-8-AI-search-API-usage-135552d6e85a4a52bc7109be82aeca42?pvs=25#e6f58c3b742c4effb4ddc625ce12ee16)

### Implementation changes

- Remove tokio, futures, reqwests
- Add a new `milli::vector::rest::Embedder` embedder
- Update OpenAI and Ollama embedders to use the REST embedder internally
- Make Embedder::embed a sync method
- Add the new embedder source as described in the usage


Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2024-03-27 09:27:46 +00:00 committed by GitHub
commit 34dfea72cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1048 additions and 764 deletions

5
Cargo.lock generated
View File

@ -3338,7 +3338,6 @@ dependencies = [
"filter-parser", "filter-parser",
"flatten-serde-json", "flatten-serde-json",
"fst", "fst",
"futures",
"fxhash", "fxhash",
"geoutils", "geoutils",
"grenad", "grenad",
@ -3362,7 +3361,6 @@ dependencies = [
"rand", "rand",
"rand_pcg", "rand_pcg",
"rayon", "rayon",
"reqwest",
"roaring", "roaring",
"rstar", "rstar",
"serde", "serde",
@ -3376,8 +3374,9 @@ dependencies = [
"tiktoken-rs", "tiktoken-rs",
"time", "time",
"tokenizers", "tokenizers",
"tokio",
"tracing", "tracing",
"ureq",
"url",
"uuid", "uuid",
] ]

View File

@ -353,6 +353,7 @@ impl ErrorCode for milli::Error {
| UserError::InvalidOpenAiModelDimensions { .. } | UserError::InvalidOpenAiModelDimensions { .. }
| UserError::InvalidOpenAiModelDimensionsMax { .. } | UserError::InvalidOpenAiModelDimensionsMax { .. }
| UserError::InvalidSettingsDimensions { .. } | UserError::InvalidSettingsDimensions { .. }
| UserError::InvalidUrl { .. }
| UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,

View File

@ -202,7 +202,7 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features(); let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
let search_result = let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
@ -241,7 +241,7 @@ pub async fn search_with_post(
let features = index_scheduler.features(); let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
let search_result = let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
@ -260,7 +260,7 @@ pub async fn search_with_post(
Ok(HttpResponse::Ok().json(search_result)) Ok(HttpResponse::Ok().json(search_result))
} }
pub async fn embed( pub fn embed(
query: &mut SearchQuery, query: &mut SearchQuery,
index_scheduler: &IndexScheduler, index_scheduler: &IndexScheduler,
index: &milli::Index, index: &milli::Index,
@ -287,7 +287,6 @@ pub async fn embed(
let embeddings = embedder let embeddings = embedder
.embed(vec![q.to_owned()]) .embed(vec![q.to_owned()])
.await
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
.pop() .pop()

View File

@ -605,6 +605,7 @@ fn embedder_analytics(
EmbedderSource::HuggingFace => sources.insert("huggingFace"), EmbedderSource::HuggingFace => sources.insert("huggingFace"),
EmbedderSource::UserProvided => sources.insert("userProvided"), EmbedderSource::UserProvided => sources.insert("userProvided"),
EmbedderSource::Ollama => sources.insert("ollama"), EmbedderSource::Ollama => sources.insert("ollama"),
EmbedderSource::Rest => sources.insert("rest"),
}; };
} }
}; };

View File

@ -75,9 +75,8 @@ pub async fn multi_search_with_post(
}) })
.with_index(query_index)?; .with_index(query_index)?;
let distribution = embed(&mut query, index_scheduler.get_ref(), &index) let distribution =
.await embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?;
.with_index(query_index)?;
let search_result = tokio::task::spawn_blocking(move || { let search_result = tokio::task::spawn_blocking(move || {
perform_search(&index, query, features, distribution) perform_search(&index, query, features, distribution)

View File

@ -80,17 +80,13 @@ tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
"online", "online",
] } ] }
tokio = { version = "1.35.1", features = ["rt"] }
futures = "0.3.30"
reqwest = { version = "0.11.23", features = [
"rustls-tls",
"json",
], default-features = false }
tiktoken-rs = "0.5.8" tiktoken-rs = "0.5.8"
liquid = "0.26.4" liquid = "0.26.4"
arroy = "0.2.0" arroy = "0.2.0"
rand = "0.8.5" rand = "0.8.5"
tracing = "0.1.40" tracing = "0.1.40"
ureq = { version = "2.9.6", features = ["json"] }
url = "2.5.0"
[dev-dependencies] [dev-dependencies]
mimalloc = { version = "0.1.39", default-features = false } mimalloc = { version = "0.1.39", default-features = false }

View File

@ -243,6 +243,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
}, },
#[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")] #[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")]
InvalidSettingsDimensions { embedder_name: String }, InvalidSettingsDimensions { embedder_name: String },
#[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")]
InvalidUrl { embedder_name: String, inner_error: url::ParseError, url: String },
} }
impl From<crate::vector::Error> for Error { impl From<crate::vector::Error> for Error {

View File

@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>, prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters, indexer: GrenadParameters,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
request_threads: &rayon::ThreadPool,
) -> Result<grenad::Reader<BufReader<File>>> { ) -> Result<grenad::Reader<BufReader<File>>> {
puffin::profile_function!(); puffin::profile_function!();
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if chunks.len() == chunks.capacity() { if chunks.len() == chunks.capacity() {
let chunked_embeds = embedder let chunked_embeds = embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))) .embed_chunks(
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
request_threads,
)
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
// send last chunk // send last chunk
if !chunks.is_empty() { if !chunks.is_empty() {
let chunked_embeds = embedder let chunked_embeds = embedder
.embed_chunks(std::mem::take(&mut chunks)) .embed_chunks(std::mem::take(&mut chunks), request_threads)
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids for (docid, embeddings) in chunks_ids
@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
if !current_chunk.is_empty() { if !current_chunk.is_empty() {
let embeds = embedder let embeds = embedder
.embed_chunks(vec![std::mem::take(&mut current_chunk)]) .embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;

View File

@ -238,6 +238,12 @@ fn send_original_documents_data(
let documents_chunk_cloned = original_documents_chunk.clone(); let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
let request_threads = rayon::ThreadPoolBuilder::new()
.num_threads(crate::vector::REQUEST_PARALLELISM)
.thread_name(|index| format!("embedding-request-{index}"))
.build()?;
rayon::spawn(move || { rayon::spawn(move || {
for (name, (embedder, prompt)) in embedders { for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points( let result = extract_vector_points(
@ -249,7 +255,12 @@ fn send_original_documents_data(
); );
match result { match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { let embeddings = match extract_embeddings(
prompts,
indexer,
embedder.clone(),
&request_threads,
) {
Ok(results) => Some(results), Ok(results) => Some(results),
Err(error) => { Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error)); let _ = lmdb_writer_sx_cloned.send(Err(error));

View File

@ -2646,6 +2646,12 @@ mod tests {
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::Set(3), dimensions: Setting::Set(3),
document_template: Setting::NotSet, document_template: Setting::NotSet,
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
}), }),
); );
settings.set_embedder_settings(embedders); settings.set_embedder_settings(embedders);

View File

@ -1140,6 +1140,12 @@ fn validate_prompt(
api_key, api_key,
dimensions, dimensions,
document_template: Setting::Set(template), document_template: Setting::Set(template),
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}) => { }) => {
// validate // validate
let template = crate::prompt::Prompt::new(template) let template = crate::prompt::Prompt::new(template)
@ -1153,6 +1159,12 @@ fn validate_prompt(
api_key, api_key,
dimensions, dimensions,
document_template: Setting::Set(template), document_template: Setting::Set(template),
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
})) }))
} }
new => Ok(new), new => Ok(new),
@ -1165,8 +1177,20 @@ pub fn validate_embedding_settings(
) -> Result<Setting<EmbeddingSettings>> { ) -> Result<Setting<EmbeddingSettings>> {
let settings = validate_prompt(name, settings)?; let settings = validate_prompt(name, settings)?;
let Setting::Set(settings) = settings else { return Ok(settings) }; let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = let EmbeddingSettings {
settings; source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = settings;
if let Some(0) = dimensions.set() { if let Some(0) = dimensions.set() {
return Err(crate::error::UserError::InvalidSettingsDimensions { return Err(crate::error::UserError::InvalidSettingsDimensions {
@ -1175,6 +1199,14 @@ pub fn validate_embedding_settings(
.into()); .into());
} }
if let Some(url) = url.as_ref().set() {
url::Url::parse(url).map_err(|error| crate::error::UserError::InvalidUrl {
embedder_name: name.to_owned(),
inner_error: error,
url: url.to_owned(),
})?;
}
let Some(inferred_source) = source.set() else { let Some(inferred_source) = source.set() else {
return Ok(Setting::Set(EmbeddingSettings { return Ok(Setting::Set(EmbeddingSettings {
source, source,
@ -1183,11 +1215,25 @@ pub fn validate_embedding_settings(
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
})); }));
}; };
match inferred_source { match inferred_source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
check_unset(&revision, "revision", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
if let Setting::Set(model) = &model { if let Setting::Set(model) = &model {
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str()) let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
.ok_or(crate::error::UserError::InvalidOpenAiModel { .ok_or(crate::error::UserError::InvalidOpenAiModel {
@ -1224,10 +1270,24 @@ pub fn validate_embedding_settings(
check_set(&model, "model", inferred_source, name)?; check_set(&model, "model", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?; check_unset(&dimensions, "dimensions", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
} }
EmbedderSource::UserProvided => { EmbedderSource::UserProvided => {
check_unset(&model, "model", inferred_source, name)?; check_unset(&model, "model", inferred_source, name)?;
@ -1235,6 +1295,18 @@ pub fn validate_embedding_settings(
check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&document_template, "documentTemplate", inferred_source, name)?; check_unset(&document_template, "documentTemplate", inferred_source, name)?;
check_set(&dimensions, "dimensions", inferred_source, name)?; check_set(&dimensions, "dimensions", inferred_source, name)?;
check_unset(&url, "url", inferred_source, name)?;
check_unset(&query, "query", inferred_source, name)?;
check_unset(&input_field, "inputField", inferred_source, name)?;
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
check_unset(&input_type, "inputType", inferred_source, name)?;
}
EmbedderSource::Rest => {
check_unset(&model, "model", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_set(&url, "url", inferred_source, name)?;
} }
} }
Ok(Setting::Set(EmbeddingSettings { Ok(Setting::Set(EmbeddingSettings {
@ -1244,6 +1316,12 @@ pub fn validate_embedding_settings(
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
})) }))
} }

View File

@ -2,9 +2,7 @@ use std::path::PathBuf;
use hf_hub::api::sync::ApiError; use hf_hub::api::sync::ApiError;
use super::ollama::OllamaError;
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::vector::openai::OpenAiError;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[error("Error while generating embeddings: {inner}")] #[error("Error while generating embeddings: {inner}")]
@ -52,37 +50,34 @@ pub enum EmbedErrorKind {
TensorValue(candle_core::Error), TensorValue(candle_core::Error),
#[error("could not run model: {0}")] #[error("could not run model: {0}")]
ModelForward(candle_core::Error), ModelForward(candle_core::Error),
#[error("could not reach OpenAI: {0}")]
OpenAiNetwork(reqwest::Error),
#[error("unexpected response from OpenAI: {0}")]
OpenAiUnexpected(reqwest::Error),
#[error("could not authenticate against OpenAI: {0}")]
OpenAiAuth(OpenAiError),
#[error("sent too many requests to OpenAI: {0}")]
OpenAiTooManyRequests(OpenAiError),
#[error("received internal error from OpenAI: {0:?}")]
OpenAiInternalServerError(Option<OpenAiError>),
#[error("sent too many tokens in a request to OpenAI: {0}")]
OpenAiTooManyTokens(OpenAiError),
#[error("received unhandled HTTP status code {0} from OpenAI")]
OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String), ManualEmbed(String),
#[error("could not initialize asynchronous runtime: {0}")] #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
OpenAiRuntimeInit(std::io::Error), OllamaModelNotFoundError(Option<String>),
#[error("initializing web client for sending embedding requests failed: {0}")] #[error("error deserialization the response body as JSON: {0}")]
InitWebClient(reqwest::Error), RestResponseDeserialization(std::io::Error),
// Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends. #[error("component `{0}` not found in path `{1}` in response: `{2}`")]
#[error("unexpected response from Ollama: {0}")] RestResponseMissingEmbeddings(String, String, String),
OllamaUnexpected(reqwest::Error), #[error("expected a response parseable as a vector or an array of vectors: {0}")]
#[error("sent too many requests to Ollama: {0}")] RestResponseFormat(serde_json::Error),
OllamaTooManyRequests(OllamaError), #[error("expected a response containing {0} embeddings, got only {1}")]
#[error("received internal error from Ollama: {0}")] RestResponseEmbeddingCount(usize, usize),
OllamaInternalServerError(OllamaError), #[error("could not authenticate against embedding server: {0:?}")]
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] RestUnauthorized(Option<String>),
OllamaModelNotFoundError(OllamaError), #[error("sent too many requests to embedding server: {0:?}")]
#[error("received unhandled HTTP status code {0} from Ollama")] RestTooManyRequests(Option<String>),
OllamaUnhandledStatusCode(u16), #[error("sent a bad request to embedding server: {0:?}")]
RestBadRequest(Option<String>),
#[error("received internal error from embedding server: {0:?}")]
RestInternalServerError(u16, Option<String>),
#[error("received HTTP {0} from embedding server: {0:?}")]
RestOtherStatusCode(u16, Option<String>),
#[error("could not reach embedding server: {0}")]
RestNetwork(ureq::Transport),
#[error("was expected '{}' to be an object in query '{0}'", .1.join("."))]
RestNotAnObject(serde_json::Value, Vec<String>),
#[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")]
OpenAiUnexpectedDimension(usize, usize),
} }
impl EmbedError { impl EmbedError {
@ -102,64 +97,98 @@ impl EmbedError {
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
} }
pub fn openai_network(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
}
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
}
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
}
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
} }
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError { pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
}
pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User } Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
} }
pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError { pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime } Self {
kind: EmbedErrorKind::RestResponseDeserialization(error),
fault: FaultSource::Runtime,
}
} }
pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError { pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime } response: serde_json::Value,
component: &str,
response_field: &[S],
) -> EmbedError {
let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
let response_field = response_field.join(".");
Self {
kind: EmbedErrorKind::RestResponseMissingEmbeddings(
component.to_owned(),
response_field,
serde_json::to_string_pretty(&response).unwrap_or_default(),
),
fault: FaultSource::Undecided,
}
} }
pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError { pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug } Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
}
pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
}
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestTooManyRequests(error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
}
pub(crate) fn rest_internal_server_error(
code: u16,
error_response: Option<String>,
) -> EmbedError {
Self {
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
fault: FaultSource::Runtime,
}
}
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
Self {
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
fault: FaultSource::Undecided,
}
}
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
}
pub(crate) fn rest_not_an_object(
query: serde_json::Value,
input_path: Vec<String>,
) -> EmbedError {
Self { kind: EmbedErrorKind::RestNotAnObject(query, input_path), fault: FaultSource::User }
}
pub(crate) fn openai_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::OpenAiUnexpectedDimension(expected, got),
fault: FaultSource::Runtime,
}
} }
} }
@ -220,23 +249,12 @@ impl NewEmbedderError {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
} }
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self { Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime, fault: FaultSource::Runtime,
} }
} }
pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::User,
}
}
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
}
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -283,7 +301,4 @@ pub enum NewEmbedderErrorKind {
CouldNotDetermineDimension(EmbedError), CouldNotDetermineDimension(EmbedError),
#[error("loading model failed: {0}")] #[error("loading model failed: {0}")]
LoadModel(candle_core::Error), LoadModel(candle_core::Error),
// openai
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
} }

View File

@ -131,7 +131,7 @@ impl Embedder {
let embeddings = this let embeddings = this
.embed(vec!["test".into()]) .embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?; .map_err(NewEmbedderError::could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension(); this.dimensions = embeddings.first().unwrap().dimension();
Ok(this) Ok(this)
@ -194,7 +194,10 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
if self.options.model == "BAAI/bge-base-en-v1.5" { if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) Some(DistributionShift {
current_mean: ordered_float::OrderedFloat(0.85),
current_sigma: ordered_float::OrderedFloat(0.1),
})
} else { } else {
None None
} }

View File

@ -1,6 +1,9 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError}; use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData}; use crate::prompt::{Prompt, PromptData};
@ -11,51 +14,70 @@ pub mod openai;
pub mod settings; pub mod settings;
pub mod ollama; pub mod ollama;
pub mod rest;
pub use self::error::Error; pub use self::error::Error;
pub type Embedding = Vec<f32>; pub type Embedding = Vec<f32>;
pub const REQUEST_PARALLELISM: usize = 40;
/// One or multiple embeddings stored consecutively in a flat vector.
pub struct Embeddings<F> { pub struct Embeddings<F> {
data: Vec<F>, data: Vec<F>,
dimension: usize, dimension: usize,
} }
impl<F> Embeddings<F> { impl<F> Embeddings<F> {
/// Declares an empty vector of embeddings of the specified dimensions.
pub fn new(dimension: usize) -> Self { pub fn new(dimension: usize) -> Self {
Self { data: Default::default(), dimension } Self { data: Default::default(), dimension }
} }
/// Declares a vector of embeddings containing a single element.
///
/// The dimension is inferred from the length of the passed embedding.
pub fn from_single_embedding(embedding: Vec<F>) -> Self { pub fn from_single_embedding(embedding: Vec<F>) -> Self {
Self { dimension: embedding.len(), data: embedding } Self { dimension: embedding.len(), data: embedding }
} }
/// Declares a vector of embeddings from its components.
///
/// `data.len()` must be a multiple of `dimension`, otherwise an error is returned.
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> { pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
let mut this = Self::new(dimension); let mut this = Self::new(dimension);
this.append(data)?; this.append(data)?;
Ok(this) Ok(this)
} }
/// Returns the number of embeddings in this vector of embeddings.
pub fn embedding_count(&self) -> usize { pub fn embedding_count(&self) -> usize {
self.data.len() / self.dimension self.data.len() / self.dimension
} }
/// Dimension of a single embedding.
pub fn dimension(&self) -> usize { pub fn dimension(&self) -> usize {
self.dimension self.dimension
} }
/// Deconstructs self into the inner flat vector.
pub fn into_inner(self) -> Vec<F> { pub fn into_inner(self) -> Vec<F> {
self.data self.data
} }
/// A reference to the inner flat vector.
pub fn as_inner(&self) -> &[F] { pub fn as_inner(&self) -> &[F] {
&self.data &self.data
} }
/// Iterates over the embeddings contained in the flat vector.
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ { pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
self.data.as_slice().chunks_exact(self.dimension) self.data.as_slice().chunks_exact(self.dimension)
} }
/// Push an embedding at the end of the embeddings.
///
/// If `embedding.len() != self.dimension`, then the push operation fails.
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> { pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
if embedding.len() != self.dimension { if embedding.len() != self.dimension {
return Err(embedding); return Err(embedding);
@ -64,6 +86,9 @@ impl<F> Embeddings<F> {
Ok(()) Ok(())
} }
/// Append a flat vector of embeddings a the end of the embeddings.
///
/// If `embeddings.len() % self.dimension != 0`, then the append operation fails.
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> { pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
if embeddings.len() % self.dimension != 0 { if embeddings.len() % self.dimension != 0 {
return Err(embeddings); return Err(embeddings);
@ -73,37 +98,60 @@ impl<F> Embeddings<F> {
} }
} }
/// An embedder can be used to transform text into embeddings.
#[derive(Debug)] #[derive(Debug)]
pub enum Embedder { pub enum Embedder {
/// An embedder based on running local models, fetched from the Hugging Face Hub.
HuggingFace(hf::Embedder), HuggingFace(hf::Embedder),
/// An embedder based on making embedding queries against the OpenAI API.
OpenAi(openai::Embedder), OpenAi(openai::Embedder),
/// An embedder based on the user providing the embeddings in the documents and queries.
UserProvided(manual::Embedder), UserProvided(manual::Embedder),
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
Ollama(ollama::Embedder), Ollama(ollama::Embedder),
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
Rest(rest::Embedder),
} }
/// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig { pub struct EmbeddingConfig {
/// Options of the embedder, specific to each kind of embedder
pub embedder_options: EmbedderOptions, pub embedder_options: EmbedderOptions,
/// Document template
pub prompt: PromptData, pub prompt: PromptData,
// TODO: add metrics and anything needed // TODO: add metrics and anything needed
} }
/// Map of embedder configurations.
///
/// Each configuration is mapped to a name.
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>); pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
impl EmbeddingConfigs { impl EmbeddingConfigs {
/// Create the map from its internal component.s
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self { pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
Self(data) Self(data)
} }
/// Get an embedder configuration and template from its name.
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> { pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.0.get(name).cloned() self.0.get(name).cloned()
} }
/// Get the default embedder configuration, if any.
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> { pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.get_default_embedder_name().and_then(|default| self.get(&default)) self.get_default_embedder_name().and_then(|default| self.get(&default))
} }
/// Get the name of the default embedder configuration.
///
/// The default embedder is determined as follows:
///
/// - If there is only one embedder, it is always the default.
/// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder.
/// - In all other cases, there is no default embedder.
pub fn get_default_embedder_name(&self) -> Option<String> { pub fn get_default_embedder_name(&self) -> Option<String> {
let mut it = self.0.keys(); let mut it = self.0.keys();
let first_name = it.next(); let first_name = it.next();
@ -126,12 +174,14 @@ impl IntoIterator for EmbeddingConfigs {
} }
} }
/// Options of an embedder, specific to each kind of embedder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions { pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions), HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions), OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions), Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions), UserProvided(manual::EmbedderOptions),
Rest(rest::EmbedderOptions),
} }
impl Default for EmbedderOptions { impl Default for EmbedderOptions {
@ -141,10 +191,12 @@ impl Default for EmbedderOptions {
} }
impl EmbedderOptions { impl EmbedderOptions {
/// Default options for the Hugging Face embedder
pub fn huggingface() -> Self { pub fn huggingface() -> Self {
Self::HuggingFace(hf::EmbedderOptions::new()) Self::HuggingFace(hf::EmbedderOptions::new())
} }
/// Default options for the OpenAI embedder
pub fn openai(api_key: Option<String>) -> Self { pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
} }
@ -155,6 +207,7 @@ impl EmbedderOptions {
} }
impl Embedder { impl Embedder {
/// Spawns a new embedder built from its options.
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options { Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
@ -163,83 +216,133 @@ impl Embedder {
EmbedderOptions::UserProvided(options) => { EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options)) Self::UserProvided(manual::Embedder::new(options))
} }
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
}) })
} }
pub async fn embed( /// Embed one or multiple texts.
///
/// Each text can be embedded as one or multiple embeddings.
pub fn embed(
&self, &self,
texts: Vec<String>, texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => { Embedder::OpenAi(embedder) => embedder.embed(texts),
let client = embedder.new_client()?; Embedder::Ollama(embedder) => embedder.embed(texts),
embedder.embed(texts, &client).await
}
Embedder::Ollama(embedder) => {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::UserProvided(embedder) => embedder.embed(texts), Embedder::UserProvided(embedder) => embedder.embed(texts),
Embedder::Rest(embedder) => embedder.embed(texts),
} }
} }
/// # Panics /// Embed multiple chunks of texts.
/// ///
/// - if called from an asynchronous context /// Each chunk is composed of one or multiple texts.
pub fn embed_chunks( pub fn embed_chunks(
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks), Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks), Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
} }
} }
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
pub fn chunk_count_hint(&self) -> usize { pub fn chunk_count_hint(&self) -> usize {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(), Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1, Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
} }
} }
/// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
pub fn prompt_count_in_chunk_hint(&self) -> usize { pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1, Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
} }
} }
/// Indicates the dimensions of a single embedding produced by the embedder.
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(), Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(), Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(), Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(), Embedder::UserProvided(embedder) => embedder.dimensions(),
Embedder::Rest(embedder) => embedder.dimensions(),
} }
} }
/// An optional distribution used to apply an affine transformation to the similarity score of a document.
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.distribution(), Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(), Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(), Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None, Embedder::UserProvided(_embedder) => None,
Embedder::Rest(embedder) => embedder.distribution(),
} }
} }
} }
#[derive(Debug, Clone, Copy)] /// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
/// This allows to correct effects where results are too "packed" around a certain value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(from = "DistributionShiftSerializable")]
#[serde(into = "DistributionShiftSerializable")]
pub struct DistributionShift { pub struct DistributionShift {
pub current_mean: f32, /// Value where the results are "packed".
pub current_sigma: f32, ///
/// Similarity scores are translated so that they are packed around 0.5 instead
pub current_mean: OrderedFloat<f32>,
/// standard deviation of a similarity score.
///
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
pub current_sigma: OrderedFloat<f32>,
}
#[derive(Serialize, Deserialize)]
struct DistributionShiftSerializable {
current_mean: f32,
current_sigma: f32,
}
impl From<DistributionShift> for DistributionShiftSerializable {
fn from(
DistributionShift {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
Self { current_mean, current_sigma }
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
fn from(
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
) -> Self {
Self {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}
}
} }
impl DistributionShift { impl DistributionShift {
@ -248,11 +351,13 @@ impl DistributionShift {
if sigma <= 0.0 { if sigma <= 0.0 {
None None
} else { } else {
Some(Self { current_mean: mean, current_sigma: sigma }) Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
} }
} }
pub fn shift(&self, score: f32) -> f32 { pub fn shift(&self, score: f32) -> f32 {
let current_mean = self.current_mean.0;
let current_sigma = self.current_sigma.0;
// <https://math.stackexchange.com/a/2894689> // <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian. // We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution. // The parameters we're given is the mean and sigma of the native result distribution.
@ -262,9 +367,9 @@ impl DistributionShift {
let target_sigma = 0.4; let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / self.current_sigma; let factor = target_sigma / current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1 // a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * self.current_mean); let offset = target_mean - (factor * current_mean);
let mut score = factor * score + offset; let mut score = factor * score + offset;
@ -280,6 +385,7 @@ impl DistributionShift {
} }
} }
/// Whether CUDA is supported in this version of Meilisearch.
pub const fn is_cuda_enabled() -> bool { pub const fn is_cuda_enabled() -> bool {
cfg!(feature = "cuda") cfg!(feature = "cuda")
} }

View File

@ -1,293 +1,94 @@
// Copied from "openai.rs" with the sections I actually understand changed for Ollama. use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
// The common components of the Ollama and OpenAI interfaces might need to be extracted.
use std::fmt::Display; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use reqwest::StatusCode; use super::{DistributionShift, Embeddings};
use super::error::{EmbedError, NewEmbedderError};
use super::openai::Retry;
use super::{DistributionShift, Embedding, Embeddings};
#[derive(Debug)] #[derive(Debug)]
pub struct Embedder { pub struct Embedder {
headers: reqwest::header::HeaderMap, rest_embedder: RestEmbedder,
options: EmbedderOptions,
} }
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub embedding_model: EmbeddingModel, pub embedding_model: String,
}
#[derive(
Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
)]
#[deserr(deny_unknown_fields)]
pub struct EmbeddingModel {
name: String,
dimensions: usize,
}
#[derive(Debug, serde::Serialize)]
struct OllamaRequest<'a> {
model: &'a str,
prompt: &'a str,
}
#[derive(Debug, serde::Deserialize)]
struct OllamaResponse {
embedding: Embedding,
}
#[derive(Debug, serde::Deserialize)]
pub struct OllamaError {
error: String,
}
impl EmbeddingModel {
pub fn max_token(&self) -> usize {
// this might not be the same for all models
8192
}
pub fn default_dimensions(&self) -> usize {
// Dimensions for nomic-embed-text
768
}
pub fn name(&self) -> String {
self.name.clone()
}
pub fn from_name(name: &str) -> Self {
Self { name: name.to_string(), dimensions: 0 }
}
pub fn supports_overriding_dimensions(&self) -> bool {
false
}
}
impl Default for EmbeddingModel {
fn default() -> Self {
Self { name: "nomic-embed-text".to_string(), dimensions: 0 }
}
} }
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model() -> Self { pub fn with_default_model() -> Self {
Self { embedding_model: Default::default() } Self { embedding_model: "nomic-embed-text".into() }
} }
pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self { pub fn with_embedding_model(embedding_model: String) -> Self {
Self { embedding_model } Self { embedding_model }
} }
} }
impl Embedder { impl Embedder {
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
reqwest::ClientBuilder::new()
.default_headers(self.headers.clone())
.build()
.map_err(EmbedError::openai_initialize_web_client)
}
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new(); let model = options.embedding_model.as_str();
headers.insert( let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
reqwest::header::CONTENT_TYPE, api_key: None,
reqwest::header::HeaderValue::from_static("application/json"), distribution: None,
); dimensions: None,
url: get_ollama_path(),
let mut embedder = Self { options, headers }; query: serde_json::json!({
"model": model,
let rt = tokio::runtime::Builder::new_current_thread() }),
.enable_io() input_field: vec!["prompt".to_owned()],
.enable_time() path_to_embeddings: Default::default(),
.build() embedding_object: vec!["embedding".to_owned()],
.map_err(EmbedError::openai_runtime_init) input_type: super::rest::InputType::Text,
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; }) {
Ok(embedder) => embedder,
// Get dimensions from Ollama Err(NewEmbedderError {
let request = kind:
OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" }; NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
// TODO: Refactor into shared error type kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
let client = embedder fault: _,
.new_client() }),
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; fault: _,
}) => {
rt.block_on(async move { return Err(NewEmbedderError::could_not_determine_dimension(
let response = client EmbedError::ollama_model_not_found(error),
.post(get_ollama_path()) ))
.json(&request)
.send()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
// Process error in case model not found
let response = Self::check_response(response).await.map_err(|_err| {
let e = EmbedError::ollama_model_not_found(OllamaError {
error: format!("model: {}", embedder.options.embedding_model.name()),
});
NewEmbedderError::ollama_could_not_determine_dimension(e)
})?;
let response: OllamaResponse = response
.json()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
let embedding = Embeddings::from_single_embedding(response.embedding);
embedder.options.embedding_model.dimensions = embedding.dimension();
tracing::info!(
"ollama model {} with dimensionality {} added",
embedder.options.embedding_model.name(),
embedding.dimension()
);
Ok(embedder)
})
}
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
if !response.status().is_success() {
// Not the same number of possible error cases covered as with OpenAI.
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let error_response: OllamaError = response
.json()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
OllamaError { error: error_response.error },
)));
}
StatusCode::SERVICE_UNAVAILABLE => {
let error_response: OllamaError = response
.json()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
OllamaError { error: error_response.error },
)));
}
StatusCode::NOT_FOUND => {
let error_response: OllamaError = response
.json()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(Retry::give_up)?;
return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
error: error_response.error,
})));
}
code => {
return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
code.as_u16(),
)));
}
} }
} Err(error) => return Err(error),
Ok(response) };
Ok(Self { rest_embedder })
} }
pub async fn embed( pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
&self, match self.rest_embedder.embed(texts) {
texts: Vec<String>, Ok(embeddings) => Ok(embeddings),
client: &reqwest::Client, Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
) -> Result<Vec<Embeddings<f32>>, EmbedError> { Err(EmbedError::ollama_model_not_found(error))
// Ollama only embedds one document at a time.
let mut results = Vec::with_capacity(texts.len());
// The retry loop is inside the texts loop, might have to switch that around
for text in texts {
// Retries copied from openai.rs
for attempt in 0..7 {
let retry_duration = match self.try_embed(&text, client).await {
Ok(result) => {
results.push(result);
break;
}
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt)
}
}?;
tracing::warn!(
"Attempt #{}, retrying after {}ms.",
attempt,
retry_duration.as_millis()
);
tokio::time::sleep(retry_duration).await;
} }
Err(error) => Err(error),
} }
Ok(results)
}
async fn try_embed(
&self,
text: &str,
client: &reqwest::Client,
) -> Result<Embeddings<f32>, Retry> {
let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
let response = client
.post(get_ollama_path())
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let response: OllamaResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
tracing::trace!("response: {:?}", response.embedding);
let embedding = Embeddings::from_single_embedding(response.embedding);
Ok(embedding)
} }
pub fn embed_chunks( pub fn embed_chunks(
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
let rt = tokio::runtime::Builder::new_current_thread() threads.install(move || {
.enable_io() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
.enable_time() })
.build()
.map_err(EmbedError::openai_runtime_init)?;
let client = self.new_client()?;
rt.block_on(futures::future::try_join_all(
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
))
} }
// Defaults copied from openai.rs
pub fn chunk_count_hint(&self) -> usize { pub fn chunk_count_hint(&self) -> usize {
10 self.rest_embedder.chunk_count_hint()
} }
pub fn prompt_count_in_chunk_hint(&self) -> usize { pub fn prompt_count_in_chunk_hint(&self) -> usize {
10 self.rest_embedder.prompt_count_in_chunk_hint()
} }
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions self.rest_embedder.dimensions()
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
@ -295,12 +96,6 @@ impl Embedder {
} }
} }
impl Display for OllamaError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)
}
}
fn get_ollama_path() -> String { fn get_ollama_path() -> String {
// Important: Hostname not enough, has to be entire path to embeddings endpoint // Important: Hostname not enough, has to be entire path to embeddings endpoint
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string()) std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())

View File

@ -1,17 +1,10 @@
use std::fmt::Display; use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use super::error::{EmbedError, NewEmbedderError}; use super::error::{EmbedError, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
#[derive(Debug)] use crate::vector::error::EmbedErrorKind;
pub struct Embedder {
headers: reqwest::header::HeaderMap,
tokenizer: tiktoken_rs::CoreBPE,
options: EmbedderOptions,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
@ -20,6 +13,32 @@ pub struct EmbedderOptions {
pub dimensions: Option<usize>, pub dimensions: Option<usize>,
} }
impl EmbedderOptions {
pub fn dimensions(&self) -> usize {
if self.embedding_model.supports_overriding_dimensions() {
self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
} else {
self.embedding_model.default_dimensions()
}
}
pub fn query(&self) -> serde_json::Value {
let model = self.embedding_model.name();
let mut query = serde_json::json!({
"model": model,
});
if self.embedding_model.supports_overriding_dimensions() {
if let Some(dimensions) = self.dimensions {
query["dimensions"] = dimensions.into();
}
}
query
}
}
#[derive( #[derive(
Debug, Debug,
Clone, Clone,
@ -92,15 +111,18 @@ impl EmbeddingModel {
fn distribution(&self) -> Option<DistributionShift> { fn distribution(&self) -> Option<DistributionShift> {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => { EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) current_mean: OrderedFloat(0.90),
} current_sigma: OrderedFloat(0.08),
EmbeddingModel::TextEmbedding3Large => { }),
Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 }) EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
} current_mean: OrderedFloat(0.70),
EmbeddingModel::TextEmbedding3Small => { current_sigma: OrderedFloat(0.1),
Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 }) }),
} EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
current_mean: OrderedFloat(0.75),
current_sigma: OrderedFloat(0.1),
}),
} }
} }
@ -125,178 +147,57 @@ impl EmbedderOptions {
} }
} }
impl Embedder { fn infer_api_key() -> String {
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> { std::env::var("MEILI_OPENAI_API_KEY")
reqwest::ClientBuilder::new() .or_else(|_| std::env::var("OPENAI_API_KEY"))
.default_headers(self.headers.clone()) .unwrap_or_default()
.build() }
.map_err(EmbedError::openai_initialize_web_client)
}
#[derive(Debug)]
pub struct Embedder {
tokenizer: tiktoken_rs::CoreBPE,
rest_embedder: RestEmbedder,
options: EmbedderOptions,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default(); let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| { let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key(); inferred_api_key = infer_api_key();
&inferred_api_key &inferred_api_key
}); });
headers.insert(
reqwest::header::AUTHORIZATION, let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) api_key: Some(api_key.clone()),
.map_err(NewEmbedderError::openai_invalid_api_key_format)?, distribution: options.embedding_model.distribution(),
); dimensions: Some(options.dimensions()),
headers.insert( url: OPENAI_EMBEDDINGS_URL.to_owned(),
reqwest::header::CONTENT_TYPE, query: options.query(),
reqwest::header::HeaderValue::from_static("application/json"), input_field: vec!["input".to_owned()],
); input_type: crate::vector::rest::InputType::TextArray,
path_to_embeddings: vec!["data".to_owned()],
embedding_object: vec!["embedding".to_owned()],
})?;
// looking at the code it is very unclear that this can actually fail. // looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs::cl100k_base().unwrap(); let tokenizer = tiktoken_rs::cl100k_base().unwrap();
Ok(Self { options, headers, tokenizer }) Ok(Self { options, rest_embedder, tokenizer })
} }
pub async fn embed( pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
&self, match self.rest_embedder.embed_ref(&texts) {
texts: Vec<String>, Ok(embeddings) => Ok(embeddings),
client: &reqwest::Client, Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => {
) -> Result<Vec<Embeddings<f32>>, EmbedError> { tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
let mut tokenized = false; self.try_embed_tokenized(&texts)
for attempt in 0..7 {
let result = if tokenized {
self.try_embed_tokenized(&texts, client).await
} else {
self.try_embed(&texts, client).await
};
let retry_duration = match result {
Ok(embeddings) => return Ok(embeddings),
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
tokenized |= retry.must_tokenize();
retry.into_duration(attempt)
}
}?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
tracing::warn!(
"Attempt #{}, retrying after {}ms.",
attempt,
retry_duration.as_millis()
);
tokio::time::sleep(retry_duration).await;
}
let result = if tokenized {
self.try_embed_tokenized(&texts, client).await
} else {
self.try_embed(&texts, client).await
};
result.map_err(Retry::into_error)
}
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
if !response.status().is_success() {
match response.status() {
StatusCode::UNAUTHORIZED => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::give_up(EmbedError::openai_auth_error(
error_response.error,
)));
}
StatusCode::TOO_MANY_REQUESTS => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
error_response.error,
)));
}
StatusCode::INTERNAL_SERVER_ERROR
| StatusCode::BAD_GATEWAY
| StatusCode::SERVICE_UNAVAILABLE => {
let error_response: Result<OpenAiErrorResponse, _> = response.json().await;
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
error_response.ok().map(|error_response| error_response.error),
)));
}
StatusCode::BAD_REQUEST => {
// Most probably, one text contained too many tokens
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
error_response.error,
)));
}
code => {
return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code(
code.as_u16(),
)));
}
} }
Err(error) => Err(error),
} }
Ok(response)
} }
async fn try_embed<S: AsRef<str> + serde::Serialize>( fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
&self,
texts: &[S],
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, Retry> {
for text in texts {
tracing::trace!("Received prompt: {}", text.as_ref())
}
let request = OpenAiRequest {
model: self.options.embedding_model.name(),
input: texts,
dimensions: self.overriden_dimensions(),
};
let response = client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
tracing::trace!("response: {:?}", response.data);
Ok(response
.data
.into_iter()
.map(|data| Embeddings::from_single_embedding(data.embedding))
.collect())
}
async fn try_embed_tokenized(
&self,
text: &[String],
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, Retry> {
pub const OVERLAP_SIZE: usize = 200; pub const OVERLAP_SIZE: usize = 200;
let mut all_embeddings = Vec::with_capacity(text.len()); let mut all_embeddings = Vec::with_capacity(text.len());
for text in text { for text in text {
@ -304,7 +205,7 @@ impl Embedder {
let encoded = self.tokenizer.encode_ordinary(text.as_str()); let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len(); let len = encoded.len();
if len < max_token_count { if len < max_token_count {
all_embeddings.append(&mut self.try_embed(&[text], client).await?); all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
continue; continue;
} }
@ -312,215 +213,49 @@ impl Embedder {
let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
while tokens.len() > max_token_count { while tokens.len() > max_token_count {
let window = &tokens[..max_token_count]; let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap(); let embedding = self.rest_embedder.embed_tokens(window)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
})?;
tokens = &tokens[max_token_count - OVERLAP_SIZE..]; tokens = &tokens[max_token_count - OVERLAP_SIZE..];
} }
// end of text // end of text
embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap(); let embedding = self.rest_embedder.embed_tokens(tokens)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
})?;
all_embeddings.push(embeddings_for_prompt); all_embeddings.push(embeddings_for_prompt);
} }
Ok(all_embeddings) Ok(all_embeddings)
} }
async fn embed_tokens(
&self,
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
for attempt in 0..9 {
let duration = match self.try_embed_tokens(tokens, client).await {
Ok(embedding) => return Ok(embedding),
Err(retry) => retry.into_duration(attempt),
}
.map_err(Retry::retry_later)?;
tokio::time::sleep(duration).await;
}
self.try_embed_tokens(tokens, client)
.await
.map_err(|retry| Retry::give_up(retry.into_error()))
}
async fn try_embed_tokens(
&self,
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
let request = OpenAiTokensRequest {
model: self.options.embedding_model.name(),
input: tokens,
dimensions: self.overriden_dimensions(),
};
let response = client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let mut response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
}
pub fn embed_chunks( pub fn embed_chunks(
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
let rt = tokio::runtime::Builder::new_current_thread() threads.install(move || {
.enable_io() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
.enable_time() })
.build()
.map_err(EmbedError::openai_runtime_init)?;
let client = self.new_client()?;
rt.block_on(futures::future::try_join_all(
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
))
} }
pub fn chunk_count_hint(&self) -> usize { pub fn chunk_count_hint(&self) -> usize {
10 self.rest_embedder.chunk_count_hint()
} }
pub fn prompt_count_in_chunk_hint(&self) -> usize { pub fn prompt_count_in_chunk_hint(&self) -> usize {
10 self.rest_embedder.prompt_count_in_chunk_hint()
} }
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
if self.options.embedding_model.supports_overriding_dimensions() { self.options.dimensions()
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
} else {
self.options.embedding_model.default_dimensions()
}
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution() self.options.embedding_model.distribution()
} }
fn overriden_dimensions(&self) -> Option<usize> {
if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions
} else {
None
}
}
}
// retrying in case of failure
pub struct Retry {
pub error: EmbedError,
strategy: RetryStrategy,
}
pub enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
pub fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
pub fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
pub fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
pub fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
pub fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
pub fn into_error(self) -> EmbedError {
self.error
}
}
// openai api structs
#[derive(Debug, Serialize)]
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str,
input: &'a [S],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> {
model: &'a str,
input: &'a [usize],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
data: Vec<OpenAiEmbedding>,
}
#[derive(Debug, Deserialize)]
struct OpenAiErrorResponse {
error: OpenAiError,
}
#[derive(Debug, Deserialize)]
pub struct OpenAiError {
message: String,
// type: String,
code: Option<String>,
}
impl Display for OpenAiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.code {
Some(code) => write!(f, "{} ({})", self.message, code),
None => write!(f, "{}", self.message),
}
}
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
embedding: Embedding,
// object: String,
// index: usize,
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
} }

373
milli/src/vector/rest.rs Normal file
View File

@ -0,0 +1,373 @@
use deserr::Deserr;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use serde::{Deserialize, Serialize};
use super::{
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
};
// retrying in case of failure
pub struct Retry {
pub error: EmbedError,
strategy: RetryStrategy,
}
pub enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
pub fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
pub fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
pub fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
pub fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
pub fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
pub fn into_error(self) -> EmbedError {
self.error
}
}
#[derive(Debug)]
pub struct Embedder {
client: ureq::Agent,
options: EmbedderOptions,
bearer: Option<String>,
dimensions: usize,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>,
pub url: String,
pub query: serde_json::Value,
pub input_field: Vec<String>,
// path to the array of embeddings
pub path_to_embeddings: Vec<String>,
// shape of a single embedding
pub embedding_object: Vec<String>,
pub input_type: InputType,
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self {
url: Default::default(),
query: Default::default(),
input_field: vec!["input".into()],
path_to_embeddings: vec!["data".into()],
embedding_object: vec!["embedding".into()],
input_type: InputType::Text,
api_key: None,
distribution: None,
dimensions: None,
}
}
}
impl std::hash::Hash for EmbedderOptions {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.api_key.hash(state);
self.distribution.hash(state);
self.dimensions.hash(state);
self.url.hash(state);
// skip hashing the query
// collisions in regular usage should be minimal,
// and the list is limited to 256 values anyway
self.input_field.hash(state);
self.path_to_embeddings.hash(state);
self.embedding_object.hash(state);
self.input_type.hash(state);
}
}
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
#[serde(rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum InputType {
Text,
TextArray,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
let client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();
let dimensions = if let Some(dimensions) = options.dimensions {
dimensions
} else {
infer_dimensions(&client, &options, bearer.as_deref())?
};
Ok(Self { client, dimensions, options, bearer })
}
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
}
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
where
S: AsRef<str> + Serialize,
{
embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
}
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
}
pub fn chunk_count_hint(&self) -> usize {
super::REQUEST_PARALLELISM
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self.options.input_type {
InputType::Text => 1,
InputType::TextArray => 10,
}
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution
}
}
fn infer_dimensions(
client: &ureq::Agent,
options: &EmbedderOptions,
bearer: Option<&str>,
) -> Result<usize, NewEmbedderError> {
let v = embed(client, options, bearer, ["test"].as_slice(), 1)
.map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().dimension())
}
fn embed<S>(
client: &ureq::Agent,
options: &EmbedderOptions,
bearer: Option<&str>,
inputs: &[S],
expected_count: usize,
) -> Result<Vec<Embeddings<f32>>, EmbedError>
where
S: Serialize,
{
let request = client.post(&options.url);
let request =
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
let request = request.set("Content-Type", "application/json");
let input_value = match options.input_type {
InputType::Text => serde_json::json!(inputs.first()),
InputType::TextArray => serde_json::json!(inputs),
};
let body = match options.input_field.as_slice() {
[] => {
// inject input in body
input_value
}
[input] => {
let mut body = options.query.clone();
body.as_object_mut()
.ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.insert(input.clone(), input_value);
body
}
[path @ .., input] => {
let mut body = options.query.clone();
let mut current_value = &mut body;
for component in path {
current_value = current_value
.as_object_mut()
.ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.entry(component.clone())
.or_insert(serde_json::json!({}));
}
current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
body
}
};
for attempt in 0..7 {
let response = request.clone().send_json(&body);
let result = check_response(response);
let retry_duration = match result {
Ok(response) => return response_to_embedding(response, options, expected_count),
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt)
}
}?;
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
std::thread::sleep(retry_duration);
}
let response = request.send_json(&body);
let result = check_response(response);
result
.map_err(Retry::into_error)
.and_then(|response| response_to_embedding(response, options, expected_count))
}
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
match response {
Ok(response) => Ok(response),
Err(ureq::Error::Status(code, response)) => {
let error_response: Option<String> = response.into_string().ok();
Err(match code {
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
500..=599 => {
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
}
402..=499 => {
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
}
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
})
}
Err(ureq::Error::Transport(transport)) => {
Err(Retry::retry_later(EmbedError::rest_network(transport)))
}
}
}
fn response_to_embedding(
response: ureq::Response,
options: &EmbedderOptions,
expected_count: usize,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let response: serde_json::Value =
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
let mut current_value = &response;
for component in &options.path_to_embeddings {
let component = component.as_ref();
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.path_to_embeddings,
)
})?;
}
let embeddings = match options.input_type {
InputType::Text => {
for component in &options.embedding_object {
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.embedding_object,
)
})?;
}
let embeddings = current_value.to_owned();
let embeddings: Embedding =
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
vec![Embeddings::from_single_embedding(embeddings)]
}
InputType::TextArray => {
let empty = vec![];
let values = current_value.as_array().unwrap_or(&empty);
let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
for value in values {
let mut current_value = value;
for component in &options.embedding_object {
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.embedding_object,
)
})?;
}
let embedding = current_value.to_owned();
let embedding: Embedding =
serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
embeddings.push(Embeddings::from_single_embedding(embedding));
}
embeddings
}
};
if embeddings.len() != expected_count {
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
}
Ok(embeddings)
}

View File

@ -1,6 +1,7 @@
use deserr::Deserr; use deserr::Deserr;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::rest::InputType;
use super::{ollama, openai}; use super::{ollama, openai};
use crate::prompt::PromptData; use crate::prompt::PromptData;
use crate::update::Setting; use crate::update::Setting;
@ -29,6 +30,24 @@ pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
pub document_template: Setting<String>, pub document_template: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub url: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub query: Setting<serde_json::Value>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub input_field: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub path_to_embeddings: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub embedding_object: Setting<Vec<String>>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub input_type: Setting<InputType>,
} }
pub fn check_unset<T>( pub fn check_unset<T>(
@ -75,20 +94,42 @@ impl EmbeddingSettings {
pub const DIMENSIONS: &'static str = "dimensions"; pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
pub const URL: &'static str = "url";
pub const QUERY: &'static str = "query";
pub const INPUT_FIELD: &'static str = "inputField";
pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings";
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
pub const INPUT_TYPE: &'static str = "inputType";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field { match field {
Self::SOURCE => { Self::SOURCE => &[
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided] EmbedderSource::HuggingFace,
} EmbedderSource::OpenAi,
EmbedderSource::UserProvided,
EmbedderSource::Rest,
EmbedderSource::Ollama,
],
Self::MODEL => { Self::MODEL => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
} }
Self::REVISION => &[EmbedderSource::HuggingFace], Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi], Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest],
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided], Self::DIMENSIONS => {
Self::DOCUMENT_TEMPLATE => { &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest]
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
} }
Self::DOCUMENT_TEMPLATE => &[
EmbedderSource::HuggingFace,
EmbedderSource::OpenAi,
EmbedderSource::Ollama,
EmbedderSource::Rest,
],
Self::URL => &[EmbedderSource::Rest],
Self::QUERY => &[EmbedderSource::Rest],
Self::INPUT_FIELD => &[EmbedderSource::Rest],
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
Self::INPUT_TYPE => &[EmbedderSource::Rest],
_other => unreachable!("unknown field"), _other => unreachable!("unknown field"),
} }
} }
@ -107,6 +148,18 @@ impl EmbeddingSettings {
} }
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE], EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
EmbedderSource::Rest => &[
Self::SOURCE,
Self::API_KEY,
Self::DIMENSIONS,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::QUERY,
Self::INPUT_FIELD,
Self::PATH_TO_EMBEDDINGS,
Self::EMBEDDING_OBJECT,
Self::INPUT_TYPE,
],
} }
} }
@ -141,6 +194,7 @@ pub enum EmbedderSource {
HuggingFace, HuggingFace,
Ollama, Ollama,
UserProvided, UserProvided,
Rest,
} }
impl std::fmt::Display for EmbedderSource { impl std::fmt::Display for EmbedderSource {
@ -150,6 +204,7 @@ impl std::fmt::Display for EmbedderSource {
EmbedderSource::HuggingFace => "huggingFace", EmbedderSource::HuggingFace => "huggingFace",
EmbedderSource::UserProvided => "userProvided", EmbedderSource::UserProvided => "userProvided",
EmbedderSource::Ollama => "ollama", EmbedderSource::Ollama => "ollama",
EmbedderSource::Rest => "rest",
}; };
f.write_str(s) f.write_str(s)
} }
@ -157,8 +212,20 @@ impl std::fmt::Display for EmbedderSource {
impl EmbeddingSettings { impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) { pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = let EmbeddingSettings {
new; source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = new;
let old_source = self.source; let old_source = self.source;
self.source.apply(source); self.source.apply(source);
// Reinitialize the whole setting object on a source change // Reinitialize the whole setting object on a source change
@ -170,6 +237,12 @@ impl EmbeddingSettings {
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}; };
return; return;
} }
@ -179,6 +252,13 @@ impl EmbeddingSettings {
self.api_key.apply(api_key); self.api_key.apply(api_key);
self.dimensions.apply(dimensions); self.dimensions.apply(dimensions);
self.document_template.apply(document_template); self.document_template.apply(document_template);
self.url.apply(url);
self.query.apply(query);
self.input_field.apply(input_field);
self.path_to_embeddings.apply(path_to_embeddings);
self.embedding_object.apply(embedding_object);
self.input_type.apply(input_type);
} }
} }
@ -193,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::NotSet, dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
}, },
super::EmbedderOptions::OpenAi(options) => Self { super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi), source: Setting::Set(EmbedderSource::OpenAi),
@ -201,14 +287,26 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: options.api_key.map(Setting::Set).unwrap_or_default(), api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(), dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
}, },
super::EmbedderOptions::Ollama(options) => Self { super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama), source: Setting::Set(EmbedderSource::Ollama),
model: Setting::Set(options.embedding_model.name().to_owned()), model: Setting::Set(options.embedding_model.to_owned()),
revision: Setting::NotSet, revision: Setting::NotSet,
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::NotSet, dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
}, },
super::EmbedderOptions::UserProvided(options) => Self { super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided), source: Setting::Set(EmbedderSource::UserProvided),
@ -217,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::Set(options.dimensions), dimensions: Setting::Set(options.dimensions),
document_template: Setting::NotSet, document_template: Setting::NotSet,
url: Setting::NotSet,
query: Setting::NotSet,
input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
},
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key,
// TODO: support distribution
distribution: _,
dimensions,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
}) => Self {
source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: api_key.map(Setting::Set).unwrap_or_default(),
dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template),
url: Setting::Set(url),
query: Setting::Set(query),
input_field: Setting::Set(input_field),
path_to_embeddings: Setting::Set(path_to_embeddings),
embedding_object: Setting::Set(embedding_object),
input_type: Setting::Set(input_type),
}, },
} }
} }
@ -225,8 +354,20 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
impl From<EmbeddingSettings> for EmbeddingConfig { impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self { fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default(); let mut this = Self::default();
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = let EmbeddingSettings {
value; source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = value;
if let Some(source) = source.set() { if let Some(source) = source.set() {
match source { match source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
@ -248,7 +389,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
let mut options: ollama::EmbedderOptions = let mut options: ollama::EmbedderOptions =
super::ollama::EmbedderOptions::with_default_model(); super::ollama::EmbedderOptions::with_default_model();
if let Some(model) = model.set() { if let Some(model) = model.set() {
options.embedding_model = super::ollama::EmbeddingModel::from_name(&model); options.embedding_model = model;
} }
this.embedder_options = super::EmbedderOptions::Ollama(options); this.embedder_options = super::EmbedderOptions::Ollama(options);
} }
@ -274,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
dimensions: dimensions.set().unwrap(), dimensions: dimensions.set().unwrap(),
}); });
} }
EmbedderSource::Rest => {
let embedder_options = super::rest::EmbedderOptions::default();
this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(),
distribution: None,
dimensions: dimensions.set(),
url: url.set().unwrap(),
query: query.set().unwrap_or(embedder_options.query),
input_field: input_field.set().unwrap_or(embedder_options.input_field),
path_to_embeddings: path_to_embeddings
.set()
.unwrap_or(embedder_options.path_to_embeddings),
embedding_object: embedding_object
.set()
.unwrap_or(embedder_options.embedding_object),
input_type: input_type.set().unwrap_or(embedder_options.input_type),
})
}
} }
} }