OpenAI sync

Louis Dureuil 2024-03-14 11:14:31 +01:00
parent bc58e8a310
commit c3d02f092d
6 changed files with 274 additions and 321 deletions

Cargo.lock (generated)

@@ -3378,6 +3378,7 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tracing",
+ "ureq",
  "uuid",
 ]


@@ -91,6 +91,7 @@ liquid = "0.26.4"
 arroy = "0.2.0"
 rand = "0.8.5"
 tracing = "0.1.40"
+ureq = { version = "2.9.6", features = ["json"] }
 
 [dev-dependencies]
 mimalloc = { version = "0.1.39", default-features = false }
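
ureq's `json` feature is what the rewritten embedder below relies on: it pulls in serde support and provides `Request::send_json` plus `Response::into_json`. A minimal sketch of that request shape, assuming a placeholder endpoint, model name, and token (this is not code from the commit):

use serde::Deserialize;

#[derive(Deserialize)]
struct Reply {
    object: String, // placeholder field, stands in for the real response model
}

fn post_json(agent: &ureq::Agent, bearer: &str) -> Result<Reply, Box<dyn std::error::Error>> {
    let reply: Reply = agent
        .post("https://api.openai.com/v1/embeddings")
        .set("Authorization", bearer)
        .send_json(ureq::json!({ "model": "text-embedding-3-small", "input": ["hello"] }))?
        .into_json()?;
    Ok(reply)
}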


@@ -53,17 +53,17 @@ pub enum EmbedErrorKind {
     #[error("could not run model: {0}")]
     ModelForward(candle_core::Error),
     #[error("could not reach OpenAI: {0}")]
-    OpenAiNetwork(reqwest::Error),
+    OpenAiNetwork(ureq::Transport),
     #[error("unexpected response from OpenAI: {0}")]
-    OpenAiUnexpected(reqwest::Error),
-    #[error("could not authenticate against OpenAI: {0}")]
-    OpenAiAuth(OpenAiError),
-    #[error("sent too many requests to OpenAI: {0}")]
-    OpenAiTooManyRequests(OpenAiError),
+    OpenAiUnexpected(ureq::Error),
+    #[error("could not authenticate against OpenAI: {0:?}")]
+    OpenAiAuth(Option<OpenAiError>),
+    #[error("sent too many requests to OpenAI: {0:?}")]
+    OpenAiTooManyRequests(Option<OpenAiError>),
     #[error("received internal error from OpenAI: {0:?}")]
     OpenAiInternalServerError(Option<OpenAiError>),
-    #[error("sent too many tokens in a request to OpenAI: {0}")]
-    OpenAiTooManyTokens(OpenAiError),
+    #[error("sent too many tokens in a request to OpenAI: {0:?}")]
+    OpenAiTooManyTokens(Option<OpenAiError>),
     #[error("received unhandled HTTP status code {0} from OpenAI")]
     OpenAiUnhandledStatusCode(u16),
     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
@@ -102,19 +102,19 @@ impl EmbedError {
         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
     }
-    pub fn openai_network(inner: reqwest::Error) -> Self {
+    pub fn openai_network(inner: ureq::Transport) -> Self {
         Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
     }
-    pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
+    pub fn openai_unexpected(inner: ureq::Error) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
     }
-    pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
     }
-    pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
     }
@@ -122,7 +122,7 @@ impl EmbedError {
         Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
     }
-    pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
     }
@@ -220,7 +220,7 @@ impl NewEmbedderError {
         Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
     }
-    pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
+    pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
         Self {
             kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
             fault: FaultSource::Runtime,
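
The change from `OpenAiError` to `Option<OpenAiError>` in the variants above follows from how ureq reports failures: a non-2xx status still hands back the response, but its body may not deserialize into OpenAI's error document. A hedged sketch of the resulting mapping, reusing the types and constructors from this diff (so it assumes it sits inside the same crate; the retry classification done by the real `check_response` further below is omitted):

fn classify(err: ureq::Error) -> EmbedError {
    match err {
        ureq::Error::Status(code, response) => {
            // The body is only maybe an OpenAI error document, hence the Option.
            let error = response.into_json::<OpenAiErrorResponse>().ok().map(|r| r.error);
            match code {
                401 => EmbedError::openai_auth_error(error),
                429 => EmbedError::openai_too_many_requests(error),
                400 => EmbedError::openai_too_many_tokens(error),
                500..=599 => EmbedError::openai_internal_server_error(error),
                code => EmbedError::openai_unhandled_status_code(code),
            }
        }
        // No response at all: a transport-level failure.
        ureq::Error::Transport(transport) => EmbedError::openai_network(transport),
    }
}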


@@ -131,7 +131,7 @@ impl Embedder {
         let embeddings = this
             .embed(vec!["test".into()])
-            .map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
+            .map_err(NewEmbedderError::could_not_determine_dimension)?;
         this.dimensions = embeddings.first().unwrap().dimension();
         Ok(this)


@@ -98,7 +98,7 @@ pub enum Embedder {
     /// An embedder based on running local models, fetched from the Hugging Face Hub.
     HuggingFace(hf::Embedder),
     /// An embedder based on making embedding queries against the OpenAI API.
-    OpenAi(openai::Embedder),
+    OpenAi(openai::sync::Embedder),
     /// An embedder based on the user providing the embeddings in the documents and queries.
     UserProvided(manual::Embedder),
     Ollama(ollama::Embedder),
@@ -201,7 +201,7 @@ impl Embedder {
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
-            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
+            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?),
             EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
             EmbedderOptions::UserProvided(options) => {
                 Self::UserProvided(manual::Embedder::new(options))
@@ -218,10 +218,7 @@ impl Embedder {
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => {
-                let client = embedder.new_client()?;
-                embedder.embed(texts, &client).await
-            }
+            Embedder::OpenAi(embedder) => embedder.embed(texts),
             Embedder::Ollama(embedder) => {
                 let client = embedder.new_client()?;
                 embedder.embed(texts, &client).await
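
One consequence of this change: `embed` is still an `async fn`, but the OpenAI arm now performs blocking I/O directly on the calling thread. If a caller on a tokio runtime needs to keep executor threads free, one standard option is `tokio::task::spawn_blocking`; whether the actual call sites need this is not shown in this diff, so treat the following as a hedged sketch:

use std::sync::Arc;

// Sketch only: Embedder, Embeddings and EmbedError are the types from this diff,
// and the Send + Sync bounds spawn_blocking requires are assumed to hold for them.
async fn embed_off_executor(
    embedder: Arc<Embedder>,
    texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
    tokio::task::spawn_blocking(move || match embedder.as_ref() {
        Embedder::OpenAi(embedder) => embedder.embed(texts),
        _ => unimplemented!("sketch covers only the OpenAI arm"),
    })
    .await
    .expect("embedding task panicked")
}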


@@ -1,18 +1,10 @@
 use std::fmt::Display;
-use reqwest::StatusCode;
 use serde::{Deserialize, Serialize};
 use super::error::{EmbedError, NewEmbedderError};
 use super::{DistributionShift, Embedding, Embeddings};
-#[derive(Debug)]
-pub struct Embedder {
-    headers: reqwest::header::HeaderMap,
-    tokenizer: tiktoken_rs::CoreBPE,
-    options: EmbedderOptions,
-}
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub struct EmbedderOptions {
     pub api_key: Option<String>,
@@ -125,298 +117,6 @@ impl EmbedderOptions {
     }
 }
-impl Embedder {
-    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
-        reqwest::ClientBuilder::new()
-            .default_headers(self.headers.clone())
-            .build()
-            .map_err(EmbedError::openai_initialize_web_client)
-    }
-    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
-        let mut headers = reqwest::header::HeaderMap::new();
-        let mut inferred_api_key = Default::default();
-        let api_key = options.api_key.as_ref().unwrap_or_else(|| {
-            inferred_api_key = infer_api_key();
-            &inferred_api_key
-        });
-        headers.insert(
-            reqwest::header::AUTHORIZATION,
-            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
-                .map_err(NewEmbedderError::openai_invalid_api_key_format)?,
-        );
-        headers.insert(
-            reqwest::header::CONTENT_TYPE,
-            reqwest::header::HeaderValue::from_static("application/json"),
-        );
-        // looking at the code it is very unclear that this can actually fail.
-        let tokenizer = tiktoken_rs::cl100k_base().unwrap();
-        Ok(Self { options, headers, tokenizer })
-    }
-    pub async fn embed(
-        &self,
-        texts: Vec<String>,
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
-        let mut tokenized = false;
-        for attempt in 0..7 {
-            let result = if tokenized {
-                self.try_embed_tokenized(&texts, client).await
-            } else {
-                self.try_embed(&texts, client).await
-            };
-            let retry_duration = match result {
-                Ok(embeddings) => return Ok(embeddings),
-                Err(retry) => {
-                    tracing::warn!("Failed: {}", retry.error);
-                    tokenized |= retry.must_tokenize();
-                    retry.into_duration(attempt)
-                }
-            }?;
-            let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
-            tracing::warn!(
-                "Attempt #{}, retrying after {}ms.",
-                attempt,
-                retry_duration.as_millis()
-            );
-            tokio::time::sleep(retry_duration).await;
-        }
-        let result = if tokenized {
-            self.try_embed_tokenized(&texts, client).await
-        } else {
-            self.try_embed(&texts, client).await
-        };
-        result.map_err(Retry::into_error)
-    }
-    async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
-        if !response.status().is_success() {
-            match response.status() {
-                StatusCode::UNAUTHORIZED => {
-                    let error_response: OpenAiErrorResponse = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::openai_unexpected)
-                        .map_err(Retry::retry_later)?;
-                    return Err(Retry::give_up(EmbedError::openai_auth_error(
-                        error_response.error,
-                    )));
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    let error_response: OpenAiErrorResponse = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::openai_unexpected)
-                        .map_err(Retry::retry_later)?;
-                    return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
-                        error_response.error,
-                    )));
-                }
-                StatusCode::INTERNAL_SERVER_ERROR
-                | StatusCode::BAD_GATEWAY
-                | StatusCode::SERVICE_UNAVAILABLE => {
-                    let error_response: Result<OpenAiErrorResponse, _> = response.json().await;
-                    return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
-                        error_response.ok().map(|error_response| error_response.error),
-                    )));
-                }
-                StatusCode::BAD_REQUEST => {
-                    // Most probably, one text contained too many tokens
-                    let error_response: OpenAiErrorResponse = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::openai_unexpected)
-                        .map_err(Retry::retry_later)?;
-                    tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
-                    return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
-                        error_response.error,
-                    )));
-                }
-                code => {
-                    return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code(
-                        code.as_u16(),
-                    )));
-                }
-            }
-        }
-        Ok(response)
-    }
-    async fn try_embed<S: AsRef<str> + serde::Serialize>(
-        &self,
-        texts: &[S],
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, Retry> {
-        for text in texts {
-            tracing::trace!("Received prompt: {}", text.as_ref())
-        }
-        let request = OpenAiRequest {
-            model: self.options.embedding_model.name(),
-            input: texts,
-            dimensions: self.overriden_dimensions(),
-        };
-        let response = client
-            .post(OPENAI_EMBEDDINGS_URL)
-            .json(&request)
-            .send()
-            .await
-            .map_err(EmbedError::openai_network)
-            .map_err(Retry::retry_later)?;
-        let response = Self::check_response(response).await?;
-        let response: OpenAiResponse = response
-            .json()
-            .await
-            .map_err(EmbedError::openai_unexpected)
-            .map_err(Retry::retry_later)?;
-        tracing::trace!("response: {:?}", response.data);
-        Ok(response
-            .data
-            .into_iter()
-            .map(|data| Embeddings::from_single_embedding(data.embedding))
-            .collect())
-    }
-    async fn try_embed_tokenized(
-        &self,
-        text: &[String],
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, Retry> {
-        pub const OVERLAP_SIZE: usize = 200;
-        let mut all_embeddings = Vec::with_capacity(text.len());
-        for text in text {
-            let max_token_count = self.options.embedding_model.max_token();
-            let encoded = self.tokenizer.encode_ordinary(text.as_str());
-            let len = encoded.len();
-            if len < max_token_count {
-                all_embeddings.append(&mut self.try_embed(&[text], client).await?);
-                continue;
-            }
-            let mut tokens = encoded.as_slice();
-            let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
-            while tokens.len() > max_token_count {
-                let window = &tokens[..max_token_count];
-                embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
-                tokens = &tokens[max_token_count - OVERLAP_SIZE..];
-            }
-            // end of text
-            embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
-            all_embeddings.push(embeddings_for_prompt);
-        }
-        Ok(all_embeddings)
-    }
-    async fn embed_tokens(
-        &self,
-        tokens: &[usize],
-        client: &reqwest::Client,
-    ) -> Result<Embedding, Retry> {
-        for attempt in 0..9 {
-            let duration = match self.try_embed_tokens(tokens, client).await {
-                Ok(embedding) => return Ok(embedding),
-                Err(retry) => retry.into_duration(attempt),
-            }
-            .map_err(Retry::retry_later)?;
-            tokio::time::sleep(duration).await;
-        }
-        self.try_embed_tokens(tokens, client)
-            .await
-            .map_err(|retry| Retry::give_up(retry.into_error()))
-    }
-    async fn try_embed_tokens(
-        &self,
-        tokens: &[usize],
-        client: &reqwest::Client,
-    ) -> Result<Embedding, Retry> {
-        let request = OpenAiTokensRequest {
-            model: self.options.embedding_model.name(),
-            input: tokens,
-            dimensions: self.overriden_dimensions(),
-        };
-        let response = client
-            .post(OPENAI_EMBEDDINGS_URL)
-            .json(&request)
-            .send()
-            .await
-            .map_err(EmbedError::openai_network)
-            .map_err(Retry::retry_later)?;
-        let response = Self::check_response(response).await?;
-        let mut response: OpenAiResponse = response
-            .json()
-            .await
-            .map_err(EmbedError::openai_unexpected)
-            .map_err(Retry::retry_later)?;
-        Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
-    }
-    pub fn embed_chunks(
-        &self,
-        text_chunks: Vec<Vec<String>>,
-    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        let rt = tokio::runtime::Builder::new_current_thread()
-            .enable_io()
-            .enable_time()
-            .build()
-            .map_err(EmbedError::openai_runtime_init)?;
-        let client = self.new_client()?;
-        rt.block_on(futures::future::try_join_all(
-            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
-        ))
-    }
-    pub fn chunk_count_hint(&self) -> usize {
-        10
-    }
-    pub fn prompt_count_in_chunk_hint(&self) -> usize {
-        10
-    }
-    pub fn dimensions(&self) -> usize {
-        if self.options.embedding_model.supports_overriding_dimensions() {
-            self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
-        } else {
-            self.options.embedding_model.default_dimensions()
-        }
-    }
-    pub fn distribution(&self) -> Option<DistributionShift> {
-        self.options.embedding_model.distribution()
-    }
-    fn overriden_dimensions(&self) -> Option<usize> {
-        if self.options.embedding_model.supports_overriding_dimensions() {
-            self.options.dimensions
-        } else {
-            None
-        }
-    }
-}
 // retrying in case of failure
 pub struct Retry {
@@ -524,3 +224,257 @@ fn infer_api_key() -> String {
         .or_else(|_| std::env::var("OPENAI_API_KEY"))
         .unwrap_or_default()
 }
+pub mod sync {
+    use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
+    use super::{
+        EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest,
+        OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL,
+    };
+    use crate::vector::DistributionShift;
+    const REQUEST_PARALLELISM: usize = 10;
+    #[derive(Debug)]
+    pub struct Embedder {
+        tokenizer: tiktoken_rs::CoreBPE,
+        options: super::EmbedderOptions,
+        bearer: String,
+        threads: rayon::ThreadPool,
+    }
+    impl Embedder {
+        pub fn new(options: super::EmbedderOptions) -> Result<Self, NewEmbedderError> {
+            let mut inferred_api_key = Default::default();
+            let api_key = options.api_key.as_ref().unwrap_or_else(|| {
+                inferred_api_key = super::infer_api_key();
+                &inferred_api_key
+            });
+            let bearer = format!("Bearer {api_key}");
+            // looking at the code it is very unclear that this can actually fail.
+            let tokenizer = tiktoken_rs::cl100k_base().unwrap();
+            // FIXME: unwrap
+            let threads = rayon::ThreadPoolBuilder::new()
+                .num_threads(REQUEST_PARALLELISM)
+                .thread_name(|index| format!("embedder-chunk-{index}"))
+                .build()
+                .unwrap();
+            Ok(Self { options, bearer, tokenizer, threads })
+        }
+        pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+            let mut tokenized = false;
+            let client = ureq::agent();
+            for attempt in 0..7 {
+                let result = if tokenized {
+                    self.try_embed_tokenized(&texts, &client)
+                } else {
+                    self.try_embed(&texts, &client)
+                };
+                let retry_duration = match result {
+                    Ok(embeddings) => return Ok(embeddings),
+                    Err(retry) => {
+                        tracing::warn!("Failed: {}", retry.error);
+                        tokenized |= retry.must_tokenize();
+                        retry.into_duration(attempt)
+                    }
+                }?;
+                let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
+                tracing::warn!(
+                    "Attempt #{}, retrying after {}ms.",
+                    attempt,
+                    retry_duration.as_millis()
+                );
+                std::thread::sleep(retry_duration);
+            }
+            let result = if tokenized {
+                self.try_embed_tokenized(&texts, &client)
+            } else {
+                self.try_embed(&texts, &client)
+            };
+            result.map_err(Retry::into_error)
+        }
+        fn check_response(
+            response: Result<ureq::Response, ureq::Error>,
+        ) -> Result<ureq::Response, Retry> {
+            match response {
+                Ok(response) => Ok(response),
+                Err(ureq::Error::Status(code, response)) => {
+                    let error_response: Option<OpenAiErrorResponse> = response.into_json().ok();
+                    let error = error_response.map(|response| response.error);
+                    Err(match code {
+                        401 => Retry::give_up(EmbedError::openai_auth_error(error)),
+                        429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)),
+                        400 => {
+                            tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
+                            Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error))
+                        }
+                        500..=599 => {
+                            Retry::retry_later(EmbedError::openai_internal_server_error(error))
+                        }
+                        _ => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)),
+                    })
+                }
+                Err(ureq::Error::Transport(transport)) => {
+                    Err(Retry::retry_later(EmbedError::openai_network(transport)))
+                }
+            }
+        }
+        fn try_embed<S: AsRef<str> + serde::Serialize>(
+            &self,
+            texts: &[S],
+            client: &ureq::Agent,
+        ) -> Result<Vec<Embeddings<f32>>, Retry> {
+            for text in texts {
+                tracing::trace!("Received prompt: {}", text.as_ref())
+            }
+            let request = OpenAiRequest {
+                model: self.options.embedding_model.name(),
+                input: texts,
+                dimensions: self.overriden_dimensions(),
+            };
+            let response = client
+                .post(OPENAI_EMBEDDINGS_URL)
+                .set("Authorization", &self.bearer)
+                .send_json(&request);
+            let response = Self::check_response(response)?;
+            let response: OpenAiResponse = response
+                .into_json()
+                .map_err(EmbedError::openai_unexpected)
+                .map_err(Retry::retry_later)?;
+            tracing::trace!("response: {:?}", response.data);
+            Ok(response
+                .data
+                .into_iter()
+                .map(|data| Embeddings::from_single_embedding(data.embedding))
+                .collect())
+        }
+        fn try_embed_tokenized(
+            &self,
+            text: &[String],
+            client: &ureq::Agent,
+        ) -> Result<Vec<Embeddings<f32>>, Retry> {
+            pub const OVERLAP_SIZE: usize = 200;
+            let mut all_embeddings = Vec::with_capacity(text.len());
+            for text in text {
+                let max_token_count = self.options.embedding_model.max_token();
+                let encoded = self.tokenizer.encode_ordinary(text.as_str());
+                let len = encoded.len();
+                if len < max_token_count {
+                    all_embeddings.append(&mut self.try_embed(&[text], client)?);
+                    continue;
+                }
+                let mut tokens = encoded.as_slice();
+                let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
+                while tokens.len() > max_token_count {
+                    let window = &tokens[..max_token_count];
+                    embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap();
+                    tokens = &tokens[max_token_count - OVERLAP_SIZE..];
+                }
+                // end of text
+                embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap();
+                all_embeddings.push(embeddings_for_prompt);
+            }
+            Ok(all_embeddings)
+        }
+        fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> {
+            for attempt in 0..9 {
+                let duration = match self.try_embed_tokens(tokens, client) {
+                    Ok(embedding) => return Ok(embedding),
+                    Err(retry) => retry.into_duration(attempt),
+                }
+                .map_err(Retry::retry_later)?;
+                std::thread::sleep(duration);
+            }
+            self.try_embed_tokens(tokens, client)
+                .map_err(|retry| Retry::give_up(retry.into_error()))
+        }
+        fn try_embed_tokens(
+            &self,
+            tokens: &[usize],
+            client: &ureq::Agent,
+        ) -> Result<Embedding, Retry> {
+            let request = OpenAiTokensRequest {
+                model: self.options.embedding_model.name(),
+                input: tokens,
+                dimensions: self.overriden_dimensions(),
+            };
+            let response = client
+                .post(OPENAI_EMBEDDINGS_URL)
+                .set("Authorization", &self.bearer)
+                .send_json(&request);
+            let response = Self::check_response(response)?;
+            let mut response: OpenAiResponse = response
+                .into_json()
+                .map_err(EmbedError::openai_unexpected)
+                .map_err(Retry::retry_later)?;
+            Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
+        }
+        pub fn embed_chunks(
+            &self,
+            text_chunks: Vec<Vec<String>>,
+        ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+            // collect() runs inside install() so the requests execute on the dedicated pool
+            self.threads
+                .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)).collect())
+        }
+        pub fn chunk_count_hint(&self) -> usize {
+            10
+        }
+        pub fn prompt_count_in_chunk_hint(&self) -> usize {
+            10
+        }
+        pub fn dimensions(&self) -> usize {
+            if self.options.embedding_model.supports_overriding_dimensions() {
+                self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
+            } else {
+                self.options.embedding_model.default_dimensions()
+            }
+        }
+        pub fn distribution(&self) -> Option<DistributionShift> {
+            self.options.embedding_model.distribution()
+        }
+        fn overriden_dimensions(&self) -> Option<usize> {
+            if self.options.embedding_model.supports_overriding_dimensions() {
+                self.options.dimensions
+            } else {
+                None
+            }
+        }
+    }
+}
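
Finally, a hedged usage sketch for the new module. Option construction is left to the caller because `EmbedderOptions` is only partially visible in this diff, and the sketch assumes the error types implement `std::error::Error` (their `#[error(...)]` attributes suggest `thiserror` derives):

// Sketch only. When `api_key` is None, `new` falls back to infer_api_key(),
// which ultimately reads the OPENAI_API_KEY environment variable.
fn embed_one(options: EmbedderOptions) -> Result<(), Box<dyn std::error::Error>> {
    let embedder = sync::Embedder::new(options)?;
    let embeddings = embedder.embed(vec!["a red couch near a window".to_string()])?;
    assert_eq!(embeddings.len(), 1);
    println!("model dimensions: {}", embedder.dimensions());
    Ok(())
}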