2023-11-15 15:46:37 +01:00
use std ::fmt ::Display ;
use reqwest ::StatusCode ;
use serde ::{ Deserialize , Serialize } ;
use super ::error ::{ EmbedError , NewEmbedderError } ;
2023-12-14 16:01:35 +01:00
use super ::{ DistributionShift , Embedding , Embeddings } ;
2023-11-15 15:46:37 +01:00
#[ derive(Debug) ]
pub struct Embedder {
2024-01-29 11:23:18 +01:00
headers : reqwest ::header ::HeaderMap ,
2023-11-15 15:46:37 +01:00
tokenizer : tiktoken_rs ::CoreBPE ,
options : EmbedderOptions ,
}
#[ derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize) ]
pub struct EmbedderOptions {
2023-12-12 21:19:48 +01:00
pub api_key : Option < String > ,
2023-11-15 15:46:37 +01:00
pub embedding_model : EmbeddingModel ,
2024-01-30 16:32:57 +01:00
pub dimensions : Option < usize > ,
2023-11-15 15:46:37 +01:00
}
#[ derive(
Debug ,
Clone ,
Copy ,
Default ,
Hash ,
PartialEq ,
Eq ,
serde ::Serialize ,
serde ::Deserialize ,
deserr ::Deserr ,
) ]
#[ serde(deny_unknown_fields, rename_all = " camelCase " ) ]
#[ deserr(rename_all = camelCase, deny_unknown_fields) ]
pub enum EmbeddingModel {
2023-12-20 17:08:32 +01:00
// # WARNING
//
// If ever adding a model, make sure to add it to the list of supported models below.
2023-11-15 15:46:37 +01:00
#[ default ]
2023-12-13 23:25:38 +01:00
#[ serde(rename = " text-embedding-ada-002 " ) ]
#[ deserr(rename = " text-embedding-ada-002 " ) ]
2023-11-15 15:46:37 +01:00
TextEmbeddingAda002 ,
2024-01-30 16:32:57 +01:00
#[ serde(rename = " text-embedding-3-small " ) ]
#[ deserr(rename = " text-embedding-3-small " ) ]
TextEmbedding3Small ,
#[ serde(rename = " text-embedding-3-large " ) ]
#[ deserr(rename = " text-embedding-3-large " ) ]
TextEmbedding3Large ,
2023-11-15 15:46:37 +01:00
}
impl EmbeddingModel {
2023-12-20 17:08:32 +01:00
pub fn supported_models ( ) -> & 'static [ & 'static str ] {
2024-01-30 16:32:57 +01:00
& [ " text-embedding-ada-002 " , " text-embedding-3-small " , " text-embedding-3-large " ]
2023-12-20 17:08:32 +01:00
}
2023-11-15 15:46:37 +01:00
pub fn max_token ( & self ) -> usize {
match self {
EmbeddingModel ::TextEmbeddingAda002 = > 8191 ,
2024-01-30 16:32:57 +01:00
EmbeddingModel ::TextEmbedding3Large = > 8191 ,
EmbeddingModel ::TextEmbedding3Small = > 8191 ,
2023-11-15 15:46:37 +01:00
}
}
2024-02-07 10:36:30 +01:00
pub fn default_dimensions ( & self ) -> usize {
2023-11-15 15:46:37 +01:00
match self {
EmbeddingModel ::TextEmbeddingAda002 = > 1536 ,
2024-02-07 11:48:19 +01:00
EmbeddingModel ::TextEmbedding3Large = > 3072 ,
EmbeddingModel ::TextEmbedding3Small = > 1536 ,
2023-11-15 15:46:37 +01:00
}
}
pub fn name ( & self ) -> & 'static str {
match self {
EmbeddingModel ::TextEmbeddingAda002 = > " text-embedding-ada-002 " ,
2024-01-30 16:32:57 +01:00
EmbeddingModel ::TextEmbedding3Large = > " text-embedding-3-large " ,
EmbeddingModel ::TextEmbedding3Small = > " text-embedding-3-small " ,
2023-11-15 15:46:37 +01:00
}
}
2023-12-20 17:08:32 +01:00
pub fn from_name ( name : & str ) -> Option < Self > {
2023-11-15 15:46:37 +01:00
match name {
" text-embedding-ada-002 " = > Some ( EmbeddingModel ::TextEmbeddingAda002 ) ,
2024-01-30 16:32:57 +01:00
" text-embedding-3-large " = > Some ( EmbeddingModel ::TextEmbedding3Large ) ,
" text-embedding-3-small " = > Some ( EmbeddingModel ::TextEmbedding3Small ) ,
2023-11-15 15:46:37 +01:00
_ = > None ,
}
}
2023-12-14 16:01:35 +01:00
fn distribution ( & self ) -> Option < DistributionShift > {
match self {
EmbeddingModel ::TextEmbeddingAda002 = > {
Some ( DistributionShift { current_mean : 0.90 , current_sigma : 0.08 } )
}
2024-01-30 16:32:57 +01:00
EmbeddingModel ::TextEmbedding3Large = > {
2024-02-07 14:22:13 +01:00
Some ( DistributionShift { current_mean : 0.70 , current_sigma : 0.1 } )
2024-01-30 16:32:57 +01:00
}
EmbeddingModel ::TextEmbedding3Small = > {
2024-02-07 14:22:13 +01:00
Some ( DistributionShift { current_mean : 0.75 , current_sigma : 0.1 } )
2024-01-30 16:32:57 +01:00
}
}
}
2024-02-07 10:36:30 +01:00
pub fn supports_overriding_dimensions ( & self ) -> bool {
2024-01-30 16:32:57 +01:00
match self {
EmbeddingModel ::TextEmbeddingAda002 = > false ,
EmbeddingModel ::TextEmbedding3Large = > true ,
EmbeddingModel ::TextEmbedding3Small = > true ,
2023-12-14 16:01:35 +01:00
}
}
2023-11-15 15:46:37 +01:00
}
/// Endpoint that every embedding request of this module is POSTed to.
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
2023-12-12 21:19:48 +01:00
pub fn with_default_model ( api_key : Option < String > ) -> Self {
2024-01-30 16:32:57 +01:00
Self { api_key , embedding_model : Default ::default ( ) , dimensions : None }
2023-11-15 15:46:37 +01:00
}
2023-12-12 21:19:48 +01:00
pub fn with_embedding_model ( api_key : Option < String > , embedding_model : EmbeddingModel ) -> Self {
2024-01-30 16:32:57 +01:00
Self { api_key , embedding_model , dimensions : None }
2023-11-15 15:46:37 +01:00
}
}
impl Embedder {
2024-01-29 11:23:18 +01:00
/// Builds a `reqwest::Client` that carries this embedder's default headers.
///
/// # Errors
///
/// Returns `EmbedError::openai_initialize_web_client` when the client cannot be built.
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
    let default_headers = self.headers.clone();
    let builder = reqwest::ClientBuilder::new().default_headers(default_headers);
    builder.build().map_err(EmbedError::openai_initialize_web_client)
}
2023-11-15 15:46:37 +01:00
pub fn new ( options : EmbedderOptions ) -> Result < Self , NewEmbedderError > {
let mut headers = reqwest ::header ::HeaderMap ::new ( ) ;
2023-12-12 21:19:48 +01:00
let mut inferred_api_key = Default ::default ( ) ;
let api_key = options . api_key . as_ref ( ) . unwrap_or_else ( | | {
inferred_api_key = infer_api_key ( ) ;
& inferred_api_key
} ) ;
2023-11-15 15:46:37 +01:00
headers . insert (
reqwest ::header ::AUTHORIZATION ,
2023-12-12 21:19:48 +01:00
reqwest ::header ::HeaderValue ::from_str ( & format! ( " Bearer {} " , api_key ) )
2023-11-15 15:46:37 +01:00
. map_err ( NewEmbedderError ::openai_invalid_api_key_format ) ? ,
) ;
headers . insert (
reqwest ::header ::CONTENT_TYPE ,
reqwest ::header ::HeaderValue ::from_static ( " application/json " ) ,
) ;
// looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs ::cl100k_base ( ) . unwrap ( ) ;
2024-01-29 11:23:18 +01:00
Ok ( Self { options , headers , tokenizer } )
2023-11-15 15:46:37 +01:00
}
2024-01-29 11:23:18 +01:00
pub async fn embed (
& self ,
texts : Vec < String > ,
client : & reqwest ::Client ,
) -> Result < Vec < Embeddings < f32 > > , EmbedError > {
2023-11-15 15:46:37 +01:00
let mut tokenized = false ;
for attempt in 0 .. 7 {
let result = if tokenized {
2024-01-29 11:23:18 +01:00
self . try_embed_tokenized ( & texts , client ) . await
2023-11-15 15:46:37 +01:00
} else {
2024-01-29 11:23:18 +01:00
self . try_embed ( & texts , client ) . await
2023-11-15 15:46:37 +01:00
} ;
let retry_duration = match result {
Ok ( embeddings ) = > return Ok ( embeddings ) ,
Err ( retry ) = > {
2024-02-06 10:49:23 +01:00
tracing ::warn! ( " Failed: {} " , retry . error ) ;
2023-11-15 15:46:37 +01:00
tokenized | = retry . must_tokenize ( ) ;
retry . into_duration ( attempt )
}
} ? ;
2024-03-05 12:19:25 +01:00
let retry_duration = retry_duration . min ( std ::time ::Duration ::from_secs ( 60 ) ) ; // don't wait more than a minute
2024-02-06 10:49:23 +01:00
tracing ::warn! (
" Attempt #{}, retrying after {}ms. " ,
attempt ,
retry_duration . as_millis ( )
) ;
2023-11-15 15:46:37 +01:00
tokio ::time ::sleep ( retry_duration ) . await ;
}
let result = if tokenized {
2024-01-29 11:23:18 +01:00
self . try_embed_tokenized ( & texts , client ) . await
2023-11-15 15:46:37 +01:00
} else {
2024-01-29 11:23:18 +01:00
self . try_embed ( & texts , client ) . await
2023-11-15 15:46:37 +01:00
} ;
result . map_err ( Retry ::into_error )
}
async fn check_response ( response : reqwest ::Response ) -> Result < reqwest ::Response , Retry > {
if ! response . status ( ) . is_success ( ) {
match response . status ( ) {
StatusCode ::UNAUTHORIZED = > {
let error_response : OpenAiErrorResponse = response
. json ( )
. await
. map_err ( EmbedError ::openai_unexpected )
. map_err ( Retry ::retry_later ) ? ;
return Err ( Retry ::give_up ( EmbedError ::openai_auth_error (
error_response . error ,
) ) ) ;
}
StatusCode ::TOO_MANY_REQUESTS = > {
let error_response : OpenAiErrorResponse = response
. json ( )
. await
. map_err ( EmbedError ::openai_unexpected )
. map_err ( Retry ::retry_later ) ? ;
return Err ( Retry ::rate_limited ( EmbedError ::openai_too_many_requests (
error_response . error ,
) ) ) ;
}
2024-03-05 12:18:54 +01:00
StatusCode ::INTERNAL_SERVER_ERROR
| StatusCode ::BAD_GATEWAY
| StatusCode ::SERVICE_UNAVAILABLE = > {
let error_response : Result < OpenAiErrorResponse , _ > = response . json ( ) . await ;
2023-11-15 15:46:37 +01:00
return Err ( Retry ::retry_later ( EmbedError ::openai_internal_server_error (
2024-03-05 12:18:54 +01:00
error_response . ok ( ) . map ( | error_response | error_response . error ) ,
2023-11-15 15:46:37 +01:00
) ) ) ;
}
StatusCode ::BAD_REQUEST = > {
// Most probably, one text contained too many tokens
let error_response : OpenAiErrorResponse = response
. json ( )
. await
. map_err ( EmbedError ::openai_unexpected )
. map_err ( Retry ::retry_later ) ? ;
2024-03-05 12:18:54 +01:00
tracing ::warn! ( " OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt. " ) ;
2023-11-15 15:46:37 +01:00
return Err ( Retry ::retry_tokenized ( EmbedError ::openai_too_many_tokens (
error_response . error ,
) ) ) ;
}
code = > {
2024-03-05 12:18:54 +01:00
return Err ( Retry ::retry_later ( EmbedError ::openai_unhandled_status_code (
2023-11-15 15:46:37 +01:00
code . as_u16 ( ) ,
) ) ) ;
}
}
}
Ok ( response )
}
async fn try_embed < S : AsRef < str > + serde ::Serialize > (
& self ,
texts : & [ S ] ,
2024-01-29 11:23:18 +01:00
client : & reqwest ::Client ,
2023-11-15 15:46:37 +01:00
) -> Result < Vec < Embeddings < f32 > > , Retry > {
for text in texts {
2024-02-06 10:49:23 +01:00
tracing ::trace! ( " Received prompt: {} " , text . as_ref ( ) )
2023-11-15 15:46:37 +01:00
}
2024-01-30 16:32:57 +01:00
let request = OpenAiRequest {
model : self . options . embedding_model . name ( ) ,
input : texts ,
2024-02-07 11:03:00 +01:00
dimensions : self . overriden_dimensions ( ) ,
2024-01-30 16:32:57 +01:00
} ;
2024-01-29 11:23:18 +01:00
let response = client
2023-11-15 15:46:37 +01:00
. post ( OPENAI_EMBEDDINGS_URL )
. json ( & request )
. send ( )
. await
. map_err ( EmbedError ::openai_network )
. map_err ( Retry ::retry_later ) ? ;
let response = Self ::check_response ( response ) . await ? ;
let response : OpenAiResponse = response
. json ( )
. await
. map_err ( EmbedError ::openai_unexpected )
. map_err ( Retry ::retry_later ) ? ;
2024-02-06 10:49:23 +01:00
tracing ::trace! ( " response: {:?} " , response . data ) ;
2023-11-15 15:46:37 +01:00
Ok ( response
. data
. into_iter ( )
. map ( | data | Embeddings ::from_single_embedding ( data . embedding ) )
. collect ( ) )
}
2024-01-29 11:23:18 +01:00
async fn try_embed_tokenized (
& self ,
text : & [ String ] ,
client : & reqwest ::Client ,
) -> Result < Vec < Embeddings < f32 > > , Retry > {
2023-11-15 15:46:37 +01:00
pub const OVERLAP_SIZE : usize = 200 ;
let mut all_embeddings = Vec ::with_capacity ( text . len ( ) ) ;
for text in text {
let max_token_count = self . options . embedding_model . max_token ( ) ;
let encoded = self . tokenizer . encode_ordinary ( text . as_str ( ) ) ;
let len = encoded . len ( ) ;
if len < max_token_count {
2024-01-29 11:23:18 +01:00
all_embeddings . append ( & mut self . try_embed ( & [ text ] , client ) . await ? ) ;
2023-11-15 15:46:37 +01:00
continue ;
}
let mut tokens = encoded . as_slice ( ) ;
2024-02-07 10:36:30 +01:00
let mut embeddings_for_prompt = Embeddings ::new ( self . dimensions ( ) ) ;
2023-11-15 15:46:37 +01:00
while tokens . len ( ) > max_token_count {
let window = & tokens [ .. max_token_count ] ;
2024-01-29 11:23:18 +01:00
embeddings_for_prompt . push ( self . embed_tokens ( window , client ) . await ? ) . unwrap ( ) ;
2023-11-15 15:46:37 +01:00
tokens = & tokens [ max_token_count - OVERLAP_SIZE .. ] ;
}
// end of text
2024-01-29 11:23:18 +01:00
embeddings_for_prompt . push ( self . embed_tokens ( tokens , client ) . await ? ) . unwrap ( ) ;
2023-11-15 15:46:37 +01:00
all_embeddings . push ( embeddings_for_prompt ) ;
}
Ok ( all_embeddings )
}
2024-01-29 11:23:18 +01:00
async fn embed_tokens (
& self ,
tokens : & [ usize ] ,
client : & reqwest ::Client ,
) -> Result < Embedding , Retry > {
2023-11-15 15:46:37 +01:00
for attempt in 0 .. 9 {
2024-01-29 11:23:18 +01:00
let duration = match self . try_embed_tokens ( tokens , client ) . await {
2023-11-15 15:46:37 +01:00
Ok ( embedding ) = > return Ok ( embedding ) ,
Err ( retry ) = > retry . into_duration ( attempt ) ,
}
. map_err ( Retry ::retry_later ) ? ;
tokio ::time ::sleep ( duration ) . await ;
}
2024-01-29 11:23:18 +01:00
self . try_embed_tokens ( tokens , client )
. await
. map_err ( | retry | Retry ::give_up ( retry . into_error ( ) ) )
2023-11-15 15:46:37 +01:00
}
2024-01-29 11:23:18 +01:00
async fn try_embed_tokens (
& self ,
tokens : & [ usize ] ,
client : & reqwest ::Client ,
) -> Result < Embedding , Retry > {
2024-02-07 11:03:00 +01:00
let request = OpenAiTokensRequest {
model : self . options . embedding_model . name ( ) ,
input : tokens ,
dimensions : self . overriden_dimensions ( ) ,
} ;
2024-01-29 11:23:18 +01:00
let response = client
2023-11-15 15:46:37 +01:00
. post ( OPENAI_EMBEDDINGS_URL )
. json ( & request )
. send ( )
. await
. map_err ( EmbedError ::openai_network )
. map_err ( Retry ::retry_later ) ? ;
let response = Self ::check_response ( response ) . await ? ;
let mut response : OpenAiResponse = response
. json ( )
. await
. map_err ( EmbedError ::openai_unexpected )
. map_err ( Retry ::retry_later ) ? ;
Ok ( response . data . pop ( ) . map ( | data | data . embedding ) . unwrap_or_default ( ) )
}
2024-01-29 11:23:18 +01:00
pub fn embed_chunks (
2023-11-15 15:46:37 +01:00
& self ,
text_chunks : Vec < Vec < String > > ,
) -> Result < Vec < Vec < Embeddings < f32 > > > , EmbedError > {
2024-01-29 11:23:18 +01:00
let rt = tokio ::runtime ::Builder ::new_current_thread ( )
. enable_io ( )
. enable_time ( )
. build ( )
. map_err ( EmbedError ::openai_runtime_init ) ? ;
let client = self . new_client ( ) ? ;
rt . block_on ( futures ::future ::try_join_all (
text_chunks . into_iter ( ) . map ( | prompts | self . embed ( prompts , & client ) ) ,
) )
2023-11-15 15:46:37 +01:00
}
/// Returns the constant chunk count hint, `10`.
// NOTE(review): presumably the number of chunks callers should pass to
// `embed_chunks` at once — confirm against callers.
pub fn chunk_count_hint(&self) -> usize {
    const CHUNK_COUNT_HINT: usize = 10;
    CHUNK_COUNT_HINT
}
/// Returns the constant prompts-per-chunk hint, `10`.
// NOTE(review): presumably the number of prompts callers should put in each
// chunk — confirm against callers.
pub fn prompt_count_in_chunk_hint(&self) -> usize {
    const PROMPT_COUNT_IN_CHUNK_HINT: usize = 10;
    PROMPT_COUNT_IN_CHUNK_HINT
}
2023-12-12 21:19:48 +01:00
pub fn dimensions ( & self ) -> usize {
2024-02-07 10:36:30 +01:00
if self . options . embedding_model . supports_overriding_dimensions ( ) {
self . options . dimensions . unwrap_or ( self . options . embedding_model . default_dimensions ( ) )
} else {
self . options . embedding_model . default_dimensions ( )
}
2023-12-12 21:19:48 +01:00
}
2023-12-14 16:01:35 +01:00
/// `DistributionShift` associated with the configured embedding model.
pub fn distribution(&self) -> Option<DistributionShift> {
    let model = &self.options.embedding_model;
    model.distribution()
}
2024-02-07 11:03:00 +01:00
/// Dimensions to send in requests: the configured value when the model
/// supports overriding its dimensions, `None` otherwise.
fn overriden_dimensions(&self) -> Option<usize> {
    if !self.options.embedding_model.supports_overriding_dimensions() {
        return None;
    }
    self.options.dimensions
}
2023-11-15 15:46:37 +01:00
}
// retrying in case of failure
2024-03-20 10:08:28 +01:00
pub struct Retry {
pub error : EmbedError ,
2023-11-15 15:46:37 +01:00
strategy : RetryStrategy ,
}
2024-03-20 10:08:28 +01:00
/// How to react to a failed attempt; see `Retry::into_duration` for the delays.
pub enum RetryStrategy {
    /// Do not retry: the error is returned to the caller.
    GiveUp,
    /// Retry after an exponential backoff.
    Retry,
    /// Retry almost immediately, on the tokenized version of the input.
    RetryTokenized,
    /// Retry after an exponential backoff with an extra fixed delay.
    RetryAfterRateLimit,
}
impl Retry {
2024-03-20 10:08:28 +01:00
pub fn give_up ( error : EmbedError ) -> Self {
2023-11-15 15:46:37 +01:00
Self { error , strategy : RetryStrategy ::GiveUp }
}
2024-03-20 10:08:28 +01:00
pub fn retry_later ( error : EmbedError ) -> Self {
2023-11-15 15:46:37 +01:00
Self { error , strategy : RetryStrategy ::Retry }
}
2024-03-20 10:08:28 +01:00
pub fn retry_tokenized ( error : EmbedError ) -> Self {
2023-11-15 15:46:37 +01:00
Self { error , strategy : RetryStrategy ::RetryTokenized }
}
2024-03-20 10:08:28 +01:00
pub fn rate_limited ( error : EmbedError ) -> Self {
2023-11-15 15:46:37 +01:00
Self { error , strategy : RetryStrategy ::RetryAfterRateLimit }
}
2024-03-20 10:08:28 +01:00
pub fn into_duration ( self , attempt : u32 ) -> Result < tokio ::time ::Duration , EmbedError > {
2023-11-15 15:46:37 +01:00
match self . strategy {
RetryStrategy ::GiveUp = > Err ( self . error ) ,
RetryStrategy ::Retry = > Ok ( tokio ::time ::Duration ::from_millis ( ( 10 u64 ) . pow ( attempt ) ) ) ,
RetryStrategy ::RetryTokenized = > Ok ( tokio ::time ::Duration ::from_millis ( 1 ) ) ,
RetryStrategy ::RetryAfterRateLimit = > {
Ok ( tokio ::time ::Duration ::from_millis ( 100 + 10 u64 . pow ( attempt ) ) )
}
}
}
2024-03-20 10:08:28 +01:00
pub fn must_tokenize ( & self ) -> bool {
2023-11-15 15:46:37 +01:00
matches! ( self . strategy , RetryStrategy ::RetryTokenized )
}
2024-03-20 10:08:28 +01:00
pub fn into_error ( self ) -> EmbedError {
2023-11-15 15:46:37 +01:00
self . error
}
}
// openai api structs
#[ derive(Debug, Serialize) ]
struct OpenAiRequest < ' a , S : AsRef < str > + serde ::Serialize > {
model : & ' a str ,
input : & ' a [ S ] ,
2024-02-07 11:03:00 +01:00
#[ serde(skip_serializing_if = " Option::is_none " ) ]
dimensions : Option < usize > ,
2023-11-15 15:46:37 +01:00
}
#[ derive(Debug, Serialize) ]
struct OpenAiTokensRequest < ' a > {
model : & ' a str ,
input : & ' a [ usize ] ,
2024-02-07 11:03:00 +01:00
#[ serde(skip_serializing_if = " Option::is_none " ) ]
dimensions : Option < usize > ,
2023-11-15 15:46:37 +01:00
}
/// Successful response of the embeddings endpoint; only `data` is decoded.
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    /// Embedding entries returned by the API.
    data: Vec<OpenAiEmbedding>,
}
#[ derive(Debug, Deserialize) ]
struct OpenAiErrorResponse {
error : OpenAiError ,
}
/// Error details found in the `error` field of an error response.
#[derive(Debug, Deserialize)]
pub struct OpenAiError {
    /// Error message.
    message: String,
    // type: String,
    /// Error code, when the response carries one.
    code: Option<String>,
}
impl Display for OpenAiError {
fn fmt ( & self , f : & mut std ::fmt ::Formatter < '_ > ) -> std ::fmt ::Result {
match & self . code {
Some ( code ) = > write! ( f , " {} ({}) " , self . message , code ) ,
None = > write! ( f , " {} " , self . message ) ,
}
}
}
/// One entry of the `data` array of a successful response.
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
    /// The embedding itself.
    embedding: Embedding,
    // object: String,
    // index: usize,
}
2023-12-12 21:19:48 +01:00
/// Reads the API key from the environment.
///
/// `MEILI_OPENAI_API_KEY` takes precedence over `OPENAI_API_KEY`; an empty
/// string is returned when neither variable can be read.
fn infer_api_key() -> String {
    std::env::var("MEILI_OPENAI_API_KEY")
        .or_else(|_| std::env::var("OPENAI_API_KEY"))
        .unwrap_or_default()
}