Make sure the overriden dimensions are always used when embedding

This commit is contained in:
Louis Dureuil 2024-02-07 10:36:30 +01:00
parent fb705116a6
commit 7ae4013478
No known key found for this signature in database

View File

@ -65,14 +65,10 @@ impl EmbeddingModel {
} }
} }
pub fn dimensions(&self) -> usize { pub fn default_dimensions(&self) -> usize {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => 1536, EmbeddingModel::TextEmbeddingAda002 => 1536,
//Default value for the model
EmbeddingModel::TextEmbedding3Large => 1536, EmbeddingModel::TextEmbedding3Large => 1536,
//Default value for the model
EmbeddingModel::TextEmbedding3Small => 3072, EmbeddingModel::TextEmbedding3Small => 3072,
} }
} }
@ -108,7 +104,7 @@ impl EmbeddingModel {
} }
} }
pub fn is_optional_dimensions_supported(&self) -> bool { pub fn supports_overriding_dimensions(&self) -> bool {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => false, EmbeddingModel::TextEmbeddingAda002 => false,
EmbeddingModel::TextEmbedding3Large => true, EmbeddingModel::TextEmbedding3Large => true,
@ -275,7 +271,7 @@ impl Embedder {
let request = OpenAiRequest { let request = OpenAiRequest {
model: self.options.embedding_model.name(), model: self.options.embedding_model.name(),
input: texts, input: texts,
dimension: if self.options.embedding_model.is_optional_dimensions_supported() { dimension: if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.as_ref() self.options.dimensions.as_ref()
} else { } else {
None None
@ -323,8 +319,7 @@ impl Embedder {
} }
let mut tokens = encoded.as_slice(); let mut tokens = encoded.as_slice();
let mut embeddings_for_prompt = let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
Embeddings::new(self.options.embedding_model.dimensions());
while tokens.len() > max_token_count { while tokens.len() > max_token_count {
let window = &tokens[..max_token_count]; let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap(); embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
@ -409,7 +404,11 @@ impl Embedder {
} }
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
self.options.dimensions.unwrap_or_else(|| self.options.embedding_model.dimensions()) if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
} else {
self.options.embedding_model.default_dimensions()
}
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {