2025-01-08 15:59:56 +01:00
use std ::fmt ;
2024-11-06 09:24:51 +01:00
use std ::time ::Instant ;
2024-03-25 10:05:38 +01:00
use ordered_float ::OrderedFloat ;
2024-03-19 15:41:37 +01:00
use rayon ::iter ::{ IntoParallelIterator , ParallelIterator as _ } ;
2024-10-28 14:08:54 +01:00
use rayon ::slice ::ParallelSlice as _ ;
2023-11-15 15:46:37 +01:00
use super ::error ::{ EmbedError , NewEmbedderError } ;
2024-03-19 15:41:37 +01:00
use super ::rest ::{ Embedder as RestEmbedder , EmbedderOptions as RestEmbedderOptions } ;
2024-10-28 14:08:54 +01:00
use super ::DistributionShift ;
2024-04-24 16:40:12 +02:00
use crate ::error ::FaultSource ;
2024-03-19 15:41:37 +01:00
use crate ::vector ::error ::EmbedErrorKind ;
2024-10-28 14:08:54 +01:00
use crate ::vector ::Embedding ;
2024-04-24 16:40:12 +02:00
use crate ::ThreadPoolNoAbort ;
2023-11-15 15:46:37 +01:00
#[ derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize) ]
pub struct EmbedderOptions {
2024-07-15 16:20:19 +02:00
pub url : Option < String > ,
2023-12-12 21:19:48 +01:00
pub api_key : Option < String > ,
2023-11-15 15:46:37 +01:00
pub embedding_model : EmbeddingModel ,
2024-01-30 16:32:57 +01:00
pub dimensions : Option < usize > ,
2024-03-27 11:50:22 +01:00
pub distribution : Option < DistributionShift > ,
2023-11-15 15:46:37 +01:00
}
2024-03-19 15:41:37 +01:00
impl EmbedderOptions {
pub fn dimensions ( & self ) -> usize {
if self . embedding_model . supports_overriding_dimensions ( ) {
self . dimensions . unwrap_or ( self . embedding_model . default_dimensions ( ) )
} else {
self . embedding_model . default_dimensions ( )
}
}
2024-07-16 15:17:49 +02:00
pub fn request ( & self ) -> serde_json ::Value {
2024-03-19 15:41:37 +01:00
let model = self . embedding_model . name ( ) ;
2024-07-16 15:17:49 +02:00
let mut request = serde_json ::json! ( {
2024-03-19 15:41:37 +01:00
" model " : model ,
2024-07-16 15:17:49 +02:00
" input " : [ super ::rest ::REQUEST_PLACEHOLDER , super ::rest ::REPEAT_PLACEHOLDER ]
2024-03-19 15:41:37 +01:00
} ) ;
if self . embedding_model . supports_overriding_dimensions ( ) {
if let Some ( dimensions ) = self . dimensions {
2024-07-16 15:17:49 +02:00
request [ " dimensions " ] = dimensions . into ( ) ;
2024-03-19 15:41:37 +01:00
}
}
2024-07-16 15:17:49 +02:00
request
2024-03-19 15:41:37 +01:00
}
2024-03-27 11:50:22 +01:00
pub fn distribution ( & self ) -> Option < DistributionShift > {
self . distribution . or ( self . embedding_model . distribution ( ) )
}
2024-03-19 15:41:37 +01:00
}
2023-11-15 15:46:37 +01:00
#[ derive(
Debug ,
Clone ,
Copy ,
Default ,
Hash ,
PartialEq ,
Eq ,
serde ::Serialize ,
serde ::Deserialize ,
deserr ::Deserr ,
) ]
#[ serde(deny_unknown_fields, rename_all = " camelCase " ) ]
#[ deserr(rename_all = camelCase, deny_unknown_fields) ]
pub enum EmbeddingModel {
2023-12-20 17:08:32 +01:00
// # WARNING
//
// If ever adding a model, make sure to add it to the list of supported models below.
2023-12-13 23:25:38 +01:00
#[ serde(rename = " text-embedding-ada-002 " ) ]
#[ deserr(rename = " text-embedding-ada-002 " ) ]
2023-11-15 15:46:37 +01:00
TextEmbeddingAda002 ,
2024-01-30 16:32:57 +01:00
2024-09-09 13:09:35 +02:00
#[ default ]
2024-01-30 16:32:57 +01:00
#[ serde(rename = " text-embedding-3-small " ) ]
#[ deserr(rename = " text-embedding-3-small " ) ]
TextEmbedding3Small ,
#[ serde(rename = " text-embedding-3-large " ) ]
#[ deserr(rename = " text-embedding-3-large " ) ]
TextEmbedding3Large ,
2023-11-15 15:46:37 +01:00
}
impl EmbeddingModel {
2023-12-20 17:08:32 +01:00
pub fn supported_models ( ) -> & 'static [ & 'static str ] {
2024-01-30 16:32:57 +01:00
& [ " text-embedding-ada-002 " , " text-embedding-3-small " , " text-embedding-3-large " ]
2023-12-20 17:08:32 +01:00
}
2023-11-15 15:46:37 +01:00
pub fn max_token ( & self ) -> usize {
match self {
EmbeddingModel ::TextEmbeddingAda002 = > 8191 ,
2024-01-30 16:32:57 +01:00
EmbeddingModel ::TextEmbedding3Large = > 8191 ,
EmbeddingModel ::TextEmbedding3Small = > 8191 ,
2023-11-15 15:46:37 +01:00
}
}
2024-02-07 10:36:30 +01:00
pub fn default_dimensions ( & self ) -> usize {
2023-11-15 15:46:37 +01:00
match self {
EmbeddingModel ::TextEmbeddingAda002 = > 1536 ,
2024-02-07 11:48:19 +01:00
EmbeddingModel ::TextEmbedding3Large = > 3072 ,
EmbeddingModel ::TextEmbedding3Small = > 1536 ,
2023-11-15 15:46:37 +01:00
}
}
pub fn name ( & self ) -> & 'static str {
match self {
EmbeddingModel ::TextEmbeddingAda002 = > " text-embedding-ada-002 " ,
2024-01-30 16:32:57 +01:00
EmbeddingModel ::TextEmbedding3Large = > " text-embedding-3-large " ,
EmbeddingModel ::TextEmbedding3Small = > " text-embedding-3-small " ,
2023-11-15 15:46:37 +01:00
}
}
2023-12-20 17:08:32 +01:00
pub fn from_name ( name : & str ) -> Option < Self > {
2023-11-15 15:46:37 +01:00
match name {
" text-embedding-ada-002 " = > Some ( EmbeddingModel ::TextEmbeddingAda002 ) ,
2024-01-30 16:32:57 +01:00
" text-embedding-3-large " = > Some ( EmbeddingModel ::TextEmbedding3Large ) ,
" text-embedding-3-small " = > Some ( EmbeddingModel ::TextEmbedding3Small ) ,
2023-11-15 15:46:37 +01:00
_ = > None ,
}
}
2023-12-14 16:01:35 +01:00
fn distribution ( & self ) -> Option < DistributionShift > {
match self {
2024-03-25 10:05:38 +01:00
EmbeddingModel ::TextEmbeddingAda002 = > Some ( DistributionShift {
current_mean : OrderedFloat ( 0.90 ) ,
current_sigma : OrderedFloat ( 0.08 ) ,
} ) ,
EmbeddingModel ::TextEmbedding3Large = > Some ( DistributionShift {
current_mean : OrderedFloat ( 0.70 ) ,
current_sigma : OrderedFloat ( 0.1 ) ,
} ) ,
EmbeddingModel ::TextEmbedding3Small = > Some ( DistributionShift {
current_mean : OrderedFloat ( 0.75 ) ,
current_sigma : OrderedFloat ( 0.1 ) ,
} ) ,
2024-01-30 16:32:57 +01:00
}
}
2024-02-07 10:36:30 +01:00
pub fn supports_overriding_dimensions ( & self ) -> bool {
2024-01-30 16:32:57 +01:00
match self {
EmbeddingModel ::TextEmbeddingAda002 = > false ,
EmbeddingModel ::TextEmbedding3Large = > true ,
EmbeddingModel ::TextEmbedding3Small = > true ,
2023-12-14 16:01:35 +01:00
}
}
2023-11-15 15:46:37 +01:00
}
/// Endpoint used for embedding requests when the options do not specify a custom URL.
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
2023-12-12 21:19:48 +01:00
pub fn with_default_model ( api_key : Option < String > ) -> Self {
2024-07-15 16:20:19 +02:00
Self {
api_key ,
embedding_model : Default ::default ( ) ,
dimensions : None ,
distribution : None ,
url : None ,
}
2023-11-15 15:46:37 +01:00
}
}
2023-12-12 21:19:48 +01:00
/// Reads the OpenAI API key from the environment.
///
/// `MEILI_OPENAI_API_KEY` is tried first, then `OPENAI_API_KEY`. When neither variable
/// can be read, an empty string is returned. A variable that is set to an empty value
/// counts as read, so an empty `MEILI_OPENAI_API_KEY` shadows `OPENAI_API_KEY`.
fn infer_api_key() -> String {
    std::env::var("MEILI_OPENAI_API_KEY")
        .or_else(|_| std::env::var("OPENAI_API_KEY"))
        .unwrap_or_default()
}
2024-03-14 11:14:31 +01:00
2024-03-19 15:41:37 +01:00
/// Embedder backed by the OpenAI embeddings endpoint.
///
/// Requests are delegated to a REST embedder; the tokenizer is kept around so that
/// inputs can be tokenized and truncated when a plain-text request is rejected.
pub struct Embedder {
    /// Tokenizer used to count and truncate the tokens of an input.
    tokenizer: tiktoken_rs::CoreBPE,
    /// Underlying REST embedder that performs the actual requests.
    rest_embedder: RestEmbedder,
    /// Options this embedder was created from.
    options: EmbedderOptions,
}
impl Embedder {
pub fn new ( options : EmbedderOptions ) -> Result < Self , NewEmbedderError > {
let mut inferred_api_key = Default ::default ( ) ;
let api_key = options . api_key . as_ref ( ) . unwrap_or_else ( | | {
inferred_api_key = infer_api_key ( ) ;
& inferred_api_key
} ) ;
2024-07-15 16:20:19 +02:00
let url = options . url . as_deref ( ) . unwrap_or ( OPENAI_EMBEDDINGS_URL ) . to_owned ( ) ;
2024-07-16 15:17:49 +02:00
let rest_embedder = RestEmbedder ::new (
RestEmbedderOptions {
2024-07-30 15:44:19 +02:00
api_key : ( ! api_key . is_empty ( ) ) . then ( | | api_key . clone ( ) ) ,
2024-07-16 15:17:49 +02:00
distribution : None ,
dimensions : Some ( options . dimensions ( ) ) ,
url ,
request : options . request ( ) ,
response : serde_json ::json! ( {
" data " : [ {
" embedding " : super ::rest ::RESPONSE_PLACEHOLDER
} ,
super ::rest ::REPEAT_PLACEHOLDER
]
} ) ,
2024-07-22 12:04:05 +02:00
headers : Default ::default ( ) ,
2024-07-16 15:17:49 +02:00
} ,
super ::rest ::ConfigurationSource ::OpenAi ,
) ? ;
2024-03-19 15:41:37 +01:00
// looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs ::cl100k_base ( ) . unwrap ( ) ;
Ok ( Self { options , rest_embedder , tokenizer } )
}
2024-10-28 14:08:54 +01:00
pub fn embed < S : AsRef < str > + serde ::Serialize > (
& self ,
texts : & [ S ] ,
2024-11-06 09:24:51 +01:00
deadline : Option < Instant > ,
2024-10-28 14:08:54 +01:00
) -> Result < Vec < Embedding > , EmbedError > {
2024-11-06 09:24:51 +01:00
match self . rest_embedder . embed_ref ( texts , deadline ) {
2024-03-19 15:41:37 +01:00
Ok ( embeddings ) = > Ok ( embeddings ) ,
2024-07-16 15:17:49 +02:00
Err ( EmbedError { kind : EmbedErrorKind ::RestBadRequest ( error , _ ) , fault : _ } ) = > {
2024-03-19 15:41:37 +01:00
tracing ::warn! ( error = ? error , " OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template. " ) ;
2024-11-06 09:24:51 +01:00
self . try_embed_tokenized ( texts , deadline )
2024-03-14 11:14:31 +01:00
}
2024-03-19 15:41:37 +01:00
Err ( error ) = > Err ( error ) ,
2024-03-14 11:14:31 +01:00
}
2024-03-19 15:41:37 +01:00
}
2024-03-14 11:14:31 +01:00
2024-11-06 09:24:51 +01:00
fn try_embed_tokenized < S : AsRef < str > > (
& self ,
text : & [ S ] ,
deadline : Option < Instant > ,
) -> Result < Vec < Embedding > , EmbedError > {
2024-03-19 15:41:37 +01:00
let mut all_embeddings = Vec ::with_capacity ( text . len ( ) ) ;
for text in text {
2024-10-28 14:08:54 +01:00
let text = text . as_ref ( ) ;
2024-03-19 15:41:37 +01:00
let max_token_count = self . options . embedding_model . max_token ( ) ;
2024-10-28 14:08:54 +01:00
let encoded = self . tokenizer . encode_ordinary ( text ) ;
2024-03-19 15:41:37 +01:00
let len = encoded . len ( ) ;
if len < max_token_count {
2024-11-06 09:24:51 +01:00
all_embeddings . append ( & mut self . rest_embedder . embed_ref ( & [ text ] , deadline ) ? ) ;
2024-03-19 15:41:37 +01:00
continue ;
2024-03-14 11:14:31 +01:00
}
2024-07-15 16:27:26 +02:00
let tokens = & encoded . as_slice ( ) [ 0 .. max_token_count ] ;
2024-03-14 11:14:31 +01:00
2024-11-06 09:24:51 +01:00
let embedding = self . rest_embedder . embed_tokens ( tokens , deadline ) ? ;
2024-03-14 11:14:31 +01:00
2024-10-28 14:08:54 +01:00
all_embeddings . push ( embedding ) ;
2024-03-14 11:14:31 +01:00
}
2024-03-19 15:41:37 +01:00
Ok ( all_embeddings )
}
2024-03-14 11:14:31 +01:00
2024-03-19 15:41:37 +01:00
pub fn embed_chunks (
& self ,
text_chunks : Vec < Vec < String > > ,
2024-04-24 16:40:12 +02:00
threads : & ThreadPoolNoAbort ,
2024-10-28 14:08:54 +01:00
) -> Result < Vec < Vec < Embedding > > , EmbedError > {
2024-04-24 16:40:12 +02:00
threads
. install ( move | | {
2024-11-06 09:24:51 +01:00
text_chunks . into_par_iter ( ) . map ( move | chunk | self . embed ( & chunk , None ) ) . collect ( )
2024-10-28 14:08:54 +01:00
} )
. map_err ( | error | EmbedError {
kind : EmbedErrorKind ::PanicInThreadPool ( error ) ,
fault : FaultSource ::Bug ,
} ) ?
}
pub ( crate ) fn embed_chunks_ref (
& self ,
texts : & [ & str ] ,
threads : & ThreadPoolNoAbort ,
) -> Result < Vec < Vec < f32 > > , EmbedError > {
threads
. install ( move | | {
let embeddings : Result < Vec < Vec < Embedding > > , _ > = texts
2024-11-12 16:31:22 +01:00
. par_chunks ( self . prompt_count_in_chunk_hint ( ) )
2024-11-06 09:24:51 +01:00
. map ( move | chunk | self . embed ( chunk , None ) )
2024-10-28 14:08:54 +01:00
. collect ( ) ;
let embeddings = embeddings ? ;
Ok ( embeddings . into_iter ( ) . flatten ( ) . collect ( ) )
2024-04-24 16:40:12 +02:00
} )
. map_err ( | error | EmbedError {
kind : EmbedErrorKind ::PanicInThreadPool ( error ) ,
fault : FaultSource ::Bug ,
} ) ?
2024-03-19 15:41:37 +01:00
}
2024-03-14 11:14:31 +01:00
2024-03-19 15:41:37 +01:00
pub fn chunk_count_hint ( & self ) -> usize {
self . rest_embedder . chunk_count_hint ( )
}
2024-03-14 11:14:31 +01:00
2024-03-19 15:41:37 +01:00
pub fn prompt_count_in_chunk_hint ( & self ) -> usize {
self . rest_embedder . prompt_count_in_chunk_hint ( )
}
2024-03-14 11:14:31 +01:00
2024-03-19 15:41:37 +01:00
pub fn dimensions ( & self ) -> usize {
self . options . dimensions ( )
}
2024-03-14 11:14:31 +01:00
2024-03-19 15:41:37 +01:00
pub fn distribution ( & self ) -> Option < DistributionShift > {
2024-03-27 11:50:22 +01:00
self . options . distribution ( )
2024-03-14 11:14:31 +01:00
}
}
2025-01-08 15:59:56 +01:00
/// Manual `Debug` implementation: the tokenizer is rendered as the fixed label
/// `"CoreBPE"` instead of its contents, the other fields use their own `Debug` output.
impl fmt::Debug for Embedder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Embedder")
            .field("tokenizer", &"CoreBPE")
            .field("rest_embedder", &self.rest_embedder)
            .field("options", &self.options)
            .finish()
    }
}