mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-22 18:17:39 +08:00
Expose REST embedder to the API
This commit is contained in:
parent
f87747f4d3
commit
a1db342f01
@ -2646,6 +2646,12 @@ mod tests {
|
|||||||
api_key: Setting::NotSet,
|
api_key: Setting::NotSet,
|
||||||
dimensions: Setting::Set(3),
|
dimensions: Setting::Set(3),
|
||||||
document_template: Setting::NotSet,
|
document_template: Setting::NotSet,
|
||||||
|
url: Setting::NotSet,
|
||||||
|
query: Setting::NotSet,
|
||||||
|
input_field: Setting::NotSet,
|
||||||
|
path_to_embeddings: Setting::NotSet,
|
||||||
|
embedding_object: Setting::NotSet,
|
||||||
|
input_type: Setting::NotSet,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
settings.set_embedder_settings(embedders);
|
settings.set_embedder_settings(embedders);
|
||||||
|
@ -1140,6 +1140,12 @@ fn validate_prompt(
|
|||||||
api_key,
|
api_key,
|
||||||
dimensions,
|
dimensions,
|
||||||
document_template: Setting::Set(template),
|
document_template: Setting::Set(template),
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
}) => {
|
}) => {
|
||||||
// validate
|
// validate
|
||||||
let template = crate::prompt::Prompt::new(template)
|
let template = crate::prompt::Prompt::new(template)
|
||||||
@ -1153,6 +1159,12 @@ fn validate_prompt(
|
|||||||
api_key,
|
api_key,
|
||||||
dimensions,
|
dimensions,
|
||||||
document_template: Setting::Set(template),
|
document_template: Setting::Set(template),
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
new => Ok(new),
|
new => Ok(new),
|
||||||
@ -1165,8 +1177,20 @@ pub fn validate_embedding_settings(
|
|||||||
) -> Result<Setting<EmbeddingSettings>> {
|
) -> Result<Setting<EmbeddingSettings>> {
|
||||||
let settings = validate_prompt(name, settings)?;
|
let settings = validate_prompt(name, settings)?;
|
||||||
let Setting::Set(settings) = settings else { return Ok(settings) };
|
let Setting::Set(settings) = settings else { return Ok(settings) };
|
||||||
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
let EmbeddingSettings {
|
||||||
settings;
|
source,
|
||||||
|
model,
|
||||||
|
revision,
|
||||||
|
api_key,
|
||||||
|
dimensions,
|
||||||
|
document_template,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
|
} = settings;
|
||||||
|
|
||||||
if let Some(0) = dimensions.set() {
|
if let Some(0) = dimensions.set() {
|
||||||
return Err(crate::error::UserError::InvalidSettingsDimensions {
|
return Err(crate::error::UserError::InvalidSettingsDimensions {
|
||||||
@ -1183,11 +1207,25 @@ pub fn validate_embedding_settings(
|
|||||||
api_key,
|
api_key,
|
||||||
dimensions,
|
dimensions,
|
||||||
document_template,
|
document_template,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
}));
|
}));
|
||||||
};
|
};
|
||||||
match inferred_source {
|
match inferred_source {
|
||||||
EmbedderSource::OpenAi => {
|
EmbedderSource::OpenAi => {
|
||||||
check_unset(&revision, "revision", inferred_source, name)?;
|
check_unset(&revision, "revision", inferred_source, name)?;
|
||||||
|
|
||||||
|
check_unset(&url, "url", inferred_source, name)?;
|
||||||
|
check_unset(&query, "query", inferred_source, name)?;
|
||||||
|
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||||
|
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||||
|
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||||
|
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||||
|
|
||||||
if let Setting::Set(model) = &model {
|
if let Setting::Set(model) = &model {
|
||||||
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
|
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
|
||||||
.ok_or(crate::error::UserError::InvalidOpenAiModel {
|
.ok_or(crate::error::UserError::InvalidOpenAiModel {
|
||||||
@ -1224,10 +1262,24 @@ pub fn validate_embedding_settings(
|
|||||||
check_set(&model, "model", inferred_source, name)?;
|
check_set(&model, "model", inferred_source, name)?;
|
||||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||||
check_unset(&revision, "revision", inferred_source, name)?;
|
check_unset(&revision, "revision", inferred_source, name)?;
|
||||||
|
|
||||||
|
check_unset(&url, "url", inferred_source, name)?;
|
||||||
|
check_unset(&query, "query", inferred_source, name)?;
|
||||||
|
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||||
|
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||||
|
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||||
|
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||||
}
|
}
|
||||||
EmbedderSource::HuggingFace => {
|
EmbedderSource::HuggingFace => {
|
||||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||||
check_unset(&dimensions, "dimensions", inferred_source, name)?;
|
check_unset(&dimensions, "dimensions", inferred_source, name)?;
|
||||||
|
|
||||||
|
check_unset(&url, "url", inferred_source, name)?;
|
||||||
|
check_unset(&query, "query", inferred_source, name)?;
|
||||||
|
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||||
|
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||||
|
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||||
|
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||||
}
|
}
|
||||||
EmbedderSource::UserProvided => {
|
EmbedderSource::UserProvided => {
|
||||||
check_unset(&model, "model", inferred_source, name)?;
|
check_unset(&model, "model", inferred_source, name)?;
|
||||||
@ -1235,6 +1287,18 @@ pub fn validate_embedding_settings(
|
|||||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||||
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
|
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
|
||||||
check_set(&dimensions, "dimensions", inferred_source, name)?;
|
check_set(&dimensions, "dimensions", inferred_source, name)?;
|
||||||
|
|
||||||
|
check_unset(&url, "url", inferred_source, name)?;
|
||||||
|
check_unset(&query, "query", inferred_source, name)?;
|
||||||
|
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||||
|
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||||
|
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||||
|
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||||
|
}
|
||||||
|
EmbedderSource::Rest => {
|
||||||
|
check_unset(&model, "model", inferred_source, name)?;
|
||||||
|
check_unset(&revision, "revision", inferred_source, name)?;
|
||||||
|
check_set(&url, "url", inferred_source, name)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(Setting::Set(EmbeddingSettings {
|
Ok(Setting::Set(EmbeddingSettings {
|
||||||
@ -1244,6 +1308,12 @@ pub fn validate_embedding_settings(
|
|||||||
api_key,
|
api_key,
|
||||||
dimensions,
|
dimensions,
|
||||||
document_template,
|
document_template,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,7 +194,10 @@ impl Embedder {
|
|||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
if self.options.model == "BAAI/bge-base-en-v1.5" {
|
if self.options.model == "BAAI/bge-base-en-v1.5" {
|
||||||
Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
|
Some(DistributionShift {
|
||||||
|
current_mean: ordered_float::OrderedFloat(0.85),
|
||||||
|
current_sigma: ordered_float::OrderedFloat(0.1),
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use self::error::{EmbedError, NewEmbedderError};
|
use self::error::{EmbedError, NewEmbedderError};
|
||||||
use crate::prompt::{Prompt, PromptData};
|
use crate::prompt::{Prompt, PromptData};
|
||||||
|
|
||||||
@ -104,7 +107,10 @@ pub enum Embedder {
|
|||||||
OpenAi(openai::Embedder),
|
OpenAi(openai::Embedder),
|
||||||
/// An embedder based on the user providing the embeddings in the documents and queries.
|
/// An embedder based on the user providing the embeddings in the documents and queries.
|
||||||
UserProvided(manual::Embedder),
|
UserProvided(manual::Embedder),
|
||||||
|
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
|
||||||
Ollama(ollama::Embedder),
|
Ollama(ollama::Embedder),
|
||||||
|
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
|
||||||
|
Rest(rest::Embedder),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Configuration for an embedder.
|
/// Configuration for an embedder.
|
||||||
@ -175,6 +181,7 @@ pub enum EmbedderOptions {
|
|||||||
OpenAi(openai::EmbedderOptions),
|
OpenAi(openai::EmbedderOptions),
|
||||||
Ollama(ollama::EmbedderOptions),
|
Ollama(ollama::EmbedderOptions),
|
||||||
UserProvided(manual::EmbedderOptions),
|
UserProvided(manual::EmbedderOptions),
|
||||||
|
Rest(rest::EmbedderOptions),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for EmbedderOptions {
|
impl Default for EmbedderOptions {
|
||||||
@ -209,6 +216,7 @@ impl Embedder {
|
|||||||
EmbedderOptions::UserProvided(options) => {
|
EmbedderOptions::UserProvided(options) => {
|
||||||
Self::UserProvided(manual::Embedder::new(options))
|
Self::UserProvided(manual::Embedder::new(options))
|
||||||
}
|
}
|
||||||
|
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,6 +232,7 @@ impl Embedder {
|
|||||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
||||||
Embedder::Ollama(embedder) => embedder.embed(texts),
|
Embedder::Ollama(embedder) => embedder.embed(texts),
|
||||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||||
|
Embedder::Rest(embedder) => embedder.embed(texts),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,6 +249,7 @@ impl Embedder {
|
|||||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||||
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
|
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||||
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
|
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,6 +260,7 @@ impl Embedder {
|
|||||||
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
||||||
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
|
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
|
||||||
Embedder::UserProvided(_) => 1,
|
Embedder::UserProvided(_) => 1,
|
||||||
|
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,6 +271,7 @@ impl Embedder {
|
|||||||
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||||
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
|
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||||
Embedder::UserProvided(_) => 1,
|
Embedder::UserProvided(_) => 1,
|
||||||
|
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,6 +282,7 @@ impl Embedder {
|
|||||||
Embedder::OpenAi(embedder) => embedder.dimensions(),
|
Embedder::OpenAi(embedder) => embedder.dimensions(),
|
||||||
Embedder::Ollama(embedder) => embedder.dimensions(),
|
Embedder::Ollama(embedder) => embedder.dimensions(),
|
||||||
Embedder::UserProvided(embedder) => embedder.dimensions(),
|
Embedder::UserProvided(embedder) => embedder.dimensions(),
|
||||||
|
Embedder::Rest(embedder) => embedder.dimensions(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,6 +293,7 @@ impl Embedder {
|
|||||||
Embedder::OpenAi(embedder) => embedder.distribution(),
|
Embedder::OpenAi(embedder) => embedder.distribution(),
|
||||||
Embedder::Ollama(embedder) => embedder.distribution(),
|
Embedder::Ollama(embedder) => embedder.distribution(),
|
||||||
Embedder::UserProvided(_embedder) => None,
|
Embedder::UserProvided(_embedder) => None,
|
||||||
|
Embedder::Rest(embedder) => embedder.distribution(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -288,17 +302,47 @@ impl Embedder {
|
|||||||
///
|
///
|
||||||
/// The intended use is to make the similarity score more comparable to the regular ranking score.
|
/// The intended use is to make the similarity score more comparable to the regular ranking score.
|
||||||
/// This allows to correct effects where results are too "packed" around a certain value.
|
/// This allows to correct effects where results are too "packed" around a certain value.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
|
||||||
|
#[serde(from = "DistributionShiftSerializable")]
|
||||||
|
#[serde(into = "DistributionShiftSerializable")]
|
||||||
pub struct DistributionShift {
|
pub struct DistributionShift {
|
||||||
/// Value where the results are "packed".
|
/// Value where the results are "packed".
|
||||||
///
|
///
|
||||||
/// Similarity scores are translated so that they are packed around 0.5 instead
|
/// Similarity scores are translated so that they are packed around 0.5 instead
|
||||||
pub current_mean: f32,
|
pub current_mean: OrderedFloat<f32>,
|
||||||
|
|
||||||
/// standard deviation of a similarity score.
|
/// standard deviation of a similarity score.
|
||||||
///
|
///
|
||||||
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
|
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
|
||||||
pub current_sigma: f32,
|
pub current_sigma: OrderedFloat<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct DistributionShiftSerializable {
|
||||||
|
current_mean: f32,
|
||||||
|
current_sigma: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DistributionShift> for DistributionShiftSerializable {
|
||||||
|
fn from(
|
||||||
|
DistributionShift {
|
||||||
|
current_mean: OrderedFloat(current_mean),
|
||||||
|
current_sigma: OrderedFloat(current_sigma),
|
||||||
|
}: DistributionShift,
|
||||||
|
) -> Self {
|
||||||
|
Self { current_mean, current_sigma }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DistributionShiftSerializable> for DistributionShift {
|
||||||
|
fn from(
|
||||||
|
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
current_mean: OrderedFloat(current_mean),
|
||||||
|
current_sigma: OrderedFloat(current_sigma),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DistributionShift {
|
impl DistributionShift {
|
||||||
@ -307,11 +351,13 @@ impl DistributionShift {
|
|||||||
if sigma <= 0.0 {
|
if sigma <= 0.0 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(Self { current_mean: mean, current_sigma: sigma })
|
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shift(&self, score: f32) -> f32 {
|
pub fn shift(&self, score: f32) -> f32 {
|
||||||
|
let current_mean = self.current_mean.0;
|
||||||
|
let current_sigma = self.current_sigma.0;
|
||||||
// <https://math.stackexchange.com/a/2894689>
|
// <https://math.stackexchange.com/a/2894689>
|
||||||
// We're somewhat abusively mapping the distribution of distances to a gaussian.
|
// We're somewhat abusively mapping the distribution of distances to a gaussian.
|
||||||
// The parameters we're given is the mean and sigma of the native result distribution.
|
// The parameters we're given is the mean and sigma of the native result distribution.
|
||||||
@ -321,9 +367,9 @@ impl DistributionShift {
|
|||||||
let target_sigma = 0.4;
|
let target_sigma = 0.4;
|
||||||
|
|
||||||
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
|
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
|
||||||
let factor = target_sigma / self.current_sigma;
|
let factor = target_sigma / current_sigma;
|
||||||
// a*mu1 + b = mu2 => b = mu2 - a*mu1
|
// a*mu1 + b = mu2 => b = mu2 - a*mu1
|
||||||
let offset = target_mean - (factor * self.current_mean);
|
let offset = target_mean - (factor * current_mean);
|
||||||
|
|
||||||
let mut score = factor * score + offset;
|
let mut score = factor * score + offset;
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
use ordered_float::OrderedFloat;
|
||||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||||
|
|
||||||
use super::error::{EmbedError, NewEmbedderError};
|
use super::error::{EmbedError, NewEmbedderError};
|
||||||
@ -110,15 +111,18 @@ impl EmbeddingModel {
|
|||||||
|
|
||||||
fn distribution(&self) -> Option<DistributionShift> {
|
fn distribution(&self) -> Option<DistributionShift> {
|
||||||
match self {
|
match self {
|
||||||
EmbeddingModel::TextEmbeddingAda002 => {
|
EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
|
||||||
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
|
current_mean: OrderedFloat(0.90),
|
||||||
}
|
current_sigma: OrderedFloat(0.08),
|
||||||
EmbeddingModel::TextEmbedding3Large => {
|
}),
|
||||||
Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 })
|
EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
|
||||||
}
|
current_mean: OrderedFloat(0.70),
|
||||||
EmbeddingModel::TextEmbedding3Small => {
|
current_sigma: OrderedFloat(0.1),
|
||||||
Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 })
|
}),
|
||||||
}
|
EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
|
||||||
|
current_mean: OrderedFloat(0.75),
|
||||||
|
current_sigma: OrderedFloat(0.1),
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
|
use deserr::Deserr;
|
||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
use serde::Serialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||||
@ -64,7 +65,7 @@ pub struct Embedder {
|
|||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
pub distribution: Option<DistributionShift>,
|
pub distribution: Option<DistributionShift>,
|
||||||
@ -79,7 +80,41 @@ pub struct EmbedderOptions {
|
|||||||
pub input_type: InputType,
|
pub input_type: InputType,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
impl Default for EmbedderOptions {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
url: Default::default(),
|
||||||
|
query: Default::default(),
|
||||||
|
input_field: vec!["input".into()],
|
||||||
|
path_to_embeddings: vec!["data".into()],
|
||||||
|
embedding_object: vec!["embedding".into()],
|
||||||
|
input_type: InputType::Text,
|
||||||
|
api_key: None,
|
||||||
|
distribution: None,
|
||||||
|
dimensions: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::hash::Hash for EmbedderOptions {
|
||||||
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
|
self.api_key.hash(state);
|
||||||
|
self.distribution.hash(state);
|
||||||
|
self.dimensions.hash(state);
|
||||||
|
self.url.hash(state);
|
||||||
|
// skip hashing the query
|
||||||
|
// collisions in regular usage should be minimal,
|
||||||
|
// and the list is limited to 256 values anyway
|
||||||
|
self.input_field.hash(state);
|
||||||
|
self.path_to_embeddings.hash(state);
|
||||||
|
self.embedding_object.hash(state);
|
||||||
|
self.input_type.hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||||
pub enum InputType {
|
pub enum InputType {
|
||||||
Text,
|
Text,
|
||||||
TextArray,
|
TextArray,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::rest::InputType;
|
||||||
use super::{ollama, openai};
|
use super::{ollama, openai};
|
||||||
use crate::prompt::PromptData;
|
use crate::prompt::PromptData;
|
||||||
use crate::update::Setting;
|
use crate::update::Setting;
|
||||||
@ -29,6 +30,24 @@ pub struct EmbeddingSettings {
|
|||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
#[deserr(default)]
|
#[deserr(default)]
|
||||||
pub document_template: Setting<String>,
|
pub document_template: Setting<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
pub url: Setting<String>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
pub query: Setting<serde_json::Value>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
pub input_field: Setting<Vec<String>>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
pub path_to_embeddings: Setting<Vec<String>>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
pub embedding_object: Setting<Vec<String>>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
|
pub input_type: Setting<InputType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_unset<T>(
|
pub fn check_unset<T>(
|
||||||
@ -75,20 +94,42 @@ impl EmbeddingSettings {
|
|||||||
pub const DIMENSIONS: &'static str = "dimensions";
|
pub const DIMENSIONS: &'static str = "dimensions";
|
||||||
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
|
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
|
||||||
|
|
||||||
|
pub const URL: &'static str = "url";
|
||||||
|
pub const QUERY: &'static str = "query";
|
||||||
|
pub const INPUT_FIELD: &'static str = "inputField";
|
||||||
|
pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings";
|
||||||
|
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
|
||||||
|
pub const INPUT_TYPE: &'static str = "inputType";
|
||||||
|
|
||||||
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
|
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
|
||||||
match field {
|
match field {
|
||||||
Self::SOURCE => {
|
Self::SOURCE => &[
|
||||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
|
EmbedderSource::HuggingFace,
|
||||||
}
|
EmbedderSource::OpenAi,
|
||||||
|
EmbedderSource::UserProvided,
|
||||||
|
EmbedderSource::Rest,
|
||||||
|
EmbedderSource::Ollama,
|
||||||
|
],
|
||||||
Self::MODEL => {
|
Self::MODEL => {
|
||||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
||||||
}
|
}
|
||||||
Self::REVISION => &[EmbedderSource::HuggingFace],
|
Self::REVISION => &[EmbedderSource::HuggingFace],
|
||||||
Self::API_KEY => &[EmbedderSource::OpenAi],
|
Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest],
|
||||||
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
|
Self::DIMENSIONS => {
|
||||||
Self::DOCUMENT_TEMPLATE => {
|
&[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest]
|
||||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
|
||||||
}
|
}
|
||||||
|
Self::DOCUMENT_TEMPLATE => &[
|
||||||
|
EmbedderSource::HuggingFace,
|
||||||
|
EmbedderSource::OpenAi,
|
||||||
|
EmbedderSource::Ollama,
|
||||||
|
EmbedderSource::Rest,
|
||||||
|
],
|
||||||
|
Self::URL => &[EmbedderSource::Rest],
|
||||||
|
Self::QUERY => &[EmbedderSource::Rest],
|
||||||
|
Self::INPUT_FIELD => &[EmbedderSource::Rest],
|
||||||
|
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
|
||||||
|
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
|
||||||
|
Self::INPUT_TYPE => &[EmbedderSource::Rest],
|
||||||
_other => unreachable!("unknown field"),
|
_other => unreachable!("unknown field"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,6 +148,18 @@ impl EmbeddingSettings {
|
|||||||
}
|
}
|
||||||
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
|
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
|
||||||
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
|
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
|
||||||
|
EmbedderSource::Rest => &[
|
||||||
|
Self::SOURCE,
|
||||||
|
Self::API_KEY,
|
||||||
|
Self::DIMENSIONS,
|
||||||
|
Self::DOCUMENT_TEMPLATE,
|
||||||
|
Self::URL,
|
||||||
|
Self::QUERY,
|
||||||
|
Self::INPUT_FIELD,
|
||||||
|
Self::PATH_TO_EMBEDDINGS,
|
||||||
|
Self::EMBEDDING_OBJECT,
|
||||||
|
Self::INPUT_TYPE,
|
||||||
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,6 +194,7 @@ pub enum EmbedderSource {
|
|||||||
HuggingFace,
|
HuggingFace,
|
||||||
Ollama,
|
Ollama,
|
||||||
UserProvided,
|
UserProvided,
|
||||||
|
Rest,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for EmbedderSource {
|
impl std::fmt::Display for EmbedderSource {
|
||||||
@ -150,6 +204,7 @@ impl std::fmt::Display for EmbedderSource {
|
|||||||
EmbedderSource::HuggingFace => "huggingFace",
|
EmbedderSource::HuggingFace => "huggingFace",
|
||||||
EmbedderSource::UserProvided => "userProvided",
|
EmbedderSource::UserProvided => "userProvided",
|
||||||
EmbedderSource::Ollama => "ollama",
|
EmbedderSource::Ollama => "ollama",
|
||||||
|
EmbedderSource::Rest => "rest",
|
||||||
};
|
};
|
||||||
f.write_str(s)
|
f.write_str(s)
|
||||||
}
|
}
|
||||||
@ -157,8 +212,20 @@ impl std::fmt::Display for EmbedderSource {
|
|||||||
|
|
||||||
impl EmbeddingSettings {
|
impl EmbeddingSettings {
|
||||||
pub fn apply(&mut self, new: Self) {
|
pub fn apply(&mut self, new: Self) {
|
||||||
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
let EmbeddingSettings {
|
||||||
new;
|
source,
|
||||||
|
model,
|
||||||
|
revision,
|
||||||
|
api_key,
|
||||||
|
dimensions,
|
||||||
|
document_template,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
|
} = new;
|
||||||
let old_source = self.source;
|
let old_source = self.source;
|
||||||
self.source.apply(source);
|
self.source.apply(source);
|
||||||
// Reinitialize the whole setting object on a source change
|
// Reinitialize the whole setting object on a source change
|
||||||
@ -170,6 +237,12 @@ impl EmbeddingSettings {
|
|||||||
api_key,
|
api_key,
|
||||||
dimensions,
|
dimensions,
|
||||||
document_template,
|
document_template,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
};
|
};
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -179,6 +252,13 @@ impl EmbeddingSettings {
|
|||||||
self.api_key.apply(api_key);
|
self.api_key.apply(api_key);
|
||||||
self.dimensions.apply(dimensions);
|
self.dimensions.apply(dimensions);
|
||||||
self.document_template.apply(document_template);
|
self.document_template.apply(document_template);
|
||||||
|
|
||||||
|
self.url.apply(url);
|
||||||
|
self.query.apply(query);
|
||||||
|
self.input_field.apply(input_field);
|
||||||
|
self.path_to_embeddings.apply(path_to_embeddings);
|
||||||
|
self.embedding_object.apply(embedding_object);
|
||||||
|
self.input_type.apply(input_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
api_key: Setting::NotSet,
|
api_key: Setting::NotSet,
|
||||||
dimensions: Setting::NotSet,
|
dimensions: Setting::NotSet,
|
||||||
document_template: Setting::Set(prompt.template),
|
document_template: Setting::Set(prompt.template),
|
||||||
|
url: Setting::NotSet,
|
||||||
|
query: Setting::NotSet,
|
||||||
|
input_field: Setting::NotSet,
|
||||||
|
path_to_embeddings: Setting::NotSet,
|
||||||
|
embedding_object: Setting::NotSet,
|
||||||
|
input_type: Setting::NotSet,
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::OpenAi(options) => Self {
|
super::EmbedderOptions::OpenAi(options) => Self {
|
||||||
source: Setting::Set(EmbedderSource::OpenAi),
|
source: Setting::Set(EmbedderSource::OpenAi),
|
||||||
@ -201,6 +287,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
|
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
|
||||||
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
|
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
|
||||||
document_template: Setting::Set(prompt.template),
|
document_template: Setting::Set(prompt.template),
|
||||||
|
url: Setting::NotSet,
|
||||||
|
query: Setting::NotSet,
|
||||||
|
input_field: Setting::NotSet,
|
||||||
|
path_to_embeddings: Setting::NotSet,
|
||||||
|
embedding_object: Setting::NotSet,
|
||||||
|
input_type: Setting::NotSet,
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::Ollama(options) => Self {
|
super::EmbedderOptions::Ollama(options) => Self {
|
||||||
source: Setting::Set(EmbedderSource::Ollama),
|
source: Setting::Set(EmbedderSource::Ollama),
|
||||||
@ -209,6 +301,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
api_key: Setting::NotSet,
|
api_key: Setting::NotSet,
|
||||||
dimensions: Setting::NotSet,
|
dimensions: Setting::NotSet,
|
||||||
document_template: Setting::Set(prompt.template),
|
document_template: Setting::Set(prompt.template),
|
||||||
|
url: Setting::NotSet,
|
||||||
|
query: Setting::NotSet,
|
||||||
|
input_field: Setting::NotSet,
|
||||||
|
path_to_embeddings: Setting::NotSet,
|
||||||
|
embedding_object: Setting::NotSet,
|
||||||
|
input_type: Setting::NotSet,
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::UserProvided(options) => Self {
|
super::EmbedderOptions::UserProvided(options) => Self {
|
||||||
source: Setting::Set(EmbedderSource::UserProvided),
|
source: Setting::Set(EmbedderSource::UserProvided),
|
||||||
@ -217,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
api_key: Setting::NotSet,
|
api_key: Setting::NotSet,
|
||||||
dimensions: Setting::Set(options.dimensions),
|
dimensions: Setting::Set(options.dimensions),
|
||||||
document_template: Setting::NotSet,
|
document_template: Setting::NotSet,
|
||||||
|
url: Setting::NotSet,
|
||||||
|
query: Setting::NotSet,
|
||||||
|
input_field: Setting::NotSet,
|
||||||
|
path_to_embeddings: Setting::NotSet,
|
||||||
|
embedding_object: Setting::NotSet,
|
||||||
|
input_type: Setting::NotSet,
|
||||||
|
},
|
||||||
|
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||||
|
api_key,
|
||||||
|
// TODO: support distribution
|
||||||
|
distribution: _,
|
||||||
|
dimensions,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
|
}) => Self {
|
||||||
|
source: Setting::Set(EmbedderSource::Rest),
|
||||||
|
model: Setting::NotSet,
|
||||||
|
revision: Setting::NotSet,
|
||||||
|
api_key: api_key.map(Setting::Set).unwrap_or_default(),
|
||||||
|
dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
|
||||||
|
document_template: Setting::Set(prompt.template),
|
||||||
|
url: Setting::Set(url),
|
||||||
|
query: Setting::Set(query),
|
||||||
|
input_field: Setting::Set(input_field),
|
||||||
|
path_to_embeddings: Setting::Set(path_to_embeddings),
|
||||||
|
embedding_object: Setting::Set(embedding_object),
|
||||||
|
input_type: Setting::Set(input_type),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -225,8 +354,20 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
impl From<EmbeddingSettings> for EmbeddingConfig {
|
impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||||
fn from(value: EmbeddingSettings) -> Self {
|
fn from(value: EmbeddingSettings) -> Self {
|
||||||
let mut this = Self::default();
|
let mut this = Self::default();
|
||||||
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
let EmbeddingSettings {
|
||||||
value;
|
source,
|
||||||
|
model,
|
||||||
|
revision,
|
||||||
|
api_key,
|
||||||
|
dimensions,
|
||||||
|
document_template,
|
||||||
|
url,
|
||||||
|
query,
|
||||||
|
input_field,
|
||||||
|
path_to_embeddings,
|
||||||
|
embedding_object,
|
||||||
|
input_type,
|
||||||
|
} = value;
|
||||||
if let Some(source) = source.set() {
|
if let Some(source) = source.set() {
|
||||||
match source {
|
match source {
|
||||||
EmbedderSource::OpenAi => {
|
EmbedderSource::OpenAi => {
|
||||||
@ -274,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
|||||||
dimensions: dimensions.set().unwrap(),
|
dimensions: dimensions.set().unwrap(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
EmbedderSource::Rest => {
|
||||||
|
let embedder_options = super::rest::EmbedderOptions::default();
|
||||||
|
|
||||||
|
this.embedder_options =
|
||||||
|
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||||
|
api_key: api_key.set(),
|
||||||
|
distribution: None,
|
||||||
|
dimensions: dimensions.set(),
|
||||||
|
url: url.set().unwrap(),
|
||||||
|
query: query.set().unwrap_or(embedder_options.query),
|
||||||
|
input_field: input_field.set().unwrap_or(embedder_options.input_field),
|
||||||
|
path_to_embeddings: path_to_embeddings
|
||||||
|
.set()
|
||||||
|
.unwrap_or(embedder_options.path_to_embeddings),
|
||||||
|
embedding_object: embedding_object
|
||||||
|
.set()
|
||||||
|
.unwrap_or(embedder_options.embedding_object),
|
||||||
|
input_type: input_type.set().unwrap_or(embedder_options.input_type),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user