use rayon::iter::{IntoParallelIterator, ParallelIterator as _};

use super::error::{EmbedError, NewEmbedderError};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
use crate::vector::error::EmbedErrorKind;
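
/// Options for the OpenAI embedder: an optional API key, the embedding model to use, and an
/// optional override of the number of dimensions for models that support it.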
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub api_key: Option<String>,
    pub embedding_model: EmbeddingModel,
    pub dimensions: Option<usize>,
}

impl EmbedderOptions {
    /// The number of dimensions to request: the user-provided override when the model supports
    /// overriding dimensions, otherwise the model's default.
    pub fn dimensions(&self) -> usize {
        if self.embedding_model.supports_overriding_dimensions() {
            self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
        } else {
            self.embedding_model.default_dimensions()
        }
    }

    /// The static part of the JSON request body sent to the embeddings endpoint.
    pub fn query(&self) -> serde_json::Value {
        let model = self.embedding_model.name();

        let mut query = serde_json::json!({
            "model": model,
        });

        if self.embedding_model.supports_overriding_dimensions() {
            if let Some(dimensions) = self.dimensions {
                query["dimensions"] = dimensions.into();
            }
        }

        query
    }
}
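
/// The OpenAI embedding model used to generate document and query vectors.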
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
    // # WARNING
    //
    // If ever adding a model, make sure to add it to the list of supported models below.
    #[default]
    #[serde(rename = "text-embedding-ada-002")]
    #[deserr(rename = "text-embedding-ada-002")]
    TextEmbeddingAda002,

    #[serde(rename = "text-embedding-3-small")]
    #[deserr(rename = "text-embedding-3-small")]
    TextEmbedding3Small,

    #[serde(rename = "text-embedding-3-large")]
    #[deserr(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

impl EmbeddingModel {
    pub fn supported_models() -> &'static [&'static str] {
        &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
    }

    pub fn max_token(&self) -> usize {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => 8191,
            EmbeddingModel::TextEmbedding3Large => 8191,
            EmbeddingModel::TextEmbedding3Small => 8191,
        }
    }

    pub fn default_dimensions(&self) -> usize {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => 1536,
            EmbeddingModel::TextEmbedding3Large => 3072,
            EmbeddingModel::TextEmbedding3Small => 1536,
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
            EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
            EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
        }
    }

    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
            "text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
            "text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
            _ => None,
        }
    }

    fn distribution(&self) -> Option<DistributionShift> {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => {
                Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
            }
            EmbeddingModel::TextEmbedding3Large => {
                Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 })
            }
            EmbeddingModel::TextEmbedding3Small => {
                Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 })
            }
        }
    }

    pub fn supports_overriding_dimensions(&self) -> bool {
        match self {
            EmbeddingModel::TextEmbeddingAda002 => false,
            EmbeddingModel::TextEmbedding3Large => true,
            EmbeddingModel::TextEmbedding3Small => true,
        }
    }
}

pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";

impl EmbedderOptions {
    pub fn with_default_model(api_key: Option<String>) -> Self {
        Self { api_key, embedding_model: Default::default(), dimensions: None }
    }

    pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
        Self { api_key, embedding_model, dimensions: None }
    }
}
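
// A minimal usage sketch (illustrative only; the calling code and error handling are
// hypothetical, not part of this module):
//
//     let options =
//         EmbedderOptions::with_embedding_model(None, EmbeddingModel::TextEmbedding3Small);
//     let embedder = Embedder::new(options)?;
//     let embeddings = embedder.embed(vec!["a document to vectorize".to_string()])?;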
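
/// Reads the API key from the `MEILI_OPENAI_API_KEY` environment variable, falling back to
/// `OPENAI_API_KEY`, and returns an empty string when neither is set.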
fn infer_api_key() -> String {
    std::env::var("MEILI_OPENAI_API_KEY")
        .or_else(|_| std::env::var("OPENAI_API_KEY"))
        .unwrap_or_default()
}
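
/// An embedder backed by the OpenAI embeddings API.
///
/// Requests go through the generic REST embedder; when the API rejects a request because the
/// input is too long, the text is retried on a tokenized, windowed version.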
#[derive(Debug)]
pub struct Embedder {
    tokenizer: tiktoken_rs::CoreBPE,
    rest_embedder: RestEmbedder,
    options: EmbedderOptions,
}

impl Embedder {
    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
        let mut inferred_api_key = Default::default();
        let api_key = options.api_key.as_ref().unwrap_or_else(|| {
            inferred_api_key = infer_api_key();
            &inferred_api_key
        });

        let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
            api_key: Some(api_key.clone()),
            distribution: options.embedding_model.distribution(),
            dimensions: Some(options.dimensions()),
            url: OPENAI_EMBEDDINGS_URL.to_owned(),
            query: options.query(),
            input_field: vec!["input".to_owned()],
            input_type: crate::vector::rest::InputType::TextArray,
            path_to_embeddings: vec!["data".to_owned()],
            embedding_object: vec!["embedding".to_owned()],
        })?;

        // Looking at the code, it is very unclear that this can actually fail.
        let tokenizer = tiktoken_rs::cl100k_base().unwrap();

        Ok(Self { options, rest_embedder, tokenizer })
    }

    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
        match self.rest_embedder.embed_ref(&texts) {
            Ok(embeddings) => Ok(embeddings),
            Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => {
                tracing::warn!(error = ?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
                self.try_embed_tokenized(&texts)
            }
            Err(error) => Err(error),
        }
    }
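
    /// Fallback used when the batched request is rejected: each text is tokenized with the
    /// `cl100k_base` tokenizer, and texts above the model's token limit are embedded window by
    /// window, with consecutive windows sharing `OVERLAP_SIZE` tokens. All window embeddings for
    /// a text are collected into a single `Embeddings` value.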
    fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
        pub const OVERLAP_SIZE: usize = 200;
        let mut all_embeddings = Vec::with_capacity(text.len());
        for text in text {
            let max_token_count = self.options.embedding_model.max_token();
            let encoded = self.tokenizer.encode_ordinary(text.as_str());
            let len = encoded.len();
            if len < max_token_count {
                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
                continue;
            }

            let mut tokens = encoded.as_slice();
            let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
            while tokens.len() > max_token_count {
                let window = &tokens[..max_token_count];
                let embedding = self.rest_embedder.embed_tokens(window)?;
                embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
                    EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
                })?;

                tokens = &tokens[max_token_count - OVERLAP_SIZE..];
            }

            // Embed whatever tokens remain after the last full window.
            let embedding = self.rest_embedder.embed_tokens(tokens)?;
            embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
                EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
            })?;

            all_embeddings.push(embeddings_for_prompt);
        }
        Ok(all_embeddings)
    }
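
    /// Embeds each chunk of texts in parallel on the given thread pool, with one `embed` call
    /// per chunk.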
    pub fn embed_chunks(
        &self,
        text_chunks: Vec<Vec<String>>,
        threads: &rayon::ThreadPool,
    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
        threads.install(move || {
            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
        })
    }

    pub fn chunk_count_hint(&self) -> usize {
        self.rest_embedder.chunk_count_hint()
    }

    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        self.rest_embedder.prompt_count_in_chunk_hint()
    }

    pub fn dimensions(&self) -> usize {
        self.options.dimensions()
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        self.options.embedding_model.distribution()
    }
}