From cc45e264ca6a1eae09cc6370b54b4dc73a1f6ff7 Mon Sep 17 00:00:00 2001 From: Tamo Date: Wed, 18 Sep 2024 18:13:37 +0200 Subject: [PATCH] implement the binary quantization in meilisearch --- Cargo.lock | 30 ++- index-scheduler/src/lib.rs | 11 +- meilisearch-types/src/error.rs | 5 +- meilisearch/src/routes/indexes/similar.rs | 5 +- meilisearch/src/search/mod.rs | 50 +++-- milli/Cargo.toml | 3 +- milli/src/error.rs | 4 + milli/src/index.rs | 55 ++--- milli/src/search/hybrid.rs | 4 +- milli/src/search/mod.rs | 7 +- milli/src/search/new/mod.rs | 4 + milli/src/search/new/vector_sort.rs | 10 +- milli/src/search/similar.rs | 9 +- .../extract/extract_vector_points.rs | 98 ++++----- milli/src/update/index_documents/mod.rs | 37 +++- milli/src/update/index_documents/transform.rs | 35 +--- .../src/update/index_documents/typed_chunk.rs | 26 ++- milli/src/update/settings.rs | 101 +++++---- milli/src/vector/mod.rs | 192 +++++++++++++++++- milli/src/vector/settings.rs | 96 +++++++-- 20 files changed, 559 insertions(+), 223 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1af89d382..485ab1305 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -384,6 +384,24 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "arroy" +version = "0.4.0" +dependencies = [ + "bytemuck", + "byteorder", + "heed", + "log", + "memmap2", + "nohash", + "ordered-float", + "rand", + "rayon", + "roaring", + "tempfile", + "thiserror", +] + [[package]] name = "arroy" version = "0.4.0" @@ -2555,7 +2573,7 @@ name = "index-scheduler" version = "1.11.0" dependencies = [ "anyhow", - "arroy", + "arroy 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "big_s", "bincode", "crossbeam", @@ -2838,7 +2856,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" dependencies = [ "cfg-if", - 
"windows-targets 0.48.1", + "windows-targets 0.52.4", ] [[package]] @@ -3545,7 +3563,7 @@ dependencies = [ name = "milli" version = "1.11.0" dependencies = [ - "arroy", + "arroy 0.4.0", "big_s", "bimap", "bincode", @@ -3686,6 +3704,12 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d02c0b00610773bb7fc61d85e13d86c7858cbdf00e1a120bfc41bc055dbaa0e" +[[package]] +name = "nohash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0f889fb66f7acdf83442c35775764b51fed3c606ab9cee51500dbde2cf528ca" + [[package]] name = "nom" version = "7.1.3" diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index 753e8c179..2126b0b94 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -1477,7 +1477,7 @@ impl IndexScheduler { .map( |IndexEmbeddingConfig { name, - config: milli::vector::EmbeddingConfig { embedder_options, prompt }, + config: milli::vector::EmbeddingConfig { embedder_options, prompt, quantized }, .. 
}| { let prompt = @@ -1486,7 +1486,10 @@ impl IndexScheduler { { let embedders = self.embedders.read().unwrap(); if let Some(embedder) = embedders.get(&embedder_options) { - return Ok((name, (embedder.clone(), prompt))); + return Ok(( + name, + (embedder.clone(), prompt, quantized.unwrap_or_default()), + )); } } @@ -1500,7 +1503,7 @@ impl IndexScheduler { let mut embedders = self.embedders.write().unwrap(); embedders.insert(embedder_options, embedder.clone()); } - Ok((name, (embedder, prompt))) + Ok((name, (embedder, prompt, quantized.unwrap_or_default()))) }, ) .collect(); @@ -5197,7 +5200,7 @@ mod tests { let simple_hf_name = name.clone(); let configs = index_scheduler.embedders(configs).unwrap(); - let (hf_embedder, _) = configs.get(&simple_hf_name).unwrap(); + let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap(); let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap(); let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap(); let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap(); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 535bf2dd6..f755998a1 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -395,7 +395,10 @@ impl ErrorCode for milli::Error { | UserError::InvalidSettingsDimensions { .. } | UserError::InvalidUrl { .. } | UserError::InvalidSettingsDocumentTemplateMaxBytes { .. } - | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + | UserError::InvalidPrompt(_) + | UserError::InvalidDisableBinaryQuantization { .. } => { + Code::InvalidSettingsEmbedders + } UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) 
=> Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs index dd30c793e..210a52b75 100644 --- a/meilisearch/src/routes/indexes/similar.rs +++ b/meilisearch/src/routes/indexes/similar.rs @@ -102,8 +102,8 @@ async fn similar( let index = index_scheduler.index(&index_uid)?; - let (embedder_name, embedder) = - SearchKind::embedder(&index_scheduler, &index, &query.embedder, None)?; + let (embedder_name, embedder, quantized) = + SearchKind::embedder(&index_scheduler, &index, query.embedder.as_deref(), None)?; tokio::task::spawn_blocking(move || { perform_similar( @@ -111,6 +111,7 @@ async fn similar( query, embedder_name, embedder, + quantized, retrieve_vectors, index_scheduler.features(), ) diff --git a/meilisearch/src/search/mod.rs b/meilisearch/src/search/mod.rs index 9abfec3e3..66b6e56de 100644 --- a/meilisearch/src/search/mod.rs +++ b/meilisearch/src/search/mod.rs @@ -274,8 +274,8 @@ pub struct HybridQuery { #[derive(Clone)] pub enum SearchKind { KeywordOnly, - SemanticOnly { embedder_name: String, embedder: Arc }, - Hybrid { embedder_name: String, embedder: Arc, semantic_ratio: f32 }, + SemanticOnly { embedder_name: String, embedder: Arc, quantized: bool }, + Hybrid { embedder_name: String, embedder: Arc, quantized: bool, semantic_ratio: f32 }, } impl SearchKind { @@ -285,9 +285,9 @@ impl SearchKind { embedder_name: &str, vector_len: Option, ) -> Result { - let (embedder_name, embedder) = + let (embedder_name, embedder, quantized) = Self::embedder(index_scheduler, index, embedder_name, vector_len)?; - Ok(Self::SemanticOnly { embedder_name, embedder }) + Ok(Self::SemanticOnly { embedder_name, embedder, quantized }) } pub(crate) fn hybrid( @@ -297,9 +297,9 @@ impl SearchKind { semantic_ratio: f32, vector_len: Option, ) -> Result { - let (embedder_name, embedder) = + let (embedder_name, embedder, quantized) = 
Self::embedder(index_scheduler, index, embedder_name, vector_len)?; - Ok(Self::Hybrid { embedder_name, embedder, semantic_ratio }) + Ok(Self::Hybrid { embedder_name, embedder, quantized, semantic_ratio }) } pub(crate) fn embedder( @@ -307,16 +307,14 @@ impl SearchKind { index: &Index, embedder_name: &str, vector_len: Option, - ) -> Result<(String, Arc), ResponseError> { + ) -> Result<(String, Arc, bool), ResponseError> { let embedder_configs = index.embedding_configs(&index.read_txn()?)?; let embedders = index_scheduler.embedders(embedder_configs)?; - let embedder = embedders.get(embedder_name); - - let embedder = embedder + let (embedder, _, quantized) = embedders + .get(embedder_name) .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) - .map_err(milli::Error::from)? - .0; + .map_err(milli::Error::from)?; if let Some(vector_len) = vector_len { if vector_len != embedder.dimensions() { @@ -330,7 +328,7 @@ impl SearchKind { } } - Ok((embedder_name.to_owned(), embedder)) + Ok((embedder_name.to_owned(), embedder, quantized)) } } @@ -791,7 +789,7 @@ fn prepare_search<'t>( search.query(q); } } - SearchKind::SemanticOnly { embedder_name, embedder } => { + SearchKind::SemanticOnly { embedder_name, embedder, quantized } => { let vector = match query.vector.clone() { Some(vector) => vector, None => { @@ -805,14 +803,19 @@ fn prepare_search<'t>( } }; - search.semantic(embedder_name.clone(), embedder.clone(), Some(vector)); + search.semantic(embedder_name.clone(), embedder.clone(), *quantized, Some(vector)); } - SearchKind::Hybrid { embedder_name, embedder, semantic_ratio: _ } => { + SearchKind::Hybrid { embedder_name, embedder, quantized, semantic_ratio: _ } => { if let Some(q) = &query.q { search.query(q); } // will be embedded in hybrid search if necessary - search.semantic(embedder_name.clone(), embedder.clone(), query.vector.clone()); + search.semantic( + embedder_name.clone(), + embedder.clone(), + *quantized, + query.vector.clone(), + ); } } @@ -1441,6 
+1444,7 @@ pub fn perform_similar( query: SimilarQuery, embedder_name: String, embedder: Arc, + quantized: bool, retrieve_vectors: RetrieveVectors, features: RoFeatures, ) -> Result { @@ -1469,8 +1473,16 @@ pub fn perform_similar( )); }; - let mut similar = - milli::Similar::new(internal_id, offset, limit, index, &rtxn, embedder_name, embedder); + let mut similar = milli::Similar::new( + internal_id, + offset, + limit, + index, + &rtxn, + embedder_name, + embedder, + quantized, + ); if let Some(ref filter) = query.filter { if let Some(facets) = parse_filter(filter, Code::InvalidSimilarFilter, features)? { diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 79b61b4f1..4d82d0a03 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -80,7 +80,8 @@ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", tiktoken-rs = "0.5.9" liquid = "0.26.6" rhai = { version = "1.19.0", features = ["serde", "no_module", "no_custom_syntax", "no_time", "sync"] } -arroy = "0.4.0" +# arroy = "0.4.0" +arroy = { path = "../../arroy" } rand = "0.8.5" tracing = "0.1.40" ureq = { version = "2.10.0", features = ["json"] } diff --git a/milli/src/error.rs b/milli/src/error.rs index f0e92a9ab..f09f48c2e 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -258,6 +258,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco }, #[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")] InvalidSettingsDimensions { embedder_name: String }, + #[error( + "`.embedders.{embedder_name}.binaryQuantized`: Cannot disable the binary quantization" + )] + InvalidDisableBinaryQuantization { embedder_name: String }, #[error("`.embedders.{embedder_name}.documentTemplateMaxBytes`: `documentTemplateMaxBytes` cannot be zero")] InvalidSettingsDocumentTemplateMaxBytes { embedder_name: String }, #[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")] diff --git a/milli/src/index.rs b/milli/src/index.rs 
index 512e911aa..63da889c4 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -21,7 +21,7 @@ use crate::heed_codec::{BEU16StrCodec, FstSetCodec, StrBEU16Codec, StrRefCodec}; use crate::order_by_map::OrderByMap; use crate::proximity::ProximityPrecision; use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME; -use crate::vector::{Embedding, EmbeddingConfig}; +use crate::vector::{ArroyReader, Embedding, EmbeddingConfig}; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdMapMissingEntry, FieldIdWordCountCodec, @@ -162,7 +162,7 @@ pub struct Index { /// Maps an embedder name to its id in the arroy store. pub embedder_category_id: Database, /// Vector store based on arroyâ„¢. - pub vector_arroy: arroy::Database, + pub vector_arroy: arroy::Database, /// Maps the document id to the document as an obkv store. pub(crate) documents: Database, @@ -1612,18 +1612,11 @@ impl Index { pub fn arroy_readers<'a>( &'a self, - rtxn: &'a RoTxn<'a>, embedder_id: u8, - ) -> impl Iterator>> + 'a { - crate::vector::arroy_db_range_for_embedder(embedder_id).map_while(move |k| { - arroy::Reader::open(rtxn, k, self.vector_arroy) - .map(Some) - .or_else(|e| match e { - arroy::Error::MissingMetadata(_) => Ok(None), - e => Err(e.into()), - }) - .transpose() - }) + quantized: bool, + ) -> impl Iterator + 'a { + crate::vector::arroy_db_range_for_embedder(embedder_id) + .map_while(move |k| Some(ArroyReader::new(self.vector_arroy, k, quantized))) } pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> { @@ -1644,32 +1637,28 @@ impl Index { docid: DocumentId, ) -> Result>> { let mut res = BTreeMap::new(); - for row in self.embedder_category_id.iter(rtxn)? 
{ - let (embedder_name, embedder_id) = row?; + let embedding_configs = self.embedding_configs(rtxn)?; + for config in embedding_configs { + // TODO: return internal error instead + let embedder_id = self.embedder_category_id.get(rtxn, &config.name)?.unwrap(); let embedder_id = (embedder_id as u16) << 8; + let mut embeddings = Vec::new(); 'vectors: for i in 0..=u8::MAX { - let reader = arroy::Reader::open(rtxn, embedder_id | (i as u16), self.vector_arroy) - .map(Some) - .or_else(|e| match e { - arroy::Error::MissingMetadata(_) => Ok(None), - e => Err(e), - }) - .transpose(); - - let Some(reader) = reader else { - break 'vectors; + let reader = ArroyReader::new( + self.vector_arroy, + embedder_id | (i as u16), + config.config.quantized(), + ); + match reader.item_vector(rtxn, docid) { + Err(arroy::Error::MissingMetadata(_)) => break 'vectors, + Err(err) => return Err(err.into()), + Ok(None) => break 'vectors, + Ok(Some(embedding)) => embeddings.push(embedding), }; - - let embedding = reader?.item_vector(rtxn, docid)?; - if let Some(embedding) = embedding { - embeddings.push(embedding) - } else { - break 'vectors; - } } - res.insert(embedder_name.to_owned(), embeddings); + res.insert(config.name.to_owned(), embeddings); } Ok(res) } diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index e08111473..8b274804c 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -190,7 +190,7 @@ impl<'a> Search<'a> { return Ok(return_keyword_results(self.limit, self.offset, keyword_results)); }; // no embedder, no semantic search - let Some(SemanticSearch { vector, embedder_name, embedder }) = semantic else { + let Some(SemanticSearch { vector, embedder_name, embedder, quantized }) = semantic else { return Ok(return_keyword_results(self.limit, self.offset, keyword_results)); }; @@ -212,7 +212,7 @@ impl<'a> Search<'a> { }; search.semantic = - Some(SemanticSearch { vector: Some(vector_query), embedder_name, embedder }); + Some(SemanticSearch { 
vector: Some(vector_query), embedder_name, embedder, quantized }); // TODO: would be better to have two distinct functions at this point let vector_results = search.execute()?; diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 3057066d2..d5b05f515 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -32,6 +32,7 @@ pub struct SemanticSearch { vector: Option>, embedder_name: String, embedder: Arc, + quantized: bool, } pub struct Search<'a> { @@ -89,9 +90,10 @@ impl<'a> Search<'a> { &mut self, embedder_name: String, embedder: Arc, + quantized: bool, vector: Option>, ) -> &mut Search<'a> { - self.semantic = Some(SemanticSearch { embedder_name, embedder, vector }); + self.semantic = Some(SemanticSearch { embedder_name, embedder, quantized, vector }); self } @@ -206,7 +208,7 @@ impl<'a> Search<'a> { degraded, used_negative_operator, } = match self.semantic.as_ref() { - Some(SemanticSearch { vector: Some(vector), embedder_name, embedder }) => { + Some(SemanticSearch { vector: Some(vector), embedder_name, embedder, quantized }) => { execute_vector_search( &mut ctx, vector, @@ -219,6 +221,7 @@ impl<'a> Search<'a> { self.limit, embedder_name, embedder, + *quantized, self.time_budget.clone(), self.ranking_score_threshold, )? 
diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index b30306a0b..4babc7acc 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -320,6 +320,7 @@ fn get_ranking_rules_for_vector<'ctx>( target: &[f32], embedder_name: &str, embedder: &Embedder, + quantized: bool, ) -> Result>> { // query graph search @@ -347,6 +348,7 @@ fn get_ranking_rules_for_vector<'ctx>( limit_plus_offset, embedder_name, embedder, + quantized, )?; ranking_rules.push(Box::new(vector_sort)); vector = true; @@ -576,6 +578,7 @@ pub fn execute_vector_search( length: usize, embedder_name: &str, embedder: &Embedder, + quantized: bool, time_budget: TimeBudget, ranking_score_threshold: Option, ) -> Result { @@ -591,6 +594,7 @@ pub fn execute_vector_search( vector, embedder_name, embedder, + quantized, )?; let mut placeholder_search_logger = logger::DefaultSearchLogger; diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index e56f3cbbe..653aae7f1 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -16,6 +16,7 @@ pub struct VectorSort { limit: usize, distribution_shift: Option, embedder_index: u8, + quantized: bool, } impl VectorSort { @@ -26,6 +27,7 @@ impl VectorSort { limit: usize, embedder_name: &str, embedder: &Embedder, + quantized: bool, ) -> Result { let embedder_index = ctx .index @@ -41,6 +43,7 @@ impl VectorSort { limit, distribution_shift: embedder.distribution(), embedder_index, + quantized, }) } @@ -49,16 +52,15 @@ impl VectorSort { ctx: &mut SearchContext<'_>, vector_candidates: &RoaringBitmap, ) -> Result<()> { - let readers: std::result::Result, _> = - ctx.index.arroy_readers(ctx.txn, self.embedder_index).collect(); - let readers = readers?; + let readers: Vec<_> = + ctx.index.arroy_readers(self.embedder_index, self.quantized).collect(); let target = &self.target; let mut results = Vec::new(); for reader in readers.iter() { let nns_by_vector = - 
reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?; + reader.nns_by_vector(ctx.txn, target, self.limit, Some(vector_candidates))?; results.extend(nns_by_vector.into_iter()); } results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance)); diff --git a/milli/src/search/similar.rs b/milli/src/search/similar.rs index bf5cc323f..de329c9c3 100644 --- a/milli/src/search/similar.rs +++ b/milli/src/search/similar.rs @@ -18,6 +18,7 @@ pub struct Similar<'a> { embedder_name: String, embedder: Arc, ranking_score_threshold: Option, + quantized: bool, } impl<'a> Similar<'a> { @@ -29,6 +30,7 @@ impl<'a> Similar<'a> { rtxn: &'a heed::RoTxn<'a>, embedder_name: String, embedder: Arc, + quantized: bool, ) -> Self { Self { id, @@ -40,6 +42,7 @@ impl<'a> Similar<'a> { embedder_name, embedder, ranking_score_threshold: None, + quantized, } } @@ -67,10 +70,7 @@ impl<'a> Similar<'a> { .get(self.rtxn, &self.embedder_name)? .ok_or_else(|| crate::UserError::InvalidEmbedder(self.embedder_name.to_owned()))?; - let readers: std::result::Result, _> = - self.index.arroy_readers(self.rtxn, embedder_index).collect(); - - let readers = readers?; + let readers: Vec<_> = self.index.arroy_readers(embedder_index, self.quantized).collect(); let mut results = Vec::new(); @@ -79,7 +79,6 @@ impl<'a> Similar<'a> { self.rtxn, self.id, self.limit + self.offset + 1, - None, Some(&universe), )?; if let Some(mut nns_by_item) = nns_by_item { diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index e9b83b92c..38a4ebe8a 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -20,7 +20,7 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::settings::InnerIndexSettingsDiff; use crate::vector::error::{EmbedErrorKind, PossibleEmbeddingMistakes, 
UnusedVectorsDistribution}; use crate::vector::parsed_vectors::{ParsedVectorsDiff, VectorState, RESERVED_VECTORS_FIELD_NAME}; -use crate::vector::settings::{EmbedderAction, ReindexAction}; +use crate::vector::settings::ReindexAction; use crate::vector::{Embedder, Embeddings}; use crate::{try_split_array_at, DocumentId, FieldId, Result, ThreadPoolNoAbort}; @@ -208,65 +208,65 @@ pub fn extract_vector_points( if reindex_vectors { for (name, action) in settings_diff.embedding_config_updates.iter() { - match action { - EmbedderAction::WriteBackToDocuments(_) => continue, // already deleted - EmbedderAction::Reindex(action) => { - let Some((embedder_name, (embedder, prompt))) = configs.remove_entry(name) - else { - tracing::error!(embedder = name, "Requested embedder config not found"); - continue; - }; + if let Some(action) = action.reindex() { + let Some((embedder_name, (embedder, prompt, _quantized))) = + configs.remove_entry(name) + else { + tracing::error!(embedder = name, "Requested embedder config not found"); + continue; + }; - // (docid, _index) -> KvWriterDelAdd -> Vector - let manual_vectors_writer = create_writer( - indexer.chunk_compression_type, - indexer.chunk_compression_level, - tempfile::tempfile()?, - ); + // (docid, _index) -> KvWriterDelAdd -> Vector + let manual_vectors_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); - // (docid) -> (prompt) - let prompts_writer = create_writer( - indexer.chunk_compression_type, - indexer.chunk_compression_level, - tempfile::tempfile()?, - ); + // (docid) -> (prompt) + let prompts_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); - // (docid) -> () - let remove_vectors_writer = create_writer( - indexer.chunk_compression_type, - indexer.chunk_compression_level, - tempfile::tempfile()?, - ); + // (docid) -> () + let remove_vectors_writer = create_writer( + 
indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); - let action = match action { - ReindexAction::FullReindex => ExtractionAction::SettingsFullReindex, - ReindexAction::RegeneratePrompts => { - let Some((_, old_prompt)) = old_configs.get(name) else { - tracing::error!(embedder = name, "Old embedder config not found"); - continue; - }; + let action = match action { + ReindexAction::FullReindex => ExtractionAction::SettingsFullReindex, + ReindexAction::RegeneratePrompts => { + let Some((_, old_prompt, _quantized)) = old_configs.get(name) else { + tracing::error!(embedder = name, "Old embedder config not found"); + continue; + }; - ExtractionAction::SettingsRegeneratePrompts { old_prompt } - } - }; + ExtractionAction::SettingsRegeneratePrompts { old_prompt } + } + }; - extractors.push(EmbedderVectorExtractor { - embedder_name, - embedder, - prompt, - prompts_writer, - remove_vectors_writer, - manual_vectors_writer, - add_to_user_provided: RoaringBitmap::new(), - action, - }); - } + extractors.push(EmbedderVectorExtractor { + embedder_name, + embedder, + prompt, + prompts_writer, + remove_vectors_writer, + manual_vectors_writer, + add_to_user_provided: RoaringBitmap::new(), + action, + }); + } else { + continue; } } } else { // document operation - for (embedder_name, (embedder, prompt)) in configs.into_iter() { + for (embedder_name, (embedder, prompt, _quantized)) in configs.into_iter() { // (docid, _index) -> KvWriterDelAdd -> Vector let manual_vectors_writer = create_writer( indexer.chunk_compression_type, diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 6d659a7a2..29530a0bb 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -43,7 +43,7 @@ use crate::update::index_documents::parallel::ImmutableObkvs; use crate::update::{ IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, }; 
-use crate::vector::EmbeddingConfigs; +use crate::vector::{ArroyReader, EmbeddingConfigs}; use crate::{CboRoaringBitmapCodec, Index, Object, Result}; static MERGED_DATABASE_COUNT: usize = 7; @@ -679,6 +679,24 @@ where let number_of_documents = self.index.number_of_documents(self.wtxn)?; let mut rng = rand::rngs::StdRng::seed_from_u64(42); + // If an embedder wasn't used in the typedchunk but must be binary quantized + // we should insert it in `dimension` + for (name, action) in settings_diff.embedding_config_updates.iter() { + if action.is_being_quantized && !dimension.contains_key(name.as_str()) { + let index = self.index.embedder_category_id.get(self.wtxn, name)?.ok_or( + InternalError::DatabaseMissingEntry { + db_name: "embedder_category_id", + key: None, + }, + )?; + let first_id = crate::vector::arroy_db_range_for_embedder(index).next().unwrap(); + let reader = + ArroyReader::new(self.index.vector_arroy, first_id, action.was_quantized); + let dim = reader.dimensions(self.wtxn)?; + dimension.insert(name.to_string(), dim); + } + } + for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; let vector_arroy = self.index.vector_arroy; @@ -686,13 +704,19 @@ where let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, )?; + let embedder_config = settings_diff.embedding_config_updates.get(&embedder_name); + let was_quantized = embedder_config.map_or(false, |action| action.was_quantized); + let is_quantizing = embedder_config.map_or(false, |action| action.is_being_quantized); pool.install(|| { for k in crate::vector::arroy_db_range_for_embedder(embedder_index) { - let writer = arroy::Writer::new(vector_arroy, k, dimension); - if writer.need_build(wtxn)? { - writer.build(wtxn, &mut rng, None)?; - } else if writer.is_empty(wtxn)? 
{ + let mut writer = ArroyReader::new(vector_arroy, k, was_quantized); + if is_quantizing { + writer.quantize(wtxn, k, dimension)?; + } + if writer.need_build(wtxn, dimension)? { + writer.build(wtxn, &mut rng, dimension)?; + } else if writer.is_empty(wtxn, dimension)? { break; } } @@ -2746,6 +2770,7 @@ mod tests { response: Setting::NotSet, distribution: Setting::NotSet, headers: Setting::NotSet, + binary_quantized: Setting::NotSet, }), ); settings.set_embedder_settings(embedders); @@ -2774,7 +2799,7 @@ mod tests { std::sync::Arc::new(crate::vector::Embedder::new(embedder.embedder_options).unwrap()); let res = index .search(&rtxn) - .semantic(embedder_name, embedder, Some([0.0, 1.0, 2.0].to_vec())) + .semantic(embedder_name, embedder, false, Some([0.0, 1.0, 2.0].to_vec())) .execute() .unwrap(); assert_eq!(res.documents_ids.len(), 3); diff --git a/milli/src/update/index_documents/transform.rs b/milli/src/update/index_documents/transform.rs index 73fa3ca7b..2467c0019 100644 --- a/milli/src/update/index_documents/transform.rs +++ b/milli/src/update/index_documents/transform.rs @@ -28,7 +28,8 @@ use crate::update::index_documents::GrenadParameters; use crate::update::settings::{InnerIndexSettings, InnerIndexSettingsDiff}; use crate::update::{AvailableDocumentsIds, UpdateIndexingStep}; use crate::vector::parsed_vectors::{ExplicitVectors, VectorOrArrayOfVectors}; -use crate::vector::settings::{EmbedderAction, WriteBackToDocuments}; +use crate::vector::settings::WriteBackToDocuments; +use crate::vector::ArroyReader; use crate::{ is_faceted_by, FieldDistribution, FieldId, FieldIdMapMissingEntry, FieldsIdsMap, Index, Result, }; @@ -989,23 +990,16 @@ impl<'a, 'i> Transform<'a, 'i> { None }; - let readers: Result< - BTreeMap<&str, (Vec>, &RoaringBitmap)>, - > = settings_diff + let readers: Result, &RoaringBitmap)>> = settings_diff .embedding_config_updates .iter() .filter_map(|(name, action)| { - if let EmbedderAction::WriteBackToDocuments(WriteBackToDocuments { - 
embedder_id, - user_provided, - }) = action + if let Some(WriteBackToDocuments { embedder_id, user_provided }) = + action.write_back() { - let readers: Result> = - self.index.arroy_readers(wtxn, *embedder_id).collect(); - match readers { - Ok(readers) => Some(Ok((name.as_str(), (readers, user_provided)))), - Err(error) => Some(Err(error)), - } + let readers: Vec<_> = + self.index.arroy_readers(*embedder_id, action.was_quantized).collect(); + Some(Ok((name.as_str(), (readers, user_provided)))) } else { None } @@ -1104,23 +1098,14 @@ impl<'a, 'i> Transform<'a, 'i> { } } - let mut writers = Vec::new(); - // delete all vectors from the embedders that need removal for (_, (readers, _)) in readers { for reader in readers { - let dimensions = reader.dimensions(); - let arroy_index = reader.index(); - drop(reader); - let writer = arroy::Writer::new(self.index.vector_arroy, arroy_index, dimensions); - writers.push(writer); + let dimensions = reader.dimensions(wtxn)?; + reader.clear(wtxn, dimensions)?; } } - for writer in writers { - writer.clear(wtxn)?; - } - let grenad_params = GrenadParameters { chunk_compression_type: self.indexer_settings.chunk_compression_type, chunk_compression_level: self.indexer_settings.chunk_compression_level, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 9de95778b..b133f7a87 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -27,6 +27,7 @@ use crate::update::index_documents::helpers::{ as_cloneable_grenad, keep_latest_obkv, try_split_array_at, }; use crate::update::settings::InnerIndexSettingsDiff; +use crate::vector::ArroyReader; use crate::{ lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError, U8StrStrCodec, @@ -666,9 +667,13 @@ pub(crate) fn write_typed_chunk_into_index( let embedder_index = index.embedder_category_id.get(wtxn, 
&embedder_name)?.ok_or( InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, )?; + let binary_quantized = settings_diff + .embedding_config_updates + .get(&embedder_name) + .map_or(false, |conf| conf.was_quantized); // FIXME: allow customizing distance let writers: Vec<_> = crate::vector::arroy_db_range_for_embedder(embedder_index) - .map(|k| arroy::Writer::new(index.vector_arroy, k, expected_dimension)) + .map(|k| ArroyReader::new(index.vector_arroy, k, binary_quantized)) .collect(); // remove vectors for docids we want them removed @@ -679,7 +684,7 @@ pub(crate) fn write_typed_chunk_into_index( for writer in &writers { // Uses invariant: vectors are packed in the first writers. - if !writer.del_item(wtxn, docid)? { + if !writer.del_item(wtxn, expected_dimension, docid)? { break; } } @@ -711,7 +716,7 @@ pub(crate) fn write_typed_chunk_into_index( ))); } for (embedding, writer) in embeddings.iter().zip(&writers) { - writer.add_item(wtxn, docid, embedding)?; + writer.add_item(wtxn, expected_dimension, docid, embedding)?; } } @@ -734,7 +739,7 @@ pub(crate) fn write_typed_chunk_into_index( break; }; if candidate == vector { - writer.del_item(wtxn, docid)?; + writer.del_item(wtxn, expected_dimension, docid)?; deleted_index = Some(index); } } @@ -751,8 +756,13 @@ pub(crate) fn write_typed_chunk_into_index( if let Some((last_index, vector)) = last_index_with_a_vector { // unwrap: computed the index from the list of writers let writer = writers.get(last_index).unwrap(); - writer.del_item(wtxn, docid)?; - writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?; + writer.del_item(wtxn, expected_dimension, docid)?; + writers.get(deleted_index).unwrap().add_item( + wtxn, + expected_dimension, + docid, + &vector, + )?; } } } @@ -762,8 +772,8 @@ pub(crate) fn write_typed_chunk_into_index( // overflow was detected during vector extraction. for writer in &writers { - if !writer.contains_item(wtxn, docid)? 
{ - writer.add_item(wtxn, docid, &vector)?; + if !writer.contains_item(wtxn, expected_dimension, docid)? { + writer.add_item(wtxn, expected_dimension, docid, &vector)?; break; } } diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 8702e7ea6..40aa22a81 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -425,11 +425,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { FP: Fn(UpdateIndexingStep) + Sync, FA: Fn() -> bool + Sync, { + println!("inside reindex"); // if the settings are set before any document update, we don't need to do anything, and // will set the primary key during the first document addition. if self.index.number_of_documents(self.wtxn)? == 0 { return Ok(()); } + println!("didnt early exit"); let transform = Transform::new( self.wtxn, @@ -954,7 +956,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let old_configs = self.index.embedding_configs(self.wtxn)?; let remove_all: Result> = old_configs .into_iter() - .map(|IndexEmbeddingConfig { name, config: _, user_provided }| -> Result<_> { + .map(|IndexEmbeddingConfig { name, config, user_provided }| -> Result<_> { let embedder_id = self.index.embedder_category_id.get(self.wtxn, &name)?.ok_or( crate::InternalError::DatabaseMissingEntry { @@ -964,10 +966,10 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { )?; Ok(( name, - EmbedderAction::WriteBackToDocuments(WriteBackToDocuments { - embedder_id, - user_provided, - }), + EmbedderAction::with_write_back( + WriteBackToDocuments { embedder_id, user_provided }, + config.quantized(), + ), )) }) .collect(); @@ -1004,7 +1006,8 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { match joined { // updated config EitherOrBoth::Both((name, (old, user_provided)), (_, new)) => { - let settings_diff = SettingsDiff::from_settings(old, new); + let was_quantized = old.binary_quantized.set().unwrap_or_default(); + let settings_diff = SettingsDiff::from_settings(old, new)?; match settings_diff { SettingsDiff::Remove => { tracing::debug!( @@ 
-1023,25 +1026,29 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { self.index.embedder_category_id.delete(self.wtxn, &name)?; embedder_actions.insert( name, - EmbedderAction::WriteBackToDocuments(WriteBackToDocuments { - embedder_id, - user_provided, - }), + EmbedderAction::with_write_back( + WriteBackToDocuments { embedder_id, user_provided }, + was_quantized, + ), ); } - SettingsDiff::Reindex { action, updated_settings } => { + SettingsDiff::Reindex { action, updated_settings, quantize } => { tracing::debug!( embedder = name, user_provided = user_provided.len(), ?action, "reindex embedder" ); - embedder_actions.insert(name.clone(), EmbedderAction::Reindex(action)); + embedder_actions.insert( + name.clone(), + EmbedderAction::with_reindex(action, was_quantized) + .with_is_being_quantized(quantize), + ); let new = validate_embedding_settings(Setting::Set(updated_settings), &name)?; updated_configs.insert(name, (new, user_provided)); } - SettingsDiff::UpdateWithoutReindex { updated_settings } => { + SettingsDiff::UpdateWithoutReindex { updated_settings, quantize } => { tracing::debug!( embedder = name, user_provided = user_provided.len(), @@ -1049,6 +1056,12 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { ); let new = validate_embedding_settings(Setting::Set(updated_settings), &name)?; + if quantize { + embedder_actions.insert( + name.clone(), + EmbedderAction::default().with_is_being_quantized(true), + ); + } updated_configs.insert(name, (new, user_provided)); } } @@ -1067,8 +1080,10 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { &mut setting, ); let setting = validate_embedding_settings(setting, &name)?; - embedder_actions - .insert(name.clone(), EmbedderAction::Reindex(ReindexAction::FullReindex)); + embedder_actions.insert( + name.clone(), + EmbedderAction::with_reindex(ReindexAction::FullReindex, false), + ); updated_configs.insert(name, (setting, RoaringBitmap::new())); } } @@ -1082,19 +1097,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let mut find_free_index = move || 
free_indices.find(|(_, free)| **free).map(|(index, _)| index as u8); for (name, action) in embedder_actions.iter() { - match action { - EmbedderAction::Reindex(ReindexAction::RegeneratePrompts) => { - /* cannot be a new embedder, so has to have an id already */ - } - EmbedderAction::Reindex(ReindexAction::FullReindex) => { - if self.index.embedder_category_id.get(self.wtxn, name)?.is_none() { - let id = find_free_index() - .ok_or(UserError::TooManyEmbedders(updated_configs.len()))?; - tracing::debug!(embedder = name, id, "assigning free id to new embedder"); - self.index.embedder_category_id.put(self.wtxn, name, &id)?; - } - } - EmbedderAction::WriteBackToDocuments(_) => { /* already removed */ } + if matches!(action.reindex(), Some(ReindexAction::FullReindex)) + && self.index.embedder_category_id.get(self.wtxn, name)?.is_none() + { + let id = + find_free_index().ok_or(UserError::TooManyEmbedders(updated_configs.len()))?; + tracing::debug!(embedder = name, id, "assigning free id to new embedder"); + self.index.embedder_category_id.put(self.wtxn, name, &id)?; } } let updated_configs: Vec = updated_configs @@ -1277,7 +1286,11 @@ impl InnerIndexSettingsDiff { // if the user-defined searchables changed, then we need to reindex prompts. 
if cache_user_defined_searchables { - for (embedder_name, (config, _)) in new_settings.embedding_configs.inner_as_ref() { + for (embedder_name, (config, _, _quantized)) in + new_settings.embedding_configs.inner_as_ref() + { + let was_quantized = + old_settings.embedding_configs.get(&embedder_name).map_or(false, |conf| conf.2); // skip embedders that don't use document templates if !config.uses_document_template() { continue; @@ -1287,16 +1300,19 @@ impl InnerIndexSettingsDiff { // this always makes the code clearer by explicitly handling the cases match embedding_config_updates.entry(embedder_name.clone()) { std::collections::btree_map::Entry::Vacant(entry) => { - entry.insert(EmbedderAction::Reindex(ReindexAction::RegeneratePrompts)); + entry.insert(EmbedderAction::with_reindex( + ReindexAction::RegeneratePrompts, + was_quantized, + )); + } + std::collections::btree_map::Entry::Occupied(entry) => { + let EmbedderAction { + was_quantized: _, + is_being_quantized: _, // We are deleting this embedder, so no point in regeneration + write_back: _, // We are already fully reindexing + reindex: _, // We are already regenerating prompts + } = entry.get(); } - std::collections::btree_map::Entry::Occupied(entry) => match entry.get() { - EmbedderAction::WriteBackToDocuments(_) => { /* we are deleting this embedder, so no point in regeneration */ - } - EmbedderAction::Reindex(ReindexAction::FullReindex) => { /* we are already fully reindexing */ - } - EmbedderAction::Reindex(ReindexAction::RegeneratePrompts) => { /* we are already regenerating prompts */ - } - }, }; } } @@ -1546,7 +1562,7 @@ fn embedders(embedding_configs: Vec) -> Result) -> Result { let max_bytes = match document_template_max_bytes.set() { Some(max_bytes) => NonZeroUsize::new(max_bytes).ok_or_else(|| { @@ -1613,6 +1630,7 @@ fn validate_prompt( response, distribution, headers, + binary_quantized: binary_quantize, })) } new => Ok(new), @@ -1638,6 +1656,7 @@ pub fn validate_embedding_settings( response, 
distribution, headers, + binary_quantized: binary_quantize, } = settings; if let Some(0) = dimensions.set() { @@ -1678,6 +1697,7 @@ pub fn validate_embedding_settings( response, distribution, headers, + binary_quantized: binary_quantize, })); }; match inferred_source { @@ -1779,6 +1799,7 @@ pub fn validate_embedding_settings( response, distribution, headers, + binary_quantized: binary_quantize, })) } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 23417ced2..edda59121 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -1,8 +1,12 @@ use std::collections::HashMap; use std::sync::Arc; +use arroy::distances::{Angular, BinaryQuantizedAngular}; +use arroy::ItemId; use deserr::{DeserializeError, Deserr}; +use heed::{RoTxn, RwTxn, Unspecified}; use ordered_float::OrderedFloat; +use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; use self::error::{EmbedError, NewEmbedderError}; @@ -26,6 +30,171 @@ pub type Embedding = Vec<f32>; pub const REQUEST_PARALLELISM: usize = 40; +pub struct ArroyReader { + quantized: bool, + index: u16, + database: arroy::Database<Unspecified>, +} + +impl ArroyReader { + pub fn new(database: arroy::Database<Unspecified>, index: u16, quantized: bool) -> Self { + Self { database, index, quantized } + } + + pub fn index(&self) -> u16 { + self.index + } + + pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> { + if self.quantized { + Ok(arroy::Reader::open(rtxn, self.index, self.quantized_db())?.dimensions()) + } else { + Ok(arroy::Reader::open(rtxn, self.index, self.angular_db())?.dimensions()) + } + } + + pub fn quantize( + &mut self, + wtxn: &mut RwTxn, + index: u16, + dimension: usize, + ) -> Result<(), arroy::Error> { + if !self.quantized { + let writer = arroy::Writer::new(self.angular_db(), index, dimension); + writer.prepare_changing_distance::<BinaryQuantizedAngular>(wtxn)?; + self.quantized = true; + } + Ok(()) + } + + pub fn need_build(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> { + if self.quantized {
+ arroy::Writer::new(self.quantized_db(), self.index, dimension).need_build(rtxn) + } else { + arroy::Writer::new(self.angular_db(), self.index, dimension).need_build(rtxn) + } + } + + pub fn build<R: rand::Rng + rand::SeedableRng>( + &self, + wtxn: &mut RwTxn, + rng: &mut R, + dimension: usize, + ) -> Result<(), arroy::Error> { + if self.quantized { + arroy::Writer::new(self.quantized_db(), self.index, dimension).build(wtxn, rng, None) + } else { + arroy::Writer::new(self.angular_db(), self.index, dimension).build(wtxn, rng, None) + } + } + + pub fn add_item( + &self, + wtxn: &mut RwTxn, + dimension: usize, + item_id: arroy::ItemId, + vector: &[f32], + ) -> Result<(), arroy::Error> { + if self.quantized { + arroy::Writer::new(self.quantized_db(), self.index, dimension) + .add_item(wtxn, item_id, vector) + } else { + arroy::Writer::new(self.angular_db(), self.index, dimension) + .add_item(wtxn, item_id, vector) + } + } + + pub fn del_item( + &self, + wtxn: &mut RwTxn, + dimension: usize, + item_id: arroy::ItemId, + ) -> Result<bool, arroy::Error> { + if self.quantized { + arroy::Writer::new(self.quantized_db(), self.index, dimension).del_item(wtxn, item_id) + } else { + arroy::Writer::new(self.angular_db(), self.index, dimension).del_item(wtxn, item_id) + } + } + + pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> { + if self.quantized { + arroy::Writer::new(self.quantized_db(), self.index, dimension).clear(wtxn) + } else { + arroy::Writer::new(self.angular_db(), self.index, dimension).clear(wtxn) + } + } + + pub fn is_empty(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> { + if self.quantized { + arroy::Writer::new(self.quantized_db(), self.index, dimension).is_empty(rtxn) + } else { + arroy::Writer::new(self.angular_db(), self.index, dimension).is_empty(rtxn) + } + } + + pub fn contains_item( + &self, + rtxn: &RoTxn, + dimension: usize, + item: arroy::ItemId, + ) -> Result<bool, arroy::Error> { + if self.quantized { + arroy::Writer::new(self.quantized_db(), self.index, dimension).contains_item(rtxn, item)
+ } else { + arroy::Writer::new(self.angular_db(), self.index, dimension).contains_item(rtxn, item) + } + } + + pub fn nns_by_item( + &self, + rtxn: &RoTxn, + item: ItemId, + limit: usize, + filter: Option<&RoaringBitmap>, + ) -> Result<Option<Vec<(ItemId, f32)>>, arroy::Error> { + if self.quantized { + arroy::Reader::open(rtxn, self.index, self.quantized_db())? + .nns_by_item(rtxn, item, limit, None, None, filter) + } else { + arroy::Reader::open(rtxn, self.index, self.angular_db())? + .nns_by_item(rtxn, item, limit, None, None, filter) + } + } + + pub fn nns_by_vector( + &self, + txn: &RoTxn, + item: &[f32], + limit: usize, + filter: Option<&RoaringBitmap>, + ) -> Result<Vec<(ItemId, f32)>, arroy::Error> { + if self.quantized { + arroy::Reader::open(txn, self.index, self.quantized_db())? + .nns_by_vector(txn, item, limit, None, None, filter) + } else { + arroy::Reader::open(txn, self.index, self.angular_db())? + .nns_by_vector(txn, item, limit, None, None, filter) + } + } + + pub fn item_vector(&self, rtxn: &RoTxn, docid: u32) -> Result<Option<Vec<f32>>, arroy::Error> { + if self.quantized { + arroy::Reader::open(rtxn, self.index, self.quantized_db())?.item_vector(rtxn, docid) + } else { + arroy::Reader::open(rtxn, self.index, self.angular_db())?.item_vector(rtxn, docid) + } + } + + fn angular_db(&self) -> arroy::Database<Angular> { + self.database.remap_data_type() + } + + fn quantized_db(&self) -> arroy::Database<BinaryQuantizedAngular> { + self.database.remap_data_type() + } +} + /// One or multiple embeddings stored consecutively in a flat vector. pub struct Embeddings<F> { data: Vec<F>, @@ -124,39 +293,48 @@ pub struct EmbeddingConfig { pub embedder_options: EmbedderOptions, /// Document template pub prompt: PromptData, + /// If this embedder is binary quantized + pub quantized: Option<bool>, // TODO: add metrics and anything needed } +impl EmbeddingConfig { + pub fn quantized(&self) -> bool { + self.quantized.unwrap_or_default() + } +} + /// Map of embedder configurations. /// /// Each configuration is mapped to a name.
#[derive(Clone, Default)] -pub struct EmbeddingConfigs(HashMap, Arc)>); +pub struct EmbeddingConfigs(HashMap, Arc, bool)>); impl EmbeddingConfigs { /// Create the map from its internal component.s - pub fn new(data: HashMap, Arc)>) -> Self { + pub fn new(data: HashMap, Arc, bool)>) -> Self { Self(data) } /// Get an embedder configuration and template from its name. - pub fn get(&self, name: &str) -> Option<(Arc, Arc)> { + pub fn get(&self, name: &str) -> Option<(Arc, Arc, bool)> { self.0.get(name).cloned() } - pub fn inner_as_ref(&self) -> &HashMap, Arc)> { + pub fn inner_as_ref(&self) -> &HashMap, Arc, bool)> { &self.0 } - pub fn into_inner(self) -> HashMap, Arc)> { + pub fn into_inner(self) -> HashMap, Arc, bool)> { self.0 } } impl IntoIterator for EmbeddingConfigs { - type Item = (String, (Arc, Arc)); + type Item = (String, (Arc, Arc, bool)); - type IntoIter = std::collections::hash_map::IntoIter, Arc)>; + type IntoIter = + std::collections::hash_map::IntoIter, Arc, bool)>; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index b7ae90d89..9b2c1c6e3 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -32,6 +32,9 @@ pub struct EmbeddingSettings { pub dimensions: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] + pub binary_quantized: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] pub document_template: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] @@ -85,23 +88,62 @@ pub enum ReindexAction { pub enum SettingsDiff { Remove, - Reindex { action: ReindexAction, updated_settings: EmbeddingSettings }, - UpdateWithoutReindex { updated_settings: EmbeddingSettings }, + Reindex { action: ReindexAction, updated_settings: EmbeddingSettings, quantize: bool }, + UpdateWithoutReindex { updated_settings: EmbeddingSettings, quantize: bool 
}, } -pub enum EmbedderAction { - WriteBackToDocuments(WriteBackToDocuments), - Reindex(ReindexAction), +#[derive(Default, Debug)] +pub struct EmbedderAction { + pub was_quantized: bool, + pub is_being_quantized: bool, + pub write_back: Option, + pub reindex: Option, } +impl EmbedderAction { + pub fn is_being_quantized(&self) -> bool { + self.is_being_quantized + } + + pub fn write_back(&self) -> Option<&WriteBackToDocuments> { + self.write_back.as_ref() + } + + pub fn reindex(&self) -> Option<&ReindexAction> { + self.reindex.as_ref() + } + + pub fn with_is_being_quantized(mut self, quantize: bool) -> Self { + self.is_being_quantized = quantize; + self + } + + pub fn with_write_back(write_back: WriteBackToDocuments, was_quantized: bool) -> Self { + Self { + was_quantized, + is_being_quantized: false, + write_back: Some(write_back), + reindex: None, + } + } + + pub fn with_reindex(reindex: ReindexAction, was_quantized: bool) -> Self { + Self { was_quantized, is_being_quantized: false, write_back: None, reindex: Some(reindex) } + } +} + +#[derive(Debug)] pub struct WriteBackToDocuments { pub embedder_id: u8, pub user_provided: RoaringBitmap, } impl SettingsDiff { - pub fn from_settings(old: EmbeddingSettings, new: Setting) -> Self { - match new { + pub fn from_settings( + old: EmbeddingSettings, + new: Setting, + ) -> Result { + let ret = match new { Setting::Set(new) => { let EmbeddingSettings { mut source, @@ -116,6 +158,7 @@ impl SettingsDiff { mut distribution, mut headers, mut document_template_max_bytes, + binary_quantized: mut binary_quantize, } = old; let EmbeddingSettings { @@ -131,8 +174,17 @@ impl SettingsDiff { distribution: new_distribution, headers: new_headers, document_template_max_bytes: new_document_template_max_bytes, + binary_quantized: new_binary_quantize, } = new; + if matches!(binary_quantize, Setting::Set(true)) + && matches!(new_binary_quantize, Setting::Set(false)) + { + return Err(UserError::InvalidDisableBinaryQuantization { + 
embedder_name: String::from("todo"), + }); + } + let mut reindex_action = None; // **Warning**: do not use short-circuiting || here, we want all these operations applied @@ -172,6 +224,7 @@ impl SettingsDiff { _ => {} } } + let binary_quantize_changed = binary_quantize.apply(new_binary_quantize); if url.apply(new_url) { match source { // do not regenerate on an url change in OpenAI @@ -231,16 +284,27 @@ impl SettingsDiff { distribution, headers, document_template_max_bytes, + binary_quantized: binary_quantize, }; match reindex_action { - Some(action) => Self::Reindex { action, updated_settings }, - None => Self::UpdateWithoutReindex { updated_settings }, + Some(action) => Self::Reindex { + action, + updated_settings, + quantize: binary_quantize_changed, + }, + None => Self::UpdateWithoutReindex { + updated_settings, + quantize: binary_quantize_changed, + }, } } Setting::Reset => Self::Remove, - Setting::NotSet => Self::UpdateWithoutReindex { updated_settings: old }, - } + Setting::NotSet => { + Self::UpdateWithoutReindex { updated_settings: old, quantize: false } + } + }; + Ok(ret) } } @@ -486,7 +550,7 @@ impl std::fmt::Display for EmbedderSource { impl From for EmbeddingSettings { fn from(value: EmbeddingConfig) -> Self { - let EmbeddingConfig { embedder_options, prompt } = value; + let EmbeddingConfig { embedder_options, prompt, quantized } = value; let document_template_max_bytes = Setting::Set(prompt.max_bytes.unwrap_or(default_max_bytes()).get()); match embedder_options { @@ -507,6 +571,7 @@ impl From for EmbeddingSettings { response: Setting::NotSet, headers: Setting::NotSet, distribution: Setting::some_or_not_set(distribution), + binary_quantized: Setting::some_or_not_set(quantized), }, super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions { url, @@ -527,6 +592,7 @@ impl From for EmbeddingSettings { response: Setting::NotSet, headers: Setting::NotSet, distribution: Setting::some_or_not_set(distribution), + binary_quantized: 
Setting::some_or_not_set(quantized), }, super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions { embedding_model, @@ -547,6 +613,7 @@ impl From for EmbeddingSettings { response: Setting::NotSet, headers: Setting::NotSet, distribution: Setting::some_or_not_set(distribution), + binary_quantized: Setting::some_or_not_set(quantized), }, super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { dimensions, @@ -564,6 +631,7 @@ impl From for EmbeddingSettings { response: Setting::NotSet, headers: Setting::NotSet, distribution: Setting::some_or_not_set(distribution), + binary_quantized: Setting::some_or_not_set(quantized), }, super::EmbedderOptions::Rest(super::rest::EmbedderOptions { api_key, @@ -586,6 +654,7 @@ impl From for EmbeddingSettings { response: Setting::Set(response), distribution: Setting::some_or_not_set(distribution), headers: Setting::Set(headers), + binary_quantized: Setting::some_or_not_set(quantized), }, } } @@ -607,8 +676,11 @@ impl From for EmbeddingConfig { response, distribution, headers, + binary_quantized, } = value; + this.quantized = binary_quantized.set(); + if let Some(source) = source.set() { match source { EmbedderSource::OpenAi => {