From 4571e512d2b306469454f7a82467282dc36d3f41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Thu, 8 Jun 2023 12:19:06 +0200
Subject: [PATCH] Store the vectors in an HNSW in LMDB

---
 Cargo.lock                                    | 63 ++++++++++++++++++-
 milli/Cargo.toml                              |  5 +-
 milli/src/dot_product.rs                      | 16 +++++
 milli/src/index.rs                            | 53 ++++++++++++----
 milli/src/lib.rs                              |  1 +
 milli/src/update/clear_documents.rs           |  3 +
 milli/src/update/delete_documents.rs          |  1 +
 .../src/update/index_documents/typed_chunk.rs | 36 +++++------
 8 files changed, 142 insertions(+), 36 deletions(-)
 create mode 100644 milli/src/dot_product.rs

diff --git a/Cargo.lock b/Cargo.lock
index 9d09fef9d..904d1c225 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1221,6 +1221,12 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "doc-comment"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
+
 [[package]]
 name = "dump"
 version = "1.2.0"
@@ -1725,6 +1731,15 @@ dependencies = [
  "byteorder",
 ]
 
+[[package]]
+name = "hashbrown"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
+dependencies = [
+ "ahash 0.7.6",
+]
+
 [[package]]
 name = "hashbrown"
 version = "0.12.3"
@@ -1826,6 +1841,22 @@ dependencies = [
  "digest",
 ]
 
+[[package]]
+name = "hnsw"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752"
+dependencies = [
+ "ahash 0.7.6",
+ "hashbrown 0.11.2",
+ "libm",
+ "num-traits",
+ "rand_core",
+ "serde",
+ "smallvec",
+ "space",
+]
+
 [[package]]
 name = "http"
 version = "0.2.9"
@@ -1956,7 +1987,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
 dependencies = [
  "autocfg",
- "hashbrown",
+ "hashbrown 0.12.3",
  "serde",
 ]
 
@@ -2057,7 +2088,7 @@ checksum = "37228e06c75842d1097432d94d02f37fe3ebfca9791c2e8fef6e9db17ed128c1"
 dependencies = [
  "cedarwood",
  "fxhash",
- "hashbrown",
+ "hashbrown 0.12.3",
  "lazy_static",
  "phf",
  "phf_codegen",
@@ -2698,6 +2729,7 @@ dependencies = [
  "geoutils",
  "grenad",
  "heed",
+ "hnsw",
  "insta",
  "itertools",
  "json-depth-checker",
@@ -2712,6 +2744,7 @@ dependencies = [
  "once_cell",
  "ordered-float",
  "rand",
+ "rand_pcg",
  "rayon",
  "roaring",
  "rstar",
@@ -2721,6 +2754,7 @@ dependencies = [
  "smallstr",
  "smallvec",
  "smartstring",
+ "space",
  "tempfile",
  "thiserror",
  "time",
@@ -3273,6 +3307,16 @@ dependencies = [
  "getrandom",
 ]
 
+[[package]]
+name = "rand_pcg"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e"
+dependencies = [
+ "rand_core",
+ "serde",
+]
+
 [[package]]
 name = "rayon"
 version = "1.7.0"
@@ -3732,6 +3776,9 @@ name = "smallvec"
 version = "1.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
+dependencies = [
+ "serde",
+]
 
 [[package]]
 name = "smartstring"
@@ -3754,6 +3801,16 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "space"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123"
+dependencies = [
+ "doc-comment",
+ "num-traits",
+]
+
 [[package]]
 name = "spin"
"0.5.2" @@ -4405,7 +4462,7 @@ version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c531a2dc4c462b833788be2c07eef4e621d0e9edbd55bf280cc164c1c1aa043" dependencies = [ - "hashbrown", + "hashbrown 0.12.3", "once_cell", ] diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 5ff73303a..08f0c2645 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -15,7 +15,7 @@ license.workspace = true bimap = { version = "0.6.3", features = ["serde"] } bincode = "1.3.3" bstr = "1.4.0" -bytemuck = "1.13.1" +bytemuck = { version = "1.13.1", features = ["extern_crate_alloc"] } byteorder = "1.4.3" charabia = { version = "0.7.2", default-features = false } concat-arrays = "0.1.2" @@ -33,18 +33,21 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f "lmdb", "sync-read-txn", ] } +hnsw = { version = "0.11.0", features = ["serde1"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.5.10" obkv = "0.2.0" once_cell = "1.17.1" ordered-float = "3.6.0" +rand_pcg = { version = "0.3.1", features = ["serde1"] } rayon = "1.7.0" roaring = "0.10.1" rstar = { version = "0.10.0", features = ["serde"] } serde = { version = "1.0.160", features = ["derive"] } serde_json = { version = "1.0.95", features = ["preserve_order"] } slice-group-by = "0.3.0" +space = "0.17.0" smallstr = { version = "0.3.0", features = ["serde"] } smallvec = "1.10.0" smartstring = "1.0.1" diff --git a/milli/src/dot_product.rs b/milli/src/dot_product.rs new file mode 100644 index 000000000..2f5f1e474 --- /dev/null +++ b/milli/src/dot_product.rs @@ -0,0 +1,16 @@ +use serde::{Deserialize, Serialize}; +use space::Metric; + +#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] +pub struct DotProduct; + +impl Metric> for DotProduct { + type Unit = u32; + + // Following . + fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { + let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum(); + debug_assert!(!dist.is_nan()); + dist.to_bits() + } +} diff --git a/milli/src/index.rs b/milli/src/index.rs index fad3f665c..4cdfb010c 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -8,10 +8,12 @@ use charabia::{Language, Script}; use heed::flags::Flags; use heed::types::*; use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn}; +use rand_pcg::Pcg32; use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; +use crate::dot_product::DotProduct; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -26,6 +28,9 @@ use crate::{ Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, }; +/// The HNSW data-structure that we serialize, fill and search in. 
diff --git a/milli/src/index.rs b/milli/src/index.rs
index fad3f665c..4cdfb010c 100644
--- a/milli/src/index.rs
+++ b/milli/src/index.rs
@@ -8,10 +8,12 @@ use charabia::{Language, Script};
 use heed::flags::Flags;
 use heed::types::*;
 use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn};
+use rand_pcg::Pcg32;
 use roaring::RoaringBitmap;
 use rstar::RTree;
 use time::OffsetDateTime;
 
+use crate::dot_product::DotProduct;
 use crate::error::{InternalError, UserError};
 use crate::facet::FacetType;
 use crate::fields_ids_map::FieldsIdsMap;
@@ -26,6 +28,9 @@ use crate::{
     Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32,
 };
 
+/// The HNSW data-structure that we serialize, fill and search in.
+pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>;
+
 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
 pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
 
@@ -42,6 +47,7 @@ pub mod main_key {
     pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
     pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
     pub const GEO_RTREE_KEY: &str = "geo-rtree";
+    pub const VECTOR_HNSW_KEY: &str = "vector-hnsw";
     pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids";
     pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids";
     pub const PRIMARY_KEY_KEY: &str = "primary-key";
@@ -86,6 +92,7 @@ pub mod db_name {
     pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids";
     pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
     pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
+    pub const VECTOR_ID_DOCID: &str = "vector-id-docids";
     pub const DOCUMENTS: &str = "documents";
     pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
 }
@@ -149,6 +156,9 @@ pub struct Index {
     /// Maps the document id, the facet field id and the strings.
     pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
 
+    /// Maps a vector id to the document id that has it.
+    pub vector_id_docid: Database<OwnedType<BEU32>, OwnedType<BEU32>>,
+
     /// Maps the document id to the document as an obkv store.
     pub(crate) documents: Database<OwnedType<BEU32>, ObkvCodec>,
 }
@@ -162,7 +172,7 @@ impl Index {
     ) -> Result<Index> {
         use db_name::*;
 
-        options.max_dbs(23);
+        options.max_dbs(24);
         unsafe { options.flag(Flags::MdbAlwaysFreePages) };
 
         let env = options.open(path)?;
@@ -198,11 +208,11 @@ impl Index {
             env.create_database(&mut wtxn, Some(FACET_ID_IS_NULL_DOCIDS))?;
         let facet_id_is_empty_docids =
             env.create_database(&mut wtxn, Some(FACET_ID_IS_EMPTY_DOCIDS))?;
-
         let field_id_docid_facet_f64s =
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
         let field_id_docid_facet_strings =
            env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
+        let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?;
         let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;
 
         wtxn.commit()?;
@@ -231,6 +241,7 @@ impl Index {
             facet_id_is_empty_docids,
             field_id_docid_facet_f64s,
             field_id_docid_facet_strings,
+            vector_id_docid,
             documents,
         })
     }
@@ -502,6 +513,26 @@ impl Index {
         }
     }
 
+    /* vector HNSW */
+
+    /// Writes the provided `hnsw`.
+    pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
+        self.main.put::<_, Str, SerdeBincode<Hnsw>>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw)
+    }
+
+    /// Delete the `hnsw`.
+    pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
+        self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY)
+    }
+
+    /// Returns the `hnsw`.
+    pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
+        match self.main.get::<_, Str, SerdeBincode<Hnsw>>(rtxn, main_key::VECTOR_HNSW_KEY)? {
+            Some(hnsw) => Ok(Some(hnsw)),
+            None => Ok(None),
+        }
+    }
+
     /* field distribution */
 
     /// Writes the field distribution which associates every field name with
@@ -1466,9 +1497,9 @@ pub(crate) mod tests {
 
         db_snap!(index, field_distribution,
             @r###"
-        age              1          |
-        id               2          |
-        name             2          |
+        age              1
+        id               2
+        name             2
         "###
         );
 
@@ -1486,9 +1517,9 @@ pub(crate) mod tests {
 
         db_snap!(index, field_distribution,
             @r###"
-        age              1          |
-        id               2          |
-        name             2          |
+        age              1
+        id               2
+        name             2
        "###
         );
 
@@ -1502,9 +1533,9 @@ pub(crate) mod tests {
 
         db_snap!(index, field_distribution,
             @r###"
-        has_dog          1          |
-        id               2          |
-        name             2          |
+        has_dog          1
+        id               2
+        name             2
         "###
         );
     }
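Note: this patch only adds storage and accessors; nothing reads the HNSW back yet. A rough, hypothetical sketch of how a search-side consumer could combine the new `vector_hnsw` getter with the `vector_id_docid` database (the `nearest` call follows the `hnsw` 0.11 API as I understand it, the `ef` value is an arbitrary placeholder, and `nearest_docids` is an invented name, not code from this patch):

```rust
use hnsw::Searcher;
use space::Neighbor;

use crate::{DocumentId, Index, Result, BEU32};

/// Hypothetical helper: approximate nearest documents for a query embedding.
pub fn nearest_docids(
    index: &Index,
    rtxn: &heed::RoTxn,
    query: Vec<f32>,
    limit: usize,
) -> Result<Vec<DocumentId>> {
    let mut docids = Vec::new();
    if let Some(hnsw) = index.vector_hnsw(rtxn)? {
        let mut searcher = Searcher::new();
        // One slot per requested neighbor; `nearest` fills as many as it can.
        let mut dest = vec![Neighbor { index: 0, distance: 0 }; limit];
        // `ef` (search breadth) chosen arbitrarily here.
        let neighbors = hnsw.nearest(&query, 16 * limit, &mut searcher, &mut dest);
        for Neighbor { index: vector_id, .. } in neighbors.iter() {
            // Translate the HNSW-internal vector id back into a document id.
            if let Some(docid) = index.vector_id_docid.get(rtxn, &BEU32::new(*vector_id as u32))? {
                docids.push(docid.get());
            }
        }
    }
    Ok(docids)
}
```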
diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index d3ee4f08e..2e62e35ac 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -10,6 +10,7 @@ pub mod documents;
 
 mod asc_desc;
 mod criterion;
+pub mod dot_product;
 mod error;
 mod external_documents_ids;
 pub mod facet;
diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs
index 04119c641..f4a2d43fe 100644
--- a/milli/src/update/clear_documents.rs
+++ b/milli/src/update/clear_documents.rs
@@ -39,6 +39,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
             facet_id_is_empty_docids,
             field_id_docid_facet_f64s,
             field_id_docid_facet_strings,
+            vector_id_docid,
             documents,
         } = self.index;
 
@@ -57,6 +58,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
         self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
         self.index.delete_geo_rtree(self.wtxn)?;
         self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
+        self.index.delete_vector_hnsw(self.wtxn)?;
 
         // We clean all the faceted documents ids.
         for field_id in faceted_fields {
@@ -95,6 +97,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
         facet_id_string_docids.clear(self.wtxn)?;
         field_id_docid_facet_f64s.clear(self.wtxn)?;
         field_id_docid_facet_strings.clear(self.wtxn)?;
+        vector_id_docid.clear(self.wtxn)?;
         documents.clear(self.wtxn)?;
 
         Ok(number_of_documents)
diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs
index b971768a3..73af66a95 100644
--- a/milli/src/update/delete_documents.rs
+++ b/milli/src/update/delete_documents.rs
@@ -240,6 +240,7 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
             facet_id_exists_docids,
             facet_id_is_null_docids,
             facet_id_is_empty_docids,
+            vector_id_docid,
             documents,
         } = self.index;
         // Remove from the documents database
diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index 8b3477948..e2c67044c 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -4,10 +4,12 @@ use std::convert::TryInto;
 use std::fs::File;
 use std::io;
 
+use bytemuck::allocation::pod_collect_to_vec;
 use charabia::{Language, Script};
 use grenad::MergerBuilder;
 use heed::types::ByteSlice;
 use heed::RwTxn;
+use hnsw::Searcher;
 use roaring::RoaringBitmap;
 
 use super::helpers::{
@@ -17,7 +19,7 @@ use super::{ClonableMmap, MergeFn};
 use crate::facet::FacetType;
 use crate::update::facet::FacetsUpdate;
 use crate::update::index_documents::helpers::as_cloneable_grenad;
-use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result};
+use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};
 
 pub(crate) enum TypedChunk {
     FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@@ -223,27 +225,19 @@ pub(crate) fn write_typed_chunk_into_index(
             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
         }
         TypedChunk::VectorPoints(vector_points) => {
-            // let mut rtree = index.geo_rtree(wtxn)?.unwrap_or_default();
-            // let mut geo_faceted_docids = index.geo_faceted_documents_ids(wtxn)?;
+            let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default();
+            let mut searcher = Searcher::new();
 
-            // let mut cursor = geo_points.into_cursor()?;
-            // while let Some((key, value)) = cursor.move_on_next()? {
-            //     // convert the key back to a u32 (4 bytes)
-            //     let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
-
-            //     // convert the latitude and longitude back to a f64 (8 bytes)
-            //     let (lat, tail) = helpers::try_split_array_at::<u8, 8>(value).unwrap();
-            //     let (lng, _) = helpers::try_split_array_at::<u8, 8>(tail).unwrap();
-            //     let point = [f64::from_ne_bytes(lat), f64::from_ne_bytes(lng)];
-            //     let xyz_point = lat_lng_to_xyz(&point);
-
-            //     rtree.insert(GeoPoint::new(xyz_point, (docid, point)));
-            //     geo_faceted_docids.insert(docid);
-            // }
-            // index.put_geo_rtree(wtxn, &rtree)?;
-            // index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
-
-            todo!("index vector points")
+            let mut cursor = vector_points.into_cursor()?;
+            while let Some((key, value)) = cursor.move_on_next()? {
+                // convert the key back to a u32 (4 bytes)
+                let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
+                // convert the vector back to a Vec<f32>
+                let vector: Vec<f32> = pod_collect_to_vec(value);
+                let vector_id = hnsw.insert(vector, &mut searcher) as u32;
+                index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
+            }
+            index.put_vector_hnsw(wtxn, &hnsw)?;
         }
         TypedChunk::ScriptLanguageDocids(hash_pair) => {
             let mut buffer = Vec::new();
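Note: the extraction side that fills the `TypedChunk::VectorPoints` grenad file is not part of this patch. For the decoding above to round-trip, each entry presumably stores the document id as a 4-byte big-endian key and the embedding as raw `f32` bytes. A hedged sketch of such a writer (the function name is invented and the writer setup is an assumption, not code from this patch):

```rust
use bytemuck::cast_slice;

use crate::DocumentId;

/// Hypothetical counterpart of the decoding done in `write_typed_chunk_into_index`:
/// one embedding per document, written so that `DocumentId::from_be_bytes` and
/// `pod_collect_to_vec` can read it back.
fn push_vector_point<W: std::io::Write>(
    writer: &mut grenad::Writer<W>,
    docid: DocumentId,
    vector: &[f32],
) {
    // 4-byte big-endian key, matching `key.try_into().map(DocumentId::from_be_bytes)`.
    let key = docid.to_be_bytes();
    // Raw f32 bytes, matching `pod_collect_to_vec(value)` on the read side.
    let value: &[u8] = cast_slice(vector);
    // Unwrap keeps the sketch short; real code would propagate the error.
    writer.insert(key, value).unwrap();
}
```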