pass dimensions only when defined

This commit is contained in:
Louis Dureuil 2024-02-07 11:03:00 +01:00
parent 517f5332d6
commit 74c180267e
No known key found for this signature in database

View File

@ -271,11 +271,7 @@ impl Embedder {
let request = OpenAiRequest { let request = OpenAiRequest {
model: self.options.embedding_model.name(), model: self.options.embedding_model.name(),
input: texts, input: texts,
dimension: if self.options.embedding_model.supports_overriding_dimensions() { dimensions: self.overriden_dimensions(),
self.options.dimensions.as_ref()
} else {
None
},
}; };
let response = client let response = client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
@ -360,8 +356,11 @@ impl Embedder {
tokens: &[usize], tokens: &[usize],
client: &reqwest::Client, client: &reqwest::Client,
) -> Result<Embedding, Retry> { ) -> Result<Embedding, Retry> {
let request = let request = OpenAiTokensRequest {
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; model: self.options.embedding_model.name(),
input: tokens,
dimensions: self.overriden_dimensions(),
};
let response = client let response = client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
.json(&request) .json(&request)
@ -414,6 +413,14 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution() self.options.embedding_model.distribution()
} }
fn overriden_dimensions(&self) -> Option<usize> {
if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions
} else {
None
}
}
} }
// retrying in case of failure // retrying in case of failure
@ -473,13 +480,16 @@ impl Retry {
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str, model: &'a str,
input: &'a [S], input: &'a [S],
dimension: Option<&'a usize>, #[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> { struct OpenAiTokensRequest<'a> {
model: &'a str, model: &'a str,
input: &'a [usize], input: &'a [usize],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]