diff --git a/Cargo.lock b/Cargo.lock index cef8e9c8a..e78372421 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,7 +80,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -216,7 +216,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "anes" @@ -441,7 +441,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -549,7 +549,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -593,6 +593,15 @@ dependencies = [ "serde", ] +[[package]] +name = "bitpacking" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c1d3e2bfd8d06048a179f7b17afc3188effa10385e7b00dc65af6aae732ea92" +dependencies = [ + "crunchy", +] + [[package]] name = "bitvec" version = "1.0.1" @@ -634,7 +643,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", "syn_derive", ] @@ -684,6 +693,10 @@ name = "bumpalo" version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +dependencies = [ + "allocator-api2", + "serde", +] [[package]] name = "byte-unit" @@ -741,7 +754,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -954,8 +967,7 @@ dependencies = [ [[package]] name = "charabia" version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ff52497324e7d168505a16949ae836c14595606fab94687238d2f6c8d4c798" +source = "git+https://github.com/meilisearch/charabia?branch=mutualize-char-normalizer#f8d8308cdb8db80819be7eeed5652cc4a995cc71" dependencies = [ "aho-corasick", "csv", @@ -1052,7 +1064,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1378,7 +1390,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1400,7 +1412,7 @@ checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" dependencies = [ "darling_core 0.20.9", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1454,7 +1466,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1496,7 +1508,7 @@ dependencies = [ "darling 0.20.9", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1516,7 +1528,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core 0.20.0", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1558,7 +1570,7 @@ dependencies = [ "convert_case 0.6.0", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1622,7 +1634,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1780,7 +1792,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1800,7 +1812,7 @@ checksum = "a1ab991c1362ac86c61ab6f556cff143daa22e5a15e4e189df818b2fd19fe65b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -1908,6 +1920,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1985,7 +2003,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -2241,11 +2259,11 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "grenad" version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350d89047298d3b1b40050acd11ab76e487b854a104b760ebc5a7f375093de77" +source = "git+https://github.com/meilisearch/grenad?branch=various-improvements#58ac87d852413571102f44c5e55ca13509a3f1a0" dependencies = [ "bytemuck", "byteorder", + "either", "rayon", "tempfile", ] @@ -2336,6 +2354,18 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashbrown" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", + "serde", +] + [[package]] name = "heapless" version = "0.8.0" @@ -2578,6 +2608,7 @@ dependencies = [ "arroy 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", "big_s", "bincode", + "bumpalo", "crossbeam", "csv", "derive_builder 0.20.0", @@ -2590,7 +2621,9 @@ dependencies = [ "meili-snap", "meilisearch-auth", "meilisearch-types", + "memmap2", "page_size", + "raw-collections", "rayon", "roaring", "serde", @@ -2670,8 +2703,7 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "irg-kvariants" version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef2af7c331f2536964a32b78a7d2e0963d78b42f4a76323b16cc7d94b1ddce26" +source = "git+https://github.com/meilisearch/charabia?branch=mutualize-char-normalizer#f8d8308cdb8db80819be7eeed5652cc4a995cc71" dependencies = [ "csv", "once_cell", @@ -3264,7 +3296,7 @@ checksum = "915f6d0a2963a27cd5205c1902f32ddfe3bc035816afd268cf88c0fc0f8d287e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -3368,7 +3400,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -3507,6 +3539,7 @@ version = "1.11.0" dependencies = [ "actix-web", "anyhow", + "bumpalo", "convert_case 0.6.0", "csv", "deserr", @@ -3519,6 +3552,7 @@ dependencies = [ "meili-snap", "memmap2", "milli", + "raw-collections", "roaring", "serde", "serde-cs", @@ -3567,11 +3601,13 @@ dependencies = [ name = "milli" version = "1.11.0" dependencies = [ + "allocator-api2", "arroy 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", "big_s", "bimap", "bincode", "bstr", + "bumpalo", "bytemuck", "byteorder", "candle-core", @@ -3583,12 +3619,14 @@ dependencies = [ "csv", "deserr", "either", + "enum-iterator", "filter-parser", "flatten-serde-json", "fst", "fxhash", "geoutils", "grenad", + "hashbrown 0.15.1", "heed", "hf-hub", "indexmap", @@ -3607,11 +3645,13 @@ dependencies = [ "once_cell", "ordered-float", "rand", + "raw-collections", "rayon", "rayon-par-bridge", "rhai", "roaring", "rstar", + "rustc-hash 2.0.0", "serde", "serde_json", "slice-group-by", @@ -3620,10 +3660,12 @@ dependencies = [ "smartstring", "tempfile", "thiserror", + "thread_local", "tiktoken-rs", "time", "tokenizers", "tracing", + "uell", "ureq", "url", "uuid", @@ -3699,7 +3741,7 @@ checksum = "371717c0a5543d6a800cac822eac735aa7d2d2fbb41002e9856a4089532dbdce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -3835,7 +3877,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -3864,9 +3906,8 @@ dependencies = [ [[package]] name = "obkv" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2e27bcfe835a379d32352112f6b8dbae2d99d16a5fff42abe6e5ba5386c1e5a" +version = "0.3.0" +source = "git+https://github.com/kerollmops/obkv?branch=unsized-kvreader#ce535874008ecac554f02e0c670e6caf62134d6b" [[package]] name = "once_cell" @@ -4047,7 +4088,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -4101,7 +4142,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -4130,7 +4171,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -4247,9 +4288,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -4434,6 +4475,19 @@ dependencies = [ "rand", ] +[[package]] +name = "raw-collections" +version = "0.1.0" +source = "git+https://github.com/dureuill/raw-collections.git#15e5d7bdebc0c149b2a28b2454f307c717d07f8a" +dependencies = [ + "allocator-api2", + "bitpacking", + "bumpalo", + "hashbrown 0.15.1", + "serde", + "serde_json", +] + [[package]] name = "raw-cpuid" version = "10.7.0" @@ -4631,7 +4685,7 @@ source = "git+https://github.com/rhaiscript/rhai?rev=ef3df63121d27aacd838f366f2b dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -4681,8 +4735,7 @@ dependencies = [ [[package]] name = "roaring" version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +source = "git+https://github.com/RoaringBitmap/roaring-rs?branch=clone-iter-slice#8ff028e484fb6192a0acf5a669eaf18c30cada6e" dependencies = [ "bytemuck", "byteorder", @@ -4873,9 +4926,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.209" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] @@ -4891,23 +4944,24 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.209" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "indexmap", "itoa", + "memchr", "ryu", "serde", ] @@ -5190,7 +5244,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -5212,9 +5266,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.60" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -5230,7 +5284,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -5256,7 +5310,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -5361,14 +5415,14 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] name = "thread_local" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" dependencies = [ "cfg-if", "once_cell", @@ -5513,7 +5567,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -5645,7 +5699,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -5740,6 +5794,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" +[[package]] +name = "uell" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40de5982e28612e20330e77d81f1559b74f66caf3c7fc10b19ada4843f4b4fd7" +dependencies = [ + "bumpalo", +] + [[package]] name = "unescaper" version = "0.1.5" @@ -5991,7 +6054,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", "wasm-bindgen-shared", ] @@ -6025,7 +6088,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -6458,7 +6521,7 @@ checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", "synstructure", ] @@ -6479,7 +6542,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] @@ -6499,7 +6562,7 @@ checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", "synstructure", ] @@ -6520,7 +6583,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.87", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4e65ae83d..68e049f7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,23 +44,5 @@ opt-level = 3 [profile.dev.package.roaring] opt-level = 3 -[profile.dev.package.lindera-ipadic-builder] -opt-level = 3 -[profile.dev.package.encoding] -opt-level = 3 -[profile.dev.package.yada] -opt-level = 3 - -[profile.release.package.lindera-ipadic-builder] -opt-level = 3 -[profile.release.package.encoding] -opt-level = 3 -[profile.release.package.yada] -opt-level = 3 - -[profile.bench.package.lindera-ipadic-builder] -opt-level = 3 -[profile.bench.package.encoding] -opt-level = 3 -[profile.bench.package.yada] -opt-level = 3 +[patch.crates-io] +roaring = { git = "https://github.com/RoaringBitmap/roaring-rs", branch = "clone-iter-slice" } diff --git a/crates/index-scheduler/Cargo.toml b/crates/index-scheduler/Cargo.toml index e80311005..deaded910 100644 --- a/crates/index-scheduler/Cargo.toml +++ b/crates/index-scheduler/Cargo.toml @@ -22,6 +22,7 @@ flate2 = "1.0.30" meilisearch-auth = { path = "../meilisearch-auth" } meilisearch-types = { path = "../meilisearch-types" } page_size = "0.6.0" +raw-collections = { git = "https://github.com/dureuill/raw-collections.git", version = "0.1.0" } rayon = "1.10.0" roaring = { version = "0.10.6", features = ["serde"] } serde = { version = "1.0.204", features = ["derive"] } @@ -29,6 +30,7 @@ serde_json = { version = "1.0.120", features = ["preserve_order"] } synchronoise = "1.0.1" tempfile = "3.10.1" thiserror = "1.0.61" +memmap2 = "0.9.4" time = { version = "0.3.36", features = [ "serde-well-known", "formatting", @@ -38,6 +40,7 @@ time = { version = "0.3.36", features = [ tracing = "0.1.40" ureq = "2.10.0" uuid = { version = "1.10.0", features = ["serde", "v4"] } +bumpalo = "3.16.0" [dev-dependencies] arroy = "0.5.0" diff --git a/crates/index-scheduler/src/batch.rs b/crates/index-scheduler/src/batch.rs index 903ec1217..c06cb6b42 100644 --- a/crates/index-scheduler/src/batch.rs +++ b/crates/index-scheduler/src/batch.rs @@ -22,21 +22,26 @@ use std::ffi::OsStr; use std::fmt; use std::fs::{self, File}; use std::io::BufWriter; +use std::sync::atomic::{self, AtomicU64}; +use std::time::Duration; +use bumpalo::collections::CollectIn; +use bumpalo::Bump; use dump::IndexMetadata; use meilisearch_types::error::Code; use meilisearch_types::heed::{RoTxn, RwTxn}; -use meilisearch_types::milli::documents::{obkv_to_object, DocumentsBatchReader}; +use meilisearch_types::milli::documents::{obkv_to_object, DocumentsBatchReader, PrimaryKey}; use meilisearch_types::milli::heed::CompactionOption; -use meilisearch_types::milli::update::{ - IndexDocumentsConfig, IndexDocumentsMethod, IndexerConfig, Settings as MilliSettings, -}; +use meilisearch_types::milli::update::new::indexer::{self, UpdateByFunction}; +use meilisearch_types::milli::update::{IndexDocumentsMethod, Settings as MilliSettings}; use meilisearch_types::milli::vector::parsed_vectors::{ ExplicitVectors, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME, }; -use meilisearch_types::milli::{self, Filter, Object}; +use meilisearch_types::milli::{self, Filter, ThreadPoolNoAbortBuilder}; use meilisearch_types::settings::{apply_settings_to_builder, Settings, Unchecked}; -use meilisearch_types::tasks::{Details, IndexSwap, Kind, KindWithContent, Status, Task}; +use meilisearch_types::tasks::{ + Details, IndexSwap, Kind, KindWithContent, Status, Task, TaskProgress, +}; use meilisearch_types::{compression, Index, VERSION_FILE_NAME}; use roaring::RoaringBitmap; use time::macros::format_description; @@ -45,7 +50,7 @@ use uuid::Uuid; use crate::autobatcher::{self, BatchKind}; use crate::utils::{self, swap_index_uid_in_task}; -use crate::{Error, IndexScheduler, MustStopProcessing, ProcessingTasks, Result, TaskId}; +use crate::{Error, IndexScheduler, ProcessingTasks, Result, TaskId}; /// Represents a combination of tasks that can all be processed at the same time. /// @@ -526,7 +531,7 @@ impl IndexScheduler { if let Some(task_id) = to_cancel.max() { // We retrieve the tasks that were processing before this tasks cancelation started. // We must *not* reset the processing tasks before calling this method. - let ProcessingTasks { started_at, processing } = + let ProcessingTasks { started_at, processing, progress: _ } = &*self.processing_tasks.read().unwrap(); return Ok(Some(Batch::TaskCancelation { task: self.get_task(rtxn, task_id)?.ok_or(Error::CorruptedTaskQueue)?, @@ -875,10 +880,8 @@ impl IndexScheduler { while let Some(doc) = cursor.next_document().map_err(milli::Error::from)? { - dump_content_file.push_document(&obkv_to_object( - &doc, - &documents_batch_index, - )?)?; + dump_content_file + .push_document(&obkv_to_object(doc, &documents_batch_index)?)?; } dump_content_file.flush()?; } @@ -1218,6 +1221,44 @@ impl IndexScheduler { index: &'i Index, operation: IndexOperation, ) -> Result> { + let indexer_alloc = Bump::new(); + + let started_processing_at = std::time::Instant::now(); + let secs_since_started_processing_at = AtomicU64::new(0); + const PRINT_SECS_DELTA: u64 = 1; + + let processing_tasks = self.processing_tasks.clone(); + let must_stop_processing = self.must_stop_processing.clone(); + let send_progress = |progress| { + let now = std::time::Instant::now(); + let elapsed = secs_since_started_processing_at.load(atomic::Ordering::Relaxed); + let previous = started_processing_at + Duration::from_secs(elapsed); + let elapsed = now - previous; + + if elapsed.as_secs() < PRINT_SECS_DELTA { + return; + } + + secs_since_started_processing_at + .store((now - started_processing_at).as_secs(), atomic::Ordering::Relaxed); + + let TaskProgress { + current_step, + finished_steps, + total_steps, + finished_documents, + total_documents, + } = processing_tasks.write().unwrap().update_progress(progress); + + tracing::info!( + current_step, + finished_steps, + total_steps, + finished_documents, + total_documents + ); + }; + match operation { IndexOperation::DocumentClear { mut tasks, .. } => { let count = milli::update::ClearDocuments::new(index_wtxn, index).execute()?; @@ -1247,155 +1288,154 @@ impl IndexScheduler { operations, mut tasks, } => { - let started_processing_at = std::time::Instant::now(); - let mut primary_key_has_been_set = false; - let must_stop_processing = self.must_stop_processing.clone(); - let indexer_config = self.index_mapper.indexer_config(); - - if let Some(primary_key) = primary_key { - match index.primary_key(index_wtxn)? { - // if a primary key was set AND had already been defined in the index - // but to a different value, we can make the whole batch fail. - Some(pk) => { - if primary_key != pk { - return Err(milli::Error::from( - milli::UserError::PrimaryKeyCannotBeChanged(pk.to_string()), - ) - .into()); - } - } - // if the primary key was set and there was no primary key set for this index - // we set it to the received value before starting the indexing process. - None => { - let mut builder = - milli::update::Settings::new(index_wtxn, index, indexer_config); - builder.set_primary_key(primary_key); - builder.execute( - |indexing_step| tracing::debug!(update = ?indexing_step), - || must_stop_processing.clone().get(), - )?; - primary_key_has_been_set = true; + // TODO: at some point, for better efficiency we might want to reuse the bumpalo for successive batches. + // this is made difficult by the fact we're doing private clones of the index scheduler and sending it + // to a fresh thread. + let mut content_files = Vec::new(); + for operation in &operations { + if let DocumentOperation::Add(content_uuid) = operation { + let content_file = self.file_store.get_update(*content_uuid)?; + let mmap = unsafe { memmap2::Mmap::map(&content_file)? }; + if !mmap.is_empty() { + content_files.push(mmap); } } } - let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; + let rtxn = index.read_txn()?; + let db_fields_ids_map = index.fields_ids_map(&rtxn)?; + let mut new_fields_ids_map = db_fields_ids_map.clone(); - let embedder_configs = index.embedding_configs(index_wtxn)?; - // TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense) - let embedders = self.embedders(embedder_configs)?; - - let mut builder = milli::update::IndexDocuments::new( - index_wtxn, - index, - indexer_config, - config, - |indexing_step| tracing::trace!(?indexing_step, "Update"), - || must_stop_processing.get(), - )?; - - for (operation, task) in operations.into_iter().zip(tasks.iter_mut()) { + let mut content_files_iter = content_files.iter(); + let mut indexer = indexer::DocumentOperation::new(method); + let embedders = index.embedding_configs(index_wtxn)?; + let embedders = self.embedders(embedders)?; + for operation in operations { match operation { - DocumentOperation::Add(content_uuid) => { - let content_file = self.file_store.get_update(content_uuid)?; - let reader = DocumentsBatchReader::from_reader(content_file) - .map_err(milli::Error::from)?; - let (new_builder, user_result) = builder.add_documents(reader)?; - builder = new_builder; - - builder = builder.with_embedders(embedders.clone()); - - let received_documents = - if let Some(Details::DocumentAdditionOrUpdate { - received_documents, - .. - }) = task.details - { - received_documents - } else { - // In the case of a `documentAdditionOrUpdate` the details MUST be set - unreachable!(); - }; - - match user_result { - Ok(count) => { - task.status = Status::Succeeded; - task.details = Some(Details::DocumentAdditionOrUpdate { - received_documents, - indexed_documents: Some(count), - }) - } - Err(e) => { - task.status = Status::Failed; - task.details = Some(Details::DocumentAdditionOrUpdate { - received_documents, - indexed_documents: Some(0), - }); - task.error = Some(milli::Error::from(e).into()); - } - } + DocumentOperation::Add(_content_uuid) => { + let mmap = content_files_iter.next().unwrap(); + indexer.add_documents(mmap)?; + // builder = builder.with_embedders(embedders.clone()); } DocumentOperation::Delete(document_ids) => { - let (new_builder, user_result) = - builder.remove_documents(document_ids)?; - builder = new_builder; - // Uses Invariant: remove documents actually always returns Ok for the inner result - let count = user_result.unwrap(); - let provided_ids = - if let Some(Details::DocumentDeletion { provided_ids, .. }) = - task.details - { - provided_ids - } else { - // In the case of a `documentAdditionOrUpdate` the details MUST be set - unreachable!(); - }; - - task.status = Status::Succeeded; - task.details = Some(Details::DocumentDeletion { - provided_ids, - deleted_documents: Some(count), - }); + let document_ids: bumpalo::collections::vec::Vec<_> = document_ids + .iter() + .map(|s| &*indexer_alloc.alloc_str(s)) + .collect_in(&indexer_alloc); + indexer.delete_documents(document_ids.into_bump_slice()); } } } - if !tasks.iter().all(|res| res.error.is_some()) { - let addition = builder.execute()?; - tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); - } else if primary_key_has_been_set { - // Everything failed but we've set a primary key. - // We need to remove it. - let mut builder = - milli::update::Settings::new(index_wtxn, index, indexer_config); - builder.reset_primary_key(); - builder.execute( - |indexing_step| tracing::trace!(update = ?indexing_step), - || must_stop_processing.clone().get(), - )?; + let local_pool; + let indexer_config = self.index_mapper.indexer_config(); + let pool = match &indexer_config.thread_pool { + Some(pool) => pool, + None => { + local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap(); + &local_pool + } + }; + + let (document_changes, operation_stats, primary_key) = indexer.into_changes( + &indexer_alloc, + index, + &rtxn, + primary_key.as_deref(), + &mut new_fields_ids_map, + )?; + + let mut addition = 0; + for (stats, task) in operation_stats.into_iter().zip(&mut tasks) { + addition += stats.document_count; + match stats.error { + Some(error) => { + task.status = Status::Failed; + task.error = Some(milli::Error::UserError(error).into()); + } + None => task.status = Status::Succeeded, + } + + task.details = match task.details { + Some(Details::DocumentAdditionOrUpdate { received_documents, .. }) => { + Some(Details::DocumentAdditionOrUpdate { + received_documents, + indexed_documents: Some(stats.document_count), + }) + } + Some(Details::DocumentDeletion { provided_ids, .. }) => { + Some(Details::DocumentDeletion { + provided_ids, + deleted_documents: Some(stats.document_count), + }) + } + _ => { + // In the case of a `documentAdditionOrUpdate` or `DocumentDeletion` + // the details MUST be set to either addition or deletion + unreachable!(); + } + } } + if tasks.iter().any(|res| res.error.is_none()) { + pool.install(|| { + indexer::index( + index_wtxn, + index, + indexer_config.grenad_parameters(), + &db_fields_ids_map, + new_fields_ids_map, + primary_key, + &document_changes, + embedders, + &|| must_stop_processing.get(), + &send_progress, + ) + }) + .unwrap()?; + + tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); + } + // else if primary_key_has_been_set { + // // Everything failed but we've set a primary key. + // // We need to remove it. + // let mut builder = + // milli::update::Settings::new(index_wtxn, index, indexer_config); + // builder.reset_primary_key(); + // builder.execute( + // |indexing_step| tracing::trace!(update = ?indexing_step), + // || must_stop_processing.clone().get(), + // )?; + // } + Ok(tasks) } IndexOperation::DocumentEdition { mut task, .. } => { - let (filter, context, function) = - if let KindWithContent::DocumentEdition { - filter_expr, context, function, .. - } = &task.kind - { - (filter_expr, context, function) - } else { - unreachable!() - }; - let result_count = edit_documents_by_function( - index_wtxn, - filter, - context.clone(), + let (filter, code) = if let KindWithContent::DocumentEdition { + filter_expr, + context: _, function, - self.index_mapper.indexer_config(), - self.must_stop_processing.clone(), - index, - ); + .. + } = &task.kind + { + (filter_expr, function) + } else { + unreachable!() + }; + + let candidates = match filter.as_ref().map(Filter::from_json) { + Some(Ok(Some(filter))) => { + filter.evaluate(index_wtxn, index).map_err(|err| match err { + milli::Error::UserError(milli::UserError::InvalidFilter(_)) => { + Error::from(err).with_custom_error_code(Code::InvalidDocumentFilter) + } + e => e.into(), + })? + } + None | Some(Ok(None)) => index.documents_ids(index_wtxn)?, + Some(Err(e)) => return Err(e.into()), + }; + let (original_filter, context, function) = if let Some(Details::DocumentEdition { original_filter, context, @@ -1409,6 +1449,68 @@ impl IndexScheduler { unreachable!(); }; + if candidates.is_empty() { + task.status = Status::Succeeded; + task.details = Some(Details::DocumentEdition { + original_filter, + context, + function, + deleted_documents: Some(0), + edited_documents: Some(0), + }); + + return Ok(vec![task]); + } + + let rtxn = index.read_txn()?; + let db_fields_ids_map = index.fields_ids_map(&rtxn)?; + let mut new_fields_ids_map = db_fields_ids_map.clone(); + // candidates not empty => index not empty => a primary key is set + let primary_key = index.primary_key(&rtxn)?.unwrap(); + + let primary_key = PrimaryKey::new_or_insert(primary_key, &mut new_fields_ids_map) + .map_err(milli::Error::from)?; + + let result_count = Ok((candidates.len(), candidates.len())) as Result<_>; + + if task.error.is_none() { + let local_pool; + let indexer_config = self.index_mapper.indexer_config(); + let pool = match &indexer_config.thread_pool { + Some(pool) => pool, + None => { + local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap(); + &local_pool + } + }; + + pool.install(|| { + let indexer = + UpdateByFunction::new(candidates, context.clone(), code.clone()); + let document_changes = indexer.into_changes(&primary_key)?; + let embedders = index.embedding_configs(index_wtxn)?; + let embedders = self.embedders(embedders)?; + + indexer::index( + index_wtxn, + index, + indexer_config.grenad_parameters(), + &db_fields_ids_map, + new_fields_ids_map, + None, // cannot change primary key in DocumentEdition + &document_changes, + embedders, + &|| must_stop_processing.get(), + &send_progress, + )?; + + Result::Ok(()) + }) + .unwrap()?; + + // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); + } + match result_count { Ok((deleted_documents, edited_documents)) => { task.status = Status::Succeeded; @@ -1509,26 +1611,55 @@ impl IndexScheduler { } } - let config = IndexDocumentsConfig { - update_method: IndexDocumentsMethod::ReplaceDocuments, - ..Default::default() - }; + if to_delete.is_empty() { + return Ok(tasks); + } - let must_stop_processing = self.must_stop_processing.clone(); - let mut builder = milli::update::IndexDocuments::new( - index_wtxn, - index, - self.index_mapper.indexer_config(), - config, - |indexing_step| tracing::debug!(update = ?indexing_step), - || must_stop_processing.get(), - )?; + let rtxn = index.read_txn()?; + let db_fields_ids_map = index.fields_ids_map(&rtxn)?; + let mut new_fields_ids_map = db_fields_ids_map.clone(); - let (new_builder, _count) = - builder.remove_documents_from_db_no_batch(&to_delete)?; - builder = new_builder; + // to_delete not empty => index not empty => primary key set + let primary_key = index.primary_key(&rtxn)?.unwrap(); - let _ = builder.execute()?; + let primary_key = PrimaryKey::new_or_insert(primary_key, &mut new_fields_ids_map) + .map_err(milli::Error::from)?; + + if !tasks.iter().all(|res| res.error.is_some()) { + let local_pool; + let indexer_config = self.index_mapper.indexer_config(); + let pool = match &indexer_config.thread_pool { + Some(pool) => pool, + None => { + local_pool = ThreadPoolNoAbortBuilder::new().build().unwrap(); + &local_pool + } + }; + + let mut indexer = indexer::DocumentDeletion::new(); + indexer.delete_documents_by_docids(to_delete); + let document_changes = indexer.into_changes(&indexer_alloc, primary_key); + let embedders = index.embedding_configs(index_wtxn)?; + let embedders = self.embedders(embedders)?; + + pool.install(|| { + indexer::index( + index_wtxn, + index, + indexer_config.grenad_parameters(), + &db_fields_ids_map, + new_fields_ids_map, + None, // document deletion never changes primary key + &document_changes, + embedders, + &|| must_stop_processing.get(), + &send_progress, + ) + }) + .unwrap()?; + + // tracing::info!(indexing_result = ?addition, processed_in = ?started_processing_at.elapsed(), "document indexing done"); + } Ok(tasks) } @@ -1546,7 +1677,6 @@ impl IndexScheduler { task.status = Status::Succeeded; } - let must_stop_processing = self.must_stop_processing.clone(); builder.execute( |indexing_step| tracing::debug!(update = ?indexing_step), || must_stop_processing.get(), @@ -1733,44 +1863,3 @@ impl IndexScheduler { Ok(content_files_to_delete) } } - -fn edit_documents_by_function<'a>( - wtxn: &mut RwTxn<'a>, - filter: &Option, - context: Option, - code: &str, - indexer_config: &IndexerConfig, - must_stop_processing: MustStopProcessing, - index: &'a Index, -) -> Result<(u64, u64)> { - let candidates = match filter.as_ref().map(Filter::from_json) { - Some(Ok(Some(filter))) => filter.evaluate(wtxn, index).map_err(|err| match err { - milli::Error::UserError(milli::UserError::InvalidFilter(_)) => { - Error::from(err).with_custom_error_code(Code::InvalidDocumentFilter) - } - e => e.into(), - })?, - None | Some(Ok(None)) => index.documents_ids(wtxn)?, - Some(Err(e)) => return Err(e.into()), - }; - - let config = IndexDocumentsConfig { - update_method: IndexDocumentsMethod::ReplaceDocuments, - ..Default::default() - }; - - let mut builder = milli::update::IndexDocuments::new( - wtxn, - index, - indexer_config, - config, - |indexing_step| tracing::debug!(update = ?indexing_step), - || must_stop_processing.get(), - )?; - - let (new_builder, count) = builder.edit_documents(&candidates, context, code)?; - builder = new_builder; - - let _ = builder.execute()?; - Ok(count.unwrap()) -} diff --git a/crates/index-scheduler/src/lib.rs b/crates/index-scheduler/src/lib.rs index 336a43b1b..4eadb8baf 100644 --- a/crates/index-scheduler/src/lib.rs +++ b/crates/index-scheduler/src/lib.rs @@ -55,11 +55,12 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; use meilisearch_types::milli::documents::DocumentsBatchBuilder; use meilisearch_types::milli::index::IndexEmbeddingConfig; +use meilisearch_types::milli::update::new::indexer::document_changes::Progress; use meilisearch_types::milli::update::IndexerConfig; use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs}; use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; use meilisearch_types::task_view::TaskView; -use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; +use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task, TaskProgress}; use rayon::current_num_threads; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use roaring::RoaringBitmap; @@ -164,12 +165,18 @@ struct ProcessingTasks { started_at: OffsetDateTime, /// The list of tasks ids that are currently running. processing: RoaringBitmap, + /// The progress on processing tasks + progress: Option, } impl ProcessingTasks { /// Creates an empty `ProcessingAt` struct. fn new() -> ProcessingTasks { - ProcessingTasks { started_at: OffsetDateTime::now_utc(), processing: RoaringBitmap::new() } + ProcessingTasks { + started_at: OffsetDateTime::now_utc(), + processing: RoaringBitmap::new(), + progress: None, + } } /// Stores the currently processing tasks, and the date time at which it started. @@ -178,8 +185,13 @@ impl ProcessingTasks { self.processing = processing; } + fn update_progress(&mut self, progress: Progress) -> TaskProgress { + self.progress.get_or_insert_with(TaskProgress::default).update(progress) + } + /// Set the processing tasks to an empty list fn stop_processing(&mut self) -> RoaringBitmap { + self.progress = None; std::mem::take(&mut self.processing) } @@ -971,9 +983,11 @@ impl IndexScheduler { let tasks = self.get_existing_tasks(&rtxn, tasks.take(query.limit.unwrap_or(u32::MAX) as usize))?; - let ProcessingTasks { started_at, processing, .. } = + let ProcessingTasks { started_at, processing, progress, .. } = self.processing_tasks.read().map_err(|_| Error::CorruptedTaskQueue)?.clone(); + let _ = progress; + let ret = tasks.into_iter(); if processing.is_empty() { Ok((ret.collect(), total)) @@ -4299,11 +4313,11 @@ mod tests { snapshot!(snapshot_index_scheduler(&index_scheduler), name: "only_first_task_succeed"); // The second batch should fail. - handle.advance_one_failed_batch(); + handle.advance_one_successful_batch(); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_task_fails"); // The second batch should fail. - handle.advance_one_failed_batch(); + handle.advance_one_successful_batch(); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "third_task_fails"); // Is the primary key still what we expect? @@ -4364,7 +4378,7 @@ mod tests { snapshot!(snapshot_index_scheduler(&index_scheduler), name: "only_first_task_succeed"); // The second batch should fail and contains two tasks. - handle.advance_one_failed_batch(); + handle.advance_one_successful_batch(); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_and_third_tasks_fails"); // Is the primary key still what we expect? @@ -4443,7 +4457,8 @@ mod tests { snapshot!(primary_key, @"id"); // We're trying to `bork` again, but now there is already a primary key set for this index. - handle.advance_one_failed_batch(); + // NOTE: it's marked as successful because the batch didn't fails, it's the individual tasks that failed. + handle.advance_one_successful_batch(); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "fourth_task_fails"); // Finally the last task should succeed since its primary key is the same as the valid one. @@ -4603,7 +4618,7 @@ mod tests { snapshot!(primary_key.is_none(), @"false"); // The second batch should contains only one task that fails because it tries to update the primary key to `bork`. - handle.advance_one_failed_batch(); + handle.advance_one_successful_batch(); snapshot!(snapshot_index_scheduler(&index_scheduler), name: "second_task_fails"); // The third batch should succeed and only contains one task. @@ -5216,9 +5231,10 @@ mod tests { let configs = index_scheduler.embedders(configs).unwrap(); let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap(); - let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap(); - let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap(); - let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap(); + let beagle_embed = + hf_embedder.embed_one(S("Intel the beagle best doggo"), None).unwrap(); + let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo"), None).unwrap(); + let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo"), None).unwrap(); (fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed) }; diff --git a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir succeeds.snap b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir succeeds.snap index 41cfcfdab..fed7be6e9 100644 --- a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir succeeds.snap +++ b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir succeeds.snap @@ -1,5 +1,5 @@ --- -source: index-scheduler/src/lib.rs +source: crates/index-scheduler/src/lib.rs --- ### Autobatching Enabled = true ### Processing Tasks: @@ -22,7 +22,7 @@ succeeded [0,1,2,] doggos [0,1,2,] ---------------------------------------------------------------------- ### Index Mapper: -doggos: { number_of_documents: 1, field_distribution: {"_vectors": 1, "breed": 1, "doggo": 1, "id": 1} } +doggos: { number_of_documents: 1, field_distribution: {"breed": 1, "doggo": 1, "id": 1} } ---------------------------------------------------------------------- ### Canceled By: diff --git a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir.snap b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir.snap index e6d0d8232..b8b204935 100644 --- a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir.snap +++ b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/Intel to kefir.snap @@ -1,5 +1,5 @@ --- -source: index-scheduler/src/lib.rs +source: crates/index-scheduler/src/lib.rs --- ### Autobatching Enabled = true ### Processing Tasks: @@ -22,7 +22,7 @@ succeeded [0,1,] doggos [0,1,2,] ---------------------------------------------------------------------- ### Index Mapper: -doggos: { number_of_documents: 1, field_distribution: {"_vectors": 1, "breed": 1, "doggo": 1, "id": 1} } +doggos: { number_of_documents: 1, field_distribution: {"breed": 1, "doggo": 1, "id": 1} } ---------------------------------------------------------------------- ### Canceled By: diff --git a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/adding Intel succeeds.snap b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/adding Intel succeeds.snap index bd4cf0c09..cead3f781 100644 --- a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/adding Intel succeeds.snap +++ b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors/adding Intel succeeds.snap @@ -1,5 +1,5 @@ --- -source: index-scheduler/src/lib.rs +source: crates/index-scheduler/src/lib.rs --- ### Autobatching Enabled = true ### Processing Tasks: @@ -21,7 +21,7 @@ succeeded [0,1,] doggos [0,1,] ---------------------------------------------------------------------- ### Index Mapper: -doggos: { number_of_documents: 1, field_distribution: {"_vectors": 1, "breed": 1, "doggo": 1, "id": 1} } +doggos: { number_of_documents: 1, field_distribution: {"breed": 1, "doggo": 1, "id": 1} } ---------------------------------------------------------------------- ### Canceled By: diff --git a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors_first_and_embedder_later/documents after initial push.snap b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors_first_and_embedder_later/documents after initial push.snap index d2473d00a..e06d09464 100644 --- a/crates/index-scheduler/src/snapshots/lib.rs/import_vectors_first_and_embedder_later/documents after initial push.snap +++ b/crates/index-scheduler/src/snapshots/lib.rs/import_vectors_first_and_embedder_later/documents after initial push.snap @@ -1,4 +1,4 @@ --- -source: index-scheduler/src/lib.rs +source: crates/index-scheduler/src/lib.rs --- -[{"id":0,"doggo":"kefir"},{"id":1,"doggo":"intel","_vectors":{"my_doggo_embedder":[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0],"unknown embedder":[1.0,2.0,3.0]}},{"id":2,"doggo":"max","_vectors":{"my_doggo_embedder":{"embeddings":[2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0],"regenerate":false},"unknown embedder":[4.0,5.0]}},{"id":3,"doggo":"marcel","_vectors":{"my_doggo_embedder":{"embeddings":[3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0],"regenerate":true}}},{"id":4,"doggo":"sora","_vectors":{"my_doggo_embedder":{"embeddings":null,"regenerate":true}}}] +[{"id":0,"doggo":"kefir"},{"id":1,"doggo":"intel","_vectors":{"my_doggo_embedder":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"unknown embedder":[1,2,3]}},{"id":2,"doggo":"max","_vectors":{"my_doggo_embedder":{"regenerate":false,"embeddings":[2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2]},"unknown embedder":[4,5]}},{"id":3,"doggo":"marcel","_vectors":{"my_doggo_embedder":{"regenerate":true,"embeddings":[3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3]}}},{"id":4,"doggo":"sora","_vectors":{"my_doggo_embedder":{"regenerate":true,"embeddings":null}}}] diff --git a/crates/meili-snap/Cargo.toml b/crates/meili-snap/Cargo.toml index e86feabd9..6c68e563c 100644 --- a/crates/meili-snap/Cargo.toml +++ b/crates/meili-snap/Cargo.toml @@ -11,6 +11,6 @@ edition.workspace = true license.workspace = true [dependencies] -insta = { version = "^1.39.0", features = ["json", "redactions"] } +insta = { version = "=1.39.0", features = ["json", "redactions"] } md5 = "0.7.0" once_cell = "1.19" diff --git a/crates/meilisearch-types/Cargo.toml b/crates/meilisearch-types/Cargo.toml index 0dae024f2..3bd368e7c 100644 --- a/crates/meilisearch-types/Cargo.toml +++ b/crates/meilisearch-types/Cargo.toml @@ -13,6 +13,7 @@ license.workspace = true [dependencies] actix-web = { version = "4.8.0", default-features = false } anyhow = "1.0.86" +bumpalo = "3.16.0" convert_case = "0.6.0" csv = "1.3.0" deserr = { version = "0.6.2", features = ["actix-web"] } @@ -23,6 +24,7 @@ flate2 = "1.0.30" fst = "0.4.7" memmap2 = "0.9.4" milli = { path = "../milli" } +raw-collections = { git = "https://github.com/dureuill/raw-collections.git", version = "0.1.0" } roaring = { version = "0.10.6", features = ["serde"] } serde = { version = "1.0.204", features = ["derive"] } serde-cs = "0.2.4" @@ -70,4 +72,3 @@ swedish-recomposition = ["milli/swedish-recomposition"] german = ["milli/german"] # allow turkish normalization turkish = ["milli/turkish"] - diff --git a/crates/meilisearch-types/src/document_formats.rs b/crates/meilisearch-types/src/document_formats.rs index 50dc5bad4..311fcccf4 100644 --- a/crates/meilisearch-types/src/document_formats.rs +++ b/crates/meilisearch-types/src/document_formats.rs @@ -1,20 +1,25 @@ use std::fmt::{self, Debug, Display}; use std::fs::File; -use std::io::{self, BufWriter, Write}; +use std::io::{self, BufWriter}; use std::marker::PhantomData; -use memmap2::MmapOptions; -use milli::documents::{DocumentsBatchBuilder, Error}; +use bumpalo::Bump; +use memmap2::Mmap; +use milli::documents::Error; +use milli::update::new::TopLevelMap; use milli::Object; +use raw_collections::RawMap; use serde::de::{SeqAccess, Visitor}; use serde::{Deserialize, Deserializer}; use serde_json::error::Category; +use serde_json::value::RawValue; +use serde_json::{to_writer, Map, Value}; use crate::error::{Code, ErrorCode}; type Result = std::result::Result; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum PayloadType { Ndjson, Json, @@ -88,6 +93,26 @@ impl From<(PayloadType, Error)> for DocumentFormatError { } } +impl From<(PayloadType, serde_json::Error)> for DocumentFormatError { + fn from((ty, error): (PayloadType, serde_json::Error)) -> Self { + if error.classify() == Category::Data { + Self::Io(error.into()) + } else { + Self::MalformedPayload(Error::Json(error), ty) + } + } +} + +impl From<(PayloadType, csv::Error)> for DocumentFormatError { + fn from((ty, error): (PayloadType, csv::Error)) -> Self { + if error.is_io_error() { + Self::Io(error.into()) + } else { + Self::MalformedPayload(Error::Csv(error), ty) + } + } +} + impl From for DocumentFormatError { fn from(error: io::Error) -> Self { Self::Io(error) @@ -103,67 +128,146 @@ impl ErrorCode for DocumentFormatError { } } -/// Reads CSV from input and write an obkv batch to writer. -pub fn read_csv(file: &File, writer: impl Write, delimiter: u8) -> Result { - let mut builder = DocumentsBatchBuilder::new(BufWriter::new(writer)); - let mmap = unsafe { MmapOptions::new().map(file)? }; - let csv = csv::ReaderBuilder::new().delimiter(delimiter).from_reader(mmap.as_ref()); - builder.append_csv(csv).map_err(|e| (PayloadType::Csv { delimiter }, e))?; - - let count = builder.documents_count(); - let _ = builder.into_inner().map_err(DocumentFormatError::Io)?; - - Ok(count as u64) +// TODO remove that from the place I've borrowed it +#[derive(Debug)] +enum AllowedType { + String, + Boolean, + Number, } -/// Reads JSON from temporary file and write an obkv batch to writer. -pub fn read_json(file: &File, writer: impl Write) -> Result { - let mut builder = DocumentsBatchBuilder::new(BufWriter::new(writer)); - let mmap = unsafe { MmapOptions::new().map(file)? }; - let mut deserializer = serde_json::Deserializer::from_slice(&mmap); +fn parse_csv_header(header: &str) -> (&str, AllowedType) { + // if there are several separators we only split on the last one. + match header.rsplit_once(':') { + Some((field_name, field_type)) => match field_type { + "string" => (field_name, AllowedType::String), + "boolean" => (field_name, AllowedType::Boolean), + "number" => (field_name, AllowedType::Number), + // if the pattern isn't recognized, we keep the whole field. + _otherwise => (header, AllowedType::String), + }, + None => (header, AllowedType::String), + } +} - match array_each(&mut deserializer, |obj| builder.append_json_object(&obj)) { +/// Reads CSV from file and write it in NDJSON in a file checking it along the way. +pub fn read_csv(input: &File, output: impl io::Write, delimiter: u8) -> Result { + let ptype = PayloadType::Csv { delimiter }; + let mut output = BufWriter::new(output); + let mut reader = csv::ReaderBuilder::new().delimiter(delimiter).from_reader(input); + + let headers = reader.headers().map_err(|e| DocumentFormatError::from((ptype, e)))?.clone(); + let typed_fields: Vec<_> = headers.iter().map(parse_csv_header).collect(); + let mut object: Map<_, _> = + typed_fields.iter().map(|(k, _)| (k.to_string(), Value::Null)).collect(); + + let mut line = 0; + let mut record = csv::StringRecord::new(); + while reader.read_record(&mut record).map_err(|e| DocumentFormatError::from((ptype, e)))? { + // We increment here and not at the end of the loop + // to take the header offset into account. + line += 1; + + // Reset the document values + object.iter_mut().for_each(|(_, v)| *v = Value::Null); + + for (i, (name, atype)) in typed_fields.iter().enumerate() { + let value = &record[i]; + let trimmed_value = value.trim(); + let value = match atype { + AllowedType::Number if trimmed_value.is_empty() => Value::Null, + AllowedType::Number => match trimmed_value.parse::() { + Ok(integer) => Value::from(integer), + Err(_) => match trimmed_value.parse::() { + Ok(float) => Value::from(float), + Err(error) => { + return Err(DocumentFormatError::MalformedPayload( + Error::ParseFloat { error, line, value: value.to_string() }, + ptype, + )) + } + }, + }, + AllowedType::Boolean if trimmed_value.is_empty() => Value::Null, + AllowedType::Boolean => match trimmed_value.parse::() { + Ok(bool) => Value::from(bool), + Err(error) => { + return Err(DocumentFormatError::MalformedPayload( + Error::ParseBool { error, line, value: value.to_string() }, + ptype, + )) + } + }, + AllowedType::String if value.is_empty() => Value::Null, + AllowedType::String => Value::from(value), + }; + + *object.get_mut(*name).expect("encountered an unknown field") = value; + } + + to_writer(&mut output, &object).map_err(|e| DocumentFormatError::from((ptype, e)))?; + } + + Ok(line as u64) +} + +/// Reads JSON from file and write it in NDJSON in a file checking it along the way. +pub fn read_json(input: &File, output: impl io::Write) -> Result { + // We memory map to be able to deserailize into a TopLevelMap<'pl> that + // does not allocate when possible and only materialize the first/top level. + let input = unsafe { Mmap::map(input).map_err(DocumentFormatError::Io)? }; + let mut doc_alloc = Bump::with_capacity(1024 * 1024 * 1024); // 1MiB + + let mut out = BufWriter::new(output); + let mut deserializer = serde_json::Deserializer::from_slice(&input); + let res = array_each(&mut deserializer, |obj: &RawValue| { + doc_alloc.reset(); + let map = RawMap::from_raw_value(obj, &doc_alloc)?; + to_writer(&mut out, &map) + }); + let count = match res { // The json data has been deserialized and does not need to be processed again. // The data has been transferred to the writer during the deserialization process. - Ok(Ok(_)) => (), - Ok(Err(e)) => return Err(DocumentFormatError::Io(e)), + Ok(Ok(count)) => count, + Ok(Err(e)) => return Err(DocumentFormatError::from((PayloadType::Json, e))), Err(e) => { // Attempt to deserialize a single json string when the cause of the exception is not Category.data // Other types of deserialisation exceptions are returned directly to the front-end - if e.classify() != serde_json::error::Category::Data { - return Err(DocumentFormatError::MalformedPayload( - Error::Json(e), - PayloadType::Json, - )); + if e.classify() != Category::Data { + return Err(DocumentFormatError::from((PayloadType::Json, e))); } - let content: Object = serde_json::from_slice(&mmap) + let content: Object = serde_json::from_slice(&input) .map_err(Error::Json) .map_err(|e| (PayloadType::Json, e))?; - builder.append_json_object(&content).map_err(DocumentFormatError::Io)?; + to_writer(&mut out, &content) + .map(|_| 1) + .map_err(|e| DocumentFormatError::from((PayloadType::Json, e)))? } + }; + + match out.into_inner() { + Ok(_) => Ok(count), + Err(ie) => Err(DocumentFormatError::Io(ie.into_error())), } - - let count = builder.documents_count(); - let _ = builder.into_inner().map_err(DocumentFormatError::Io)?; - - Ok(count as u64) } -/// Reads JSON from temporary file and write an obkv batch to writer. -pub fn read_ndjson(file: &File, writer: impl Write) -> Result { - let mut builder = DocumentsBatchBuilder::new(BufWriter::new(writer)); - let mmap = unsafe { MmapOptions::new().map(file)? }; +/// Reads NDJSON from file and write it in NDJSON in a file checking it along the way. +pub fn read_ndjson(input: &File, output: impl io::Write) -> Result { + // We memory map to be able to deserailize into a TopLevelMap<'pl> that + // does not allocate when possible and only materialize the first/top level. + let input = unsafe { Mmap::map(input).map_err(DocumentFormatError::Io)? }; + let mut output = BufWriter::new(output); - for result in serde_json::Deserializer::from_slice(&mmap).into_iter() { - let object = result.map_err(Error::Json).map_err(|e| (PayloadType::Ndjson, e))?; - builder.append_json_object(&object).map_err(Into::into).map_err(DocumentFormatError::Io)?; + let mut count = 0; + for result in serde_json::Deserializer::from_slice(&input).into_iter() { + count += 1; + result + .and_then(|map: TopLevelMap| to_writer(&mut output, &map)) + .map_err(|e| DocumentFormatError::from((PayloadType::Ndjson, e)))?; } - let count = builder.documents_count(); - let _ = builder.into_inner().map_err(Into::into).map_err(DocumentFormatError::Io)?; - - Ok(count as u64) + Ok(count) } /// The actual handling of the deserialization process in serde @@ -172,20 +276,23 @@ pub fn read_ndjson(file: &File, writer: impl Write) -> Result { /// ## References /// /// -fn array_each<'de, D, T, F>(deserializer: D, f: F) -> std::result::Result, D::Error> +fn array_each<'de, D, T, F>( + deserializer: D, + f: F, +) -> std::result::Result, D::Error> where D: Deserializer<'de>, T: Deserialize<'de>, - F: FnMut(T) -> io::Result<()>, + F: FnMut(T) -> serde_json::Result<()>, { struct SeqVisitor(F, PhantomData); impl<'de, T, F> Visitor<'de> for SeqVisitor where T: Deserialize<'de>, - F: FnMut(T) -> io::Result<()>, + F: FnMut(T) -> serde_json::Result<()>, { - type Value = io::Result; + type Value = serde_json::Result; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a nonempty sequence") @@ -194,7 +301,7 @@ where fn visit_seq( mut self, mut seq: A, - ) -> std::result::Result, >::Error> + ) -> std::result::Result, >::Error> where A: SeqAccess<'de>, { @@ -203,7 +310,7 @@ where match self.0(value) { Ok(()) => max += 1, Err(e) => return Ok(Err(e)), - }; + } } Ok(Ok(max)) } diff --git a/crates/meilisearch-types/src/tasks.rs b/crates/meilisearch-types/src/tasks.rs index 1dd6d3fbf..7f4431da1 100644 --- a/crates/meilisearch-types/src/tasks.rs +++ b/crates/meilisearch-types/src/tasks.rs @@ -4,6 +4,7 @@ use std::fmt::{Display, Write}; use std::str::FromStr; use enum_iterator::Sequence; +use milli::update::new::indexer::document_changes::Progress; use milli::update::IndexDocumentsMethod; use milli::Object; use roaring::RoaringBitmap; @@ -38,6 +39,62 @@ pub struct Task { pub kind: KindWithContent, } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskProgress { + pub current_step: &'static str, + pub finished_steps: u16, + pub total_steps: u16, + pub finished_documents: Option, + pub total_documents: Option, +} + +impl Default for TaskProgress { + fn default() -> Self { + Self::new() + } +} + +impl TaskProgress { + pub fn new() -> Self { + Self { + current_step: "start", + finished_steps: 0, + total_steps: 1, + finished_documents: None, + total_documents: None, + } + } + + pub fn update(&mut self, progress: Progress) -> TaskProgress { + if self.finished_steps > progress.finished_steps { + return *self; + } + + if self.current_step != progress.step_name { + self.current_step = progress.step_name + } + + self.total_steps = progress.total_steps; + + if self.finished_steps < progress.finished_steps { + self.finished_documents = None; + self.total_documents = None; + } + self.finished_steps = progress.finished_steps; + if let Some((finished_documents, total_documents)) = progress.finished_total_documents { + if let Some(task_finished_documents) = self.finished_documents { + if task_finished_documents > finished_documents { + return *self; + } + } + self.finished_documents = Some(finished_documents); + self.total_documents = Some(total_documents); + } + *self + } +} + impl Task { pub fn index_uid(&self) -> Option<&str> { use KindWithContent::*; diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index 57202f59f..b11d90151 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -57,7 +57,7 @@ meilisearch-types = { path = "../meilisearch-types" } mimalloc = { version = "0.1.43", default-features = false } mime = "0.3.17" num_cpus = "1.16.0" -obkv = "0.2.2" +obkv = { git = "https://github.com/kerollmops/obkv", branch = "unsized-kvreader" } once_cell = "1.19.0" ordered-float = "4.2.1" parking_lot = "0.12.3" diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs index 241c3ab81..7e185e951 100644 --- a/crates/meilisearch/src/search/mod.rs +++ b/crates/meilisearch/src/search/mod.rs @@ -796,8 +796,10 @@ fn prepare_search<'t>( let span = tracing::trace_span!(target: "search::vector", "embed_one"); let _entered = span.enter(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + embedder - .embed_one(query.q.clone().unwrap()) + .embed_one(query.q.clone().unwrap(), Some(deadline)) .map_err(milli::vector::Error::from) .map_err(milli::Error::from)? } @@ -1687,7 +1689,7 @@ fn add_non_formatted_ids_to_formatted_options( fn make_document( displayed_attributes: &BTreeSet, field_ids_map: &FieldsIdsMap, - obkv: obkv::KvReaderU16, + obkv: &obkv::KvReaderU16, ) -> Result { let mut document = serde_json::Map::new(); diff --git a/crates/meilisearch/tests/documents/add_documents.rs b/crates/meilisearch/tests/documents/add_documents.rs index c37b3a5e3..17b1d6697 100644 --- a/crates/meilisearch/tests/documents/add_documents.rs +++ b/crates/meilisearch/tests/documents/add_documents.rs @@ -1335,7 +1335,6 @@ async fn error_add_documents_missing_document_id() { } #[actix_rt::test] -#[should_panic] async fn error_document_field_limit_reached_in_one_document() { let server = Server::new().await; let index = server.index("test"); @@ -1352,7 +1351,7 @@ async fn error_document_field_limit_reached_in_one_document() { let documents = json!([big_object]); let (response, code) = index.update_documents(documents, Some("id")).await; - snapshot!(code, @"500 Internal Server Error"); + snapshot!(code, @"202 Accepted"); let response = index.wait_task(response.uid()).await; snapshot!(code, @"202 Accepted"); @@ -1360,16 +1359,21 @@ async fn error_document_field_limit_reached_in_one_document() { snapshot!(response, @r###" { - "uid": 1, + "uid": "[uid]", "indexUid": "test", - "status": "succeeded", + "status": "failed", "type": "documentAdditionOrUpdate", "canceledBy": null, "details": { "receivedDocuments": 1, - "indexedDocuments": 1 + "indexedDocuments": 0 + }, + "error": { + "message": "A document cannot contain more than 65,535 fields.", + "code": "max_fields_limit_exceeded", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#max_fields_limit_exceeded" }, - "error": null, "duration": "[duration]", "enqueuedAt": "[date]", "startedAt": "[date]", @@ -1660,7 +1664,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "The `_geo` field in the document with the id: `11` is not an object. Was expecting an object with the `_geo.lat` and `_geo.lng` fields but instead got `\"foobar\"`.", + "message": "The `_geo` field in the document with the id: `\"11\"` is not an object. Was expecting an object with the `_geo.lat` and `_geo.lng` fields but instead got `\"foobar\"`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1697,7 +1701,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find latitude nor longitude in the document with the id: `11`. Was expecting `_geo.lat` and `_geo.lng` fields.", + "message": "Could not find latitude nor longitude in the document with the id: `\"11\"`. Was expecting `_geo.lat` and `_geo.lng` fields.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1734,7 +1738,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find latitude nor longitude in the document with the id: `11`. Was expecting `_geo.lat` and `_geo.lng` fields.", + "message": "Could not find latitude nor longitude in the document with the id: `\"11\"`. Was expecting `_geo.lat` and `_geo.lng` fields.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1771,7 +1775,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find longitude in the document with the id: `11`. Was expecting a `_geo.lng` field.", + "message": "Could not find longitude in the document with the id: `\"11\"`. Was expecting a `_geo.lng` field.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1808,7 +1812,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find latitude in the document with the id: `11`. Was expecting a `_geo.lat` field.", + "message": "Could not find latitude in the document with the id: `\"11\"`. Was expecting a `_geo.lat` field.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1845,7 +1849,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find longitude in the document with the id: `11`. Was expecting a `_geo.lng` field.", + "message": "Could not find longitude in the document with the id: `\"11\"`. Was expecting a `_geo.lng` field.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1882,7 +1886,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find latitude in the document with the id: `11`. Was expecting a `_geo.lat` field.", + "message": "Could not find latitude in the document with the id: `\"11\"`. Was expecting a `_geo.lat` field.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1919,7 +1923,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not parse latitude nor longitude in the document with the id: `11`. Was expecting finite numbers but instead got `false` and `true`.", + "message": "Could not parse latitude nor longitude in the document with the id: `\"11\"`. Was expecting finite numbers but instead got `false` and `true`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1956,7 +1960,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find longitude in the document with the id: `11`. Was expecting a `_geo.lng` field.", + "message": "Could not find longitude in the document with the id: `\"11\"`. Was expecting a `_geo.lng` field.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -1993,7 +1997,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not find latitude in the document with the id: `11`. Was expecting a `_geo.lat` field.", + "message": "Could not find latitude in the document with the id: `\"11\"`. Was expecting a `_geo.lat` field.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -2030,7 +2034,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not parse latitude nor longitude in the document with the id: `11`. Was expecting finite numbers but instead got `\"doggo\"` and `\"doggo\"`.", + "message": "Could not parse latitude nor longitude in the document with the id: `\"11\"`. Was expecting finite numbers but instead got `\"doggo\"` and `\"doggo\"`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -2067,7 +2071,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "The `_geo` field in the document with the id: `11` contains the following unexpected fields: `{\"doggo\":\"are the best\"}`.", + "message": "The `_geo` field in the document with the id: `\"11\"` contains the following unexpected fields: `{\"doggo\":\"are the best\"}`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -2105,7 +2109,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not parse longitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.", + "message": "Could not parse longitude in the document with the id: `\"12\"`. Was expecting a finite number but instead got `null`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -2141,7 +2145,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not parse latitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.", + "message": "Could not parse latitude in the document with the id: `\"12\"`. Was expecting a finite number but instead got `null`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -2177,7 +2181,7 @@ async fn add_documents_invalid_geo_field() { "indexedDocuments": 0 }, "error": { - "message": "Could not parse latitude nor longitude in the document with the id: `13`. Was expecting finite numbers but instead got `null` and `null`.", + "message": "Could not parse latitude nor longitude in the document with the id: `\"13\"`. Was expecting finite numbers but instead got `null` and `null`.", "code": "invalid_document_geo_field", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" @@ -2197,7 +2201,7 @@ async fn add_invalid_geo_and_then_settings() { let index = server.index("test"); index.create(Some("id")).await; - // _geo is not an object + // _geo is not a correct object let documents = json!([ { "id": "11", @@ -2226,7 +2230,7 @@ async fn add_invalid_geo_and_then_settings() { } "###); - let (ret, code) = index.update_settings(json!({"sortableAttributes": ["_geo"]})).await; + let (ret, code) = index.update_settings(json!({ "sortableAttributes": ["_geo"] })).await; snapshot!(code, @"202 Accepted"); let ret = index.wait_task(ret.uid()).await; snapshot!(ret, @r###" diff --git a/crates/meilisearch/tests/search/geo.rs b/crates/meilisearch/tests/search/geo.rs index 7804f1ad0..e92056191 100644 --- a/crates/meilisearch/tests/search/geo.rs +++ b/crates/meilisearch/tests/search/geo.rs @@ -70,8 +70,8 @@ async fn geo_bounding_box_with_string_and_number() { let documents = DOCUMENTS.clone(); index.update_settings_filterable_attributes(json!(["_geo"])).await; index.update_settings_sortable_attributes(json!(["_geo"])).await; - index.add_documents(documents, None).await; - index.wait_task(2).await; + let (ret, _code) = index.add_documents(documents, None).await; + index.wait_task(ret.uid()).await.succeeded(); index .search( diff --git a/crates/meilisearch/tests/search/mod.rs b/crates/meilisearch/tests/search/mod.rs index d1091d944..afac667bb 100644 --- a/crates/meilisearch/tests/search/mod.rs +++ b/crates/meilisearch/tests/search/mod.rs @@ -750,9 +750,9 @@ async fn test_score_details() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] }, "_rankingScoreDetails": { @@ -1543,9 +1543,9 @@ async fn simple_search_with_strange_synonyms() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] } } @@ -1568,9 +1568,9 @@ async fn simple_search_with_strange_synonyms() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] } } @@ -1593,9 +1593,9 @@ async fn simple_search_with_strange_synonyms() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] } } diff --git a/crates/meilisearch/tests/search/multi.rs b/crates/meilisearch/tests/search/multi.rs index eaa1da15f..8d7340f0d 100644 --- a/crates/meilisearch/tests/search/multi.rs +++ b/crates/meilisearch/tests/search/multi.rs @@ -113,9 +113,9 @@ async fn simple_search_single_index() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] } } @@ -138,9 +138,9 @@ async fn simple_search_single_index() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] } } @@ -182,9 +182,9 @@ async fn federation_single_search_single_index() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -305,9 +305,9 @@ async fn federation_two_search_single_index() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -325,9 +325,9 @@ async fn federation_two_search_single_index() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -480,9 +480,9 @@ async fn simple_search_two_indexes() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] } } @@ -513,9 +513,9 @@ async fn simple_search_two_indexes() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] } }, @@ -535,9 +535,9 @@ async fn simple_search_two_indexes() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] } } @@ -585,9 +585,9 @@ async fn federation_two_search_two_indexes() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -613,9 +613,9 @@ async fn federation_two_search_two_indexes() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -640,9 +640,9 @@ async fn federation_two_search_two_indexes() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -707,9 +707,9 @@ async fn federation_multiple_search_multiple_indexes() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -735,9 +735,9 @@ async fn federation_multiple_search_multiple_indexes() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -773,9 +773,9 @@ async fn federation_multiple_search_multiple_indexes() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -793,9 +793,9 @@ async fn federation_multiple_search_multiple_indexes() { ], "_vectors": { "manual": [ - 10.0, - -23.0, - 32.0 + 10, + -23, + 32 ] }, "_federation": { @@ -824,9 +824,9 @@ async fn federation_multiple_search_multiple_indexes() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -869,9 +869,9 @@ async fn federation_multiple_search_multiple_indexes() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -898,9 +898,9 @@ async fn federation_multiple_search_multiple_indexes() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] }, "_federation": { @@ -1393,9 +1393,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -1414,9 +1414,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -1442,9 +1442,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -1474,9 +1474,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -1522,9 +1522,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -1550,9 +1550,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -1582,9 +1582,9 @@ async fn federation_sort_same_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -1716,9 +1716,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -1748,9 +1748,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -1769,9 +1769,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -1797,9 +1797,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -1845,9 +1845,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -1874,9 +1874,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { "cattos": "pésti", "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -1906,9 +1906,9 @@ async fn federation_sort_same_indexes_different_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - 23.0, - 32.0 + 10, + 23, + 32 ] }, "_federation": { @@ -2103,9 +2103,9 @@ async fn federation_sort_different_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -2124,9 +2124,9 @@ async fn federation_sort_different_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 10.0, - -23.0, - 32.0 + 10, + -23, + 32 ] }, "_federation": { @@ -2145,9 +2145,9 @@ async fn federation_sort_different_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -2166,9 +2166,9 @@ async fn federation_sort_different_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] }, "_federation": { @@ -2187,9 +2187,9 @@ async fn federation_sort_different_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -2228,9 +2228,9 @@ async fn federation_sort_different_indexes_same_criterion_same_direction() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -2415,9 +2415,9 @@ async fn federation_sort_different_ranking_rules() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -2436,9 +2436,9 @@ async fn federation_sort_different_ranking_rules() { ], "_vectors": { "manual": [ - 10.0, - -23.0, - 32.0 + 10, + -23, + 32 ] }, "_federation": { @@ -2457,9 +2457,9 @@ async fn federation_sort_different_ranking_rules() { ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -2478,9 +2478,9 @@ async fn federation_sort_different_ranking_rules() { ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] }, "_federation": { @@ -2499,9 +2499,9 @@ async fn federation_sort_different_ranking_rules() { ], "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -2716,9 +2716,9 @@ async fn federation_sort_different_indexes_different_criterion_same_direction() ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -2757,9 +2757,9 @@ async fn federation_sort_different_indexes_different_criterion_same_direction() ], "_vectors": { "manual": [ - 10.0, - -23.0, - 32.0 + 10, + -23, + 32 ] }, "_federation": { @@ -2778,9 +2778,9 @@ async fn federation_sort_different_indexes_different_criterion_same_direction() ], "_vectors": { "manual": [ - -100.0, - 340.0, - 90.0 + -100, + 340, + 90 ] }, "_federation": { @@ -2799,9 +2799,9 @@ async fn federation_sort_different_indexes_different_criterion_same_direction() ], "_vectors": { "manual": [ - -100.0, - 231.0, - 32.0 + -100, + 231, + 32 ] }, "_federation": { @@ -2820,9 +2820,9 @@ async fn federation_sort_different_indexes_different_criterion_same_direction() ], "_vectors": { "manual": [ - 1.0, - 2.0, - 3.0 + 1, + 2, + 3 ] }, "_federation": { @@ -2881,9 +2881,9 @@ async fn federation_sort_different_indexes_different_criterion_same_direction() ], "_vectors": { "manual": [ - 1.0, - 2.0, - 54.0 + 1, + 2, + 54 ] }, "_federation": { @@ -4346,10 +4346,10 @@ async fn federation_vector_two_indexes() { let (response, code) = server .multi_search(json!({"federation": {}, "queries": [ - {"indexUid" : "vectors-animal", "vector": [1.0, 0.0, 0.5], "hybrid": {"semanticRatio": 1.0, "embedder": "animal"}}, + {"indexUid" : "vectors-animal", "vector": [1.0, 0.0, 0.5], "hybrid": {"semanticRatio": 1.0, "embedder": "animal"}, "retrieveVectors": true}, // joyful and energetic first - {"indexUid": "vectors-sentiment", "vector": [0.8, 0.6], "hybrid": {"semanticRatio": 1.0, "embedder": "sentiment"}}, - {"indexUid": "vectors-sentiment", "q": "dog"}, + {"indexUid": "vectors-sentiment", "vector": [0.8, 0.6], "hybrid": {"semanticRatio": 1.0, "embedder": "sentiment"}, "retrieveVectors": true}, + {"indexUid": "vectors-sentiment", "q": "dog", "retrieveVectors": true}, ]})) .await; snapshot!(code, @"200 OK"); @@ -4364,7 +4364,16 @@ async fn federation_vector_two_indexes() { 0.8, 0.09, 0.8 - ] + ], + "sentiment": { + "embeddings": [ + [ + 0.800000011920929, + 0.30000001192092896 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4379,7 +4388,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ 0.8, 0.3 - ] + ], + "animal": { + "embeddings": [ + [ + 0.800000011920929, + 0.09000000357627869, + 0.800000011920929 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4394,7 +4413,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ -1.0, 0.1 - ] + ], + "animal": { + "embeddings": [ + [ + 0.8500000238418579, + 0.019999999552965164, + 0.10000000149011612 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4410,7 +4439,16 @@ async fn federation_vector_two_indexes() { 0.9, 0.8, 0.05 - ] + ], + "sentiment": { + "embeddings": [ + [ + -0.10000000149011612, + 0.550000011920929 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4426,7 +4464,16 @@ async fn federation_vector_two_indexes() { 0.85, 0.02, 0.1 - ] + ], + "sentiment": { + "embeddings": [ + [ + -1.0, + 0.10000000149011612 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4441,7 +4488,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ -0.2, 0.65 - ] + ], + "animal": { + "embeddings": [ + [ + 0.800000011920929, + 0.8999999761581421, + 0.5 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4456,7 +4513,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ -0.1, 0.55 - ] + ], + "animal": { + "embeddings": [ + [ + 0.8999999761581421, + 0.800000011920929, + 0.05000000074505806 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4472,7 +4539,16 @@ async fn federation_vector_two_indexes() { 0.8, 0.9, 0.5 - ] + ], + "sentiment": { + "embeddings": [ + [ + -0.20000000298023224, + 0.6499999761581421 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4492,8 +4568,8 @@ async fn federation_vector_two_indexes() { // hybrid search, distinct embedder let (response, code) = server .multi_search(json!({"federation": {}, "queries": [ - {"indexUid" : "vectors-animal", "vector": [1.0, 0.0, 0.5], "hybrid": {"semanticRatio": 1.0, "embedder": "animal"}, "showRankingScore": true}, - {"indexUid": "vectors-sentiment", "vector": [-1, 0.6], "q": "beagle", "hybrid": {"semanticRatio": 1.0, "embedder": "sentiment"}, "showRankingScore": true}, + {"indexUid" : "vectors-animal", "vector": [1.0, 0.0, 0.5], "hybrid": {"semanticRatio": 1.0, "embedder": "animal"}, "showRankingScore": true, "retrieveVectors": true}, + {"indexUid": "vectors-sentiment", "vector": [-1, 0.6], "q": "beagle", "hybrid": {"semanticRatio": 1.0, "embedder": "sentiment"}, "showRankingScore": true, "retrieveVectors": true,}, ]})) .await; snapshot!(code, @"200 OK"); @@ -4507,7 +4583,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ 0.8, 0.3 - ] + ], + "animal": { + "embeddings": [ + [ + 0.800000011920929, + 0.09000000357627869, + 0.800000011920929 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4523,7 +4609,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ -1.0, 0.1 - ] + ], + "animal": { + "embeddings": [ + [ + 0.8500000238418579, + 0.019999999552965164, + 0.10000000149011612 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4540,7 +4636,16 @@ async fn federation_vector_two_indexes() { 0.85, 0.02, 0.1 - ] + ], + "sentiment": { + "embeddings": [ + [ + -1.0, + 0.10000000149011612 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4557,7 +4662,16 @@ async fn federation_vector_two_indexes() { 0.8, 0.9, 0.5 - ] + ], + "sentiment": { + "embeddings": [ + [ + -0.20000000298023224, + 0.6499999761581421 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4573,7 +4687,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ -0.2, 0.65 - ] + ], + "animal": { + "embeddings": [ + [ + 0.800000011920929, + 0.8999999761581421, + 0.5 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4589,7 +4713,17 @@ async fn federation_vector_two_indexes() { "sentiment": [ -0.1, 0.55 - ] + ], + "animal": { + "embeddings": [ + [ + 0.8999999761581421, + 0.800000011920929, + 0.05000000074505806 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-animal", @@ -4606,7 +4740,16 @@ async fn federation_vector_two_indexes() { 0.9, 0.8, 0.05 - ] + ], + "sentiment": { + "embeddings": [ + [ + -0.10000000149011612, + 0.550000011920929 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", @@ -4623,7 +4766,16 @@ async fn federation_vector_two_indexes() { 0.8, 0.09, 0.8 - ] + ], + "sentiment": { + "embeddings": [ + [ + 0.800000011920929, + 0.30000001192092896 + ] + ], + "regenerate": false + } }, "_federation": { "indexUid": "vectors-sentiment", diff --git a/crates/meilisearch/tests/vector/mod.rs b/crates/meilisearch/tests/vector/mod.rs index 47d0c1051..8f4e9cc70 100644 --- a/crates/meilisearch/tests/vector/mod.rs +++ b/crates/meilisearch/tests/vector/mod.rs @@ -249,7 +249,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Missing field `regenerate` inside `.manual`", + "message": "Bad embedder configuration in the document with id: `0`. Missing field `._vectors.manual.regenerate`\n - note: `._vectors.manual` must be an array of floats, an array of arrays of floats, or an object with field `regenerate`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -278,7 +278,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Missing field `regenerate` inside `.manual`", + "message": "Bad embedder configuration in the document with id: `0`. Missing field `._vectors.manual.regenerate`\n - note: `._vectors.manual` must be an array of floats, an array of arrays of floats, or an object with field `regenerate`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -308,7 +308,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.regenerate`: expected a boolean, but found a string: `\"yes please\"`", + "message": "Bad embedder configuration in the document with id: `0`. Could not parse `._vectors.manual.regenerate`: invalid type: string \"yes please\", expected a boolean at line 1 column 26", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -320,8 +320,7 @@ async fn user_provided_embeddings_error() { } "###); - let documents = - json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": true }}}); + let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": true, "regenerate": true }}}); let (value, code) = index.add_documents(documents, None).await; snapshot!(code, @"202 Accepted"); let task = index.wait_task(value.uid()).await; @@ -337,7 +336,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings`: expected null or an array, but found a boolean: `true`", + "message": "Bad embedder configuration in the document with id: `0`. Invalid value type at `._vectors.manual.embeddings`: expected null or an array, but found a boolean: `true`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -349,8 +348,7 @@ async fn user_provided_embeddings_error() { } "###); - let documents = - json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [true] }}}); + let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [true], "regenerate": true }}}); let (value, code) = index.add_documents(documents, None).await; snapshot!(code, @"202 Accepted"); let task = index.wait_task(value.uid()).await; @@ -366,7 +364,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[0]`: expected a number or an array, but found a boolean: `true`", + "message": "Bad embedder configuration in the document with id: `0`. Invalid value type at `._vectors.manual.embeddings[0]`: expected a number or an array, but found a boolean: `true`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -378,8 +376,7 @@ async fn user_provided_embeddings_error() { } "###); - let documents = - json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [[true]] }}}); + let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [[true]], "regenerate": false }}}); let (value, code) = index.add_documents(documents, None).await; snapshot!(code, @"202 Accepted"); let task = index.wait_task(value.uid()).await; @@ -395,7 +392,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[0][0]`: expected a number, but found a boolean: `true`", + "message": "Bad embedder configuration in the document with id: `0`. Invalid value type at `._vectors.manual.embeddings[0][0]`: expected a number, but found a boolean: `true`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -436,7 +433,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[1]`: expected a number, but found an array: `[0.2,0.3]`", + "message": "Bad embedder configuration in the document with id: `0`. Invalid value type at `._vectors.manual.embeddings[1]`: expected a number, but found an array: `[0.2,0.3]`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -464,7 +461,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[1]`: expected an array, but found a number: `0.3`", + "message": "Bad embedder configuration in the document with id: `0`. Invalid value type at `._vectors.manual.embeddings[1]`: expected an array, but found a number: `0.3`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -492,7 +489,7 @@ async fn user_provided_embeddings_error() { "indexedDocuments": 0 }, "error": { - "message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[0][1]`: expected a number, but found a boolean: `true`", + "message": "Bad embedder configuration in the document with id: `0`. Invalid value type at `._vectors.manual.embeddings[0][1]`: expected a number, but found a boolean: `true`", "code": "invalid_vectors_type", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_vectors_type" @@ -532,7 +529,7 @@ async fn user_provided_vectors_error() { "indexedDocuments": 0 }, "error": { - "message": "While embedding documents for embedder `manual`: no vectors provided for document \"40\" and at least 4 other document(s)\n- Note: `manual` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.manual`.\n- Hint: opt-out for a document with `_vectors.manual: null`", + "message": "While embedding documents for embedder `manual`: no vectors provided for document `40` and at least 4 other document(s)\n- Note: `manual` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.manual`.\n- Hint: opt-out for a document with `_vectors.manual: null`", "code": "vector_embedding_error", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#vector_embedding_error" @@ -561,7 +558,7 @@ async fn user_provided_vectors_error() { "indexedDocuments": 0 }, "error": { - "message": "While embedding documents for embedder `manual`: no vectors provided for document \"42\"\n- Note: `manual` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.manual`.\n- Hint: try replacing `_vector` by `_vectors` in 1 document(s).", + "message": "While embedding documents for embedder `manual`: no vectors provided for document `42`\n- Note: `manual` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.manual`.\n- Hint: try replacing `_vector` by `_vectors` in 1 document(s).", "code": "vector_embedding_error", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#vector_embedding_error" @@ -590,7 +587,7 @@ async fn user_provided_vectors_error() { "indexedDocuments": 0 }, "error": { - "message": "While embedding documents for embedder `manual`: no vectors provided for document \"42\"\n- Note: `manual` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.manual`.\n- Hint: try replacing `_vectors.manaul` by `_vectors.manual` in 1 document(s).", + "message": "While embedding documents for embedder `manual`: no vectors provided for document `42`\n- Note: `manual` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.manual`.\n- Hint: try replacing `_vectors.manaul` by `_vectors.manual` in 1 document(s).", "code": "vector_embedding_error", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#vector_embedding_error" diff --git a/crates/meilisearch/tests/vector/openai.rs b/crates/meilisearch/tests/vector/openai.rs index 04c068c40..94291ebea 100644 --- a/crates/meilisearch/tests/vector/openai.rs +++ b/crates/meilisearch/tests/vector/openai.rs @@ -137,13 +137,14 @@ fn long_text() -> &'static str { } async fn create_mock_tokenized() -> (MockServer, Value) { - create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false).await + create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false, false).await } async fn create_mock_with_template( document_template: &str, model_dimensions: ModelDimensions, fallible: bool, + slow: bool, ) -> (MockServer, Value) { let mock_server = MockServer::start().await; const API_KEY: &str = "my-api-key"; @@ -154,7 +155,11 @@ async fn create_mock_with_template( Mock::given(method("POST")) .and(path("/")) .respond_with(move |req: &Request| { - // 0. maybe return 500 + // 0. wait for a long time + if slow { + std::thread::sleep(std::time::Duration::from_secs(1)); + } + // 1. maybe return 500 if fallible { let attempt = attempt.fetch_add(1, Ordering::Relaxed); let failed = matches!(attempt % 4, 0 | 1 | 3); @@ -167,7 +172,7 @@ async fn create_mock_with_template( })) } } - // 1. check API key + // 3. check API key match req.headers.get("Authorization") { Some(api_key) if api_key == API_KEY_BEARER => { {} @@ -202,7 +207,7 @@ async fn create_mock_with_template( ) } } - // 2. parse text inputs + // 3. parse text inputs let query: serde_json::Value = match req.body_json() { Ok(query) => query, Err(_error) => return ResponseTemplate::new(400).set_body_json( @@ -223,7 +228,7 @@ async fn create_mock_with_template( panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}") } - // 3. for each text, find embedding in responses + // 4. for each text, find embedding in responses let serde_json::Value::Array(inputs) = &query["input"] else { panic!("Unexpected `input` value") }; @@ -283,7 +288,7 @@ async fn create_mock_with_template( "embedding": embedding, })).collect(); - // 4. produce output from embeddings + // 5. produce output from embeddings ResponseTemplate::new(200).set_body_json(json!({ "object": "list", "data": data, @@ -317,23 +322,27 @@ const DOGGO_TEMPLATE: &str = r#"{%- if doc.gender == "F" -%}Une chienne nommée {%- endif %}, de race {{doc.breed}}."#; async fn create_mock() -> (MockServer, Value) { - create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false).await + create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false, false).await } async fn create_mock_dimensions() -> (MockServer, Value) { - create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false).await + create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false, false).await } async fn create_mock_small_embedding_model() -> (MockServer, Value) { - create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false).await + create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false, false).await } async fn create_mock_legacy_embedding_model() -> (MockServer, Value) { - create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false).await + create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false, false).await } async fn create_fallible_mock() -> (MockServer, Value) { - create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true).await + create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, false).await +} + +async fn create_slow_mock() -> (MockServer, Value) { + create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, true).await } // basic test "it works" @@ -1873,4 +1882,114 @@ async fn it_still_works() { ] "###); } + +// test with a server that responds 500 on 3 out of 4 calls +#[actix_rt::test] +async fn timeout() { + let (_mock, setting) = create_slow_mock().await; + let server = get_server_vector().await; + let index = server.index("doggo"); + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "default": setting, + }, + })) + .await; + snapshot!(code, @"202 Accepted"); + let task = server.wait_task(response.uid()).await; + snapshot!(task["status"], @r###""succeeded""###); + let documents = json!([ + {"id": 0, "name": "kefir", "gender": "M", "birthyear": 2023, "breed": "Patou"}, + ]); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + let task = index.wait_task(value.uid()).await; + snapshot!(task, @r###" + { + "uid": "[uid]", + "indexUid": "doggo", + "status": "succeeded", + "type": "documentAdditionOrUpdate", + "canceledBy": null, + "details": { + "receivedDocuments": 1, + "indexedDocuments": 1 + }, + "error": null, + "duration": "[duration]", + "enqueuedAt": "[date]", + "startedAt": "[date]", + "finishedAt": "[date]" + } + "###); + + let (documents, _code) = index + .get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() }) + .await; + snapshot!(json_string!(documents, {".results.*._vectors.default.embeddings" => "[vector]"}), @r###" + { + "results": [ + { + "id": 0, + "name": "kefir", + "gender": "M", + "birthyear": 2023, + "breed": "Patou", + "_vectors": { + "default": { + "embeddings": "[vector]", + "regenerate": true + } + } + } + ], + "offset": 0, + "limit": 20, + "total": 1 + } + "###); + + let (response, code) = index + .search_post(json!({ + "q": "grand chien de berger des montagnes", + "hybrid": {"semanticRatio": 0.99, "embedder": "default"} + })) + .await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["semanticHitCount"]), @"0"); + snapshot!(json_string!(response["hits"]), @"[]"); + + let (response, code) = index + .search_post(json!({ + "q": "grand chien de berger des montagnes", + "hybrid": {"semanticRatio": 0.99, "embedder": "default"} + })) + .await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["semanticHitCount"]), @"1"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "id": 0, + "name": "kefir", + "gender": "M", + "birthyear": 2023, + "breed": "Patou" + } + ] + "###); + + let (response, code) = index + .search_post(json!({ + "q": "grand chien de berger des montagnes", + "hybrid": {"semanticRatio": 0.99, "embedder": "default"} + })) + .await; + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["semanticHitCount"]), @"0"); + snapshot!(json_string!(response["hits"]), @"[]"); +} + // test with a server that wrongly responds 400 diff --git a/crates/meilisearch/tests/vector/rest.rs b/crates/meilisearch/tests/vector/rest.rs index 2748d0846..09188595c 100644 --- a/crates/meilisearch/tests/vector/rest.rs +++ b/crates/meilisearch/tests/vector/rest.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::sync::atomic::{AtomicUsize, Ordering}; use meili_snap::{json_string, snapshot}; use reqwest::IntoUrl; @@ -13,13 +12,22 @@ use crate::vector::{get_server_vector, GetAllDocumentsOptions}; async fn create_mock() -> (MockServer, Value) { let mock_server = MockServer::start().await; - let counter = AtomicUsize::new(0); + let text_to_embedding: BTreeMap<_, _> = vec![ + // text -> embedding + ("kefir", [0.0, 0.0, 0.0]), + ("intel", [1.0, 1.0, 1.0]), + ] + // turn into btree + .into_iter() + .collect(); Mock::given(method("POST")) .and(path("/")) - .respond_with(move |_req: &Request| { - let counter = counter.fetch_add(1, Ordering::Relaxed); - ResponseTemplate::new(200).set_body_json(json!({ "data": vec![counter; 3] })) + .respond_with(move |req: &Request| { + let text: String = req.body_json().unwrap(); + ResponseTemplate::new(200).set_body_json( + json!({ "data": text_to_embedding.get(text.as_str()).unwrap_or(&[99., 99., 99.]) }), + ) }) .mount(&mock_server) .await; @@ -32,13 +40,14 @@ async fn create_mock() -> (MockServer, Value) { "request": "{{text}}", "response": { "data": "{{embedding}}" - } + }, + "documentTemplate": "{{doc.name}}", }); (mock_server, embedder_settings) } -async fn create_mock_map() -> (MockServer, Value) { +async fn create_mock_default_template() -> (MockServer, Value) { let mock_server = MockServer::start().await; let text_to_embedding: BTreeMap<_, _> = vec![ @@ -97,7 +106,14 @@ struct SingleResponse { async fn create_mock_multiple() -> (MockServer, Value) { let mock_server = MockServer::start().await; - let counter = AtomicUsize::new(0); + let text_to_embedding: BTreeMap<_, _> = vec![ + // text -> embedding + ("kefir", [0.0, 0.0, 0.0]), + ("intel", [1.0, 1.0, 1.0]), + ] + // turn into btree + .into_iter() + .collect(); Mock::given(method("POST")) .and(path("/")) @@ -115,8 +131,11 @@ async fn create_mock_multiple() -> (MockServer, Value) { .input .into_iter() .map(|text| SingleResponse { + embedding: text_to_embedding + .get(text.as_str()) + .unwrap_or(&[99., 99., 99.]) + .to_vec(), text, - embedding: vec![counter.fetch_add(1, Ordering::Relaxed) as f32; 3], }) .collect(); @@ -142,7 +161,8 @@ async fn create_mock_multiple() -> (MockServer, Value) { }, "{{..}}" ] - } + }, + "documentTemplate": "{{doc.name}}" }); (mock_server, embedder_settings) @@ -156,7 +176,14 @@ struct SingleRequest { async fn create_mock_single_response_in_array() -> (MockServer, Value) { let mock_server = MockServer::start().await; - let counter = AtomicUsize::new(0); + let text_to_embedding: BTreeMap<_, _> = vec![ + // text -> embedding + ("kefir", [0.0, 0.0, 0.0]), + ("intel", [1.0, 1.0, 1.0]), + ] + // turn into btree + .into_iter() + .collect(); Mock::given(method("POST")) .and(path("/")) @@ -171,8 +198,11 @@ async fn create_mock_single_response_in_array() -> (MockServer, Value) { }; let output = vec![SingleResponse { + embedding: text_to_embedding + .get(req.input.as_str()) + .unwrap_or(&[99., 99., 99.]) + .to_vec(), text: req.input, - embedding: vec![counter.fetch_add(1, Ordering::Relaxed) as f32; 3], }]; let response = MultipleResponse { output }; @@ -196,7 +226,8 @@ async fn create_mock_single_response_in_array() -> (MockServer, Value) { "embedding": "{{embedding}}" } ] - } + }, + "documentTemplate": "{{doc.name}}" }); (mock_server, embedder_settings) @@ -205,7 +236,14 @@ async fn create_mock_single_response_in_array() -> (MockServer, Value) { async fn create_mock_raw_with_custom_header() -> (MockServer, Value) { let mock_server = MockServer::start().await; - let counter = AtomicUsize::new(0); + let text_to_embedding: BTreeMap<_, _> = vec![ + // text -> embedding + ("kefir", [0.0, 0.0, 0.0]), + ("intel", [1.0, 1.0, 1.0]), + ] + // turn into btree + .into_iter() + .collect(); Mock::given(method("POST")) .and(path("/")) @@ -223,7 +261,7 @@ async fn create_mock_raw_with_custom_header() -> (MockServer, Value) { } } - let _req: String = match req.body_json() { + let req: String = match req.body_json() { Ok(req) => req, Err(error) => { return ResponseTemplate::new(400).set_body_json(json!({ @@ -232,7 +270,7 @@ async fn create_mock_raw_with_custom_header() -> (MockServer, Value) { } }; - let output = vec![counter.fetch_add(1, Ordering::Relaxed) as f32; 3]; + let output = text_to_embedding.get(req.as_str()).unwrap_or(&[99., 99., 99.]).to_vec(); ResponseTemplate::new(200).set_body_json(output) }) @@ -245,7 +283,8 @@ async fn create_mock_raw_with_custom_header() -> (MockServer, Value) { "url": url, "request": "{{text}}", "response": "{{embedding}}", - "headers": {"my-nonstandard-auth": "bearer of the ring"} + "headers": {"my-nonstandard-auth": "bearer of the ring"}, + "documentTemplate": "{{doc.name}}" }); (mock_server, embedder_settings) @@ -254,12 +293,19 @@ async fn create_mock_raw_with_custom_header() -> (MockServer, Value) { async fn create_mock_raw() -> (MockServer, Value) { let mock_server = MockServer::start().await; - let counter = AtomicUsize::new(0); + let text_to_embedding: BTreeMap<_, _> = vec![ + // text -> embedding + ("kefir", [0.0, 0.0, 0.0]), + ("intel", [1.0, 1.0, 1.0]), + ] + // turn into btree + .into_iter() + .collect(); Mock::given(method("POST")) .and(path("/")) .respond_with(move |req: &Request| { - let _req: String = match req.body_json() { + let req: String = match req.body_json() { Ok(req) => req, Err(error) => { return ResponseTemplate::new(400).set_body_json(json!({ @@ -268,7 +314,7 @@ async fn create_mock_raw() -> (MockServer, Value) { } }; - let output = vec![counter.fetch_add(1, Ordering::Relaxed) as f32; 3]; + let output = text_to_embedding.get(req.as_str()).unwrap_or(&[99., 99., 99.]).to_vec(); ResponseTemplate::new(200).set_body_json(output) }) @@ -281,29 +327,30 @@ async fn create_mock_raw() -> (MockServer, Value) { "url": url, "dimensions": 3, "request": "{{text}}", - "response": "{{embedding}}" + "response": "{{embedding}}", + "documentTemplate": "{{doc.name}}" }); (mock_server, embedder_settings) } -pub async fn post(url: T) -> reqwest::Result { - reqwest::Client::builder().build()?.post(url).send().await +pub async fn post(url: T, text: &str) -> reqwest::Result { + reqwest::Client::builder().build()?.post(url).json(&json!(text)).send().await } #[actix_rt::test] async fn dummy_testing_the_mock() { let (mock, _setting) = create_mock().await; - let body = post(&mock.uri()).await.unwrap().text().await.unwrap(); - snapshot!(body, @r###"{"data":[0,0,0]}"###); - let body = post(&mock.uri()).await.unwrap().text().await.unwrap(); - snapshot!(body, @r###"{"data":[1,1,1]}"###); - let body = post(&mock.uri()).await.unwrap().text().await.unwrap(); - snapshot!(body, @r###"{"data":[2,2,2]}"###); - let body = post(&mock.uri()).await.unwrap().text().await.unwrap(); - snapshot!(body, @r###"{"data":[3,3,3]}"###); - let body = post(&mock.uri()).await.unwrap().text().await.unwrap(); - snapshot!(body, @r###"{"data":[4,4,4]}"###); + let body = post(&mock.uri(), "kefir").await.unwrap().text().await.unwrap(); + snapshot!(body, @r###"{"data":[0.0,0.0,0.0]}"###); + let body = post(&mock.uri(), "intel").await.unwrap().text().await.unwrap(); + snapshot!(body, @r###"{"data":[1.0,1.0,1.0]}"###); + let body = post(&mock.uri(), "kefir").await.unwrap().text().await.unwrap(); + snapshot!(body, @r###"{"data":[0.0,0.0,0.0]}"###); + let body = post(&mock.uri(), "kefir").await.unwrap().text().await.unwrap(); + snapshot!(body, @r###"{"data":[0.0,0.0,0.0]}"###); + let body = post(&mock.uri(), "intel").await.unwrap().text().await.unwrap(); + snapshot!(body, @r###"{"data":[1.0,1.0,1.0]}"###); } #[actix_rt::test] @@ -953,7 +1000,7 @@ async fn bad_settings() { let (response, code) = index .update_settings(json!({ "embedders": { - "rest": json!({ "source": "rest", "url": mock.uri(), "request": "{{text}}", "response": { "data": "{{embedding}}" }, "dimensions": 2 }), + "rest": json!({ "source": "rest", "url": mock.uri(), "request": "{{text}}", "response": { "data": "{{embedding}}" }, "dimensions": 2, "documentTemplate": "{{doc.name}}" }), }, })) .await; @@ -1920,6 +1967,7 @@ async fn server_custom_header() { "embedders": { "rest": { "source": "rest", + "documentTemplate": "{{doc.name}}", "url": "[url]", "request": "{{text}}", "response": "{{embedding}}", @@ -1940,7 +1988,7 @@ async fn server_custom_header() { #[actix_rt::test] async fn searchable_reindex() { - let (_mock, setting) = create_mock_map().await; + let (_mock, setting) = create_mock_default_template().await; let server = get_server_vector().await; let index = server.index("doggo"); diff --git a/crates/meilitool/src/main.rs b/crates/meilitool/src/main.rs index 978824356..f84cea98d 100644 --- a/crates/meilitool/src/main.rs +++ b/crates/meilitool/src/main.rs @@ -264,7 +264,7 @@ fn export_a_dump( format!("While iterating on content file {:?}", content_file_uuid) })? { dump_content_file - .push_document(&obkv_to_object(&doc, &documents_batch_index)?)?; + .push_document(&obkv_to_object(doc, &documents_batch_index)?)?; } dump_content_file.flush()?; count += 1; diff --git a/crates/milli/Cargo.toml b/crates/milli/Cargo.toml index 7b43fbf33..07e18ef4d 100644 --- a/crates/milli/Cargo.toml +++ b/crates/milli/Cargo.toml @@ -12,12 +12,14 @@ readme.workspace = true license.workspace = true [dependencies] +big_s = "1.0.2" bimap = { version = "0.6.3", features = ["serde"] } bincode = "1.3.3" bstr = "1.9.1" bytemuck = { version = "1.18.0", features = ["extern_crate_alloc"] } byteorder = "1.5.0" -charabia = { version = "0.9.1", default-features = false } +# charabia = { version = "0.9.0", default-features = false } +charabia = { git = "https://github.com/meilisearch/charabia", branch = "mutualize-char-normalizer", default-features = false } concat-arrays = "0.1.2" crossbeam-channel = "0.5.13" deserr = "0.6.2" @@ -27,9 +29,9 @@ fst = "0.4.7" fxhash = "0.2.1" geoutils = "0.5.1" grenad = { version = "0.4.7", default-features = false, features = [ - "rayon", + "rayon", # TODO Should we keep this feature "tempfile", -] } +], git = "https://github.com/meilisearch/grenad", branch = "various-improvements" } heed = { version = "0.20.3", default-features = false, features = [ "serde-json", "serde-bincode", @@ -40,14 +42,14 @@ json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memchr = "2.5.0" memmap2 = "0.9.4" -obkv = "0.2.2" +obkv = { git = "https://github.com/kerollmops/obkv", branch = "unsized-kvreader" } once_cell = "1.19.0" ordered-float = "4.2.1" rayon = "1.10.0" roaring = { version = "0.10.6", features = ["serde"] } rstar = { version = "0.12.0", features = ["serde"] } serde = { version = "1.0.204", features = ["derive"] } -serde_json = { version = "1.0.120", features = ["preserve_order"] } +serde_json = { version = "1.0.120", features = ["preserve_order", "raw_value"] } slice-group-by = "0.3.1" smallstr = { version = "0.3.0", features = ["serde"] } smallvec = "1.13.2" @@ -79,17 +81,30 @@ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", ] } tiktoken-rs = "0.5.9" liquid = "0.26.6" -rhai = { git = "https://github.com/rhaiscript/rhai", rev = "ef3df63121d27aacd838f366f2b83fd65f20a1e4", features = ["serde", "no_module", "no_custom_syntax", "no_time", "sync"] } +rhai = { git = "https://github.com/rhaiscript/rhai", rev = "ef3df63121d27aacd838f366f2b83fd65f20a1e4", features = [ + "serde", + "no_module", + "no_custom_syntax", + "no_time", + "sync", +] } arroy = "0.5.0" rand = "0.8.5" tracing = "0.1.40" ureq = { version = "2.10.0", features = ["json"] } url = "2.5.2" rayon-par-bridge = "0.1.0" +hashbrown = "0.15.0" +raw-collections = { git = "https://github.com/dureuill/raw-collections.git", version = "0.1.0" } +bumpalo = "3.16.0" +thread_local = "1.1.8" +allocator-api2 = "0.2.18" +rustc-hash = "2.0.0" +uell = "0.1.0" +enum-iterator = "2.1.0" [dev-dependencies] mimalloc = { version = "0.1.43", default-features = false } -big_s = "1.0.2" insta = "1.39.0" maplit = "1.0.2" md5 = "0.7.0" diff --git a/crates/milli/src/documents/builder.rs b/crates/milli/src/documents/builder.rs index ec4d634aa..1cf90447e 100644 --- a/crates/milli/src/documents/builder.rs +++ b/crates/milli/src/documents/builder.rs @@ -292,7 +292,7 @@ mod test { .unwrap() .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -321,7 +321,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -348,7 +348,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -375,7 +375,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -402,7 +402,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -429,7 +429,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -456,7 +456,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -483,7 +483,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -510,7 +510,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, @@ -555,7 +555,7 @@ mod test { .into_cursor_and_fields_index(); let doc = cursor.next_document().unwrap().unwrap(); - let val = obkv_to_object(&doc, &index).map(Value::from).unwrap(); + let val = obkv_to_object(doc, &index).map(Value::from).unwrap(); assert_eq!( val, diff --git a/crates/milli/src/documents/enriched.rs b/crates/milli/src/documents/enriched.rs index 609765068..cede4d2f0 100644 --- a/crates/milli/src/documents/enriched.rs +++ b/crates/milli/src/documents/enriched.rs @@ -69,7 +69,7 @@ impl EnrichedDocumentsBatchReader { #[derive(Debug, Clone)] pub struct EnrichedDocument<'a> { - pub document: KvReader<'a, FieldId>, + pub document: &'a KvReader, pub document_id: DocumentId, } diff --git a/crates/milli/src/documents/mod.rs b/crates/milli/src/documents/mod.rs index f4509256d..001e2293a 100644 --- a/crates/milli/src/documents/mod.rs +++ b/crates/milli/src/documents/mod.rs @@ -13,8 +13,8 @@ pub use builder::DocumentsBatchBuilder; pub use enriched::{EnrichedDocument, EnrichedDocumentsBatchCursor, EnrichedDocumentsBatchReader}; use obkv::KvReader; pub use primary_key::{ - validate_document_id_value, DocumentIdExtractionError, FieldIdMapper, PrimaryKey, - DEFAULT_PRIMARY_KEY, + validate_document_id_str, validate_document_id_value, DocumentIdExtractionError, FieldIdMapper, + PrimaryKey, DEFAULT_PRIMARY_KEY, }; pub use reader::{DocumentsBatchCursor, DocumentsBatchCursorError, DocumentsBatchReader}; use serde::{Deserialize, Serialize}; @@ -27,7 +27,7 @@ use crate::{FieldId, Object, Result}; const DOCUMENTS_BATCH_INDEX_KEY: [u8; 8] = u64::MAX.to_be_bytes(); /// Helper function to convert an obkv reader into a JSON object. -pub fn obkv_to_object(obkv: &KvReader<'_, FieldId>, index: &DocumentsBatchIndex) -> Result { +pub fn obkv_to_object(obkv: &KvReader, index: &DocumentsBatchIndex) -> Result { obkv.iter() .map(|(field_id, value)| { let field_name = index @@ -76,7 +76,7 @@ impl DocumentsBatchIndex { self.0.get_by_right(name).cloned() } - pub fn recreate_json(&self, document: &obkv::KvReaderU16<'_>) -> Result { + pub fn recreate_json(&self, document: &obkv::KvReaderU16) -> Result { let mut map = Object::new(); for (k, v) in document.iter() { @@ -96,6 +96,10 @@ impl FieldIdMapper for DocumentsBatchIndex { fn id(&self, name: &str) -> Option { self.id(name) } + + fn name(&self, id: FieldId) -> Option<&str> { + self.name(id) + } } #[derive(Debug, thiserror::Error)] diff --git a/crates/milli/src/documents/primary_key.rs b/crates/milli/src/documents/primary_key.rs index 9ac5ace91..fb8b3d027 100644 --- a/crates/milli/src/documents/primary_key.rs +++ b/crates/milli/src/documents/primary_key.rs @@ -1,8 +1,14 @@ use std::iter; +use std::ops::ControlFlow; use std::result::Result as StdResult; +use bumpalo::Bump; +use serde_json::value::RawValue; use serde_json::Value; +use crate::fields_ids_map::MutFieldIdMapper; +use crate::update::new::indexer::de::{match_component, DeOrBumpStr}; +use crate::update::new::KvReaderFieldId; use crate::{FieldId, InternalError, Object, Result, UserError}; /// The symbol used to define levels in a nested primary key. @@ -17,6 +23,21 @@ pub trait FieldIdMapper { /// /// `None` if the field with this name was not found. fn id(&self, name: &str) -> Option; + + fn name(&self, id: FieldId) -> Option<&str>; +} + +impl FieldIdMapper for &T +where + T: FieldIdMapper, +{ + fn id(&self, name: &str) -> Option { + T::id(self, name) + } + + fn name(&self, id: FieldId) -> Option<&str> { + T::name(self, id) + } } /// A type that represent the type of primary key that has been set @@ -43,7 +64,19 @@ impl<'a> PrimaryKey<'a> { }) } - pub fn name(&self) -> &str { + pub fn new_or_insert( + path: &'a str, + fields: &mut impl MutFieldIdMapper, + ) -> StdResult { + Ok(if path.contains(PRIMARY_KEY_SPLIT_SYMBOL) { + Self::Nested { name: path } + } else { + let field_id = fields.insert(path).ok_or(UserError::AttributeLimitReached)?; + Self::Flat { name: path, field_id } + }) + } + + pub fn name(&self) -> &'a str { match self { PrimaryKey::Flat { name, .. } => name, PrimaryKey::Nested { name } => name, @@ -52,7 +85,7 @@ impl<'a> PrimaryKey<'a> { pub fn document_id( &self, - document: &obkv::KvReader<'_, FieldId>, + document: &obkv::KvReader, fields: &impl FieldIdMapper, ) -> Result> { match self { @@ -100,9 +133,105 @@ impl<'a> PrimaryKey<'a> { } } + pub fn extract_docid_from_db<'pl, 'bump: 'pl, Mapper: FieldIdMapper>( + &self, + document: &'pl KvReaderFieldId, + db_fields_ids_map: &Mapper, + indexer: &'bump Bump, + ) -> Result> { + use serde::Deserializer as _; + + match self { + PrimaryKey::Flat { name: _, field_id } => { + let Some(document_id) = document.get(*field_id) else { + return Err(InternalError::DocumentsError( + crate::documents::Error::InvalidDocumentFormat, + ) + .into()); + }; + + let document_id: &RawValue = + serde_json::from_slice(document_id).map_err(InternalError::SerdeJson)?; + + let document_id = document_id + .deserialize_any(crate::update::new::indexer::de::DocumentIdVisitor(indexer)) + .map_err(InternalError::SerdeJson)?; + + let external_document_id = match document_id { + Ok(document_id) => Ok(document_id), + Err(_) => Err(InternalError::DocumentsError( + crate::documents::Error::InvalidDocumentFormat, + )), + }?; + + Ok(external_document_id) + } + nested @ PrimaryKey::Nested { name: _ } => { + let mut docid = None; + for (first_level, right) in nested.possible_level_names() { + let Some(fid) = db_fields_ids_map.id(first_level) else { continue }; + + let Some(value) = document.get(fid) else { continue }; + let value: &RawValue = + serde_json::from_slice(value).map_err(InternalError::SerdeJson)?; + match match_component(first_level, right, value, indexer, &mut docid) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(Ok(_)) => { + return Err(InternalError::DocumentsError( + crate::documents::Error::InvalidDocumentFormat, + ) + .into()) + } + ControlFlow::Break(Err(err)) => { + return Err(InternalError::SerdeJson(err).into()) + } + } + } + Ok(docid.ok_or(InternalError::DocumentsError( + crate::documents::Error::InvalidDocumentFormat, + ))?) + } + } + } + + pub fn extract_fields_and_docid<'pl, 'bump: 'pl, Mapper: MutFieldIdMapper>( + &self, + document: &'pl RawValue, + new_fields_ids_map: &mut Mapper, + indexer: &'bump Bump, + ) -> Result> { + use serde::Deserializer as _; + let res = document + .deserialize_map(crate::update::new::indexer::de::FieldAndDocidExtractor::new( + new_fields_ids_map, + self, + indexer, + )) + .map_err(UserError::SerdeJson)??; + + let external_document_id = match res { + Ok(document_id) => Ok(document_id), + Err(DocumentIdExtractionError::InvalidDocumentId(e)) => Err(e), + Err(DocumentIdExtractionError::MissingDocumentId) => { + Err(UserError::MissingDocumentId { + primary_key: self.name().to_string(), + document: serde_json::from_str(document.get()).unwrap(), + }) + } + Err(DocumentIdExtractionError::TooManyDocumentIds(_)) => { + Err(UserError::TooManyDocumentIds { + primary_key: self.name().to_string(), + document: serde_json::from_str(document.get()).unwrap(), + }) + } + }?; + + Ok(external_document_id) + } + /// Returns an `Iterator` that gives all the possible fields names the primary key /// can have depending of the first level name and depth of the objects. - pub fn possible_level_names(&self) -> impl Iterator + '_ { + pub fn possible_level_names(&self) -> impl Iterator + '_ { let name = self.name(); name.match_indices(PRIMARY_KEY_SPLIT_SYMBOL) .map(move |(i, _)| (&name[..i], &name[i + PRIMARY_KEY_SPLIT_SYMBOL.len_utf8()..])) @@ -149,7 +278,7 @@ fn starts_with(selector: &str, key: &str) -> bool { // FIXME: move to a DocumentId struct -fn validate_document_id(document_id: &str) -> Option<&str> { +pub fn validate_document_id_str(document_id: &str) -> Option<&str> { if document_id.is_empty() || document_id.len() > 512 || !document_id.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') @@ -162,7 +291,7 @@ fn validate_document_id(document_id: &str) -> Option<&str> { pub fn validate_document_id_value(document_id: Value) -> StdResult { match document_id { - Value::String(string) => match validate_document_id(&string) { + Value::String(string) => match validate_document_id_str(&string) { Some(s) if s.len() == string.len() => Ok(string), Some(s) => Ok(s.to_string()), None => Err(UserError::InvalidDocumentId { document_id: Value::String(string) }), diff --git a/crates/milli/src/documents/reader.rs b/crates/milli/src/documents/reader.rs index c7c125c80..20e932805 100644 --- a/crates/milli/src/documents/reader.rs +++ b/crates/milli/src/documents/reader.rs @@ -72,15 +72,24 @@ impl DocumentsBatchCursor { } impl DocumentsBatchCursor { + /// Returns a single document from the database. + pub fn get( + &mut self, + offset: u32, + ) -> Result>, DocumentsBatchCursorError> { + match self.cursor.move_on_key_equal_to(offset.to_be_bytes())? { + Some((key, value)) if key != DOCUMENTS_BATCH_INDEX_KEY => Ok(Some(value.into())), + _otherwise => Ok(None), + } + } + /// Returns the next document, starting from the first one. Subsequent calls to /// `next_document` advance the document reader until all the documents have been read. pub fn next_document( &mut self, - ) -> Result>, DocumentsBatchCursorError> { + ) -> Result>, DocumentsBatchCursorError> { match self.cursor.move_on_next()? { - Some((key, value)) if key != DOCUMENTS_BATCH_INDEX_KEY => { - Ok(Some(KvReader::new(value))) - } + Some((key, value)) if key != DOCUMENTS_BATCH_INDEX_KEY => Ok(Some(value.into())), _otherwise => Ok(None), } } diff --git a/crates/milli/src/error.rs b/crates/milli/src/error.rs index 3b48b50f2..3a9c81e10 100644 --- a/crates/milli/src/error.rs +++ b/crates/milli/src/error.rs @@ -31,23 +31,23 @@ pub enum Error { pub enum InternalError { #[error("{}", HeedError::DatabaseClosing)] DatabaseClosing, - #[error("Missing {} in the {db_name} database.", key.unwrap_or("key"))] + #[error("missing {} in the {db_name} database", key.unwrap_or("key"))] DatabaseMissingEntry { db_name: &'static str, key: Option<&'static str> }, - #[error("Missing {key} in the fieldids weights mapping.")] + #[error("missing {key} in the fieldids weights mapping")] FieldidsWeightsMapMissingEntry { key: FieldId }, #[error(transparent)] FieldIdMapMissingEntry(#[from] FieldIdMapMissingEntry), - #[error("Missing {key} in the field id mapping.")] + #[error("missing {key} in the field id mapping")] FieldIdMappingMissingEntry { key: FieldId }, #[error(transparent)] Fst(#[from] fst::Error), #[error(transparent)] DocumentsError(#[from] documents::Error), - #[error("Invalid compression type have been specified to grenad")] + #[error("invalid compression type have been specified to grenad")] GrenadInvalidCompressionType, - #[error("Invalid grenad file with an invalid version format")] + #[error("invalid grenad file with an invalid version format")] GrenadInvalidFormatVersion, - #[error("Invalid merge while processing {process}")] + #[error("invalid merge while processing {process}")] IndexingMergingKeys { process: &'static str }, #[error(transparent)] RayonThreadPool(#[from] ThreadPoolBuildError), @@ -122,7 +122,7 @@ and can not be more than 512 bytes.", .document_id.to_string() #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")] InvalidVectorsMapType { document_id: String, value: Value }, #[error("Bad embedder configuration in the document with id: `{document_id}`. {error}")] - InvalidVectorsEmbedderConf { document_id: String, error: deserr::errors::JsonError }, + InvalidVectorsEmbedderConf { document_id: String, error: String }, #[error("{0}")] InvalidFilter(String), #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] diff --git a/crates/milli/src/fields_ids_map.rs b/crates/milli/src/fields_ids_map.rs index f9d7c3704..9a016e7bd 100644 --- a/crates/milli/src/fields_ids_map.rs +++ b/crates/milli/src/fields_ids_map.rs @@ -4,6 +4,10 @@ use serde::{Deserialize, Serialize}; use crate::FieldId; +mod global; +pub mod metadata; +pub use global::GlobalFieldsIdsMap; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FieldsIdsMap { names_ids: BTreeMap, @@ -95,6 +99,20 @@ impl crate::documents::FieldIdMapper for FieldsIdsMap { fn id(&self, name: &str) -> Option { self.id(name) } + + fn name(&self, id: FieldId) -> Option<&str> { + self.name(id) + } +} + +pub trait MutFieldIdMapper { + fn insert(&mut self, name: &str) -> Option; +} + +impl MutFieldIdMapper for FieldsIdsMap { + fn insert(&mut self, name: &str) -> Option { + self.insert(name) + } } #[cfg(test)] diff --git a/crates/milli/src/fields_ids_map/global.rs b/crates/milli/src/fields_ids_map/global.rs new file mode 100644 index 000000000..32aefbfdf --- /dev/null +++ b/crates/milli/src/fields_ids_map/global.rs @@ -0,0 +1,129 @@ +use std::collections::BTreeMap; +use std::sync::RwLock; + +use super::metadata::{FieldIdMapWithMetadata, Metadata}; +use super::MutFieldIdMapper; +use crate::documents::FieldIdMapper; +use crate::FieldId; + +/// A fields ids map that can be globally updated to add fields +#[derive(Debug, Clone)] +pub struct GlobalFieldsIdsMap<'indexing> { + global: &'indexing RwLock, + local: LocalFieldsIdsMap, +} + +#[derive(Debug, Clone)] +pub struct LocalFieldsIdsMap { + names_ids: BTreeMap, + ids_names: BTreeMap, + metadata: BTreeMap, +} + +impl FieldIdMapper for LocalFieldsIdsMap { + fn id(&self, name: &str) -> Option { + self.id(name) + } + + fn name(&self, id: FieldId) -> Option<&str> { + self.name(id) + } +} + +impl LocalFieldsIdsMap { + fn new(global: &RwLock) -> Self { + let global = global.read().unwrap(); + Self { + names_ids: global.as_fields_ids_map().names_ids.clone(), + ids_names: global.as_fields_ids_map().ids_names.clone(), + metadata: global.iter_id_metadata().collect(), + } + } + + fn insert(&mut self, name: &str, field_id: FieldId, metadata: Metadata) { + self.names_ids.insert(name.to_owned(), field_id); + self.ids_names.insert(field_id, name.to_owned()); + self.metadata.insert(field_id, metadata); + } + + fn name(&self, id: FieldId) -> Option<&str> { + self.ids_names.get(&id).map(String::as_str) + } + + fn id(&self, name: &str) -> Option { + self.names_ids.get(name).copied() + } + + fn id_with_metadata(&self, name: &str) -> Option<(FieldId, Metadata)> { + let id = self.id(name)?; + Some((id, self.metadata(id).unwrap())) + } + + fn metadata(&self, id: FieldId) -> Option { + self.metadata.get(&id).copied() + } +} + +impl<'indexing> GlobalFieldsIdsMap<'indexing> { + pub fn new(global: &'indexing RwLock) -> Self { + Self { local: LocalFieldsIdsMap::new(global), global } + } + + /// Returns the field id related to a field name, it will create a new field id if the + /// name is not already known. Returns `None` if the maximum field id as been reached. + pub fn id_or_insert(&mut self, name: &str) -> Option { + self.id_with_metadata_or_insert(name).map(|(fid, _meta)| fid) + } + + pub fn id_with_metadata_or_insert(&mut self, name: &str) -> Option<(FieldId, Metadata)> { + if let Some(entry) = self.local.id_with_metadata(name) { + return Some(entry); + } + + { + // optimistically lookup the global map + let global = self.global.read().unwrap(); + + if let Some((field_id, metadata)) = global.id_with_metadata(name) { + self.local.insert(name, field_id, metadata); + return Some((field_id, metadata)); + } + } + + { + let mut global = self.global.write().unwrap(); + + if let Some((field_id, metadata)) = global.id_with_metadata(name) { + self.local.insert(name, field_id, metadata); + return Some((field_id, metadata)); + } + + let field_id = global.insert(name)?; + let metadata = global.metadata(field_id).unwrap(); + self.local.insert(name, field_id, metadata); + Some((field_id, metadata)) + } + } + + /// Get the name of a field based on its id. + pub fn name(&mut self, id: FieldId) -> Option<&str> { + if self.local.name(id).is_none() { + let global = self.global.read().unwrap(); + + let (name, metadata) = global.name_with_metadata(id)?; + self.local.insert(name, id, metadata); + } + + self.local.name(id) + } + + pub fn local_map(&self) -> &LocalFieldsIdsMap { + &self.local + } +} + +impl<'indexing> MutFieldIdMapper for GlobalFieldsIdsMap<'indexing> { + fn insert(&mut self, name: &str) -> Option { + self.id_or_insert(name) + } +} diff --git a/crates/milli/src/fields_ids_map/metadata.rs b/crates/milli/src/fields_ids_map/metadata.rs new file mode 100644 index 000000000..54fdc7b4b --- /dev/null +++ b/crates/milli/src/fields_ids_map/metadata.rs @@ -0,0 +1,183 @@ +use std::collections::{BTreeMap, HashSet}; +use std::num::NonZeroU16; + +use charabia::Language; +use heed::RoTxn; + +use super::FieldsIdsMap; +use crate::{FieldId, Index, LocalizedAttributesRule, Result}; + +#[derive(Debug, Clone, Copy)] +pub struct Metadata { + pub searchable: bool, + pub filterable: bool, + pub sortable: bool, + localized_attributes_rule_id: Option, +} + +#[derive(Debug, Clone)] +pub struct FieldIdMapWithMetadata { + fields_ids_map: FieldsIdsMap, + builder: MetadataBuilder, + metadata: BTreeMap, +} + +impl FieldIdMapWithMetadata { + pub fn new(existing_fields_ids_map: FieldsIdsMap, builder: MetadataBuilder) -> Self { + let metadata = existing_fields_ids_map + .iter() + .map(|(id, name)| (id, builder.metadata_for_field(name))) + .collect(); + Self { fields_ids_map: existing_fields_ids_map, builder, metadata } + } + + pub fn as_fields_ids_map(&self) -> &FieldsIdsMap { + &self.fields_ids_map + } + + /// Returns the number of fields ids in the map. + pub fn len(&self) -> usize { + self.fields_ids_map.len() + } + + /// Returns `true` if the map is empty. + pub fn is_empty(&self) -> bool { + self.fields_ids_map.is_empty() + } + + /// Returns the field id related to a field name, it will create a new field id if the + /// name is not already known. Returns `None` if the maximum field id as been reached. + pub fn insert(&mut self, name: &str) -> Option { + let id = self.fields_ids_map.insert(name)?; + self.metadata.insert(id, self.builder.metadata_for_field(name)); + Some(id) + } + + /// Get the id of a field based on its name. + pub fn id(&self, name: &str) -> Option { + self.fields_ids_map.id(name) + } + + pub fn id_with_metadata(&self, name: &str) -> Option<(FieldId, Metadata)> { + let id = self.fields_ids_map.id(name)?; + Some((id, self.metadata(id).unwrap())) + } + + /// Get the name of a field based on its id. + pub fn name(&self, id: FieldId) -> Option<&str> { + self.fields_ids_map.name(id) + } + + /// Get the name of a field based on its id. + pub fn name_with_metadata(&self, id: FieldId) -> Option<(&str, Metadata)> { + let name = self.fields_ids_map.name(id)?; + Some((name, self.metadata(id).unwrap())) + } + + pub fn metadata(&self, id: FieldId) -> Option { + self.metadata.get(&id).copied() + } + + /// Iterate over the ids and names in the ids order. + pub fn iter(&self) -> impl Iterator { + self.fields_ids_map.iter().map(|(id, name)| (id, name, self.metadata(id).unwrap())) + } + + pub fn iter_id_metadata(&self) -> impl Iterator + '_ { + self.metadata.iter().map(|(k, v)| (*k, *v)) + } + + pub fn iter_metadata(&self) -> impl Iterator + '_ { + self.metadata.values().copied() + } + + pub fn metadata_builder(&self) -> &MetadataBuilder { + &self.builder + } +} + +impl Metadata { + pub fn locales<'rules>( + &self, + rules: &'rules [LocalizedAttributesRule], + ) -> Option<&'rules [Language]> { + let localized_attributes_rule_id = self.localized_attributes_rule_id?.get(); + let rule = rules.get((localized_attributes_rule_id - 1) as usize).unwrap(); + Some(rule.locales()) + } +} + +#[derive(Debug, Clone)] +pub struct MetadataBuilder { + searchable_attributes: Vec, + filterable_attributes: HashSet, + sortable_attributes: HashSet, + localized_attributes: Option>, +} + +impl MetadataBuilder { + pub fn from_index(index: &Index, rtxn: &RoTxn) -> Result { + let searchable_attributes = + index.searchable_fields(rtxn)?.into_iter().map(|s| s.to_string()).collect(); + let filterable_attributes = index.filterable_fields(rtxn)?; + let sortable_attributes = index.sortable_fields(rtxn)?; + let localized_attributes = index.localized_attributes_rules(rtxn)?; + + Ok(Self { + searchable_attributes, + filterable_attributes, + sortable_attributes, + localized_attributes, + }) + } + + pub fn new( + searchable_attributes: Vec, + filterable_attributes: HashSet, + sortable_attributes: HashSet, + localized_attributes: Option>, + ) -> Self { + Self { + searchable_attributes, + filterable_attributes, + sortable_attributes, + localized_attributes, + } + } + + pub fn metadata_for_field(&self, field: &str) -> Metadata { + let searchable = self + .searchable_attributes + .iter() + .any(|attribute| attribute == "*" || attribute == field); + + let filterable = self.filterable_attributes.contains(field); + + let sortable = self.sortable_attributes.contains(field); + + let localized_attributes_rule_id = self + .localized_attributes + .iter() + .flat_map(|v| v.iter()) + .position(|rule| rule.match_str(field)) + .map(|id| NonZeroU16::new(id.saturating_add(1).try_into().unwrap()).unwrap()); + + Metadata { searchable, filterable, sortable, localized_attributes_rule_id } + } + + pub fn searchable_attributes(&self) -> &[String] { + self.searchable_attributes.as_slice() + } + + pub fn sortable_attributes(&self) -> &HashSet { + &self.sortable_attributes + } + + pub fn filterable_attributes(&self) -> &HashSet { + &self.filterable_attributes + } + + pub fn localized_attributes_rules(&self) -> Option<&[LocalizedAttributesRule]> { + self.localized_attributes.as_deref() + } +} diff --git a/crates/milli/src/heed_codec/facet/ordered_f64_codec.rs b/crates/milli/src/heed_codec/facet/ordered_f64_codec.rs index 4eccdb68b..19ba7a460 100644 --- a/crates/milli/src/heed_codec/facet/ordered_f64_codec.rs +++ b/crates/milli/src/heed_codec/facet/ordered_f64_codec.rs @@ -27,17 +27,34 @@ impl heed::BytesEncode<'_> for OrderedF64Codec { fn bytes_encode(f: &Self::EItem) -> Result, BoxedError> { let mut buffer = [0u8; 16]; - // write the globally ordered float - let bytes = f64_into_bytes(*f).ok_or(InvalidGloballyOrderedFloatError { float: *f })?; - buffer[..8].copy_from_slice(&bytes[..]); - // Then the f64 value just to be able to read it back - let bytes = f.to_be_bytes(); - buffer[8..16].copy_from_slice(&bytes[..]); + encode_f64_into_ordered_bytes(*f, &mut buffer)?; Ok(Cow::Owned(buffer.to_vec())) } } +impl OrderedF64Codec { + pub fn serialize_into( + f: f64, + buffer: &mut [u8; 16], + ) -> Result<(), InvalidGloballyOrderedFloatError> { + encode_f64_into_ordered_bytes(f, buffer) + } +} + +fn encode_f64_into_ordered_bytes( + f: f64, + buffer: &mut [u8; 16], +) -> Result<(), InvalidGloballyOrderedFloatError> { + let bytes = f64_into_bytes(f).ok_or(InvalidGloballyOrderedFloatError { float: f })?; + buffer[..8].copy_from_slice(&bytes[..]); + // Then the f64 value just to be able to read it back + let bytes = f.to_be_bytes(); + buffer[8..16].copy_from_slice(&bytes[..]); + + Ok(()) +} + #[derive(Error, Debug)] #[error("the float {float} cannot be converted to a globally ordered representation")] pub struct InvalidGloballyOrderedFloatError { diff --git a/crates/milli/src/heed_codec/obkv_codec.rs b/crates/milli/src/heed_codec/obkv_codec.rs index 390a57af3..447323571 100644 --- a/crates/milli/src/heed_codec/obkv_codec.rs +++ b/crates/milli/src/heed_codec/obkv_codec.rs @@ -6,10 +6,10 @@ use obkv::{KvReaderU16, KvWriterU16}; pub struct ObkvCodec; impl<'a> heed::BytesDecode<'a> for ObkvCodec { - type DItem = KvReaderU16<'a>; + type DItem = &'a KvReaderU16; fn bytes_decode(bytes: &'a [u8]) -> Result { - Ok(KvReaderU16::new(bytes)) + Ok(KvReaderU16::from_slice(bytes)) } } diff --git a/crates/milli/src/heed_codec/roaring_bitmap/cbo_roaring_bitmap_codec.rs b/crates/milli/src/heed_codec/roaring_bitmap/cbo_roaring_bitmap_codec.rs index fa65d5217..257d5bd0a 100644 --- a/crates/milli/src/heed_codec/roaring_bitmap/cbo_roaring_bitmap_codec.rs +++ b/crates/milli/src/heed_codec/roaring_bitmap/cbo_roaring_bitmap_codec.rs @@ -122,7 +122,7 @@ impl CboRoaringBitmapCodec { /// Merges a DelAdd delta into a CboRoaringBitmap. pub fn merge_deladd_into<'a>( - deladd: KvReaderDelAdd<'_>, + deladd: &KvReaderDelAdd, previous: &[u8], buffer: &'a mut Vec, ) -> io::Result> { diff --git a/crates/milli/src/index.rs b/crates/milli/src/index.rs index 5b7a9c58c..08a8e36f8 100644 --- a/crates/milli/src/index.rs +++ b/crates/milli/src/index.rs @@ -1251,12 +1251,20 @@ impl Index { /* documents */ + /// Returns a document by using the document id. + pub fn document<'t>(&self, rtxn: &'t RoTxn, id: DocumentId) -> Result<&'t obkv::KvReaderU16> { + self.documents + .get(rtxn, &id)? + .ok_or(UserError::UnknownInternalDocumentId { document_id: id }) + .map_err(Into::into) + } + /// Returns an iterator over the requested documents. The next item will be an error if a document is missing. pub fn iter_documents<'a, 't: 'a>( &'a self, rtxn: &'t RoTxn<'t>, ids: impl IntoIterator + 'a, - ) -> Result)>> + 'a> { + ) -> Result> + 'a> { Ok(ids.into_iter().map(move |id| { let kv = self .documents @@ -1271,7 +1279,7 @@ impl Index { &self, rtxn: &'t RoTxn<'t>, ids: impl IntoIterator, - ) -> Result)>> { + ) -> Result> { self.iter_documents(rtxn, ids)?.collect() } @@ -1279,7 +1287,7 @@ impl Index { pub fn all_documents<'a, 't: 'a>( &'a self, rtxn: &'t RoTxn<'t>, - ) -> Result)>> + 'a> { + ) -> Result> + 'a> { self.iter_documents(rtxn, self.documents_ids(rtxn)?) } @@ -1303,7 +1311,7 @@ impl Index { })?; Ok(self.iter_documents(rtxn, ids)?.map(move |entry| -> Result<_> { let (_docid, obkv) = entry?; - match primary_key.document_id(&obkv, &fields)? { + match primary_key.document_id(obkv, &fields)? { Ok(document_id) => Ok(document_id), Err(_) => Err(InternalError::DocumentsError( crate::documents::Error::InvalidDocumentFormat, @@ -1638,6 +1646,14 @@ impl Index { } Ok(res) } + + pub fn prefix_settings(&self, _rtxn: &RoTxn<'_>) -> Result { + Ok(PrefixSettings { + compute_prefixes: true, + max_prefix_length: 4, + prefix_count_threshold: 100, + }) + } } #[derive(Debug, Deserialize, Serialize)] @@ -1647,6 +1663,13 @@ pub struct IndexEmbeddingConfig { pub user_provided: RoaringBitmap, } +#[derive(Debug, Deserialize, Serialize)] +pub struct PrefixSettings { + pub prefix_count_threshold: u64, + pub max_prefix_length: usize, + pub compute_prefixes: bool, +} + #[derive(Serialize, Deserialize)] #[serde(transparent)] struct OffsetDateTime(#[serde(with = "time::serde::rfc3339")] time::OffsetDateTime); diff --git a/crates/milli/src/lib.rs b/crates/milli/src/lib.rs index 8008b7bd1..48b03b6cc 100644 --- a/crates/milli/src/lib.rs +++ b/crates/milli/src/lib.rs @@ -55,7 +55,7 @@ pub use self::error::{ }; pub use self::external_documents_ids::ExternalDocumentsIds; pub use self::fieldids_weights_map::FieldidsWeightsMap; -pub use self::fields_ids_map::FieldsIdsMap; +pub use self::fields_ids_map::{FieldsIdsMap, GlobalFieldsIdsMap}; pub use self::heed_codec::{ BEU16StrCodec, BEU32StrCodec, BoRoaringBitmapCodec, BoRoaringBitmapLenCodec, CboRoaringBitmapCodec, CboRoaringBitmapLenCodec, FieldIdWordCountCodec, ObkvCodec, @@ -88,6 +88,7 @@ pub type Object = serde_json::Map; pub type Position = u32; pub type RelativePosition = u16; pub type SmallString32 = smallstr::SmallString<[u8; 32]>; +pub type Prefix = smallstr::SmallString<[u8; 16]>; pub type SmallVec16 = smallvec::SmallVec<[T; 16]>; pub type SmallVec32 = smallvec::SmallVec<[T; 32]>; pub type SmallVec8 = smallvec::SmallVec<[T; 8]>; @@ -214,7 +215,7 @@ pub fn bucketed_position(relative: u16) -> u16 { pub fn obkv_to_json( displayed_fields: &[FieldId], fields_ids_map: &FieldsIdsMap, - obkv: obkv::KvReaderU16<'_>, + obkv: &obkv::KvReaderU16, ) -> Result { displayed_fields .iter() @@ -232,10 +233,7 @@ pub fn obkv_to_json( } /// Transform every field of a raw obkv store into a JSON Object. -pub fn all_obkv_to_json( - obkv: obkv::KvReaderU16<'_>, - fields_ids_map: &FieldsIdsMap, -) -> Result { +pub fn all_obkv_to_json(obkv: &obkv::KvReaderU16, fields_ids_map: &FieldsIdsMap) -> Result { let all_keys = obkv.iter().map(|(k, _v)| k).collect::>(); obkv_to_json(all_keys.as_slice(), fields_ids_map, obkv) } @@ -434,7 +432,7 @@ mod tests { writer.insert(id1, b"1234").unwrap(); writer.insert(id2, b"4321").unwrap(); let contents = writer.into_inner().unwrap(); - let obkv = obkv::KvReaderU16::new(&contents); + let obkv = obkv::KvReaderU16::from_slice(&contents); let expected = json!({ "field1": 1234, diff --git a/crates/milli/src/prompt/context.rs b/crates/milli/src/prompt/context.rs index 7ab08301a..02258d067 100644 --- a/crates/milli/src/prompt/context.rs +++ b/crates/milli/src/prompt/context.rs @@ -3,23 +3,19 @@ use liquid::model::{ }; use liquid::{ObjectView, ValueView}; -use super::document::Document; -use super::fields::Fields; -use super::FieldsIdsMapWithMetadata; - #[derive(Debug, Clone)] -pub struct Context<'a> { - document: &'a Document<'a>, - fields: Fields<'a>, +pub struct Context<'a, D: ObjectView, F: ArrayView> { + document: &'a D, + fields: &'a F, } -impl<'a> Context<'a> { - pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMapWithMetadata<'a>) -> Self { - Self { document, fields: Fields::new(document, field_id_map) } +impl<'a, D: ObjectView, F: ArrayView> Context<'a, D, F> { + pub fn new(document: &'a D, fields: &'a F) -> Self { + Self { document, fields } } } -impl<'a> ObjectView for Context<'a> { +impl<'a, D: ObjectView, F: ArrayView> ObjectView for Context<'a, D, F> { fn as_value(&self) -> &dyn ValueView { self } @@ -56,7 +52,7 @@ impl<'a> ObjectView for Context<'a> { } } -impl<'a> ValueView for Context<'a> { +impl<'a, D: ObjectView, F: ArrayView> ValueView for Context<'a, D, F> { fn as_debug(&self) -> &dyn std::fmt::Debug { self } diff --git a/crates/milli/src/prompt/document.rs b/crates/milli/src/prompt/document.rs index b5d43b5be..28c0f47af 100644 --- a/crates/milli/src/prompt/document.rs +++ b/crates/milli/src/prompt/document.rs @@ -1,10 +1,15 @@ use std::cell::OnceCell; use std::collections::BTreeMap; +use std::fmt::{self, Debug}; +use bumpalo::Bump; use liquid::model::{ - DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, + ArrayView, DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, ScalarCow, State, + Value as LiquidValue, }; use liquid::{ObjectView, ValueView}; +use raw_collections::{RawMap, RawVec}; +use serde_json::value::RawValue; use crate::update::del_add::{DelAdd, KvReaderDelAdd}; use crate::FieldsIdsMap; @@ -30,13 +35,13 @@ impl ParsedValue { impl<'a> Document<'a> { pub fn new( - data: obkv::KvReaderU16<'a>, + data: &'a obkv::KvReaderU16, side: DelAdd, inverted_field_map: &'a FieldsIdsMap, ) -> Self { let mut out_data = BTreeMap::new(); for (fid, raw) in data { - let obkv = KvReaderDelAdd::new(raw); + let obkv = KvReaderDelAdd::from_slice(raw); let Some(raw) = obkv.get(side) else { continue; }; @@ -93,7 +98,7 @@ impl<'a> ObjectView for Document<'a> { } impl<'a> ValueView for Document<'a> { - fn as_debug(&self) -> &dyn std::fmt::Debug { + fn as_debug(&self) -> &dyn Debug { self } @@ -128,4 +133,515 @@ impl<'a> ValueView for Document<'a> { fn as_object(&self) -> Option<&dyn ObjectView> { Some(self) } + + fn is_object(&self) -> bool { + true + } +} + +/// Implementation for any type that implements the Document trait +use crate::update::new::document::Document as DocumentTrait; + +#[derive(Debug)] +pub struct ParseableDocument<'doc, D> { + document: D, + doc_alloc: &'doc Bump, +} + +impl<'doc, D> ParseableDocument<'doc, D> { + pub fn new(document: D, doc_alloc: &'doc Bump) -> Self { + Self { document, doc_alloc } + } +} + +impl<'doc, D: DocumentTrait<'doc> + Debug> ObjectView for ParseableDocument<'doc, D> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.document.len() as i64 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(self.document.iter_top_level_fields().map(|res| { + let (field, _) = res.unwrap(); + KStringCow::from_ref(field) + })) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.document.iter_top_level_fields().map(|res| { + let (_, value) = res.unwrap(); + ParseableValue::new_bump(value, self.doc_alloc) as _ + })) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.document.iter_top_level_fields().map(|res| { + let (field, value) = res.unwrap(); + (KStringCow::from_ref(field), ParseableValue::new_bump(value, self.doc_alloc) as _) + })) + } + + fn contains_key(&self, index: &str) -> bool { + self.document.top_level_field(index).unwrap().is_some() + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + let s = self.document.top_level_field(index).unwrap()?; + Some(ParseableValue::new_bump(s, self.doc_alloc)) + } +} + +impl<'doc, D: DocumentTrait<'doc> + Debug> ValueView for ParseableDocument<'doc, D> { + fn as_debug(&self) -> &dyn fmt::Debug { + self + } + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.document + .iter_top_level_fields() + .map(|res| { + let (k, v) = res.unwrap(); + (k.to_string().into(), ParseableValue::new(v, self.doc_alloc).to_value()) + }) + .collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } + + fn is_object(&self) -> bool { + true + } +} + +#[derive(Debug)] +struct ParseableValue<'doc> { + value: raw_collections::Value<'doc>, +} + +impl<'doc> ParseableValue<'doc> { + pub fn new(value: &'doc RawValue, doc_alloc: &'doc Bump) -> Self { + let value = raw_collections::Value::from_raw_value(value, doc_alloc).unwrap(); + Self { value } + } + + pub fn new_bump(value: &'doc RawValue, doc_alloc: &'doc Bump) -> &'doc Self { + doc_alloc.alloc(Self::new(value, doc_alloc)) + } +} + +// transparent newtype for implementing ValueView +#[repr(transparent)] +#[derive(Debug)] +struct ParseableMap<'doc>(RawMap<'doc>); + +// transparent newtype for implementing ValueView +#[repr(transparent)] +#[derive(Debug)] +struct ParseableArray<'doc>(RawVec<'doc>); + +impl<'doc> ParseableMap<'doc> { + pub fn as_parseable<'a>(map: &'a RawMap<'doc>) -> &'a ParseableMap<'doc> { + // SAFETY: repr(transparent) + unsafe { &*(map as *const RawMap as *const Self) } + } +} + +impl<'doc> ParseableArray<'doc> { + pub fn as_parseable<'a>(array: &'a RawVec<'doc>) -> &'a ParseableArray<'doc> { + // SAFETY: repr(transparent) + unsafe { &*(array as *const RawVec as *const Self) } + } +} + +impl<'doc> ArrayView for ParseableArray<'doc> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.0.len() as _ + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.0.iter().map(|v| ParseableValue::new_bump(v, self.0.bump()) as _)) + } + + fn contains_key(&self, index: i64) -> bool { + let index = convert_index(index, self.size()); + index < self.size() && index >= 0 + } + + fn get(&self, index: i64) -> Option<&dyn ValueView> { + let index = convert_index(index, self.size()); + if index <= 0 { + return None; + } + let v = self.0.get(index as usize)?; + Some(ParseableValue::new_bump(v, self.0.bump())) + } +} + +impl<'doc> ValueView for ParseableArray<'doc> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DisplayCow::Owned(Box::new(ArrayRender { s: &self.0 })) + } + + fn source(&self) -> DisplayCow<'_> { + DisplayCow::Owned(Box::new(ArraySource { s: &self.0 })) + } + + fn type_name(&self) -> &'static str { + "array" + } + + fn query_state(&self, state: State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.0.is_empty(), + } + } + + fn to_kstr(&self) -> KStringCow<'_> { + let s = ArrayRender { s: &self.0 }.to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Array(self.values().map(|v| v.to_value()).collect()) + } + + fn is_array(&self) -> bool { + true + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self as _) + } +} + +impl<'doc> ObjectView for ParseableMap<'doc> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.0.len() as i64 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(self.0.keys().map(Into::into)) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.0.values().map(|value| { + let doc_alloc = self.0.bump(); + ParseableValue::new_bump(value, doc_alloc) as _ + })) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.0.iter().map(|(k, v)| { + let doc_alloc = self.0.bump(); + (k.into(), ParseableValue::new_bump(v, doc_alloc) as _) + })) + } + + fn contains_key(&self, index: &str) -> bool { + self.0.get(index).is_some() + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + let v = self.0.get(index)?; + let doc_alloc = self.0.bump(); + let value = ParseableValue::new(v, doc_alloc); + Some(doc_alloc.alloc(value) as _) + } +} + +impl<'doc> ValueView for ParseableMap<'doc> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.0.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.0 + .iter() + .map(|(k, v)| { + (k.to_string().into(), ParseableValue::new(v, self.0.bump()).to_value()) + }) + .collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } + + fn is_object(&self) -> bool { + true + } +} + +impl<'doc> ValueView for ParseableValue<'doc> { + fn as_debug(&self) -> &dyn Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + use raw_collections::value::Number; + use raw_collections::Value; + match &self.value { + Value::Null => LiquidValue::Nil.render(), + Value::Bool(v) => v.render(), + Value::Number(number) => match number { + Number::PosInt(x) => DisplayCow::Borrowed(x), + Number::NegInt(x) => x.render(), + Number::Finite(x) => x.render(), + }, + Value::String(s) => s.render(), + Value::Array(raw_vec) => ParseableArray::as_parseable(raw_vec).render(), + Value::Object(raw_map) => ParseableMap::as_parseable(raw_map).render(), + } + } + + fn source(&self) -> DisplayCow<'_> { + use raw_collections::value::Number; + use raw_collections::Value; + match &self.value { + Value::Null => LiquidValue::Nil.source(), + Value::Bool(v) => ValueView::source(v), + Value::Number(number) => match number { + Number::PosInt(x) => DisplayCow::Borrowed(x), + Number::NegInt(x) => x.source(), + Number::Finite(x) => x.source(), + }, + Value::String(s) => s.source(), + Value::Array(raw_vec) => ParseableArray::as_parseable(raw_vec).source(), + Value::Object(raw_map) => ParseableMap::as_parseable(raw_map).source(), + } + } + + fn type_name(&self) -> &'static str { + use raw_collections::value::Number; + use raw_collections::Value; + match &self.value { + Value::Null => LiquidValue::Nil.type_name(), + Value::Bool(v) => v.type_name(), + Value::Number(number) => match number { + Number::PosInt(_x) => "whole positive number", + Number::NegInt(x) => x.type_name(), + Number::Finite(x) => x.type_name(), + }, + Value::String(s) => s.type_name(), + Value::Array(_raw_vec) => "array", + Value::Object(_raw_map) => "object", + } + } + + fn query_state(&self, state: State) -> bool { + use raw_collections::Value; + match &self.value { + Value::Null => ValueView::query_state(&LiquidValue::Nil, state), + Value::Bool(v) => ValueView::query_state(v, state), + Value::Number(_number) => match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + }, + Value::String(s) => ValueView::query_state(s, state), + Value::Array(raw_vec) => ParseableArray::as_parseable(raw_vec).query_state(state), + Value::Object(raw_map) => ParseableMap::as_parseable(raw_map).query_state(state), + } + } + + fn to_kstr(&self) -> KStringCow<'_> { + use raw_collections::Value; + match &self.value { + Value::Null => ValueView::to_kstr(&LiquidValue::Nil), + Value::Bool(v) => ValueView::to_kstr(v), + Value::Number(_number) => self.render().to_string().into(), + Value::String(s) => KStringCow::from_ref(s), + Value::Array(raw_vec) => ParseableArray::as_parseable(raw_vec).to_kstr(), + Value::Object(raw_map) => ParseableMap::as_parseable(raw_map).to_kstr(), + } + } + + fn to_value(&self) -> LiquidValue { + use raw_collections::Value; + match &self.value { + Value::Null => LiquidValue::Nil, + Value::Bool(v) => LiquidValue::Scalar(liquid::model::ScalarCow::new(*v)), + Value::Number(number) => match number { + raw_collections::value::Number::PosInt(number) => { + let number: i64 = match (*number).try_into() { + Ok(number) => number, + Err(_) => { + return LiquidValue::Scalar(ScalarCow::new(self.render().to_string())) + } + }; + LiquidValue::Scalar(ScalarCow::new(number)) + } + raw_collections::value::Number::NegInt(number) => { + LiquidValue::Scalar(ScalarCow::new(*number)) + } + raw_collections::value::Number::Finite(number) => { + LiquidValue::Scalar(ScalarCow::new(*number)) + } + }, + Value::String(s) => LiquidValue::Scalar(liquid::model::ScalarCow::new(s.to_string())), + Value::Array(raw_vec) => ParseableArray::as_parseable(raw_vec).to_value(), + Value::Object(raw_map) => ParseableMap::as_parseable(raw_map).to_value(), + } + } + + fn as_scalar(&self) -> Option> { + use raw_collections::value::Number; + use raw_collections::Value; + match &self.value { + Value::Bool(v) => Some(liquid::model::ScalarCow::new(*v)), + Value::Number(number) => match number { + Number::PosInt(number) => { + let number: i64 = match (*number).try_into() { + Ok(number) => number, + Err(_) => return Some(ScalarCow::new(self.render().to_string())), + }; + Some(ScalarCow::new(number)) + } + Number::NegInt(number) => Some(ScalarCow::new(*number)), + Number::Finite(number) => Some(ScalarCow::new(*number)), + }, + Value::String(s) => Some(ScalarCow::new(*s)), + _ => None, + } + } + + fn is_scalar(&self) -> bool { + use raw_collections::Value; + matches!(&self.value, Value::Bool(_) | Value::Number(_) | Value::String(_)) + } + + fn as_array(&self) -> Option<&dyn liquid::model::ArrayView> { + if let raw_collections::Value::Array(array) = &self.value { + return Some(ParseableArray::as_parseable(array) as _); + } + None + } + + fn is_array(&self) -> bool { + matches!(&self.value, raw_collections::Value::Array(_)) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + if let raw_collections::Value::Object(object) = &self.value { + return Some(ParseableMap::as_parseable(object) as _); + } + None + } + + fn is_object(&self) -> bool { + matches!(&self.value, raw_collections::Value::Object(_)) + } + + fn is_nil(&self) -> bool { + matches!(&self.value, raw_collections::Value::Null) + } +} + +struct ArraySource<'s, 'doc> { + s: &'s RawVec<'doc>, +} + +impl<'s, 'doc> fmt::Display for ArraySource<'s, 'doc> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for item in self.s { + let v = ParseableValue::new(item, self.s.bump()); + write!(f, "{}, ", v.render())?; + } + write!(f, "]")?; + Ok(()) + } +} + +struct ArrayRender<'s, 'doc> { + s: &'s RawVec<'doc>, +} + +impl<'s, 'doc> fmt::Display for ArrayRender<'s, 'doc> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for item in self.s { + let v = ParseableValue::new(item, self.s.bump()); + + write!(f, "{}", v.render())?; + } + Ok(()) + } +} + +fn convert_index(index: i64, max_size: i64) -> i64 { + if 0 <= index { + index + } else { + max_size + index + } } diff --git a/crates/milli/src/prompt/fields.rs b/crates/milli/src/prompt/fields.rs index 81ea88ca6..ab15c31b0 100644 --- a/crates/milli/src/prompt/fields.rs +++ b/crates/milli/src/prompt/fields.rs @@ -1,36 +1,23 @@ +use std::cell::RefCell; +use std::fmt; + +use bumpalo::Bump; use liquid::model::{ ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, }; use liquid::{ObjectView, ValueView}; -use super::document::Document; use super::{FieldMetadata, FieldsIdsMapWithMetadata}; -#[derive(Debug, Clone)] -pub struct Fields<'a>(Vec>); - -impl<'a> Fields<'a> { - pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMapWithMetadata<'a>) -> Self { - Self( - std::iter::repeat(document) - .zip(field_id_map.iter()) - .map(|(document, (fid, name))| FieldValue { - document, - name, - metadata: field_id_map.metadata(fid).unwrap_or_default(), - }) - .collect(), - ) - } -} +use crate::GlobalFieldsIdsMap; #[derive(Debug, Clone, Copy)] -pub struct FieldValue<'a> { +pub struct FieldValue<'a, D: ObjectView> { name: &'a str, - document: &'a Document<'a>, + document: &'a D, metadata: FieldMetadata, } -impl<'a> ValueView for FieldValue<'a> { +impl<'a, D: ObjectView> ValueView for FieldValue<'a, D> { fn as_debug(&self) -> &dyn std::fmt::Debug { self } @@ -70,7 +57,7 @@ impl<'a> ValueView for FieldValue<'a> { } } -impl<'a> FieldValue<'a> { +impl<'a, D: ObjectView> FieldValue<'a, D> { pub fn name(&self) -> &&'a str { &self.name } @@ -88,7 +75,7 @@ impl<'a> FieldValue<'a> { } } -impl<'a> ObjectView for FieldValue<'a> { +impl<'a, D: ObjectView> ObjectView for FieldValue<'a, D> { fn as_value(&self) -> &dyn ValueView { self } @@ -127,7 +114,42 @@ impl<'a> ObjectView for FieldValue<'a> { } } -impl<'a> ArrayView for Fields<'a> { +#[derive(Debug, Clone)] +pub struct OwnedFields<'a, D: ObjectView>(Vec>); + +#[derive(Debug)] +pub struct BorrowedFields<'a, 'map, D: ObjectView> { + document: &'a D, + field_id_map: &'a RefCell>, + doc_alloc: &'a Bump, +} + +impl<'a, D: ObjectView> OwnedFields<'a, D> { + pub fn new(document: &'a D, field_id_map: &'a FieldsIdsMapWithMetadata<'a>) -> Self { + Self( + std::iter::repeat(document) + .zip(field_id_map.iter()) + .map(|(document, (fid, name))| FieldValue { + document, + name, + metadata: field_id_map.metadata(fid).unwrap_or_default(), + }) + .collect(), + ) + } +} + +impl<'a, 'map, D: ObjectView> BorrowedFields<'a, 'map, D> { + pub fn new( + document: &'a D, + field_id_map: &'a RefCell>, + doc_alloc: &'a Bump, + ) -> Self { + Self { document, field_id_map, doc_alloc } + } +} + +impl<'a, D: ObjectView> ArrayView for OwnedFields<'a, D> { fn as_value(&self) -> &dyn ValueView { self.0.as_value() } @@ -149,7 +171,91 @@ impl<'a> ArrayView for Fields<'a> { } } -impl<'a> ValueView for Fields<'a> { +impl<'a, 'map, D: ObjectView> ArrayView for BorrowedFields<'a, 'map, D> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.document.size() + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.document.keys().map(|k| { + let mut field_id_map = self.field_id_map.borrow_mut(); + let (_, metadata) = field_id_map.id_with_metadata_or_insert(&k).unwrap(); + let fv = self.doc_alloc.alloc(FieldValue { + name: self.doc_alloc.alloc_str(&k), + document: self.document, + metadata: FieldMetadata { searchable: metadata.searchable }, + }); + fv as _ + })) + } + + fn contains_key(&self, index: i64) -> bool { + let index = if index >= 0 { index } else { self.size() + index }; + index >= 0 && index < self.size() + } + + fn get(&self, index: i64) -> Option<&dyn ValueView> { + let index = if index >= 0 { index } else { self.size() + index }; + let index: usize = index.try_into().ok()?; + let key = self.document.keys().nth(index)?; + let mut field_id_map = self.field_id_map.borrow_mut(); + let (_, metadata) = field_id_map.id_with_metadata_or_insert(&key)?; + let fv = self.doc_alloc.alloc(FieldValue { + name: self.doc_alloc.alloc_str(&key), + document: self.document, + metadata: FieldMetadata { searchable: metadata.searchable }, + }); + Some(fv as _) + } +} + +impl<'a, 'map, D: ObjectView> ValueView for BorrowedFields<'a, 'map, D> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ArrayRender { s: self })) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ArraySource { s: self })) + } + + fn type_name(&self) -> &'static str { + "array" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.document.size() == 0, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ArrayRender { s: self }.to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Array(self.values().map(|v| v.to_value()).collect()) + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } + + fn is_array(&self) -> bool { + true + } +} + +impl<'a, D: ObjectView> ValueView for OwnedFields<'a, D> { fn as_debug(&self) -> &dyn std::fmt::Debug { self } @@ -182,3 +288,31 @@ impl<'a> ValueView for Fields<'a> { Some(self) } } + +struct ArraySource<'a, 'map, D: ObjectView> { + s: &'a BorrowedFields<'a, 'map, D>, +} + +impl<'a, 'map, D: ObjectView> fmt::Display for ArraySource<'a, 'map, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for item in self.s.values() { + write!(f, "{}, ", item.render())?; + } + write!(f, "]")?; + Ok(()) + } +} + +struct ArrayRender<'a, 'map, D: ObjectView> { + s: &'a BorrowedFields<'a, 'map, D>, +} + +impl<'a, 'map, D: ObjectView> fmt::Display for ArrayRender<'a, 'map, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for item in self.s.values() { + write!(f, "{}", item.render())?; + } + Ok(()) + } +} diff --git a/crates/milli/src/prompt/mod.rs b/crates/milli/src/prompt/mod.rs index 3b32b916f..bbcf054e6 100644 --- a/crates/milli/src/prompt/mod.rs +++ b/crates/milli/src/prompt/mod.rs @@ -4,17 +4,22 @@ pub(crate) mod error; mod fields; mod template_checker; +use std::cell::RefCell; use std::collections::BTreeMap; use std::convert::TryFrom; +use std::fmt::Debug; use std::num::NonZeroUsize; use std::ops::Deref; +use bumpalo::Bump; +use document::ParseableDocument; use error::{NewPromptError, RenderPromptError}; +use fields::{BorrowedFields, OwnedFields}; use self::context::Context; use self::document::Document; use crate::update::del_add::DelAdd; -use crate::{FieldId, FieldsIdsMap}; +use crate::{FieldId, FieldsIdsMap, GlobalFieldsIdsMap}; pub struct Prompt { template: liquid::Template, @@ -109,14 +114,38 @@ impl Prompt { Ok(this) } - pub fn render( + pub fn render_document< + 'a, // lifetime of the borrow of the document + 'doc: 'a, // lifetime of the allocator, will live for an entire chunk of documents + >( &self, - document: obkv::KvReaderU16<'_>, + document: impl crate::update::new::document::Document<'a> + Debug, + field_id_map: &RefCell, + doc_alloc: &'doc Bump, + ) -> Result<&'doc str, RenderPromptError> { + let document = ParseableDocument::new(document, doc_alloc); + let fields = BorrowedFields::new(&document, field_id_map, doc_alloc); + let context = Context::new(&document, &fields); + let mut rendered = bumpalo::collections::Vec::with_capacity_in( + self.max_bytes.unwrap_or_else(default_max_bytes).get(), + doc_alloc, + ); + self.template + .render_to(&mut rendered, &context) + .map_err(RenderPromptError::missing_context)?; + Ok(std::str::from_utf8(rendered.into_bump_slice()) + .expect("render can only write UTF-8 because all inputs and processing preserve utf-8")) + } + + pub fn render_kvdeladd( + &self, + document: &obkv::KvReaderU16, side: DelAdd, field_id_map: &FieldsIdsMapWithMetadata, ) -> Result { let document = Document::new(document, side, field_id_map); - let context = Context::new(&document, field_id_map); + let fields = OwnedFields::new(&document, field_id_map); + let context = Context::new(&document, &fields); let mut rendered = self.template.render(&context).map_err(RenderPromptError::missing_context)?; diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs index 8b274804c..5187b572b 100644 --- a/crates/milli/src/search/hybrid.rs +++ b/crates/milli/src/search/hybrid.rs @@ -201,7 +201,9 @@ impl<'a> Search<'a> { let span = tracing::trace_span!(target: "search::hybrid", "embed_one"); let _entered = span.enter(); - match embedder.embed_one(query) { + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3); + + match embedder.embed_one(query, Some(deadline)) { Ok(embedding) => embedding, Err(error) => { tracing::error!(error=%error, "Embedding failed"); diff --git a/crates/milli/src/search/new/db_cache.rs b/crates/milli/src/search/new/db_cache.rs index d1d9d6d9a..243303ba2 100644 --- a/crates/milli/src/search/new/db_cache.rs +++ b/crates/milli/src/search/new/db_cache.rs @@ -3,6 +3,7 @@ use std::collections::hash_map::Entry; use std::hash::Hash; use fxhash::FxHashMap; +use grenad::MergeFunction; use heed::types::Bytes; use heed::{BytesEncode, Database, RoTxn}; use roaring::RoaringBitmap; @@ -11,7 +12,7 @@ use super::interner::Interned; use super::Word; use crate::heed_codec::{BytesDecodeOwned, StrBEU16Codec}; use crate::proximity::ProximityPrecision; -use crate::update::{merge_cbo_roaring_bitmaps, MergeFn}; +use crate::update::MergeCboRoaringBitmaps; use crate::{ CboRoaringBitmapCodec, CboRoaringBitmapLenCodec, Result, SearchContext, U8StrStrCodec, }; @@ -110,19 +111,21 @@ impl<'ctx> DatabaseCache<'ctx> { .map_err(Into::into) } - fn get_value_from_keys<'v, K1, KC>( + fn get_value_from_keys<'v, K1, KC, MF>( txn: &'ctx RoTxn<'_>, cache_key: K1, db_keys: &'v [KC::EItem], cache: &mut FxHashMap>>, db: Database, universe: Option<&RoaringBitmap>, - merger: MergeFn, + merger: MF, ) -> Result> where K1: Copy + Eq + Hash, KC: BytesEncode<'v>, KC::EItem: Sized, + MF: MergeFunction, + crate::Error: From, { if let Entry::Vacant(entry) = cache.entry(cache_key) { let bitmap_ptr: Option> = match db_keys { @@ -138,7 +141,7 @@ impl<'ctx> DatabaseCache<'ctx> { if bitmaps.is_empty() { None } else { - Some(merger(&[], &bitmaps[..])?) + Some(merger.merge(&[], &bitmaps[..])?) } } }; @@ -213,17 +216,17 @@ impl<'ctx> SearchContext<'ctx> { let keys: Vec<_> = restricted_fids.tolerant.iter().map(|(fid, _)| (interned, *fid)).collect(); - DatabaseCache::get_value_from_keys::<_, _>( + DatabaseCache::get_value_from_keys( self.txn, word, &keys[..], &mut self.db_cache.word_docids, self.index.word_fid_docids.remap_data_type::(), universe, - merge_cbo_roaring_bitmaps, + MergeCboRoaringBitmaps, ) } - None => DatabaseCache::get_value::<_, _>( + None => DatabaseCache::get_value( self.txn, word, self.word_interner.get(word).as_str(), @@ -245,17 +248,17 @@ impl<'ctx> SearchContext<'ctx> { let keys: Vec<_> = restricted_fids.exact.iter().map(|(fid, _)| (interned, *fid)).collect(); - DatabaseCache::get_value_from_keys::<_, _>( + DatabaseCache::get_value_from_keys( self.txn, word, &keys[..], &mut self.db_cache.exact_word_docids, self.index.word_fid_docids.remap_data_type::(), universe, - merge_cbo_roaring_bitmaps, + MergeCboRoaringBitmaps, ) } - None => DatabaseCache::get_value::<_, _>( + None => DatabaseCache::get_value( self.txn, word, self.word_interner.get(word).as_str(), @@ -302,17 +305,17 @@ impl<'ctx> SearchContext<'ctx> { let keys: Vec<_> = restricted_fids.tolerant.iter().map(|(fid, _)| (interned, *fid)).collect(); - DatabaseCache::get_value_from_keys::<_, _>( + DatabaseCache::get_value_from_keys( self.txn, prefix, &keys[..], &mut self.db_cache.word_prefix_docids, self.index.word_prefix_fid_docids.remap_data_type::(), universe, - merge_cbo_roaring_bitmaps, + MergeCboRoaringBitmaps, ) } - None => DatabaseCache::get_value::<_, _>( + None => DatabaseCache::get_value( self.txn, prefix, self.word_interner.get(prefix).as_str(), @@ -334,17 +337,17 @@ impl<'ctx> SearchContext<'ctx> { let keys: Vec<_> = restricted_fids.exact.iter().map(|(fid, _)| (interned, *fid)).collect(); - DatabaseCache::get_value_from_keys::<_, _>( + DatabaseCache::get_value_from_keys( self.txn, prefix, &keys[..], &mut self.db_cache.exact_word_prefix_docids, self.index.word_prefix_fid_docids.remap_data_type::(), universe, - merge_cbo_roaring_bitmaps, + MergeCboRoaringBitmaps, ) } - None => DatabaseCache::get_value::<_, _>( + None => DatabaseCache::get_value( self.txn, prefix, self.word_interner.get(prefix).as_str(), @@ -405,7 +408,7 @@ impl<'ctx> SearchContext<'ctx> { Ok(docids) } - ProximityPrecision::ByWord => DatabaseCache::get_value::<_, _>( + ProximityPrecision::ByWord => DatabaseCache::get_value( self.txn, (proximity, word1, word2), &( @@ -538,7 +541,7 @@ impl<'ctx> SearchContext<'ctx> { return Ok(None); } - DatabaseCache::get_value::<_, _>( + DatabaseCache::get_value( self.txn, (word, fid), &(self.word_interner.get(word).as_str(), fid), @@ -559,7 +562,7 @@ impl<'ctx> SearchContext<'ctx> { return Ok(None); } - DatabaseCache::get_value::<_, _>( + DatabaseCache::get_value( self.txn, (word_prefix, fid), &(self.word_interner.get(word_prefix).as_str(), fid), @@ -629,7 +632,7 @@ impl<'ctx> SearchContext<'ctx> { word: Interned, position: u16, ) -> Result> { - DatabaseCache::get_value::<_, _>( + DatabaseCache::get_value( self.txn, (word, position), &(self.word_interner.get(word).as_str(), position), @@ -645,7 +648,7 @@ impl<'ctx> SearchContext<'ctx> { word_prefix: Interned, position: u16, ) -> Result> { - DatabaseCache::get_value::<_, _>( + DatabaseCache::get_value( self.txn, (word_prefix, position), &(self.word_interner.get(word_prefix).as_str(), position), diff --git a/crates/milli/src/search/new/matches/mod.rs b/crates/milli/src/search/new/matches/mod.rs index 7d8d25502..d7bc27c94 100644 --- a/crates/milli/src/search/new/matches/mod.rs +++ b/crates/milli/src/search/new/matches/mod.rs @@ -3,6 +3,9 @@ mod r#match; mod matching_words; mod simple_token_kind; +use std::borrow::Cow; +use std::cmp::{max, min}; + use charabia::{Language, SeparatorKind, Token, Tokenizer}; use either::Either; pub use matching_words::MatchingWords; @@ -10,10 +13,6 @@ use matching_words::{MatchType, PartialMatch}; use r#match::{Match, MatchPosition}; use serde::Serialize; use simple_token_kind::SimpleTokenKind; -use std::{ - borrow::Cow, - cmp::{max, min}, -}; const DEFAULT_CROP_MARKER: &str = "…"; const DEFAULT_HIGHLIGHT_PREFIX: &str = ""; diff --git a/crates/milli/src/update/available_documents_ids.rs b/crates/milli/src/update/available_documents_ids.rs index 3b05c5d6e..e69de29bb 100644 --- a/crates/milli/src/update/available_documents_ids.rs +++ b/crates/milli/src/update/available_documents_ids.rs @@ -1,65 +0,0 @@ -use std::iter::{Chain, FromIterator}; -use std::ops::RangeInclusive; - -use roaring::bitmap::{IntoIter, RoaringBitmap}; - -pub struct AvailableDocumentsIds { - iter: Chain>, -} - -impl AvailableDocumentsIds { - pub fn from_documents_ids(docids: &RoaringBitmap) -> AvailableDocumentsIds { - match docids.max() { - Some(last_id) => { - let mut available = RoaringBitmap::from_iter(0..last_id); - available -= docids; - - let iter = match last_id.checked_add(1) { - Some(id) => id..=u32::MAX, - #[allow(clippy::reversed_empty_ranges)] - None => 1..=0, // empty range iterator - }; - - AvailableDocumentsIds { iter: available.into_iter().chain(iter) } - } - None => { - let empty = RoaringBitmap::new().into_iter(); - AvailableDocumentsIds { iter: empty.chain(0..=u32::MAX) } - } - } - } -} - -impl Iterator for AvailableDocumentsIds { - type Item = u32; - - fn next(&mut self) -> Option { - self.iter.next() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn empty() { - let base = RoaringBitmap::new(); - let left = AvailableDocumentsIds::from_documents_ids(&base); - let right = 0..=u32::MAX; - left.zip(right).take(500).for_each(|(l, r)| assert_eq!(l, r)); - } - - #[test] - fn scattered() { - let mut base = RoaringBitmap::new(); - base.insert(0); - base.insert(10); - base.insert(100); - base.insert(405); - - let left = AvailableDocumentsIds::from_documents_ids(&base); - let right = (0..=u32::MAX).filter(|&n| n != 0 && n != 10 && n != 100 && n != 405); - left.zip(right).take(500).for_each(|(l, r)| assert_eq!(l, r)); - } -} diff --git a/crates/milli/src/update/available_ids.rs b/crates/milli/src/update/available_ids.rs new file mode 100644 index 000000000..68e3dd5a6 --- /dev/null +++ b/crates/milli/src/update/available_ids.rs @@ -0,0 +1,65 @@ +use std::iter::{Chain, FromIterator}; +use std::ops::RangeInclusive; + +use roaring::bitmap::{IntoIter, RoaringBitmap}; + +pub struct AvailableIds { + iter: Chain>, +} + +impl AvailableIds { + pub fn new(docids: &RoaringBitmap) -> AvailableIds { + match docids.max() { + Some(last_id) => { + let mut available = RoaringBitmap::from_iter(0..last_id); + available -= docids; + + let iter = match last_id.checked_add(1) { + Some(id) => id..=u32::MAX, + #[allow(clippy::reversed_empty_ranges)] + None => 1..=0, // empty range iterator + }; + + AvailableIds { iter: available.into_iter().chain(iter) } + } + None => { + let empty = RoaringBitmap::new().into_iter(); + AvailableIds { iter: empty.chain(0..=u32::MAX) } + } + } + } +} + +impl Iterator for AvailableIds { + type Item = u32; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let base = RoaringBitmap::new(); + let left = AvailableIds::new(&base); + let right = 0..=u32::MAX; + left.zip(right).take(500).for_each(|(l, r)| assert_eq!(l, r)); + } + + #[test] + fn scattered() { + let mut base = RoaringBitmap::new(); + base.insert(0); + base.insert(10); + base.insert(100); + base.insert(405); + + let left = AvailableIds::new(&base); + let right = (0..=u32::MAX).filter(|&n| n != 0 && n != 10 && n != 100 && n != 405); + left.zip(right).take(500).for_each(|(l, r)| assert_eq!(l, r)); + } +} diff --git a/crates/milli/src/update/concurrent_available_ids.rs b/crates/milli/src/update/concurrent_available_ids.rs new file mode 100644 index 000000000..f3b15ac45 --- /dev/null +++ b/crates/milli/src/update/concurrent_available_ids.rs @@ -0,0 +1,59 @@ +use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; + +use roaring::RoaringBitmap; + +/// A concurrent ID generate that will never return the same ID twice. +#[derive(Debug)] +pub struct ConcurrentAvailableIds { + /// The current tree node ID we should use if there is no other IDs available. + current: AtomicU32, + /// The total number of tree node IDs used. + used: AtomicU64, + + /// A list of IDs to exhaust before picking IDs from `current`. + available: RoaringBitmap, + /// The current Nth ID to select in the bitmap. + select_in_bitmap: AtomicU32, + /// Tells if you should look in the roaring bitmap or if all the IDs are already exhausted. + look_into_bitmap: AtomicBool, +} + +impl ConcurrentAvailableIds { + /// Creates an ID generator returning unique IDs, avoiding the specified used IDs. + pub fn new(used: RoaringBitmap) -> ConcurrentAvailableIds { + let last_id = used.max().map_or(0, |id| id + 1); + let used_ids = used.len(); + let available = RoaringBitmap::from_sorted_iter(0..last_id).unwrap() - used; + + ConcurrentAvailableIds { + current: AtomicU32::new(last_id), + used: AtomicU64::new(used_ids), + select_in_bitmap: AtomicU32::new(0), + look_into_bitmap: AtomicBool::new(!available.is_empty()), + available, + } + } + + /// Returns a new unique ID and increase the count of IDs used. + pub fn next(&self) -> Option { + if self.used.fetch_add(1, Ordering::Relaxed) > u32::MAX as u64 { + None + } else if self.look_into_bitmap.load(Ordering::Relaxed) { + let current = self.select_in_bitmap.fetch_add(1, Ordering::Relaxed); + match self.available.select(current) { + Some(id) => Some(id), + None => { + self.look_into_bitmap.store(false, Ordering::Relaxed); + Some(self.current.fetch_add(1, Ordering::Relaxed)) + } + } + } else { + Some(self.current.fetch_add(1, Ordering::Relaxed)) + } + } + + /// Returns the number of used ids in total. + pub fn used(&self) -> u64 { + self.used.load(Ordering::Relaxed) + } +} diff --git a/crates/milli/src/update/del_add.rs b/crates/milli/src/update/del_add.rs index 570d292ef..97ff86f2a 100644 --- a/crates/milli/src/update/del_add.rs +++ b/crates/milli/src/update/del_add.rs @@ -1,7 +1,7 @@ use obkv::Key; pub type KvWriterDelAdd = obkv::KvWriter; -pub type KvReaderDelAdd<'a> = obkv::KvReader<'a, DelAdd>; +pub type KvReaderDelAdd = obkv::KvReader; /// DelAdd defines the new value to add in the database and old value to delete from the database. /// @@ -36,7 +36,7 @@ impl Key for DelAdd { /// Addition: put all the values under DelAdd::Addition, /// DeletionAndAddition: put all the values under DelAdd::Deletion and DelAdd::Addition, pub fn into_del_add_obkv( - reader: obkv::KvReader<'_, K>, + reader: &obkv::KvReader, operation: DelAddOperation, buffer: &mut Vec, ) -> Result<(), std::io::Error> { @@ -46,7 +46,7 @@ pub fn into_del_add_obkv( /// Akin to the [into_del_add_obkv] function but lets you /// conditionally define the `DelAdd` variant based on the obkv key. pub fn into_del_add_obkv_conditional_operation( - reader: obkv::KvReader<'_, K>, + reader: &obkv::KvReader, buffer: &mut Vec, operation: F, ) -> std::io::Result<()> @@ -86,8 +86,8 @@ pub enum DelAddOperation { /// putting each deletion obkv's keys under an DelAdd::Deletion /// and putting each addition obkv's keys under an DelAdd::Addition pub fn del_add_from_two_obkvs( - deletion: &obkv::KvReader<'_, K>, - addition: &obkv::KvReader<'_, K>, + deletion: &obkv::KvReader, + addition: &obkv::KvReader, buffer: &mut Vec, ) -> Result<(), std::io::Error> { use itertools::merge_join_by; @@ -121,7 +121,7 @@ pub fn del_add_from_two_obkvs( writer.finish() } -pub fn is_noop_del_add_obkv(del_add: KvReaderDelAdd<'_>) -> bool { +pub fn is_noop_del_add_obkv(del_add: &KvReaderDelAdd) -> bool { del_add.get(DelAdd::Deletion) == del_add.get(DelAdd::Addition) } @@ -136,5 +136,5 @@ pub fn deladd_serialize_add_side<'a>( obkv: &'a [u8], _buffer: &mut Vec, ) -> crate::Result<&'a [u8]> { - Ok(KvReaderDelAdd::new(obkv).get(DelAdd::Addition).unwrap_or_default()) + Ok(KvReaderDelAdd::from_slice(obkv).get(DelAdd::Addition).unwrap_or_default()) } diff --git a/crates/milli/src/update/facet/bulk.rs b/crates/milli/src/update/facet/bulk.rs index a63d59693..19dfc310b 100644 --- a/crates/milli/src/update/facet/bulk.rs +++ b/crates/milli/src/update/facet/bulk.rs @@ -14,7 +14,7 @@ use crate::heed_codec::facet::{ use crate::heed_codec::BytesRefCodec; use crate::update::del_add::{DelAdd, KvReaderDelAdd}; use crate::update::index_documents::{create_writer, valid_lmdb_key, writer_into_reader}; -use crate::update::MergeFn; +use crate::update::MergeDeladdCboRoaringBitmaps; use crate::{CboRoaringBitmapCodec, CboRoaringBitmapLenCodec, FieldId, Index, Result}; /// Algorithm to insert elememts into the `facet_id_(string/f64)_docids` databases @@ -29,7 +29,7 @@ pub struct FacetsUpdateBulk<'i> { facet_type: FacetType, field_ids: Vec, // None if level 0 does not need to be updated - delta_data: Option, MergeFn>>, + delta_data: Option, MergeDeladdCboRoaringBitmaps>>, } impl<'i> FacetsUpdateBulk<'i> { @@ -37,7 +37,7 @@ impl<'i> FacetsUpdateBulk<'i> { index: &'i Index, field_ids: Vec, facet_type: FacetType, - delta_data: Merger, MergeFn>, + delta_data: Merger, MergeDeladdCboRoaringBitmaps>, group_size: u8, min_level_size: u8, ) -> FacetsUpdateBulk<'i> { @@ -90,7 +90,7 @@ impl<'i> FacetsUpdateBulk<'i> { /// Implementation of `FacetsUpdateBulk` that is independent of milli's `Index` type pub(crate) struct FacetsUpdateBulkInner { pub db: heed::Database, FacetGroupValueCodec>, - pub delta_data: Option>, + pub delta_data: Option>, pub group_size: u8, pub min_level_size: u8, } @@ -135,7 +135,7 @@ impl FacetsUpdateBulkInner { if !valid_lmdb_key(key) { continue; } - let value = KvReaderDelAdd::new(value); + let value = KvReaderDelAdd::from_slice(value); // DB is empty, it is safe to ignore Del operations let Some(value) = value.get(DelAdd::Addition) else { @@ -161,7 +161,7 @@ impl FacetsUpdateBulkInner { continue; } - let value = KvReaderDelAdd::new(value); + let value = KvReaderDelAdd::from_slice(value); // the value is a CboRoaringBitmap, but I still need to prepend the // group size for level 0 (= 1) to it diff --git a/crates/milli/src/update/facet/incremental.rs b/crates/milli/src/update/facet/incremental.rs index 0f0937855..a1fa07fe3 100644 --- a/crates/milli/src/update/facet/incremental.rs +++ b/crates/milli/src/update/facet/incremental.rs @@ -15,7 +15,7 @@ use crate::heed_codec::BytesRefCodec; use crate::search::facet::get_highest_level; use crate::update::del_add::DelAdd; use crate::update::index_documents::valid_lmdb_key; -use crate::update::MergeFn; +use crate::update::MergeDeladdCboRoaringBitmaps; use crate::{CboRoaringBitmapCodec, Index, Result}; /// Enum used as a return value for the facet incremental indexing. @@ -57,14 +57,14 @@ enum ModificationResult { /// `facet_id_(string/f64)_docids` databases. pub struct FacetsUpdateIncremental { inner: FacetsUpdateIncrementalInner, - delta_data: Merger, MergeFn>, + delta_data: Merger, MergeDeladdCboRoaringBitmaps>, } impl FacetsUpdateIncremental { pub fn new( index: &Index, facet_type: FacetType, - delta_data: Merger, MergeFn>, + delta_data: Merger, MergeDeladdCboRoaringBitmaps>, group_size: u8, min_level_size: u8, max_group_size: u8, @@ -109,7 +109,7 @@ impl FacetsUpdateIncremental { } current_field_id = Some(key.field_id); - let value = KvReader::new(value); + let value = KvReader::from_slice(value); let docids_to_delete = value .get(DelAdd::Deletion) .map(CboRoaringBitmapCodec::bytes_decode) diff --git a/crates/milli/src/update/facet/mod.rs b/crates/milli/src/update/facet/mod.rs index ad3ddc38f..2e592519b 100644 --- a/crates/milli/src/update/facet/mod.rs +++ b/crates/milli/src/update/facet/mod.rs @@ -86,12 +86,11 @@ use time::OffsetDateTime; use tracing::debug; use self::incremental::FacetsUpdateIncremental; -use super::FacetsUpdateBulk; +use super::{FacetsUpdateBulk, MergeDeladdBtreesetString, MergeDeladdCboRoaringBitmaps}; use crate::facet::FacetType; use crate::heed_codec::facet::{FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec}; use crate::heed_codec::BytesRefCodec; use crate::update::del_add::{DelAdd, KvReaderDelAdd}; -use crate::update::MergeFn; use crate::{try_split_array_at, FieldId, Index, Result}; pub mod bulk; @@ -105,8 +104,8 @@ pub struct FacetsUpdate<'i> { index: &'i Index, database: heed::Database, FacetGroupValueCodec>, facet_type: FacetType, - delta_data: Merger, MergeFn>, - normalized_delta_data: Option, MergeFn>>, + delta_data: Merger, MergeDeladdCboRoaringBitmaps>, + normalized_delta_data: Option, MergeDeladdBtreesetString>>, group_size: u8, max_group_size: u8, min_level_size: u8, @@ -116,8 +115,8 @@ impl<'i> FacetsUpdate<'i> { pub fn new( index: &'i Index, facet_type: FacetType, - delta_data: Merger, MergeFn>, - normalized_delta_data: Option, MergeFn>>, + delta_data: Merger, MergeDeladdCboRoaringBitmaps>, + normalized_delta_data: Option, MergeDeladdBtreesetString>>, data_size: u64, ) -> Self { let database = match facet_type { @@ -182,12 +181,12 @@ impl<'i> FacetsUpdate<'i> { fn index_facet_search( wtxn: &mut heed::RwTxn<'_>, - normalized_delta_data: Merger, MergeFn>, + normalized_delta_data: Merger, MergeDeladdBtreesetString>, index: &Index, ) -> Result<()> { let mut iter = normalized_delta_data.into_stream_merger_iter()?; while let Some((key_bytes, delta_bytes)) = iter.next()? { - let deladd_reader = KvReaderDelAdd::new(delta_bytes); + let deladd_reader = KvReaderDelAdd::from_slice(delta_bytes); let database_set = index .facet_id_normalized_string_strings @@ -298,8 +297,8 @@ pub(crate) mod test_helpers { use crate::search::facet::get_highest_level; use crate::snapshot_tests::display_bitmap; use crate::update::del_add::{DelAdd, KvWriterDelAdd}; - use crate::update::index_documents::merge_deladd_cbo_roaring_bitmaps; - use crate::update::{FacetsUpdateIncrementalInner, MergeFn}; + use crate::update::index_documents::MergeDeladdCboRoaringBitmaps; + use crate::update::FacetsUpdateIncrementalInner; use crate::CboRoaringBitmapCodec; /// Utility function to generate a string whose position in a lexicographically @@ -484,7 +483,7 @@ pub(crate) mod test_helpers { } writer.finish().unwrap(); let reader = grenad::Reader::new(std::io::Cursor::new(new_data)).unwrap(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); builder.push(reader.into_cursor().unwrap()); let merger = builder.build(); diff --git a/crates/milli/src/update/index_documents/enrich.rs b/crates/milli/src/update/index_documents/enrich.rs index 691b2b9d1..85f871830 100644 --- a/crates/milli/src/update/index_documents/enrich.rs +++ b/crates/milli/src/update/index_documents/enrich.rs @@ -47,7 +47,7 @@ pub fn enrich_documents_batch( return match cursor.next_document()? { Some(first_document) => Ok(Err(UserError::MissingDocumentId { primary_key: primary_key.to_string(), - document: obkv_to_object(&first_document, &documents_batch_index)?, + document: obkv_to_object(first_document, &documents_batch_index)?, })), None => unreachable!("Called with reader.is_empty()"), }; @@ -106,7 +106,7 @@ pub fn enrich_documents_batch( let mut count = 0; while let Some(document) = cursor.next_document()? { let document_id = match fetch_or_generate_document_id( - &document, + document, &documents_batch_index, primary_key, autogenerate_docids, @@ -145,7 +145,7 @@ pub fn enrich_documents_batch( #[tracing::instrument(level = "trace", skip(uuid_buffer, documents_batch_index, document) target = "indexing::documents")] fn fetch_or_generate_document_id( - document: &obkv::KvReader<'_, FieldId>, + document: &obkv::KvReader, documents_batch_index: &DocumentsBatchIndex, primary_key: PrimaryKey<'_>, autogenerate_docids: bool, diff --git a/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs b/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs index ba11ceeb3..b1e6f24be 100644 --- a/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs +++ b/crates/milli/src/update/index_documents/extract/extract_docid_word_positions.rs @@ -8,7 +8,7 @@ use obkv::{KvReader, KvWriterU16}; use roaring::RoaringBitmap; use serde_json::Value; -use super::helpers::{create_sorter, keep_latest_obkv, sorter_into_reader, GrenadParameters}; +use super::helpers::{create_sorter, sorter_into_reader, GrenadParameters, KeepLatestObkv}; use crate::error::{InternalError, SerializationError}; use crate::update::del_add::{del_add_from_two_obkvs, DelAdd, KvReaderDelAdd}; use crate::update::settings::{InnerIndexSettings, InnerIndexSettingsDiff}; @@ -35,11 +35,12 @@ pub fn extract_docid_word_positions( let mut documents_ids = RoaringBitmap::new(); let mut docid_word_positions_sorter = create_sorter( grenad::SortAlgorithm::Stable, - keep_latest_obkv, + KeepLatestObkv, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory, + true, ); // initialize buffers. @@ -80,10 +81,10 @@ pub fn extract_docid_word_positions( .try_into() .map(u32::from_be_bytes) .map_err(|_| SerializationError::InvalidNumberSerialization)?; - let obkv = KvReader::::new(value); + let obkv = KvReader::::from_slice(value); // if the searchable fields didn't change, skip the searchable indexing for this document. - if !force_reindexing && !searchable_fields_changed(&obkv, settings_diff) { + if !force_reindexing && !searchable_fields_changed(obkv, settings_diff) { continue; } @@ -98,7 +99,7 @@ pub fn extract_docid_word_positions( || { // deletions tokens_from_document( - &obkv, + obkv, &settings_diff.old, &del_tokenizer, max_positions_per_attributes, @@ -109,7 +110,7 @@ pub fn extract_docid_word_positions( || { // additions tokens_from_document( - &obkv, + obkv, &settings_diff.new, &add_tokenizer, max_positions_per_attributes, @@ -126,13 +127,13 @@ pub fn extract_docid_word_positions( // transforming two KV> into one KV>> value_buffer.clear(); del_add_from_two_obkvs( - &KvReader::::new(del_obkv), - &KvReader::::new(add_obkv), + KvReader::::from_slice(del_obkv), + KvReader::::from_slice(add_obkv), &mut value_buffer, )?; // write each KV> into the sorter, field by field. - let obkv = KvReader::::new(&value_buffer); + let obkv = KvReader::::from_slice(&value_buffer); for (field_id, value) in obkv.iter() { key_buffer.truncate(mem::size_of::()); key_buffer.extend_from_slice(&field_id.to_be_bytes()); @@ -146,13 +147,13 @@ pub fn extract_docid_word_positions( /// Check if any searchable fields of a document changed. fn searchable_fields_changed( - obkv: &KvReader<'_, FieldId>, + obkv: &KvReader, settings_diff: &InnerIndexSettingsDiff, ) -> bool { let searchable_fields = &settings_diff.new.searchable_fields_ids; for (field_id, field_bytes) in obkv.iter() { if searchable_fields.contains(&field_id) { - let del_add = KvReaderDelAdd::new(field_bytes); + let del_add = KvReaderDelAdd::from_slice(field_bytes); match (del_add.get(DelAdd::Deletion), del_add.get(DelAdd::Addition)) { // if both fields are None, check the next field. (None, None) => (), @@ -189,7 +190,7 @@ fn tokenizer_builder<'a>( /// Extract words mapped with their positions of a document. fn tokens_from_document<'a>( - obkv: &KvReader<'a, FieldId>, + obkv: &'a KvReader, settings: &InnerIndexSettings, tokenizer: &Tokenizer<'_>, max_positions_per_attributes: u32, @@ -202,7 +203,7 @@ fn tokens_from_document<'a>( // if field is searchable. if settings.searchable_fields_ids.contains(&field_id) { // extract deletion or addition only. - if let Some(field_bytes) = KvReaderDelAdd::new(field_bytes).get(del_add) { + if let Some(field_bytes) = KvReaderDelAdd::from_slice(field_bytes).get(del_add) { // parse json. let value = serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?; diff --git a/crates/milli/src/update/index_documents/extract/extract_facet_number_docids.rs b/crates/milli/src/update/index_documents/extract/extract_facet_number_docids.rs index bfd769604..34bece989 100644 --- a/crates/milli/src/update/index_documents/extract/extract_facet_number_docids.rs +++ b/crates/milli/src/update/index_documents/extract/extract_facet_number_docids.rs @@ -4,7 +4,7 @@ use std::io::{self, BufReader}; use heed::{BytesDecode, BytesEncode}; use super::helpers::{ - create_sorter, merge_deladd_cbo_roaring_bitmaps, sorter_into_reader, GrenadParameters, + create_sorter, sorter_into_reader, GrenadParameters, MergeDeladdCboRoaringBitmaps, }; use crate::heed_codec::facet::{ FacetGroupKey, FacetGroupKeyCodec, FieldDocIdFacetF64Codec, OrderedF64Codec, @@ -27,11 +27,12 @@ pub fn extract_facet_number_docids( let mut facet_number_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory, + true, ); let mut buffer = Vec::new(); @@ -45,7 +46,7 @@ pub fn extract_facet_number_docids( buffer.clear(); let mut obkv = KvWriterDelAdd::new(&mut buffer); - for (deladd_key, _) in KvReaderDelAdd::new(deladd_obkv_bytes).iter() { + for (deladd_key, _) in KvReaderDelAdd::from_slice(deladd_obkv_bytes).iter() { obkv.insert(deladd_key, document_id.to_ne_bytes())?; } obkv.finish()?; diff --git a/crates/milli/src/update/index_documents/extract/extract_facet_string_docids.rs b/crates/milli/src/update/index_documents/extract/extract_facet_string_docids.rs index 36dd20b15..e0d7e1386 100644 --- a/crates/milli/src/update/index_documents/extract/extract_facet_string_docids.rs +++ b/crates/milli/src/update/index_documents/extract/extract_facet_string_docids.rs @@ -15,7 +15,7 @@ use crate::heed_codec::{BEU16StrCodec, StrRefCodec}; use crate::localized_attributes_rules::LocalizedFieldIds; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::{ - merge_deladd_btreeset_string, merge_deladd_cbo_roaring_bitmaps, + MergeDeladdBtreesetString, MergeDeladdCboRoaringBitmaps, }; use crate::update::settings::InnerIndexSettingsDiff; use crate::{FieldId, Result, MAX_FACET_VALUE_LENGTH}; @@ -56,26 +56,28 @@ fn extract_facet_string_docids_document_update( let mut facet_string_docids_sorter = create_sorter( grenad::SortAlgorithm::Stable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 2), + true, ); let mut normalized_facet_string_docids_sorter = create_sorter( grenad::SortAlgorithm::Stable, - merge_deladd_btreeset_string, + MergeDeladdBtreesetString, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 2), + true, ); let mut buffer = Vec::new(); let mut cursor = docid_fid_facet_string.into_cursor()?; while let Some((key, deladd_original_value_bytes)) = cursor.move_on_next()? { - let deladd_reader = KvReaderDelAdd::new(deladd_original_value_bytes); + let deladd_reader = KvReaderDelAdd::from_slice(deladd_original_value_bytes); let is_same_value = deladd_reader.get(DelAdd::Deletion).is_some() && deladd_reader.get(DelAdd::Addition).is_some(); @@ -144,26 +146,28 @@ fn extract_facet_string_docids_settings( let mut facet_string_docids_sorter = create_sorter( grenad::SortAlgorithm::Stable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 2), + true, ); let mut normalized_facet_string_docids_sorter = create_sorter( grenad::SortAlgorithm::Stable, - merge_deladd_btreeset_string, + MergeDeladdBtreesetString, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 2), + true, ); let mut buffer = Vec::new(); let mut cursor = docid_fid_facet_string.into_cursor()?; while let Some((key, deladd_original_value_bytes)) = cursor.move_on_next()? { - let deladd_reader = KvReaderDelAdd::new(deladd_original_value_bytes); + let deladd_reader = KvReaderDelAdd::from_slice(deladd_original_value_bytes); let is_same_value = deladd_reader.get(DelAdd::Deletion).is_some() && deladd_reader.get(DelAdd::Addition).is_some(); diff --git a/crates/milli/src/update/index_documents/extract/extract_fid_docid_facet_values.rs b/crates/milli/src/update/index_documents/extract/extract_fid_docid_facet_values.rs index 93c6ab408..047669521 100644 --- a/crates/milli/src/update/index_documents/extract/extract_fid_docid_facet_values.rs +++ b/crates/milli/src/update/index_documents/extract/extract_fid_docid_facet_values.rs @@ -1,10 +1,8 @@ -use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet}; use std::convert::TryInto; use std::fs::File; use std::io::{self, BufReader}; use std::mem::size_of; -use std::result::Result as StdResult; use bytemuck::bytes_of; use grenad::Sorter; @@ -15,13 +13,13 @@ use roaring::RoaringBitmap; use serde_json::{from_slice, Value}; use FilterableValues::{Empty, Null, Values}; -use super::helpers::{create_sorter, keep_first, sorter_into_reader, GrenadParameters}; +use super::helpers::{create_sorter, sorter_into_reader, GrenadParameters, KeepFirst}; use crate::error::InternalError; use crate::facet::value_encoding::f64_into_bytes; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::{create_writer, writer_into_reader}; use crate::update::settings::InnerIndexSettingsDiff; -use crate::{CboRoaringBitmapCodec, DocumentId, Error, FieldId, Result, MAX_FACET_VALUE_LENGTH}; +use crate::{CboRoaringBitmapCodec, DocumentId, FieldId, Result, MAX_FACET_VALUE_LENGTH}; /// The length of the elements that are always in the buffer when inserting new values. const TRUNCATE_SIZE: usize = size_of::() + size_of::(); @@ -50,20 +48,22 @@ pub fn extract_fid_docid_facet_values( let mut fid_docid_facet_numbers_sorter = create_sorter( grenad::SortAlgorithm::Stable, - keep_first, + KeepFirst, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 2), + true, ); let mut fid_docid_facet_strings_sorter = create_sorter( grenad::SortAlgorithm::Stable, - keep_first, + KeepFirst, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 2), + true, ); // The tuples represents the Del and Add side for a bitmap @@ -83,10 +83,10 @@ pub fn extract_fid_docid_facet_values( if !settings_diff.settings_update_only || old_faceted_fids != new_faceted_fids { let mut cursor = obkv_documents.into_cursor()?; while let Some((docid_bytes, value)) = cursor.move_on_next()? { - let obkv = obkv::KvReader::new(value); + let obkv = obkv::KvReader::from_slice(value); let get_document_json_value = move |field_id, side| { obkv.get(field_id) - .map(KvReaderDelAdd::new) + .map(KvReaderDelAdd::from_slice) .and_then(|kv| kv.get(side)) .map(from_slice) .transpose() @@ -330,15 +330,12 @@ fn truncate_str(s: &str) -> &str { /// Computes the diff between both Del and Add numbers and /// only inserts the parts that differ in the sorter. -fn insert_numbers_diff( - fid_docid_facet_numbers_sorter: &mut Sorter, +fn insert_numbers_diff( + fid_docid_facet_numbers_sorter: &mut Sorter, key_buffer: &mut Vec, mut del_numbers: Vec, mut add_numbers: Vec, -) -> Result<()> -where - MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> StdResult, Error>, -{ +) -> Result<()> { // We sort and dedup the float numbers del_numbers.sort_unstable_by_key(|f| OrderedFloat(*f)); add_numbers.sort_unstable_by_key(|f| OrderedFloat(*f)); @@ -390,15 +387,12 @@ where /// Computes the diff between both Del and Add strings and /// only inserts the parts that differ in the sorter. -fn insert_strings_diff( - fid_docid_facet_strings_sorter: &mut Sorter, +fn insert_strings_diff( + fid_docid_facet_strings_sorter: &mut Sorter, key_buffer: &mut Vec, mut del_strings: Vec<(String, String)>, mut add_strings: Vec<(String, String)>, -) -> Result<()> -where - MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> StdResult, Error>, -{ +) -> Result<()> { // We sort and dedup the normalized and original strings del_strings.sort_unstable(); add_strings.sort_unstable(); diff --git a/crates/milli/src/update/index_documents/extract/extract_fid_word_count_docids.rs b/crates/milli/src/update/index_documents/extract/extract_fid_word_count_docids.rs index f252df1cd..5739a5e15 100644 --- a/crates/milli/src/update/index_documents/extract/extract_fid_word_count_docids.rs +++ b/crates/milli/src/update/index_documents/extract/extract_fid_word_count_docids.rs @@ -4,8 +4,8 @@ use std::io::{self, BufReader}; use obkv::KvReaderU16; use super::helpers::{ - create_sorter, merge_deladd_cbo_roaring_bitmaps, sorter_into_reader, try_split_array_at, - GrenadParameters, + create_sorter, sorter_into_reader, try_split_array_at, GrenadParameters, + MergeDeladdCboRoaringBitmaps, }; use crate::error::SerializationError; use crate::index::db_name::DOCID_WORD_POSITIONS; @@ -30,11 +30,12 @@ pub fn extract_fid_word_count_docids( let mut fid_word_count_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory, + true, ); let mut key_buffer = Vec::new(); @@ -45,19 +46,23 @@ pub fn extract_fid_word_count_docids( .ok_or(SerializationError::Decoding { db_name: Some(DOCID_WORD_POSITIONS) })?; let document_id = u32::from_be_bytes(document_id_bytes); - let del_add_reader = KvReaderDelAdd::new(value); + let del_add_reader = KvReaderDelAdd::from_slice(value); let deletion = del_add_reader // get deleted words .get(DelAdd::Deletion) // count deleted words - .map(|deletion| KvReaderU16::new(deletion).iter().take(MAX_COUNTED_WORDS + 1).count()) + .map(|deletion| { + KvReaderU16::from_slice(deletion).iter().take(MAX_COUNTED_WORDS + 1).count() + }) // keep the count if under or equal to MAX_COUNTED_WORDS .filter(|&word_count| word_count <= MAX_COUNTED_WORDS); let addition = del_add_reader // get added words .get(DelAdd::Addition) // count added words - .map(|addition| KvReaderU16::new(addition).iter().take(MAX_COUNTED_WORDS + 1).count()) + .map(|addition| { + KvReaderU16::from_slice(addition).iter().take(MAX_COUNTED_WORDS + 1).count() + }) // keep the count if under or equal to MAX_COUNTED_WORDS .filter(|&word_count| word_count <= MAX_COUNTED_WORDS); diff --git a/crates/milli/src/update/index_documents/extract/extract_geo_points.rs b/crates/milli/src/update/index_documents/extract/extract_geo_points.rs index ac8b7abee..84f5e556b 100644 --- a/crates/milli/src/update/index_documents/extract/extract_geo_points.rs +++ b/crates/milli/src/update/index_documents/extract/extract_geo_points.rs @@ -29,22 +29,20 @@ pub fn extract_geo_points( let mut cursor = obkv_documents.into_cursor()?; while let Some((docid_bytes, value)) = cursor.move_on_next()? { - let obkv = obkv::KvReader::new(value); + let obkv = obkv::KvReader::from_slice(value); // since we only need the primary key when we throw an error // we create this getter to lazily get it when needed let document_id = || -> Value { - let reader = KvReaderDelAdd::new(obkv.get(primary_key_id).unwrap()); + let reader = KvReaderDelAdd::from_slice(obkv.get(primary_key_id).unwrap()); let document_id = reader.get(DelAdd::Deletion).or(reader.get(DelAdd::Addition)).unwrap(); serde_json::from_slice(document_id).unwrap() }; // extract old version - let del_lat_lng = - extract_lat_lng(&obkv, &settings_diff.old, DelAdd::Deletion, document_id)?; + let del_lat_lng = extract_lat_lng(obkv, &settings_diff.old, DelAdd::Deletion, document_id)?; // extract new version - let add_lat_lng = - extract_lat_lng(&obkv, &settings_diff.new, DelAdd::Addition, document_id)?; + let add_lat_lng = extract_lat_lng(obkv, &settings_diff.new, DelAdd::Addition, document_id)?; if del_lat_lng != add_lat_lng { let mut obkv = KvWriterDelAdd::memory(); @@ -68,15 +66,17 @@ pub fn extract_geo_points( /// Extract the finite floats lat and lng from two bytes slices. fn extract_lat_lng( - document: &obkv::KvReader<'_, FieldId>, + document: &obkv::KvReader, settings: &InnerIndexSettings, deladd: DelAdd, document_id: impl Fn() -> Value, ) -> Result> { match settings.geo_fields_ids { Some((lat_fid, lng_fid)) => { - let lat = document.get(lat_fid).map(KvReaderDelAdd::new).and_then(|r| r.get(deladd)); - let lng = document.get(lng_fid).map(KvReaderDelAdd::new).and_then(|r| r.get(deladd)); + let lat = + document.get(lat_fid).map(KvReaderDelAdd::from_slice).and_then(|r| r.get(deladd)); + let lng = + document.get(lng_fid).map(KvReaderDelAdd::from_slice).and_then(|r| r.get(deladd)); let (lat, lng) = match (lat, lng) { (Some(lat), Some(lng)) => (lat, lng), (Some(_), None) => { diff --git a/crates/milli/src/update/index_documents/extract/extract_vector_points.rs b/crates/milli/src/update/index_documents/extract/extract_vector_points.rs index 38a4ebe8a..7b5bf3f40 100644 --- a/crates/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/crates/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -21,7 +21,7 @@ use crate::update::settings::InnerIndexSettingsDiff; use crate::vector::error::{EmbedErrorKind, PossibleEmbeddingMistakes, UnusedVectorsDistribution}; use crate::vector::parsed_vectors::{ParsedVectorsDiff, VectorState, RESERVED_VECTORS_FIELD_NAME}; use crate::vector::settings::ReindexAction; -use crate::vector::{Embedder, Embeddings}; +use crate::vector::{Embedder, Embedding}; use crate::{try_split_array_at, DocumentId, FieldId, Result, ThreadPoolNoAbort}; /// The length of the elements that are always in the buffer when inserting new values. @@ -313,7 +313,7 @@ pub fn extract_vector_points( debug_assert!(from_utf8(external_id_bytes).is_ok()); let docid = DocumentId::from_be_bytes(docid_bytes); - let obkv = obkv::KvReader::new(value); + let obkv = obkv::KvReader::from_slice(value); key_buffer.clear(); key_buffer.extend_from_slice(docid_bytes.as_slice()); @@ -481,7 +481,7 @@ pub fn extract_vector_points( #[allow(clippy::too_many_arguments)] // feel free to find efficient way to factor arguments fn extract_vector_document_diff( docid: DocumentId, - obkv: obkv::KvReader<'_, FieldId>, + obkv: &obkv::KvReader, prompt: &Prompt, (add_to_user_provided, remove_from_user_provided): (&mut RoaringBitmap, &mut RoaringBitmap), (old, new): (VectorState, VectorState), @@ -526,7 +526,7 @@ fn extract_vector_document_diff( // Do we keep this document? let document_is_kept = obkv .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .map(|(_, deladd)| KvReaderDelAdd::from_slice(deladd)) .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { @@ -536,9 +536,11 @@ fn extract_vector_document_diff( } // Don't give up if the old prompt was failing let old_prompt = Some(&prompt).map(|p| { - p.render(obkv, DelAdd::Deletion, old_fields_ids_map).unwrap_or_default() + p.render_kvdeladd(obkv, DelAdd::Deletion, old_fields_ids_map) + .unwrap_or_default() }); - let new_prompt = prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?; + let new_prompt = + prompt.render_kvdeladd(obkv, DelAdd::Addition, new_fields_ids_map)?; if old_prompt.as_ref() != Some(&new_prompt) { let old_prompt = old_prompt.unwrap_or_default(); tracing::trace!( @@ -562,7 +564,7 @@ fn extract_vector_document_diff( // Do we keep this document? let document_is_kept = obkv .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .map(|(_, deladd)| KvReaderDelAdd::from_slice(deladd)) .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { if embedder_is_manual { @@ -570,7 +572,7 @@ fn extract_vector_document_diff( return Ok(VectorStateDelta::NoChange); } // becomes autogenerated - VectorStateDelta::NowGenerated(prompt.render( + VectorStateDelta::NowGenerated(prompt.render_kvdeladd( obkv, DelAdd::Addition, new_fields_ids_map, @@ -588,7 +590,7 @@ fn extract_vector_document_diff( // Do we keep this document? let document_is_kept = obkv .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .map(|(_, deladd)| KvReaderDelAdd::from_slice(deladd)) .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { // if the new version of documents has the vectors in the DB, @@ -606,16 +608,17 @@ fn extract_vector_document_diff( } fn regenerate_if_prompt_changed( - obkv: obkv::KvReader<'_, FieldId>, + obkv: &obkv::KvReader, (old_prompt, new_prompt): (&Prompt, &Prompt), (old_fields_ids_map, new_fields_ids_map): ( &FieldsIdsMapWithMetadata, &FieldsIdsMapWithMetadata, ), ) -> Result { - let old_prompt = - old_prompt.render(obkv, DelAdd::Deletion, old_fields_ids_map).unwrap_or(Default::default()); - let new_prompt = new_prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?; + let old_prompt = old_prompt + .render_kvdeladd(obkv, DelAdd::Deletion, old_fields_ids_map) + .unwrap_or(Default::default()); + let new_prompt = new_prompt.render_kvdeladd(obkv, DelAdd::Addition, new_fields_ids_map)?; if new_prompt == old_prompt { return Ok(VectorStateDelta::NoChange); @@ -624,11 +627,11 @@ fn regenerate_if_prompt_changed( } fn regenerate_prompt( - obkv: obkv::KvReader<'_, FieldId>, + obkv: &obkv::KvReader, prompt: &Prompt, new_fields_ids_map: &FieldsIdsMapWithMetadata, ) -> Result { - let prompt = prompt.render(obkv, DelAdd::Addition, new_fields_ids_map)?; + let prompt = prompt.render_kvdeladd(obkv, DelAdd::Addition, new_fields_ids_map)?; Ok(VectorStateDelta::NowGenerated(prompt)) } @@ -738,7 +741,7 @@ pub fn extract_embeddings( .flat_map(|docids| docids.iter()) .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) { - state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings))?; } chunks_ids.clear(); } @@ -759,7 +762,7 @@ pub fn extract_embeddings( .flat_map(|docids| docids.iter()) .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) { - state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings))?; } } @@ -775,7 +778,7 @@ pub fn extract_embeddings( if let Some(embeds) = embeds.first() { for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { - state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings))?; } } } @@ -790,7 +793,7 @@ fn embed_chunks( possible_embedding_mistakes: &PossibleEmbeddingMistakes, unused_vectors_distribution: &UnusedVectorsDistribution, request_threads: &ThreadPoolNoAbort, -) -> Result>>> { +) -> Result>> { match embedder.embed_chunks(text_chunks, request_threads) { Ok(chunks) => Ok(chunks), Err(error) => { diff --git a/crates/milli/src/update/index_documents/extract/extract_word_docids.rs b/crates/milli/src/update/index_documents/extract/extract_word_docids.rs index 457d2359e..829da768c 100644 --- a/crates/milli/src/update/index_documents/extract/extract_word_docids.rs +++ b/crates/milli/src/update/index_documents/extract/extract_word_docids.rs @@ -7,8 +7,8 @@ use obkv::KvReaderU16; use roaring::RoaringBitmap; use super::helpers::{ - create_sorter, create_writer, merge_deladd_cbo_roaring_bitmaps, try_split_array_at, - writer_into_reader, GrenadParameters, + create_sorter, create_writer, try_split_array_at, writer_into_reader, GrenadParameters, + MergeDeladdCboRoaringBitmaps, }; use crate::error::SerializationError; use crate::heed_codec::StrBEU16Codec; @@ -16,7 +16,6 @@ use crate::index::db_name::DOCID_WORD_POSITIONS; use crate::update::del_add::{is_noop_del_add_obkv, DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::sorter_into_reader; use crate::update::settings::InnerIndexSettingsDiff; -use crate::update::MergeFn; use crate::{CboRoaringBitmapCodec, DocumentId, FieldId, Result}; /// Extracts the word and the documents ids where this word appear. @@ -40,11 +39,12 @@ pub fn extract_word_docids( let mut word_fid_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 3), + true, ); let mut key_buffer = Vec::new(); let mut del_words = BTreeSet::new(); @@ -58,17 +58,17 @@ pub fn extract_word_docids( let document_id = u32::from_be_bytes(document_id_bytes); let fid = u16::from_be_bytes(fid_bytes); - let del_add_reader = KvReaderDelAdd::new(value); + let del_add_reader = KvReaderDelAdd::from_slice(value); // extract all unique words to remove. if let Some(deletion) = del_add_reader.get(DelAdd::Deletion) { - for (_pos, word) in KvReaderU16::new(deletion).iter() { + for (_pos, word) in KvReaderU16::from_slice(deletion).iter() { del_words.insert(word.to_vec()); } } // extract all unique additional words. if let Some(addition) = del_add_reader.get(DelAdd::Addition) { - for (_pos, word) in KvReaderU16::new(addition).iter() { + for (_pos, word) in KvReaderU16::from_slice(addition).iter() { add_words.insert(word.to_vec()); } } @@ -94,20 +94,22 @@ pub fn extract_word_docids( let mut word_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 3), + true, ); let mut exact_word_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / 3), + true, ); let mut iter = word_fid_docids_sorter.into_stream_merger_iter()?; @@ -115,7 +117,7 @@ pub fn extract_word_docids( // NOTE: replacing sorters by bitmap merging is less efficient, so, use sorters. while let Some((key, value)) = iter.next()? { // only keep the value if their is a change to apply in the DB. - if !is_noop_del_add_obkv(KvReaderDelAdd::new(value)) { + if !is_noop_del_add_obkv(KvReaderDelAdd::from_slice(value)) { word_fid_docids_writer.insert(key, value)?; } @@ -123,7 +125,7 @@ pub fn extract_word_docids( .map_err(|_| SerializationError::Decoding { db_name: Some(DOCID_WORD_POSITIONS) })?; // merge all deletions - let obkv = KvReaderDelAdd::new(value); + let obkv = KvReaderDelAdd::from_slice(value); if let Some(value) = obkv.get(DelAdd::Deletion) { let delete_from_exact = settings_diff.old.exact_attributes.contains(&fid); buffer.clear(); @@ -163,7 +165,7 @@ fn words_into_sorter( key_buffer: &mut Vec, del_words: &BTreeSet>, add_words: &BTreeSet>, - word_fid_docids_sorter: &mut grenad::Sorter, + word_fid_docids_sorter: &mut grenad::Sorter, ) -> Result<()> { use itertools::merge_join_by; use itertools::EitherOrBoth::{Both, Left, Right}; diff --git a/crates/milli/src/update/index_documents/extract/extract_word_pair_proximity_docids.rs b/crates/milli/src/update/index_documents/extract/extract_word_pair_proximity_docids.rs index 5a9363942..6194da23d 100644 --- a/crates/milli/src/update/index_documents/extract/extract_word_pair_proximity_docids.rs +++ b/crates/milli/src/update/index_documents/extract/extract_word_pair_proximity_docids.rs @@ -6,8 +6,8 @@ use std::{cmp, io}; use obkv::KvReaderU16; use super::helpers::{ - create_sorter, create_writer, merge_deladd_cbo_roaring_bitmaps, try_split_array_at, - writer_into_reader, GrenadParameters, MergeFn, + create_sorter, create_writer, try_split_array_at, writer_into_reader, GrenadParameters, + MergeDeladdCboRoaringBitmaps, }; use crate::error::SerializationError; use crate::index::db_name::DOCID_WORD_POSITIONS; @@ -44,11 +44,12 @@ pub fn extract_word_pair_proximity_docids( .map(|_| { create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory.map(|m| m / MAX_DISTANCE as usize), + true, ) }) .collect(); @@ -92,8 +93,8 @@ pub fn extract_word_pair_proximity_docids( } // deletions - if let Some(deletion) = KvReaderDelAdd::new(value).get(DelAdd::Deletion) { - for (position, word) in KvReaderU16::new(deletion).iter() { + if let Some(deletion) = KvReaderDelAdd::from_slice(value).get(DelAdd::Deletion) { + for (position, word) in KvReaderU16::from_slice(deletion).iter() { // drain the proximity window until the head word is considered close to the word we are inserting. while del_word_positions.front().map_or(false, |(_w, p)| { index_proximity(*p as u32, position as u32) >= MAX_DISTANCE @@ -125,8 +126,8 @@ pub fn extract_word_pair_proximity_docids( } // additions - if let Some(addition) = KvReaderDelAdd::new(value).get(DelAdd::Addition) { - for (position, word) in KvReaderU16::new(addition).iter() { + if let Some(addition) = KvReaderDelAdd::from_slice(value).get(DelAdd::Addition) { + for (position, word) in KvReaderU16::from_slice(addition).iter() { // drain the proximity window until the head word is considered close to the word we are inserting. while add_word_positions.front().map_or(false, |(_w, p)| { index_proximity(*p as u32, position as u32) >= MAX_DISTANCE @@ -197,7 +198,7 @@ fn document_word_positions_into_sorter( document_id: DocumentId, del_word_pair_proximity: &BTreeMap<(String, String), u8>, add_word_pair_proximity: &BTreeMap<(String, String), u8>, - word_pair_proximity_docids_sorters: &mut [grenad::Sorter], + word_pair_proximity_docids_sorters: &mut [grenad::Sorter], ) -> Result<()> { use itertools::merge_join_by; use itertools::EitherOrBoth::{Both, Left, Right}; diff --git a/crates/milli/src/update/index_documents/extract/extract_word_position_docids.rs b/crates/milli/src/update/index_documents/extract/extract_word_position_docids.rs index 50b1617f9..f870fbe1b 100644 --- a/crates/milli/src/update/index_documents/extract/extract_word_position_docids.rs +++ b/crates/milli/src/update/index_documents/extract/extract_word_position_docids.rs @@ -5,14 +5,13 @@ use std::io::{self, BufReader}; use obkv::KvReaderU16; use super::helpers::{ - create_sorter, merge_deladd_cbo_roaring_bitmaps, sorter_into_reader, try_split_array_at, - GrenadParameters, + create_sorter, sorter_into_reader, try_split_array_at, GrenadParameters, + MergeDeladdCboRoaringBitmaps, }; use crate::error::SerializationError; use crate::index::db_name::DOCID_WORD_POSITIONS; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::settings::InnerIndexSettingsDiff; -use crate::update::MergeFn; use crate::{bucketed_position, DocumentId, Result}; /// Extracts the word positions and the documents ids where this word appear. @@ -29,11 +28,12 @@ pub fn extract_word_position_docids( let mut word_position_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, indexer.chunk_compression_type, indexer.chunk_compression_level, indexer.max_nb_chunks, max_memory, + true, ); let mut del_word_positions: BTreeSet<(u16, Vec)> = BTreeSet::new(); @@ -60,10 +60,10 @@ pub fn extract_word_position_docids( current_document_id = Some(document_id); - let del_add_reader = KvReaderDelAdd::new(value); + let del_add_reader = KvReaderDelAdd::from_slice(value); // extract all unique words to remove. if let Some(deletion) = del_add_reader.get(DelAdd::Deletion) { - for (position, word_bytes) in KvReaderU16::new(deletion).iter() { + for (position, word_bytes) in KvReaderU16::from_slice(deletion).iter() { let position = bucketed_position(position); del_word_positions.insert((position, word_bytes.to_vec())); } @@ -71,7 +71,7 @@ pub fn extract_word_position_docids( // extract all unique additional words. if let Some(addition) = del_add_reader.get(DelAdd::Addition) { - for (position, word_bytes) in KvReaderU16::new(addition).iter() { + for (position, word_bytes) in KvReaderU16::from_slice(addition).iter() { let position = bucketed_position(position); add_word_positions.insert((position, word_bytes.to_vec())); } @@ -100,7 +100,7 @@ fn words_position_into_sorter( key_buffer: &mut Vec, del_word_positions: &BTreeSet<(u16, Vec)>, add_word_positions: &BTreeSet<(u16, Vec)>, - word_position_docids_sorter: &mut grenad::Sorter, + word_position_docids_sorter: &mut grenad::Sorter, ) -> Result<()> { use itertools::merge_join_by; use itertools::EitherOrBoth::{Both, Left, Right}; diff --git a/crates/milli/src/update/index_documents/helpers/grenad_helpers.rs b/crates/milli/src/update/index_documents/helpers/grenad_helpers.rs index 44009f2fa..62dc40edc 100644 --- a/crates/milli/src/update/index_documents/helpers/grenad_helpers.rs +++ b/crates/milli/src/update/index_documents/helpers/grenad_helpers.rs @@ -1,11 +1,10 @@ -use std::borrow::Cow; use std::fs::File; use std::io::{self, BufReader, BufWriter, Seek}; -use grenad::{CompressionType, Sorter}; +use grenad::{CompressionType, MergeFunction, Sorter}; use heed::types::Bytes; -use super::{ClonableMmap, MergeFn}; +use super::ClonableMmap; use crate::update::index_documents::valid_lmdb_key; use crate::Result; @@ -31,14 +30,15 @@ pub fn create_writer( /// A helper function that creates a grenad sorter /// with the given parameters. The max memory is /// clamped to something reasonable. -pub fn create_sorter( +pub fn create_sorter( sort_algorithm: grenad::SortAlgorithm, - merge: MergeFn, + merge: MF, chunk_compression_type: grenad::CompressionType, chunk_compression_level: Option, max_nb_chunks: Option, max_memory: Option, -) -> grenad::Sorter { + sort_in_parallel: bool, +) -> grenad::Sorter { let mut builder = grenad::Sorter::builder(merge); builder.chunk_compression_type(chunk_compression_type); if let Some(level) = chunk_compression_level { @@ -52,15 +52,19 @@ pub fn create_sorter( builder.allow_realloc(false); } builder.sort_algorithm(sort_algorithm); - builder.sort_in_parallel(true); + builder.sort_in_parallel(sort_in_parallel); builder.build() } #[tracing::instrument(level = "trace", skip_all, target = "indexing::grenad")] -pub fn sorter_into_reader( - sorter: grenad::Sorter, +pub fn sorter_into_reader( + sorter: grenad::Sorter, indexer: GrenadParameters, -) -> Result>> { +) -> Result>> +where + MF: MergeFunction, + crate::Error: From, +{ let mut writer = create_writer( indexer.chunk_compression_type, indexer.chunk_compression_level, @@ -79,6 +83,8 @@ pub fn writer_into_reader( grenad::Reader::new(BufReader::new(file)).map_err(Into::into) } +/// # Safety +/// We use memory mapping inside. So, according to the Rust community, it's unsafe. pub unsafe fn as_cloneable_grenad( reader: &grenad::Reader>, ) -> Result> { @@ -113,12 +119,8 @@ impl GrenadParameters { /// /// This should be called inside of a rayon thread pool, /// otherwise, it will take the global number of threads. - /// - /// The max memory cannot exceed a given reasonable value. pub fn max_memory_by_thread(&self) -> Option { - self.max_memory.map(|max_memory| { - (max_memory / rayon::current_num_threads()).min(MAX_GRENAD_SORTER_USAGE) - }) + self.max_memory.map(|max_memory| (max_memory / rayon::current_num_threads())) } } @@ -169,8 +171,8 @@ pub fn grenad_obkv_into_chunks( /// Write provided sorter in database using serialize_value function. /// merge_values function is used if an entry already exist in the database. #[tracing::instrument(level = "trace", skip_all, target = "indexing::grenad")] -pub fn write_sorter_into_database( - sorter: Sorter, +pub fn write_sorter_into_database( + sorter: Sorter, database: &heed::Database, wtxn: &mut heed::RwTxn<'_>, index_is_empty: bool, @@ -180,6 +182,8 @@ pub fn write_sorter_into_database( where FS: for<'a> Fn(&'a [u8], &'a mut Vec) -> Result<&'a [u8]>, FM: for<'a> Fn(&[u8], &[u8], &'a mut Vec) -> Result>, + MF: MergeFunction, + crate::Error: From, { let mut buffer = Vec::new(); let database = database.remap_types::(); @@ -207,8 +211,3 @@ where Ok(()) } - -/// Used when trying to merge readers, but you don't actually care about the values. -pub fn merge_ignore_values<'a>(_key: &[u8], _values: &[Cow<'a, [u8]>]) -> Result> { - Ok(Cow::Owned(Vec::new())) -} diff --git a/crates/milli/src/update/index_documents/helpers/merge_functions.rs b/crates/milli/src/update/index_documents/helpers/merge_functions.rs index 42784048a..ab8a09a60 100644 --- a/crates/milli/src/update/index_documents/helpers/merge_functions.rs +++ b/crates/milli/src/update/index_documents/helpers/merge_functions.rs @@ -3,6 +3,8 @@ use std::collections::BTreeSet; use std::io; use std::result::Result as StdResult; +use either::Either; +use grenad::MergeFunction; use roaring::RoaringBitmap; use crate::heed_codec::CboRoaringBitmapCodec; @@ -10,7 +12,8 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::transform::Operation; use crate::Result; -pub type MergeFn = for<'a> fn(&[u8], &[Cow<'a, [u8]>]) -> Result>; +pub type EitherObkvMerge = + Either; pub fn serialize_roaring_bitmap(bitmap: &RoaringBitmap, buffer: &mut Vec) -> io::Result<()> { buffer.clear(); @@ -18,35 +21,53 @@ pub fn serialize_roaring_bitmap(bitmap: &RoaringBitmap, buffer: &mut Vec) -> bitmap.serialize_into(buffer) } -pub fn merge_roaring_bitmaps<'a>(_key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { - if values.len() == 1 { - Ok(values[0].clone()) - } else { - let merged = values - .iter() - .map(AsRef::as_ref) - .map(RoaringBitmap::deserialize_from) - .map(StdResult::unwrap) - .reduce(|a, b| a | b) - .unwrap(); - let mut buffer = Vec::new(); - serialize_roaring_bitmap(&merged, &mut buffer)?; - Ok(Cow::Owned(buffer)) +pub struct MergeRoaringBitmaps; + +impl MergeFunction for MergeRoaringBitmaps { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { + if values.len() == 1 { + Ok(values[0].clone()) + } else { + let merged = values + .iter() + .map(AsRef::as_ref) + .map(RoaringBitmap::deserialize_from) + .map(StdResult::unwrap) + .reduce(|a, b| a | b) + .unwrap(); + let mut buffer = Vec::new(); + serialize_roaring_bitmap(&merged, &mut buffer)?; + Ok(Cow::Owned(buffer)) + } } } -pub fn keep_first<'a>(_key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { - Ok(values[0].clone()) +pub struct KeepFirst; + +impl MergeFunction for KeepFirst { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { + Ok(values[0].clone()) + } } /// Only the last value associated with an id is kept. -pub fn keep_latest_obkv<'a>(_key: &[u8], obkvs: &[Cow<'a, [u8]>]) -> Result> { - Ok(obkvs.last().unwrap().clone()) +pub struct KeepLatestObkv; + +impl MergeFunction for KeepLatestObkv { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], obkvs: &[Cow<'a, [u8]>]) -> Result> { + Ok(obkvs.last().unwrap().clone()) + } } pub fn merge_two_del_add_obkvs( - base: obkv::KvReaderU16<'_>, - update: obkv::KvReaderU16<'_>, + base: &obkv::KvReaderU16, + update: &obkv::KvReaderU16, merge_additions: bool, buffer: &mut Vec, ) { @@ -66,7 +87,7 @@ pub fn merge_two_del_add_obkvs( // If merge_additions is false, recreate an obkv keeping the deletions only. value_buffer.clear(); let mut value_writer = KvWriterDelAdd::new(&mut value_buffer); - let base_reader = KvReaderDelAdd::new(v); + let base_reader = KvReaderDelAdd::from_slice(v); if let Some(deletion) = base_reader.get(DelAdd::Deletion) { value_writer.insert(DelAdd::Deletion, deletion).unwrap(); @@ -80,8 +101,8 @@ pub fn merge_two_del_add_obkvs( // merge deletions and additions. value_buffer.clear(); let mut value_writer = KvWriterDelAdd::new(&mut value_buffer); - let base_reader = KvReaderDelAdd::new(base); - let update_reader = KvReaderDelAdd::new(update); + let base_reader = KvReaderDelAdd::from_slice(base); + let update_reader = KvReaderDelAdd::from_slice(update); // keep newest deletion. if let Some(deletion) = update_reader @@ -131,8 +152,8 @@ fn inner_merge_del_add_obkvs<'a>( break; } - let newest = obkv::KvReader::new(&acc); - let oldest = obkv::KvReader::new(¤t[1..]); + let newest = obkv::KvReader::from_slice(&acc); + let oldest = obkv::KvReader::from_slice(¤t[1..]); merge_two_del_add_obkvs(oldest, newest, merge_additions, &mut buffer); // we want the result of the merge into our accumulator. @@ -145,65 +166,79 @@ fn inner_merge_del_add_obkvs<'a>( } /// Merge all the obkvs from the newest to the oldest. -pub fn obkvs_merge_additions_and_deletions<'a>( - _key: &[u8], - obkvs: &[Cow<'a, [u8]>], -) -> Result> { - inner_merge_del_add_obkvs(obkvs, true) +#[derive(Copy, Clone)] +pub struct ObkvsMergeAdditionsAndDeletions; + +impl MergeFunction for ObkvsMergeAdditionsAndDeletions { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], obkvs: &[Cow<'a, [u8]>]) -> Result> { + inner_merge_del_add_obkvs(obkvs, true) + } } /// Merge all the obkvs deletions from the newest to the oldest and keep only the newest additions. -pub fn obkvs_keep_last_addition_merge_deletions<'a>( - _key: &[u8], - obkvs: &[Cow<'a, [u8]>], -) -> Result> { - inner_merge_del_add_obkvs(obkvs, false) +#[derive(Copy, Clone)] +pub struct ObkvsKeepLastAdditionMergeDeletions; + +impl MergeFunction for ObkvsKeepLastAdditionMergeDeletions { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], obkvs: &[Cow<'a, [u8]>]) -> Result> { + inner_merge_del_add_obkvs(obkvs, false) + } } /// Do a union of all the CboRoaringBitmaps in the values. -pub fn merge_cbo_roaring_bitmaps<'a>( - _key: &[u8], - values: &[Cow<'a, [u8]>], -) -> Result> { - if values.len() == 1 { - Ok(values[0].clone()) - } else { - let mut vec = Vec::new(); - CboRoaringBitmapCodec::merge_into(values, &mut vec)?; - Ok(Cow::from(vec)) +pub struct MergeCboRoaringBitmaps; + +impl MergeFunction for MergeCboRoaringBitmaps { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { + if values.len() == 1 { + Ok(values[0].clone()) + } else { + let mut vec = Vec::new(); + CboRoaringBitmapCodec::merge_into(values, &mut vec)?; + Ok(Cow::from(vec)) + } } } /// Do a union of CboRoaringBitmaps on both sides of a DelAdd obkv /// separately and outputs a new DelAdd with both unions. -pub fn merge_deladd_cbo_roaring_bitmaps<'a>( - _key: &[u8], - values: &[Cow<'a, [u8]>], -) -> Result> { - if values.len() == 1 { - Ok(values[0].clone()) - } else { - // Retrieve the bitmaps from both sides - let mut del_bitmaps_bytes = Vec::new(); - let mut add_bitmaps_bytes = Vec::new(); - for value in values { - let obkv = KvReaderDelAdd::new(value); - if let Some(bitmap_bytes) = obkv.get(DelAdd::Deletion) { - del_bitmaps_bytes.push(bitmap_bytes); - } - if let Some(bitmap_bytes) = obkv.get(DelAdd::Addition) { - add_bitmaps_bytes.push(bitmap_bytes); - } - } +pub struct MergeDeladdCboRoaringBitmaps; - let mut output_deladd_obkv = KvWriterDelAdd::memory(); - let mut buffer = Vec::new(); - CboRoaringBitmapCodec::merge_into(del_bitmaps_bytes, &mut buffer)?; - output_deladd_obkv.insert(DelAdd::Deletion, &buffer)?; - buffer.clear(); - CboRoaringBitmapCodec::merge_into(add_bitmaps_bytes, &mut buffer)?; - output_deladd_obkv.insert(DelAdd::Addition, &buffer)?; - output_deladd_obkv.into_inner().map(Cow::from).map_err(Into::into) +impl MergeFunction for MergeDeladdCboRoaringBitmaps { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { + if values.len() == 1 { + Ok(values[0].clone()) + } else { + // Retrieve the bitmaps from both sides + let mut del_bitmaps_bytes = Vec::new(); + let mut add_bitmaps_bytes = Vec::new(); + for value in values { + let obkv = KvReaderDelAdd::from_slice(value); + if let Some(bitmap_bytes) = obkv.get(DelAdd::Deletion) { + del_bitmaps_bytes.push(bitmap_bytes); + } + if let Some(bitmap_bytes) = obkv.get(DelAdd::Addition) { + add_bitmaps_bytes.push(bitmap_bytes); + } + } + + let mut output_deladd_obkv = KvWriterDelAdd::memory(); + let mut buffer = Vec::new(); + CboRoaringBitmapCodec::merge_into(del_bitmaps_bytes, &mut buffer)?; + output_deladd_obkv.insert(DelAdd::Deletion, &buffer)?; + buffer.clear(); + CboRoaringBitmapCodec::merge_into(add_bitmaps_bytes, &mut buffer)?; + output_deladd_obkv.insert(DelAdd::Addition, &buffer)?; + output_deladd_obkv.into_inner().map(Cow::from).map_err(Into::into) + } } } @@ -217,7 +252,7 @@ pub fn merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap<'a>( buffer: &'a mut Vec, ) -> Result> { Ok(CboRoaringBitmapCodec::merge_deladd_into( - KvReaderDelAdd::new(deladd_obkv), + KvReaderDelAdd::from_slice(deladd_obkv), previous, buffer, )?) @@ -225,37 +260,55 @@ pub fn merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap<'a>( /// Do a union of BtreeSet on both sides of a DelAdd obkv /// separately and outputs a new DelAdd with both unions. -pub fn merge_deladd_btreeset_string<'a>( - _key: &[u8], - values: &[Cow<'a, [u8]>], -) -> Result> { - if values.len() == 1 { - Ok(values[0].clone()) - } else { - // Retrieve the bitmaps from both sides - let mut del_set = BTreeSet::new(); - let mut add_set = BTreeSet::new(); - for value in values { - let obkv = KvReaderDelAdd::new(value); - if let Some(bytes) = obkv.get(DelAdd::Deletion) { - let set = serde_json::from_slice::>(bytes).unwrap(); - for value in set { - del_set.insert(value); - } - } - if let Some(bytes) = obkv.get(DelAdd::Addition) { - let set = serde_json::from_slice::>(bytes).unwrap(); - for value in set { - add_set.insert(value); - } - } - } +pub struct MergeDeladdBtreesetString; - let mut output_deladd_obkv = KvWriterDelAdd::memory(); - let del = serde_json::to_vec(&del_set).unwrap(); - output_deladd_obkv.insert(DelAdd::Deletion, &del)?; - let add = serde_json::to_vec(&add_set).unwrap(); - output_deladd_obkv.insert(DelAdd::Addition, &add)?; - output_deladd_obkv.into_inner().map(Cow::from).map_err(Into::into) +impl MergeFunction for MergeDeladdBtreesetString { + type Error = crate::Error; + + fn merge<'a>(&self, _key: &[u8], values: &[Cow<'a, [u8]>]) -> Result> { + if values.len() == 1 { + Ok(values[0].clone()) + } else { + // Retrieve the bitmaps from both sides + let mut del_set = BTreeSet::new(); + let mut add_set = BTreeSet::new(); + for value in values { + let obkv = KvReaderDelAdd::from_slice(value); + if let Some(bytes) = obkv.get(DelAdd::Deletion) { + let set = serde_json::from_slice::>(bytes).unwrap(); + for value in set { + del_set.insert(value); + } + } + if let Some(bytes) = obkv.get(DelAdd::Addition) { + let set = serde_json::from_slice::>(bytes).unwrap(); + for value in set { + add_set.insert(value); + } + } + } + + let mut output_deladd_obkv = KvWriterDelAdd::memory(); + let del = serde_json::to_vec(&del_set).unwrap(); + output_deladd_obkv.insert(DelAdd::Deletion, &del)?; + let add = serde_json::to_vec(&add_set).unwrap(); + output_deladd_obkv.insert(DelAdd::Addition, &add)?; + output_deladd_obkv.into_inner().map(Cow::from).map_err(Into::into) + } + } +} + +/// Used when trying to merge readers, but you don't actually care about the values. +pub struct MergeIgnoreValues; + +impl MergeFunction for MergeIgnoreValues { + type Error = crate::Error; + + fn merge<'a>( + &self, + _key: &[u8], + _values: &[Cow<'a, [u8]>], + ) -> std::result::Result, Self::Error> { + Ok(Cow::Owned(Vec::new())) } } diff --git a/crates/milli/src/update/index_documents/helpers/mod.rs b/crates/milli/src/update/index_documents/helpers/mod.rs index 5d8f16fae..c188e324d 100644 --- a/crates/milli/src/update/index_documents/helpers/mod.rs +++ b/crates/milli/src/update/index_documents/helpers/mod.rs @@ -7,17 +7,8 @@ use std::convert::{TryFrom, TryInto}; pub use clonable_mmap::{ClonableMmap, CursorClonableMmap}; use fst::{IntoStreamer, Streamer}; -pub use grenad_helpers::{ - as_cloneable_grenad, create_sorter, create_writer, grenad_obkv_into_chunks, - merge_ignore_values, sorter_into_reader, write_sorter_into_database, writer_into_reader, - GrenadParameters, -}; -pub use merge_functions::{ - keep_first, keep_latest_obkv, merge_cbo_roaring_bitmaps, merge_deladd_btreeset_string, - merge_deladd_cbo_roaring_bitmaps, merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, - merge_roaring_bitmaps, obkvs_keep_last_addition_merge_deletions, - obkvs_merge_additions_and_deletions, MergeFn, -}; +pub use grenad_helpers::*; +pub use merge_functions::*; use crate::MAX_WORD_LENGTH; diff --git a/crates/milli/src/update/index_documents/mod.rs b/crates/milli/src/update/index_documents/mod.rs index 88d20fff0..befde896d 100644 --- a/crates/milli/src/update/index_documents/mod.rs +++ b/crates/milli/src/update/index_documents/mod.rs @@ -27,13 +27,7 @@ use typed_chunk::{write_typed_chunk_into_index, ChunkAccumulator, TypedChunk}; use self::enrich::enrich_documents_batch; pub use self::enrich::{extract_finite_float_from_value, DocumentId}; -pub use self::helpers::{ - as_cloneable_grenad, create_sorter, create_writer, fst_stream_into_hashset, - fst_stream_into_vec, merge_cbo_roaring_bitmaps, merge_deladd_cbo_roaring_bitmaps, - merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, merge_roaring_bitmaps, - valid_lmdb_key, write_sorter_into_database, writer_into_reader, MergeFn, -}; -use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; +pub use self::helpers::*; pub use self::transform::{Transform, TransformOutput}; use crate::documents::{obkv_to_object, DocumentsBatchBuilder, DocumentsBatchReader}; use crate::error::{Error, InternalError, UserError}; @@ -605,7 +599,7 @@ where let cloneable_chunk = unsafe { as_cloneable_grenad(&word_docids_reader)? }; let word_docids = word_docids.get_or_insert_with(|| { - MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn) + MergerBuilder::new(MergeDeladdCboRoaringBitmaps) }); word_docids.push(cloneable_chunk.into_cursor()?); let cloneable_chunk = @@ -613,14 +607,14 @@ where let exact_word_docids = exact_word_docids.get_or_insert_with(|| { MergerBuilder::new( - merge_deladd_cbo_roaring_bitmaps as MergeFn, + MergeDeladdCboRoaringBitmaps, ) }); exact_word_docids.push(cloneable_chunk.into_cursor()?); let cloneable_chunk = unsafe { as_cloneable_grenad(&word_fid_docids_reader)? }; let word_fid_docids = word_fid_docids.get_or_insert_with(|| { - MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn) + MergerBuilder::new(MergeDeladdCboRoaringBitmaps) }); word_fid_docids.push(cloneable_chunk.into_cursor()?); TypedChunk::WordDocids { @@ -634,7 +628,7 @@ where let word_position_docids = word_position_docids.get_or_insert_with(|| { MergerBuilder::new( - merge_deladd_cbo_roaring_bitmaps as MergeFn, + MergeDeladdCboRoaringBitmaps, ) }); word_position_docids.push(cloneable_chunk.into_cursor()?); @@ -738,10 +732,10 @@ where )] pub fn execute_prefix_databases( self, - word_docids: Option>, - exact_word_docids: Option>, - word_position_docids: Option>, - word_fid_docids: Option>, + word_docids: Option>, + exact_word_docids: Option>, + word_position_docids: Option>, + word_fid_docids: Option>, ) -> Result<()> where FP: Fn(UpdateIndexingStep) + Sync, @@ -921,7 +915,7 @@ where )] fn execute_word_prefix_docids( txn: &mut heed::RwTxn<'_>, - merger: Merger, + merger: Merger, word_docids_db: Database, word_prefix_docids_db: Database, indexer_config: &IndexerConfig, diff --git a/crates/milli/src/update/index_documents/parallel.rs b/crates/milli/src/update/index_documents/parallel.rs index 52e72a378..2f6bf9caf 100644 --- a/crates/milli/src/update/index_documents/parallel.rs +++ b/crates/milli/src/update/index_documents/parallel.rs @@ -31,14 +31,14 @@ impl<'t> ImmutableObkvs<'t> { } /// Returns the OBKVs identified by the given ID. - pub fn obkv(&self, docid: DocumentId) -> heed::Result>> { + pub fn obkv(&self, docid: DocumentId) -> heed::Result> { match self .ids .rank(docid) .checked_sub(1) .and_then(|offset| self.slices.get(offset as usize)) { - Some(bytes) => Ok(Some(KvReaderU16::new(bytes))), + Some(&bytes) => Ok(Some(bytes.into())), None => Ok(None), } } diff --git a/crates/milli/src/update/index_documents/transform.rs b/crates/milli/src/update/index_documents/transform.rs index 763f30d0f..7239e8bff 100644 --- a/crates/milli/src/update/index_documents/transform.rs +++ b/crates/milli/src/update/index_documents/transform.rs @@ -5,6 +5,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::fs::File; use std::io::{Read, Seek}; +use either::Either; use fxhash::FxHashMap; use itertools::Itertools; use obkv::{KvReader, KvReaderU16, KvWriter}; @@ -13,10 +14,10 @@ use serde_json::Value; use smartstring::SmartString; use super::helpers::{ - create_sorter, create_writer, keep_first, obkvs_keep_last_addition_merge_deletions, - obkvs_merge_additions_and_deletions, sorter_into_reader, MergeFn, + create_sorter, create_writer, sorter_into_reader, EitherObkvMerge, + ObkvsKeepLastAdditionMergeDeletions, ObkvsMergeAdditionsAndDeletions, }; -use super::{IndexDocumentsMethod, IndexerConfig}; +use super::{IndexDocumentsMethod, IndexerConfig, KeepFirst}; use crate::documents::{DocumentsBatchIndex, EnrichedDocument, EnrichedDocumentsBatchReader}; use crate::error::{Error, InternalError, UserError}; use crate::index::{db_name, main_key}; @@ -26,7 +27,7 @@ use crate::update::del_add::{ }; use crate::update::index_documents::GrenadParameters; use crate::update::settings::{InnerIndexSettings, InnerIndexSettingsDiff}; -use crate::update::{AvailableDocumentsIds, UpdateIndexingStep}; +use crate::update::{AvailableIds, UpdateIndexingStep}; use crate::vector::parsed_vectors::{ExplicitVectors, VectorOrArrayOfVectors}; use crate::vector::settings::WriteBackToDocuments; use crate::vector::ArroyWrapper; @@ -55,13 +56,13 @@ pub struct Transform<'a, 'i> { indexer_settings: &'a IndexerConfig, pub index_documents_method: IndexDocumentsMethod, - available_documents_ids: AvailableDocumentsIds, + available_documents_ids: AvailableIds, // Both grenad follows the same format: // key | value // u32 | 1 byte for the Operation byte, the rest is the obkv of the document stored - original_sorter: grenad::Sorter, - flattened_sorter: grenad::Sorter, + original_sorter: grenad::Sorter, + flattened_sorter: grenad::Sorter, replaced_documents_ids: RoaringBitmap, new_documents_ids: RoaringBitmap, @@ -109,11 +110,13 @@ impl<'a, 'i> Transform<'a, 'i> { index_documents_method: IndexDocumentsMethod, _autogenerate_docids: bool, ) -> Result { + use IndexDocumentsMethod::{ReplaceDocuments, UpdateDocuments}; + // We must choose the appropriate merge function for when two or more documents // with the same user id must be merged or fully replaced in the same batch. let merge_function = match index_documents_method { - IndexDocumentsMethod::ReplaceDocuments => obkvs_keep_last_addition_merge_deletions, - IndexDocumentsMethod::UpdateDocuments => obkvs_merge_additions_and_deletions, + ReplaceDocuments => Either::Left(ObkvsKeepLastAdditionMergeDeletions), + UpdateDocuments => Either::Right(ObkvsMergeAdditionsAndDeletions), }; // We initialize the sorter with the user indexing settings. @@ -124,6 +127,7 @@ impl<'a, 'i> Transform<'a, 'i> { indexer_settings.chunk_compression_level, indexer_settings.max_nb_chunks, indexer_settings.max_memory.map(|mem| mem / 2), + true, ); // We initialize the sorter with the user indexing settings. @@ -134,6 +138,7 @@ impl<'a, 'i> Transform<'a, 'i> { indexer_settings.chunk_compression_level, indexer_settings.max_nb_chunks, indexer_settings.max_memory.map(|mem| mem / 2), + true, ); let documents_ids = index.documents_ids(wtxn)?; @@ -141,7 +146,7 @@ impl<'a, 'i> Transform<'a, 'i> { index, fields_ids_map: index.fields_ids_map(wtxn)?, indexer_settings, - available_documents_ids: AvailableDocumentsIds::from_documents_ids(&documents_ids), + available_documents_ids: AvailableIds::new(&documents_ids), original_sorter, flattened_sorter, index_documents_method, @@ -279,21 +284,21 @@ impl<'a, 'i> Transform<'a, 'i> { document_sorter_value_buffer.clear(); document_sorter_value_buffer.push(Operation::Addition as u8); into_del_add_obkv( - KvReaderU16::new(base_obkv), + KvReaderU16::from_slice(base_obkv), deladd_operation, &mut document_sorter_value_buffer, )?; self.original_sorter .insert(&document_sorter_key_buffer, &document_sorter_value_buffer)?; - let base_obkv = KvReader::new(base_obkv); + let base_obkv = KvReader::from_slice(base_obkv); if let Some(flattened_obkv) = - Self::flatten_from_fields_ids_map(&base_obkv, &mut self.fields_ids_map)? + Self::flatten_from_fields_ids_map(base_obkv, &mut self.fields_ids_map)? { // we recreate our buffer with the flattened documents document_sorter_value_buffer.clear(); document_sorter_value_buffer.push(Operation::Addition as u8); into_del_add_obkv( - KvReaderU16::new(&flattened_obkv), + KvReaderU16::from_slice(&flattened_obkv), deladd_operation, &mut document_sorter_value_buffer, )?; @@ -312,7 +317,7 @@ impl<'a, 'i> Transform<'a, 'i> { document_sorter_value_buffer.clear(); document_sorter_value_buffer.push(Operation::Addition as u8); into_del_add_obkv( - KvReaderU16::new(&obkv_buffer), + KvReaderU16::from_slice(&obkv_buffer), DelAddOperation::Addition, &mut document_sorter_value_buffer, )?; @@ -320,14 +325,14 @@ impl<'a, 'i> Transform<'a, 'i> { self.original_sorter .insert(&document_sorter_key_buffer, &document_sorter_value_buffer)?; - let flattened_obkv = KvReader::new(&obkv_buffer); + let flattened_obkv = KvReader::from_slice(&obkv_buffer); if let Some(obkv) = - Self::flatten_from_fields_ids_map(&flattened_obkv, &mut self.fields_ids_map)? + Self::flatten_from_fields_ids_map(flattened_obkv, &mut self.fields_ids_map)? { document_sorter_value_buffer.clear(); document_sorter_value_buffer.push(Operation::Addition as u8); into_del_add_obkv( - KvReaderU16::new(&obkv), + KvReaderU16::from_slice(&obkv), DelAddOperation::Addition, &mut document_sorter_value_buffer, )? @@ -520,22 +525,22 @@ impl<'a, 'i> Transform<'a, 'i> { document_sorter_value_buffer.clear(); document_sorter_value_buffer.push(Operation::Deletion as u8); into_del_add_obkv( - KvReaderU16::new(base_obkv), + KvReaderU16::from_slice(base_obkv), DelAddOperation::Deletion, document_sorter_value_buffer, )?; self.original_sorter.insert(&document_sorter_key_buffer, &document_sorter_value_buffer)?; // flatten it and push it as to delete in the flattened_sorter - let flattened_obkv = KvReader::new(base_obkv); + let flattened_obkv = KvReader::from_slice(base_obkv); if let Some(obkv) = - Self::flatten_from_fields_ids_map(&flattened_obkv, &mut self.fields_ids_map)? + Self::flatten_from_fields_ids_map(flattened_obkv, &mut self.fields_ids_map)? { // we recreate our buffer with the flattened documents document_sorter_value_buffer.clear(); document_sorter_value_buffer.push(Operation::Deletion as u8); into_del_add_obkv( - KvReaderU16::new(&obkv), + KvReaderU16::from_slice(&obkv), DelAddOperation::Deletion, document_sorter_value_buffer, )?; @@ -553,7 +558,7 @@ impl<'a, 'i> Transform<'a, 'i> { target = "indexing::transform" )] fn flatten_from_fields_ids_map( - obkv: &KvReader<'_, FieldId>, + obkv: &KvReader, fields_ids_map: &mut FieldsIdsMap, ) -> Result>> { if obkv @@ -721,10 +726,10 @@ impl<'a, 'i> Transform<'a, 'i> { total_documents: self.documents_count, }); - for (key, value) in KvReader::new(val) { - let reader = KvReaderDelAdd::new(value); + for (key, value) in KvReader::from_slice(val) { + let reader = KvReaderDelAdd::from_slice(value); match (reader.get(DelAdd::Deletion), reader.get(DelAdd::Addition)) { - (None, None) => {} + (None, None) => (), (None, Some(_)) => { // New field let name = self.fields_ids_map.name(key).ok_or( @@ -838,7 +843,7 @@ impl<'a, 'i> Transform<'a, 'i> { /// then fill the provided buffers with delta documents using KvWritterDelAdd. #[allow(clippy::too_many_arguments)] // need the vectors + fid, feel free to create a struct xo xo fn rebind_existing_document( - old_obkv: KvReader<'_, FieldId>, + old_obkv: &KvReader, settings_diff: &InnerIndexSettingsDiff, modified_faceted_fields: &HashSet, mut injected_vectors: serde_json::Map, @@ -926,7 +931,7 @@ impl<'a, 'i> Transform<'a, 'i> { } let data = obkv_writer.into_inner()?; - let obkv = KvReader::::new(&data); + let obkv = KvReader::::from_slice(&data); if let Some(original_obkv_buffer) = original_obkv_buffer { original_obkv_buffer.clear(); @@ -936,8 +941,8 @@ impl<'a, 'i> Transform<'a, 'i> { if let Some(flattened_obkv_buffer) = flattened_obkv_buffer { // take the non-flattened version if flatten_from_fields_ids_map returns None. let mut fields_ids_map = settings_diff.new.fields_ids_map.clone(); - let flattened = Self::flatten_from_fields_ids_map(&obkv, &mut fields_ids_map)?; - let flattened = flattened.as_deref().map_or(obkv, KvReader::new); + let flattened = Self::flatten_from_fields_ids_map(obkv, &mut fields_ids_map)?; + let flattened = flattened.as_deref().map_or(obkv, KvReader::from_slice); flattened_obkv_buffer.clear(); into_del_add_obkv_conditional_operation(flattened, flattened_obkv_buffer, |id| { @@ -980,11 +985,12 @@ impl<'a, 'i> Transform<'a, 'i> { let mut original_sorter = if settings_diff.reindex_vectors() { Some(create_sorter( grenad::SortAlgorithm::Stable, - keep_first, + KeepFirst, self.indexer_settings.chunk_compression_type, self.indexer_settings.chunk_compression_level, self.indexer_settings.max_nb_chunks, self.indexer_settings.max_memory.map(|mem| mem / 2), + true, )) } else { None @@ -1019,11 +1025,12 @@ impl<'a, 'i> Transform<'a, 'i> { if settings_diff.reindex_searchable() || settings_diff.reindex_facets() { Some(create_sorter( grenad::SortAlgorithm::Stable, - keep_first, + KeepFirst, self.indexer_settings.chunk_compression_type, self.indexer_settings.chunk_compression_level, self.indexer_settings.max_nb_chunks, self.indexer_settings.max_memory.map(|mem| mem / 2), + true, )) } else { None @@ -1137,6 +1144,8 @@ fn drop_and_reuse(mut vec: Vec) -> Vec { #[cfg(test)] mod test { + use grenad::MergeFunction; + use super::*; #[test] @@ -1148,21 +1157,21 @@ mod test { kv_writer.insert(0_u8, [0]).unwrap(); let buffer = kv_writer.into_inner().unwrap(); into_del_add_obkv( - KvReaderU16::new(&buffer), + KvReaderU16::from_slice(&buffer), DelAddOperation::Addition, &mut additive_doc_0, ) .unwrap(); additive_doc_0.insert(0, Operation::Addition as u8); into_del_add_obkv( - KvReaderU16::new(&buffer), + KvReaderU16::from_slice(&buffer), DelAddOperation::Deletion, &mut deletive_doc_0, ) .unwrap(); deletive_doc_0.insert(0, Operation::Deletion as u8); into_del_add_obkv( - KvReaderU16::new(&buffer), + KvReaderU16::from_slice(&buffer), DelAddOperation::DeletionAndAddition, &mut del_add_doc_0, ) @@ -1174,7 +1183,7 @@ mod test { kv_writer.insert(1_u8, [1]).unwrap(); let buffer = kv_writer.into_inner().unwrap(); into_del_add_obkv( - KvReaderU16::new(&buffer), + KvReaderU16::from_slice(&buffer), DelAddOperation::Addition, &mut additive_doc_1, ) @@ -1187,32 +1196,39 @@ mod test { kv_writer.insert(1_u8, [1]).unwrap(); let buffer = kv_writer.into_inner().unwrap(); into_del_add_obkv( - KvReaderU16::new(&buffer), + KvReaderU16::from_slice(&buffer), DelAddOperation::Addition, &mut additive_doc_0_1, ) .unwrap(); additive_doc_0_1.insert(0, Operation::Addition as u8); - let ret = obkvs_merge_additions_and_deletions(&[], &[Cow::from(additive_doc_0.as_slice())]) - .unwrap(); + let ret = MergeFunction::merge( + &ObkvsMergeAdditionsAndDeletions, + &[], + &[Cow::from(additive_doc_0.as_slice())], + ) + .unwrap(); assert_eq!(*ret, additive_doc_0); - let ret = obkvs_merge_additions_and_deletions( + let ret = MergeFunction::merge( + &ObkvsMergeAdditionsAndDeletions, &[], &[Cow::from(deletive_doc_0.as_slice()), Cow::from(additive_doc_0.as_slice())], ) .unwrap(); assert_eq!(*ret, del_add_doc_0); - let ret = obkvs_merge_additions_and_deletions( + let ret = MergeFunction::merge( + &ObkvsMergeAdditionsAndDeletions, &[], &[Cow::from(additive_doc_0.as_slice()), Cow::from(deletive_doc_0.as_slice())], ) .unwrap(); assert_eq!(*ret, deletive_doc_0); - let ret = obkvs_merge_additions_and_deletions( + let ret = MergeFunction::merge( + &ObkvsMergeAdditionsAndDeletions, &[], &[ Cow::from(additive_doc_1.as_slice()), @@ -1223,21 +1239,24 @@ mod test { .unwrap(); assert_eq!(*ret, del_add_doc_0); - let ret = obkvs_merge_additions_and_deletions( + let ret = MergeFunction::merge( + &ObkvsMergeAdditionsAndDeletions, &[], &[Cow::from(additive_doc_1.as_slice()), Cow::from(additive_doc_0.as_slice())], ) .unwrap(); assert_eq!(*ret, additive_doc_0_1); - let ret = obkvs_keep_last_addition_merge_deletions( + let ret = MergeFunction::merge( + &ObkvsKeepLastAdditionMergeDeletions, &[], &[Cow::from(additive_doc_1.as_slice()), Cow::from(additive_doc_0.as_slice())], ) .unwrap(); assert_eq!(*ret, additive_doc_0); - let ret = obkvs_keep_last_addition_merge_deletions( + let ret = MergeFunction::merge( + &ObkvsKeepLastAdditionMergeDeletions, &[], &[ Cow::from(deletive_doc_0.as_slice()), diff --git a/crates/milli/src/update/index_documents/typed_chunk.rs b/crates/milli/src/update/index_documents/typed_chunk.rs index 20e70b2a6..a97569800 100644 --- a/crates/milli/src/update/index_documents/typed_chunk.rs +++ b/crates/milli/src/update/index_documents/typed_chunk.rs @@ -4,18 +4,17 @@ use std::fs::File; use std::io::{self, BufReader}; use bytemuck::allocation::pod_collect_to_vec; -use grenad::{Merger, MergerBuilder}; +use grenad::{MergeFunction, Merger, MergerBuilder}; use heed::types::Bytes; use heed::{BytesDecode, RwTxn}; use obkv::{KvReader, KvWriter}; use roaring::RoaringBitmap; use super::helpers::{ - self, keep_first, merge_deladd_btreeset_string, merge_deladd_cbo_roaring_bitmaps, - merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, merge_ignore_values, valid_lmdb_key, - CursorClonableMmap, + self, merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, valid_lmdb_key, + CursorClonableMmap, KeepFirst, MergeDeladdBtreesetString, MergeDeladdCboRoaringBitmaps, + MergeIgnoreValues, }; -use super::MergeFn; use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind}; use crate::facet::FacetType; use crate::index::db_name::DOCUMENTS; @@ -24,7 +23,7 @@ use crate::proximity::MAX_DISTANCE; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{ - as_cloneable_grenad, keep_latest_obkv, try_split_array_at, + as_cloneable_grenad, try_split_array_at, KeepLatestObkv, }; use crate::update::settings::InnerIndexSettingsDiff; use crate::vector::ArroyWrapper; @@ -141,7 +140,7 @@ pub(crate) fn write_typed_chunk_into_index( let vectors_fid = fields_ids_map.id(crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); - let mut builder = MergerBuilder::new(keep_latest_obkv as MergeFn); + let mut builder = MergerBuilder::new(KeepLatestObkv); for typed_chunk in typed_chunks { let TypedChunk::Documents(chunk) = typed_chunk else { unreachable!(); @@ -163,7 +162,7 @@ pub(crate) fn write_typed_chunk_into_index( let mut vectors_buffer = Vec::new(); while let Some((key, reader)) = iter.next()? { let mut writer: KvWriter<_, FieldId> = KvWriter::memory(); - let reader: KvReader<'_, FieldId> = KvReader::new(reader); + let reader: &KvReader = reader.into(); let (document_id_bytes, external_id_bytes) = try_split_array_at(key) .ok_or(SerializationError::Decoding { db_name: Some(DOCUMENTS) })?; @@ -171,7 +170,7 @@ pub(crate) fn write_typed_chunk_into_index( let external_id = std::str::from_utf8(external_id_bytes)?; for (field_id, value) in reader.iter() { - let del_add_reader = KvReaderDelAdd::new(value); + let del_add_reader = KvReaderDelAdd::from_slice(value); if let Some(addition) = del_add_reader.get(DelAdd::Addition) { let addition = if vectors_fid == Some(field_id) { @@ -235,7 +234,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "field_id_word_count_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); for typed_chunk in typed_chunks { let TypedChunk::FieldIdWordCountDocids(chunk) = typed_chunk else { unreachable!(); @@ -258,13 +257,10 @@ pub(crate) fn write_typed_chunk_into_index( let span = tracing::trace_span!(target: "indexing::write_db", "word_docids"); let _entered = span.enter(); - let mut word_docids_builder = - MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); - let mut exact_word_docids_builder = - MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); - let mut word_fid_docids_builder = - MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); - let mut fst_merger_builder = MergerBuilder::new(merge_ignore_values as MergeFn); + let mut word_docids_builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); + let mut exact_word_docids_builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); + let mut word_fid_docids_builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); + let mut fst_merger_builder = MergerBuilder::new(MergeIgnoreValues); for typed_chunk in typed_chunks { let TypedChunk::WordDocids { word_docids_reader, @@ -329,7 +325,7 @@ pub(crate) fn write_typed_chunk_into_index( let span = tracing::trace_span!(target: "indexing::write_db", "word_position_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); for typed_chunk in typed_chunks { let TypedChunk::WordPositionDocids(chunk) = typed_chunk else { unreachable!(); @@ -353,7 +349,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db","field_id_facet_number_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); let mut data_size = 0; for typed_chunk in typed_chunks { let TypedChunk::FieldIdFacetNumberDocids(facet_id_number_docids) = typed_chunk @@ -375,10 +371,9 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "field_id_facet_string_docids"); let _entered = span.enter(); - let mut facet_id_string_builder = - MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut facet_id_string_builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); let mut normalized_facet_id_string_builder = - MergerBuilder::new(merge_deladd_btreeset_string as MergeFn); + MergerBuilder::new(MergeDeladdBtreesetString); let mut data_size = 0; for typed_chunk in typed_chunks { let TypedChunk::FieldIdFacetStringDocids(( @@ -412,7 +407,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "field_id_facet_exists_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); for typed_chunk in typed_chunks { let TypedChunk::FieldIdFacetExistsDocids(chunk) = typed_chunk else { unreachable!(); @@ -436,7 +431,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "field_id_facet_is_null_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); for typed_chunk in typed_chunks { let TypedChunk::FieldIdFacetIsNullDocids(chunk) = typed_chunk else { unreachable!(); @@ -459,7 +454,7 @@ pub(crate) fn write_typed_chunk_into_index( let span = tracing::trace_span!(target: "indexing::write_db", "field_id_facet_is_empty_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); for typed_chunk in typed_chunks { let TypedChunk::FieldIdFacetIsEmptyDocids(chunk) = typed_chunk else { unreachable!(); @@ -483,7 +478,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "word_pair_proximity_docids"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(merge_deladd_cbo_roaring_bitmaps as MergeFn); + let mut builder = MergerBuilder::new(MergeDeladdCboRoaringBitmaps); for typed_chunk in typed_chunks { let TypedChunk::WordPairProximityDocids(chunk) = typed_chunk else { unreachable!(); @@ -516,7 +511,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "field_id_docid_facet_numbers"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(keep_first as MergeFn); + let mut builder = MergerBuilder::new(KeepFirst); for typed_chunk in typed_chunks { let TypedChunk::FieldIdDocidFacetNumbers(chunk) = typed_chunk else { unreachable!(); @@ -530,7 +525,7 @@ pub(crate) fn write_typed_chunk_into_index( index.field_id_docid_facet_f64s.remap_types::(); let mut iter = merger.into_stream_merger_iter()?; while let Some((key, value)) = iter.next()? { - let reader = KvReaderDelAdd::new(value); + let reader = KvReaderDelAdd::from_slice(value); if valid_lmdb_key(key) { match (reader.get(DelAdd::Deletion), reader.get(DelAdd::Addition)) { (None, None) => {} @@ -550,7 +545,7 @@ pub(crate) fn write_typed_chunk_into_index( tracing::trace_span!(target: "indexing::write_db", "field_id_docid_facet_strings"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(keep_first as MergeFn); + let mut builder = MergerBuilder::new(KeepFirst); for typed_chunk in typed_chunks { let TypedChunk::FieldIdDocidFacetStrings(chunk) = typed_chunk else { unreachable!(); @@ -564,7 +559,7 @@ pub(crate) fn write_typed_chunk_into_index( index.field_id_docid_facet_strings.remap_types::(); let mut iter = merger.into_stream_merger_iter()?; while let Some((key, value)) = iter.next()? { - let reader = KvReaderDelAdd::new(value); + let reader = KvReaderDelAdd::from_slice(value); if valid_lmdb_key(key) { match (reader.get(DelAdd::Deletion), reader.get(DelAdd::Addition)) { (None, None) => {} @@ -583,7 +578,7 @@ pub(crate) fn write_typed_chunk_into_index( let span = tracing::trace_span!(target: "indexing::write_db", "geo_points"); let _entered = span.enter(); - let mut builder = MergerBuilder::new(keep_first as MergeFn); + let mut builder = MergerBuilder::new(KeepFirst); for typed_chunk in typed_chunks { let TypedChunk::GeoPoints(chunk) = typed_chunk else { unreachable!(); @@ -601,7 +596,7 @@ pub(crate) fn write_typed_chunk_into_index( // convert the key back to a u32 (4 bytes) let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); - let deladd_obkv = KvReaderDelAdd::new(value); + let deladd_obkv = KvReaderDelAdd::from_slice(value); if let Some(value) = deladd_obkv.get(DelAdd::Deletion) { let geopoint = extract_geo_point(value, docid); rtree.remove(&geopoint); @@ -620,9 +615,9 @@ pub(crate) fn write_typed_chunk_into_index( let span = tracing::trace_span!(target: "indexing::write_db", "vector_points"); let _entered = span.enter(); - let mut remove_vectors_builder = MergerBuilder::new(keep_first as MergeFn); - let mut manual_vectors_builder = MergerBuilder::new(keep_first as MergeFn); - let mut embeddings_builder = MergerBuilder::new(keep_first as MergeFn); + let mut remove_vectors_builder = MergerBuilder::new(KeepFirst); + let mut manual_vectors_builder = MergerBuilder::new(KeepFirst); + let mut embeddings_builder = MergerBuilder::new(KeepFirst); let mut add_to_user_provided = RoaringBitmap::new(); let mut remove_from_user_provided = RoaringBitmap::new(); let mut params = None; @@ -719,7 +714,7 @@ pub(crate) fn write_typed_chunk_into_index( let (left, _index) = try_split_array_at(key).unwrap(); let docid = DocumentId::from_be_bytes(left); - let vector_deladd_obkv = KvReaderDelAdd::new(value); + let vector_deladd_obkv = KvReaderDelAdd::from_slice(value); if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { let vector: Vec = pod_collect_to_vec(value); @@ -742,7 +737,7 @@ pub(crate) fn write_typed_chunk_into_index( } /// Converts the latitude and longitude back to an xyz GeoPoint. -fn extract_geo_point(value: &[u8], docid: DocumentId) -> GeoPoint { +pub fn extract_geo_point(value: &[u8], docid: DocumentId) -> GeoPoint { let (lat, tail) = helpers::try_split_array_at::(value).unwrap(); let (lng, _) = helpers::try_split_array_at::(tail).unwrap(); let point = [f64::from_ne_bytes(lat), f64::from_ne_bytes(lng)]; @@ -750,9 +745,13 @@ fn extract_geo_point(value: &[u8], docid: DocumentId) -> GeoPoint { GeoPoint::new(xyz_point, (docid, point)) } -fn merge_word_docids_reader_into_fst( - merger: Merger, -) -> Result>> { +fn merge_word_docids_reader_into_fst( + merger: Merger, +) -> Result>> +where + MF: MergeFunction, + crate::Error: From, +{ let mut iter = merger.into_stream_merger_iter()?; let mut builder = fst::SetBuilder::memory(); @@ -766,8 +765,8 @@ fn merge_word_docids_reader_into_fst( /// Write provided entries in database using serialize_value function. /// merge_values function is used if an entry already exist in the database. #[tracing::instrument(level = "trace", skip_all, target = "indexing::write_db")] -fn write_entries_into_database( - merger: Merger, +fn write_entries_into_database( + merger: Merger, database: &heed::Database, wtxn: &mut RwTxn<'_>, serialize_value: FS, @@ -777,6 +776,8 @@ where R: io::Read + io::Seek, FS: for<'a> Fn(&'a [u8], &'a mut Vec) -> Result<&'a [u8]>, FM: for<'a> Fn(&[u8], &[u8], &'a mut Vec) -> Result>, + MF: MergeFunction, + crate::Error: From, { let mut buffer = Vec::new(); let database = database.remap_types::(); @@ -803,20 +804,22 @@ where /// Akin to the `write_entries_into_database` function but specialized /// for the case when we only index additional searchable fields only. #[tracing::instrument(level = "trace", skip_all, target = "indexing::write_db")] -fn write_proximity_entries_into_database_additional_searchables( - merger: Merger, +fn write_proximity_entries_into_database_additional_searchables( + merger: Merger, database: &heed::Database, wtxn: &mut RwTxn<'_>, ) -> Result<()> where R: io::Read + io::Seek, + MF: MergeFunction, + crate::Error: From, { let mut iter = merger.into_stream_merger_iter()?; while let Some((key, value)) = iter.next()? { if valid_lmdb_key(key) { let (proximity_to_insert, word1, word2) = U8StrStrCodec::bytes_decode(key).map_err(heed::Error::Decoding)?; - let data_to_insert = match KvReaderDelAdd::new(value).get(DelAdd::Addition) { + let data_to_insert = match KvReaderDelAdd::from_slice(value).get(DelAdd::Addition) { Some(value) => { CboRoaringBitmapCodec::bytes_decode(value).map_err(heed::Error::Decoding)? } diff --git a/crates/milli/src/update/indexer_config.rs b/crates/milli/src/update/indexer_config.rs index 115059a1d..6fb33ad78 100644 --- a/crates/milli/src/update/indexer_config.rs +++ b/crates/milli/src/update/indexer_config.rs @@ -1,5 +1,6 @@ use grenad::CompressionType; +use super::GrenadParameters; use crate::thread_pool_no_abort::ThreadPoolNoAbort; #[derive(Debug)] @@ -15,6 +16,17 @@ pub struct IndexerConfig { pub skip_index_budget: bool, } +impl IndexerConfig { + pub fn grenad_parameters(&self) -> GrenadParameters { + GrenadParameters { + chunk_compression_type: self.chunk_compression_type, + chunk_compression_level: self.chunk_compression_level, + max_memory: self.max_memory, + max_nb_chunks: self.max_nb_chunks, + } + } +} + impl Default for IndexerConfig { fn default() -> Self { Self { diff --git a/crates/milli/src/update/mod.rs b/crates/milli/src/update/mod.rs index 195b95d1e..772a73236 100644 --- a/crates/milli/src/update/mod.rs +++ b/crates/milli/src/update/mod.rs @@ -1,11 +1,9 @@ -pub use self::available_documents_ids::AvailableDocumentsIds; +pub use self::available_ids::AvailableIds; pub use self::clear_documents::ClearDocuments; +pub use self::concurrent_available_ids::ConcurrentAvailableIds; pub use self::facet::bulk::FacetsUpdateBulk; pub use self::facet::incremental::FacetsUpdateIncrementalInner; -pub use self::index_documents::{ - merge_cbo_roaring_bitmaps, merge_roaring_bitmaps, DocumentAdditionResult, DocumentId, - IndexDocuments, IndexDocumentsConfig, IndexDocumentsMethod, MergeFn, -}; +pub use self::index_documents::*; pub use self::indexer_config::IndexerConfig; pub use self::settings::{validate_embedding_settings, Setting, Settings}; pub use self::update_step::UpdateIndexingStep; @@ -13,12 +11,14 @@ pub use self::word_prefix_docids::WordPrefixDocids; pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids; pub use self::words_prefixes_fst::WordsPrefixesFst; -mod available_documents_ids; +mod available_ids; mod clear_documents; +mod concurrent_available_ids; pub(crate) mod del_add; pub(crate) mod facet; mod index_documents; mod indexer_config; +pub mod new; mod settings; mod update_step; mod word_prefix_docids; diff --git a/crates/milli/src/update/new/channel.rs b/crates/milli/src/update/new/channel.rs new file mode 100644 index 000000000..9e8039ffd --- /dev/null +++ b/crates/milli/src/update/new/channel.rs @@ -0,0 +1,516 @@ +use std::marker::PhantomData; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crossbeam_channel::{IntoIter, Receiver, SendError, Sender}; +use heed::types::Bytes; +use heed::BytesDecode; +use memmap2::Mmap; +use roaring::RoaringBitmap; + +use super::extract::FacetKind; +use super::StdResult; +use crate::heed_codec::facet::{FieldDocIdFacetF64Codec, FieldDocIdFacetStringCodec}; +use crate::index::main_key::{GEO_FACETED_DOCUMENTS_IDS_KEY, GEO_RTREE_KEY}; +use crate::index::IndexEmbeddingConfig; +use crate::update::new::KvReaderFieldId; +use crate::vector::Embedding; +use crate::{DocumentId, Index}; + +/// The capacity of the channel is currently in number of messages. +pub fn extractor_writer_channel(cap: usize) -> (ExtractorSender, WriterReceiver) { + let (sender, receiver) = crossbeam_channel::bounded(cap); + ( + ExtractorSender { + sender, + send_count: Default::default(), + writer_contentious_count: Default::default(), + extractor_contentious_count: Default::default(), + }, + WriterReceiver(receiver), + ) +} + +pub enum KeyValueEntry { + Small { key_length: usize, data: Box<[u8]> }, + Large { key_entry: KeyEntry, data: Mmap }, +} + +impl KeyValueEntry { + pub fn from_small_key_value(key: &[u8], value: &[u8]) -> Self { + let mut data = Vec::with_capacity(key.len() + value.len()); + data.extend_from_slice(key); + data.extend_from_slice(value); + KeyValueEntry::Small { key_length: key.len(), data: data.into_boxed_slice() } + } + + fn from_large_key_value(key: &[u8], value: Mmap) -> Self { + KeyValueEntry::Large { key_entry: KeyEntry::from_key(key), data: value } + } + + pub fn key(&self) -> &[u8] { + match self { + KeyValueEntry::Small { key_length, data } => &data[..*key_length], + KeyValueEntry::Large { key_entry, data: _ } => key_entry.entry(), + } + } + + pub fn value(&self) -> &[u8] { + match self { + KeyValueEntry::Small { key_length, data } => &data[*key_length..], + KeyValueEntry::Large { key_entry: _, data } => &data[..], + } + } +} + +pub struct KeyEntry { + data: Box<[u8]>, +} + +impl KeyEntry { + pub fn from_key(key: &[u8]) -> Self { + KeyEntry { data: key.to_vec().into_boxed_slice() } + } + + pub fn entry(&self) -> &[u8] { + self.data.as_ref() + } +} + +pub enum EntryOperation { + Delete(KeyEntry), + Write(KeyValueEntry), +} + +pub enum WriterOperation { + DbOperation(DbOperation), + ArroyOperation(ArroyOperation), +} + +pub enum ArroyOperation { + /// TODO: call when deleting regular documents + DeleteVectors { + docid: DocumentId, + }, + SetVectors { + docid: DocumentId, + embedder_id: u8, + embeddings: Vec, + }, + SetVector { + docid: DocumentId, + embedder_id: u8, + embedding: Embedding, + }, + Finish { + configs: Vec, + }, +} + +pub struct DbOperation { + database: Database, + entry: EntryOperation, +} + +#[derive(Debug)] +pub enum Database { + Main, + Documents, + ExternalDocumentsIds, + ExactWordDocids, + FidWordCountDocids, + WordDocids, + WordFidDocids, + WordPairProximityDocids, + WordPositionDocids, + FacetIdIsNullDocids, + FacetIdIsEmptyDocids, + FacetIdExistsDocids, + FacetIdF64NumberDocids, + FacetIdStringDocids, + FieldIdDocidFacetStrings, + FieldIdDocidFacetF64s, +} + +impl Database { + pub fn database(&self, index: &Index) -> heed::Database { + match self { + Database::Main => index.main.remap_types(), + Database::Documents => index.documents.remap_types(), + Database::ExternalDocumentsIds => index.external_documents_ids.remap_types(), + Database::ExactWordDocids => index.exact_word_docids.remap_types(), + Database::WordDocids => index.word_docids.remap_types(), + Database::WordFidDocids => index.word_fid_docids.remap_types(), + Database::WordPositionDocids => index.word_position_docids.remap_types(), + Database::FidWordCountDocids => index.field_id_word_count_docids.remap_types(), + Database::WordPairProximityDocids => index.word_pair_proximity_docids.remap_types(), + Database::FacetIdIsNullDocids => index.facet_id_is_null_docids.remap_types(), + Database::FacetIdIsEmptyDocids => index.facet_id_is_empty_docids.remap_types(), + Database::FacetIdExistsDocids => index.facet_id_exists_docids.remap_types(), + Database::FacetIdF64NumberDocids => index.facet_id_f64_docids.remap_types(), + Database::FacetIdStringDocids => index.facet_id_string_docids.remap_types(), + Database::FieldIdDocidFacetStrings => index.field_id_docid_facet_strings.remap_types(), + Database::FieldIdDocidFacetF64s => index.field_id_docid_facet_f64s.remap_types(), + } + } +} + +impl From for Database { + fn from(value: FacetKind) -> Self { + match value { + FacetKind::Number => Database::FacetIdF64NumberDocids, + FacetKind::String => Database::FacetIdStringDocids, + FacetKind::Null => Database::FacetIdIsNullDocids, + FacetKind::Empty => Database::FacetIdIsEmptyDocids, + FacetKind::Exists => Database::FacetIdExistsDocids, + } + } +} + +impl DbOperation { + pub fn database(&self, index: &Index) -> heed::Database { + self.database.database(index) + } + + pub fn entry(self) -> EntryOperation { + self.entry + } +} + +pub struct WriterReceiver(Receiver); + +impl IntoIterator for WriterReceiver { + type Item = WriterOperation; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +pub struct ExtractorSender { + sender: Sender, + /// The number of message we sent in total in the channel. + send_count: AtomicUsize, + /// The number of times we sent something in a channel that was full. + writer_contentious_count: AtomicUsize, + /// The number of times we sent something in a channel that was empty. + extractor_contentious_count: AtomicUsize, +} + +impl Drop for ExtractorSender { + fn drop(&mut self) { + let send_count = *self.send_count.get_mut(); + let writer_contentious_count = *self.writer_contentious_count.get_mut(); + let extractor_contentious_count = *self.extractor_contentious_count.get_mut(); + eprintln!( + "Extractor channel stats: {send_count} sends, \ + {writer_contentious_count} writer contentions ({}%), \ + {extractor_contentious_count} extractor contentions ({}%)", + (writer_contentious_count as f32 / send_count as f32) * 100.0, + (extractor_contentious_count as f32 / send_count as f32) * 100.0 + ) + } +} + +impl ExtractorSender { + pub fn docids(&self) -> WordDocidsSender<'_, D> { + WordDocidsSender { sender: self, _marker: PhantomData } + } + + pub fn facet_docids(&self) -> FacetDocidsSender<'_> { + FacetDocidsSender { sender: self } + } + + pub fn field_id_docid_facet_sender(&self) -> FieldIdDocidFacetSender<'_> { + FieldIdDocidFacetSender(self) + } + + pub fn documents(&self) -> DocumentsSender<'_> { + DocumentsSender(self) + } + + pub fn embeddings(&self) -> EmbeddingSender<'_> { + EmbeddingSender(&self.sender) + } + + pub fn geo(&self) -> GeoSender<'_> { + GeoSender(&self.sender) + } + + fn send_delete_vector(&self, docid: DocumentId) -> StdResult<(), SendError<()>> { + match self + .sender + .send(WriterOperation::ArroyOperation(ArroyOperation::DeleteVectors { docid })) + { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } + + fn send_db_operation(&self, op: DbOperation) -> StdResult<(), SendError<()>> { + if self.sender.is_full() { + self.writer_contentious_count.fetch_add(1, Ordering::SeqCst); + } + if self.sender.is_empty() { + self.extractor_contentious_count.fetch_add(1, Ordering::SeqCst); + } + + self.send_count.fetch_add(1, Ordering::SeqCst); + match self.sender.send(WriterOperation::DbOperation(op)) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } +} + +pub enum ExactWordDocids {} +pub enum FidWordCountDocids {} +pub enum WordDocids {} +pub enum WordFidDocids {} +pub enum WordPairProximityDocids {} +pub enum WordPositionDocids {} + +pub trait DatabaseType { + const DATABASE: Database; +} + +impl DatabaseType for ExactWordDocids { + const DATABASE: Database = Database::ExactWordDocids; +} + +impl DatabaseType for FidWordCountDocids { + const DATABASE: Database = Database::FidWordCountDocids; +} + +impl DatabaseType for WordDocids { + const DATABASE: Database = Database::WordDocids; +} + +impl DatabaseType for WordFidDocids { + const DATABASE: Database = Database::WordFidDocids; +} + +impl DatabaseType for WordPairProximityDocids { + const DATABASE: Database = Database::WordPairProximityDocids; +} + +impl DatabaseType for WordPositionDocids { + const DATABASE: Database = Database::WordPositionDocids; +} + +pub trait DocidsSender { + fn write(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>>; + fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>>; +} + +pub struct WordDocidsSender<'a, D> { + sender: &'a ExtractorSender, + _marker: PhantomData, +} + +impl DocidsSender for WordDocidsSender<'_, D> { + fn write(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>> { + let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)); + match self.sender.send_db_operation(DbOperation { database: D::DATABASE, entry }) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } + + fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> { + let entry = EntryOperation::Delete(KeyEntry::from_key(key)); + match self.sender.send_db_operation(DbOperation { database: D::DATABASE, entry }) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } +} + +pub struct FacetDocidsSender<'a> { + sender: &'a ExtractorSender, +} + +impl DocidsSender for FacetDocidsSender<'_> { + fn write(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>> { + let (facet_kind, key) = FacetKind::extract_from_key(key); + let database = Database::from(facet_kind); + // let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)); + let entry = match facet_kind { + // skip level group size + FacetKind::String | FacetKind::Number => { + // add facet group size + let value = [&[1], value].concat(); + EntryOperation::Write(KeyValueEntry::from_small_key_value(key, &value)) + } + _ => EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)), + }; + match self.sender.send_db_operation(DbOperation { database, entry }) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } + + fn delete(&self, key: &[u8]) -> StdResult<(), SendError<()>> { + let (facet_kind, key) = FacetKind::extract_from_key(key); + let database = Database::from(facet_kind); + let entry = EntryOperation::Delete(KeyEntry::from_key(key)); + match self.sender.send_db_operation(DbOperation { database, entry }) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } +} + +pub struct FieldIdDocidFacetSender<'a>(&'a ExtractorSender); + +impl FieldIdDocidFacetSender<'_> { + pub fn write_facet_string(&self, key: &[u8], value: &[u8]) -> StdResult<(), SendError<()>> { + debug_assert!(FieldDocIdFacetStringCodec::bytes_decode(key).is_ok()); + let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, value)); + self.0 + .send_db_operation(DbOperation { database: Database::FieldIdDocidFacetStrings, entry }) + } + + pub fn write_facet_f64(&self, key: &[u8]) -> StdResult<(), SendError<()>> { + debug_assert!(FieldDocIdFacetF64Codec::bytes_decode(key).is_ok()); + let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value(key, &[])); + self.0.send_db_operation(DbOperation { database: Database::FieldIdDocidFacetF64s, entry }) + } + + pub fn delete_facet_string(&self, key: &[u8]) -> StdResult<(), SendError<()>> { + debug_assert!(FieldDocIdFacetStringCodec::bytes_decode(key).is_ok()); + let entry = EntryOperation::Delete(KeyEntry::from_key(key)); + self.0 + .send_db_operation(DbOperation { database: Database::FieldIdDocidFacetStrings, entry }) + } + + pub fn delete_facet_f64(&self, key: &[u8]) -> StdResult<(), SendError<()>> { + debug_assert!(FieldDocIdFacetF64Codec::bytes_decode(key).is_ok()); + let entry = EntryOperation::Delete(KeyEntry::from_key(key)); + self.0.send_db_operation(DbOperation { database: Database::FieldIdDocidFacetF64s, entry }) + } +} + +pub struct DocumentsSender<'a>(&'a ExtractorSender); + +impl DocumentsSender<'_> { + /// TODO do that efficiently + pub fn uncompressed( + &self, + docid: DocumentId, + external_id: String, + document: &KvReaderFieldId, + ) -> StdResult<(), SendError<()>> { + let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value( + &docid.to_be_bytes(), + document.as_bytes(), + )); + match self.0.send_db_operation(DbOperation { database: Database::Documents, entry }) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + }?; + + let entry = EntryOperation::Write(KeyValueEntry::from_small_key_value( + external_id.as_bytes(), + &docid.to_be_bytes(), + )); + match self + .0 + .send_db_operation(DbOperation { database: Database::ExternalDocumentsIds, entry }) + { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } + + pub fn delete(&self, docid: DocumentId, external_id: String) -> StdResult<(), SendError<()>> { + let entry = EntryOperation::Delete(KeyEntry::from_key(&docid.to_be_bytes())); + match self.0.send_db_operation(DbOperation { database: Database::Documents, entry }) { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + }?; + + self.0.send_delete_vector(docid)?; + + let entry = EntryOperation::Delete(KeyEntry::from_key(external_id.as_bytes())); + match self + .0 + .send_db_operation(DbOperation { database: Database::ExternalDocumentsIds, entry }) + { + Ok(()) => Ok(()), + Err(SendError(_)) => Err(SendError(())), + } + } +} + +pub struct EmbeddingSender<'a>(&'a Sender); + +impl EmbeddingSender<'_> { + pub fn set_vectors( + &self, + docid: DocumentId, + embedder_id: u8, + embeddings: Vec, + ) -> StdResult<(), SendError<()>> { + self.0 + .send(WriterOperation::ArroyOperation(ArroyOperation::SetVectors { + docid, + embedder_id, + embeddings, + })) + .map_err(|_| SendError(())) + } + + pub fn set_vector( + &self, + docid: DocumentId, + embedder_id: u8, + embedding: Embedding, + ) -> StdResult<(), SendError<()>> { + self.0 + .send(WriterOperation::ArroyOperation(ArroyOperation::SetVector { + docid, + embedder_id, + embedding, + })) + .map_err(|_| SendError(())) + } + + /// Marks all embedders as "to be built" + pub fn finish(self, configs: Vec) -> StdResult<(), SendError<()>> { + self.0 + .send(WriterOperation::ArroyOperation(ArroyOperation::Finish { configs })) + .map_err(|_| SendError(())) + } +} + +pub struct GeoSender<'a>(&'a Sender); + +impl GeoSender<'_> { + pub fn set_rtree(&self, value: Mmap) -> StdResult<(), SendError<()>> { + self.0 + .send(WriterOperation::DbOperation(DbOperation { + database: Database::Main, + entry: EntryOperation::Write(KeyValueEntry::from_large_key_value( + GEO_RTREE_KEY.as_bytes(), + value, + )), + })) + .map_err(|_| SendError(())) + } + + pub fn set_geo_faceted(&self, bitmap: &RoaringBitmap) -> StdResult<(), SendError<()>> { + let mut buffer = Vec::new(); + bitmap.serialize_into(&mut buffer).unwrap(); + + self.0 + .send(WriterOperation::DbOperation(DbOperation { + database: Database::Main, + entry: EntryOperation::Write(KeyValueEntry::from_small_key_value( + GEO_FACETED_DOCUMENTS_IDS_KEY.as_bytes(), + &buffer, + )), + })) + .map_err(|_| SendError(())) + } +} diff --git a/crates/milli/src/update/new/document.rs b/crates/milli/src/update/new/document.rs new file mode 100644 index 000000000..ae9aa9de9 --- /dev/null +++ b/crates/milli/src/update/new/document.rs @@ -0,0 +1,425 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use heed::RoTxn; +use raw_collections::RawMap; +use serde_json::value::RawValue; + +use super::vector_document::VectorDocument; +use super::{KvReaderFieldId, KvWriterFieldId}; +use crate::documents::FieldIdMapper; +use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME; +use crate::{DocumentId, GlobalFieldsIdsMap, Index, InternalError, Result, UserError}; + +/// A view into a document that can represent either the current version from the DB, +/// the update data from payload or other means, or the merged updated version. +/// +/// The 'doc lifetime is meant to live sufficiently for the document to be handled by the extractors. +pub trait Document<'doc> { + /// Iterate over all **top-level** fields of the document, returning their name and raw JSON value. + /// + /// - The returned values *may* contain nested fields. + /// - The `_vectors` and `_geo` fields are **ignored** by this method, meaning they are **not returned** by this method. + fn iter_top_level_fields(&self) -> impl Iterator>; + + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get the **top-level** with the specified name, if exists. + /// + /// - The `_vectors` and `_geo` fields are **ignored** by this method, meaning e.g. `top_level_field("_vectors")` will return `Ok(None)` + fn top_level_field(&self, k: &str) -> Result>; + + /// Returns the unparsed value of the `_vectors` field from the document data. + /// + /// This field alone is insufficient to retrieve vectors, as they may be stored in a dedicated location in the database. + /// Use a [`super::vector_document::VectorDocument`] to access the vector. + /// + /// This method is meant as a convenience for implementors of [`super::vector_document::VectorDocument`]. + fn vectors_field(&self) -> Result>; + + /// Returns the unparsed value of the `_geo` field from the document data. + /// + /// This field alone is insufficient to retrieve geo data, as they may be stored in a dedicated location in the database. + /// Use a [`super::geo_document::GeoDocument`] to access the vector. + /// + /// This method is meant as a convenience for implementors of [`super::geo_document::GeoDocument`]. + fn geo_field(&self) -> Result>; +} + +#[derive(Debug)] +pub struct DocumentFromDb<'t, Mapper: FieldIdMapper> +where + Mapper: FieldIdMapper, +{ + fields_ids_map: &'t Mapper, + content: &'t KvReaderFieldId, +} + +impl<'t, Mapper: FieldIdMapper> Clone for DocumentFromDb<'t, Mapper> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl<'t, Mapper: FieldIdMapper> Copy for DocumentFromDb<'t, Mapper> {} + +impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> { + fn iter_top_level_fields(&self) -> impl Iterator> { + let mut it = self.content.iter(); + + std::iter::from_fn(move || loop { + let (fid, value) = it.next()?; + let name = match self.fields_ids_map.name(fid).ok_or( + InternalError::FieldIdMapMissingEntry(crate::FieldIdMapMissingEntry::FieldId { + field_id: fid, + process: "getting current document", + }), + ) { + Ok(name) => name, + Err(error) => return Some(Err(error.into())), + }; + + if name == RESERVED_VECTORS_FIELD_NAME || name == "_geo" { + continue; + } + + let res = (|| { + let value = + serde_json::from_slice(value).map_err(crate::InternalError::SerdeJson)?; + + Ok((name, value)) + })(); + + return Some(res); + }) + } + + fn vectors_field(&self) -> Result> { + self.field(RESERVED_VECTORS_FIELD_NAME) + } + + fn geo_field(&self) -> Result> { + self.field("_geo") + } + + fn len(&self) -> usize { + self.content.iter().count() + } + + fn top_level_field(&self, k: &str) -> Result> { + if k == RESERVED_VECTORS_FIELD_NAME || k == "_geo" { + return Ok(None); + } + self.field(k) + } +} + +impl<'t, Mapper: FieldIdMapper> DocumentFromDb<'t, Mapper> { + pub fn new( + docid: DocumentId, + rtxn: &'t RoTxn, + index: &'t Index, + db_fields_ids_map: &'t Mapper, + ) -> Result> { + index.documents.get(rtxn, &docid).map_err(crate::Error::from).map(|reader| { + reader.map(|reader| Self { fields_ids_map: db_fields_ids_map, content: reader }) + }) + } + + pub fn field(&self, name: &str) -> Result> { + let Some(fid) = self.fields_ids_map.id(name) else { + return Ok(None); + }; + let Some(value) = self.content.get(fid) else { return Ok(None) }; + Ok(Some(serde_json::from_slice(value).map_err(InternalError::SerdeJson)?)) + } +} + +#[derive(Debug)] +pub struct DocumentFromVersions<'a, 'doc> { + versions: &'a Versions<'doc>, +} + +impl<'a, 'doc> DocumentFromVersions<'a, 'doc> { + pub fn new(versions: &'a Versions<'doc>) -> Self { + Self { versions } + } +} + +impl<'a, 'doc> Document<'doc> for DocumentFromVersions<'a, 'doc> { + fn iter_top_level_fields(&self) -> impl Iterator> { + self.versions.iter_top_level_fields().map(Ok) + } + + fn vectors_field(&self) -> Result> { + Ok(self.versions.vectors_field()) + } + + fn geo_field(&self) -> Result> { + Ok(self.versions.geo_field()) + } + + fn len(&self) -> usize { + self.versions.len() + } + + fn top_level_field(&self, k: &str) -> Result> { + Ok(self.versions.top_level_field(k)) + } +} + +#[derive(Debug)] +pub struct MergedDocument<'a, 'doc, 't, Mapper: FieldIdMapper> { + new_doc: DocumentFromVersions<'a, 'doc>, + db: Option>, +} + +impl<'a, 'doc, 't, Mapper: FieldIdMapper> MergedDocument<'a, 'doc, 't, Mapper> { + pub fn with_db( + docid: DocumentId, + rtxn: &'t RoTxn, + index: &'t Index, + db_fields_ids_map: &'t Mapper, + new_doc: DocumentFromVersions<'a, 'doc>, + ) -> Result { + let db = DocumentFromDb::new(docid, rtxn, index, db_fields_ids_map)?; + Ok(Self { new_doc, db }) + } + + pub fn without_db(new_doc: DocumentFromVersions<'a, 'doc>) -> Self { + Self { new_doc, db: None } + } +} + +impl<'d, 'doc: 'd, 't: 'd, Mapper: FieldIdMapper> Document<'d> + for MergedDocument<'d, 'doc, 't, Mapper> +{ + fn iter_top_level_fields(&self) -> impl Iterator> { + let mut new_doc_it = self.new_doc.iter_top_level_fields(); + let mut db_it = self.db.iter().flat_map(|db| db.iter_top_level_fields()); + let mut seen_fields = BTreeSet::new(); + + std::iter::from_fn(move || { + if let Some(next) = new_doc_it.next() { + if let Ok((name, _)) = next { + seen_fields.insert(name); + } + return Some(next); + } + loop { + match db_it.next()? { + Ok((name, value)) => { + if seen_fields.contains(name) { + continue; + } + return Some(Ok((name, value))); + } + Err(err) => return Some(Err(err)), + } + } + }) + } + + fn vectors_field(&self) -> Result> { + if let Some(vectors) = self.new_doc.vectors_field()? { + return Ok(Some(vectors)); + } + + let Some(db) = self.db else { return Ok(None) }; + + db.vectors_field() + } + + fn geo_field(&self) -> Result> { + if let Some(geo) = self.new_doc.geo_field()? { + return Ok(Some(geo)); + } + + let Some(db) = self.db else { return Ok(None) }; + + db.geo_field() + } + + fn len(&self) -> usize { + self.iter_top_level_fields().count() + } + + fn top_level_field(&self, k: &str) -> Result> { + if let Some(f) = self.new_doc.top_level_field(k)? { + return Ok(Some(f)); + } + if let Some(db) = self.db { + return db.field(k); + } + Ok(None) + } +} + +impl<'doc, D> Document<'doc> for &D +where + D: Document<'doc>, +{ + fn iter_top_level_fields(&self) -> impl Iterator> { + D::iter_top_level_fields(self) + } + + fn vectors_field(&self) -> Result> { + D::vectors_field(self) + } + + fn geo_field(&self) -> Result> { + D::geo_field(self) + } + + fn len(&self) -> usize { + D::len(self) + } + + fn top_level_field(&self, k: &str) -> Result> { + D::top_level_field(self, k) + } +} + +/// Turn this document into an obkv, whose fields are indexed by the provided `FieldIdMapper`. +/// +/// The produced obkv is suitable for storing into the documents DB, meaning: +/// +/// - It contains the contains of `_vectors` that are not configured as an embedder +/// - It contains all the top-level fields of the document, with their raw JSON value as value. +/// +/// # Panics +/// +/// - If the document contains a top-level field that is not present in `fields_ids_map`. +/// +pub fn write_to_obkv<'s, 'a, 'map, 'buffer>( + document: &'s impl Document<'s>, + vector_document: Option<&'s impl VectorDocument<'s>>, + fields_ids_map: &'a mut GlobalFieldsIdsMap<'map>, + mut document_buffer: &'a mut bumpalo::collections::Vec<'buffer, u8>, +) -> Result<&'a KvReaderFieldId> +where + 's: 'a, +{ + // will be used in 'inject_vectors + let vectors_value: Box; + + document_buffer.clear(); + let mut unordered_field_buffer = Vec::new(); + unordered_field_buffer.clear(); + + let mut writer = KvWriterFieldId::new(&mut document_buffer); + + for res in document.iter_top_level_fields() { + let (field_name, value) = res?; + let field_id = + fields_ids_map.id_or_insert(field_name).ok_or(UserError::AttributeLimitReached)?; + unordered_field_buffer.push((field_id, value)); + } + + 'inject_vectors: { + let Some(vector_document) = vector_document else { break 'inject_vectors }; + + let vectors_fid = fields_ids_map + .id_or_insert(RESERVED_VECTORS_FIELD_NAME) + .ok_or(UserError::AttributeLimitReached)?; + + let mut vectors = BTreeMap::new(); + for res in vector_document.iter_vectors() { + let (name, entry) = res?; + if entry.has_configured_embedder { + continue; // we don't write vectors with configured embedder in documents + } + vectors.insert( + name, + if entry.implicit { + serde_json::json!(entry.embeddings) + } else { + serde_json::json!({ + "regenerate": entry.regenerate, + // TODO: consider optimizing the shape of embedders here to store an array of f32 rather than a JSON object + "embeddings": entry.embeddings, + }) + }, + ); + } + + if vectors.is_empty() { + break 'inject_vectors; + } + + vectors_value = serde_json::value::to_raw_value(&vectors).unwrap(); + unordered_field_buffer.push((vectors_fid, &vectors_value)); + } + + if let Some(geo_value) = document.geo_field()? { + let fid = fields_ids_map.id_or_insert("_geo").ok_or(UserError::AttributeLimitReached)?; + fields_ids_map.id_or_insert("_geo.lat").ok_or(UserError::AttributeLimitReached)?; + fields_ids_map.id_or_insert("_geo.lng").ok_or(UserError::AttributeLimitReached)?; + unordered_field_buffer.push((fid, geo_value)); + } + + unordered_field_buffer.sort_by_key(|(fid, _)| *fid); + for (fid, value) in unordered_field_buffer.iter() { + writer.insert(*fid, value.get().as_bytes()).unwrap(); + } + + writer.finish().unwrap(); + Ok(KvReaderFieldId::from_slice(document_buffer)) +} + +pub type Entry<'doc> = (&'doc str, &'doc RawValue); + +#[derive(Debug)] +pub struct Versions<'doc> { + data: RawMap<'doc>, +} + +impl<'doc> Versions<'doc> { + pub fn multiple( + mut versions: impl Iterator>>, + ) -> Result> { + let Some(data) = versions.next() else { return Ok(None) }; + let mut data = data?; + for future_version in versions { + let future_version = future_version?; + for (field, value) in future_version { + data.insert(field, value); + } + } + Ok(Some(Self::single(data))) + } + + pub fn single(version: RawMap<'doc>) -> Self { + Self { data: version } + } + + pub fn iter_top_level_fields(&self) -> impl Iterator + '_ { + self.data.iter().filter(|(k, _)| *k != RESERVED_VECTORS_FIELD_NAME && *k != "_geo") + } + + pub fn vectors_field(&self) -> Option<&'doc RawValue> { + self.data.get(RESERVED_VECTORS_FIELD_NAME) + } + + pub fn geo_field(&self) -> Option<&'doc RawValue> { + self.data.get("_geo") + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + pub fn top_level_field(&self, k: &str) -> Option<&'doc RawValue> { + if k == RESERVED_VECTORS_FIELD_NAME || k == "_geo" { + return None; + } + self.data.get(k) + } +} diff --git a/crates/milli/src/update/new/document_change.rs b/crates/milli/src/update/new/document_change.rs new file mode 100644 index 000000000..899655db1 --- /dev/null +++ b/crates/milli/src/update/new/document_change.rs @@ -0,0 +1,203 @@ +use bumpalo::Bump; +use heed::RoTxn; + +use super::document::{DocumentFromDb, DocumentFromVersions, MergedDocument, Versions}; +use super::vector_document::{ + MergedVectorDocument, VectorDocumentFromDb, VectorDocumentFromVersions, +}; +use crate::documents::FieldIdMapper; +use crate::vector::EmbeddingConfigs; +use crate::{DocumentId, Index, Result}; + +pub enum DocumentChange<'doc> { + Deletion(Deletion<'doc>), + Update(Update<'doc>), + Insertion(Insertion<'doc>), +} + +pub struct Deletion<'doc> { + docid: DocumentId, + external_document_id: &'doc str, +} + +pub struct Update<'doc> { + docid: DocumentId, + external_document_id: &'doc str, + new: Versions<'doc>, + has_deletion: bool, +} + +pub struct Insertion<'doc> { + docid: DocumentId, + external_document_id: &'doc str, + new: Versions<'doc>, +} + +impl<'doc> DocumentChange<'doc> { + pub fn docid(&self) -> DocumentId { + match &self { + Self::Deletion(inner) => inner.docid(), + Self::Update(inner) => inner.docid(), + Self::Insertion(inner) => inner.docid(), + } + } + + pub fn external_docid(&self) -> &'doc str { + match self { + DocumentChange::Deletion(deletion) => deletion.external_document_id(), + DocumentChange::Update(update) => update.external_document_id(), + DocumentChange::Insertion(insertion) => insertion.external_document_id(), + } + } +} + +impl<'doc> Deletion<'doc> { + pub fn create(docid: DocumentId, external_document_id: &'doc str) -> Self { + Self { docid, external_document_id } + } + + pub fn docid(&self) -> DocumentId { + self.docid + } + + pub fn external_document_id(&self) -> &'doc str { + self.external_document_id + } + + pub fn current<'a, Mapper: FieldIdMapper>( + &self, + rtxn: &'a RoTxn, + index: &'a Index, + mapper: &'a Mapper, + ) -> Result> { + Ok(DocumentFromDb::new(self.docid, rtxn, index, mapper)?.ok_or( + crate::error::UserError::UnknownInternalDocumentId { document_id: self.docid }, + )?) + } +} + +impl<'doc> Insertion<'doc> { + pub fn create(docid: DocumentId, external_document_id: &'doc str, new: Versions<'doc>) -> Self { + Insertion { docid, external_document_id, new } + } + + pub fn docid(&self) -> DocumentId { + self.docid + } + + pub fn external_document_id(&self) -> &'doc str { + self.external_document_id + } + pub fn inserted(&self) -> DocumentFromVersions<'_, 'doc> { + DocumentFromVersions::new(&self.new) + } + + pub fn inserted_vectors( + &self, + doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result>> { + VectorDocumentFromVersions::new(self.external_document_id, &self.new, doc_alloc, embedders) + } +} + +impl<'doc> Update<'doc> { + pub fn create( + docid: DocumentId, + external_document_id: &'doc str, + new: Versions<'doc>, + has_deletion: bool, + ) -> Self { + Update { docid, new, external_document_id, has_deletion } + } + + pub fn docid(&self) -> DocumentId { + self.docid + } + + pub fn external_document_id(&self) -> &'doc str { + self.external_document_id + } + pub fn current<'a, Mapper: FieldIdMapper>( + &self, + rtxn: &'a RoTxn, + index: &'a Index, + mapper: &'a Mapper, + ) -> Result> { + Ok(DocumentFromDb::new(self.docid, rtxn, index, mapper)?.ok_or( + crate::error::UserError::UnknownInternalDocumentId { document_id: self.docid }, + )?) + } + + pub fn current_vectors<'a, Mapper: FieldIdMapper>( + &self, + rtxn: &'a RoTxn, + index: &'a Index, + mapper: &'a Mapper, + doc_alloc: &'a Bump, + ) -> Result> { + Ok(VectorDocumentFromDb::new(self.docid, index, rtxn, mapper, doc_alloc)?.ok_or( + crate::error::UserError::UnknownInternalDocumentId { document_id: self.docid }, + )?) + } + + pub fn updated(&self) -> DocumentFromVersions<'_, 'doc> { + DocumentFromVersions::new(&self.new) + } + + pub fn merged<'t, Mapper: FieldIdMapper>( + &self, + rtxn: &'t RoTxn, + index: &'t Index, + mapper: &'t Mapper, + ) -> Result> { + if self.has_deletion { + Ok(MergedDocument::without_db(DocumentFromVersions::new(&self.new))) + } else { + MergedDocument::with_db( + self.docid, + rtxn, + index, + mapper, + DocumentFromVersions::new(&self.new), + ) + } + } + + pub fn updated_vectors( + &self, + doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result>> { + VectorDocumentFromVersions::new(self.external_document_id, &self.new, doc_alloc, embedders) + } + + pub fn merged_vectors( + &self, + rtxn: &'doc RoTxn, + index: &'doc Index, + mapper: &'doc Mapper, + doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result>> { + if self.has_deletion { + MergedVectorDocument::without_db( + self.external_document_id, + &self.new, + doc_alloc, + embedders, + ) + } else { + MergedVectorDocument::with_db( + self.docid, + self.external_document_id, + index, + rtxn, + mapper, + &self.new, + doc_alloc, + embedders, + ) + } + } +} diff --git a/crates/milli/src/update/new/extract/cache.rs b/crates/milli/src/update/new/extract/cache.rs new file mode 100644 index 000000000..9c864372d --- /dev/null +++ b/crates/milli/src/update/new/extract/cache.rs @@ -0,0 +1,722 @@ +//! # How the Merge Algorithm works +//! +//! Each extractor create #Threads caches and balances the entries +//! based on the hash of the keys. To do that we can use the +//! hashbrown::hash_map::RawEntryBuilderMut::from_key_hashed_nocheck. +//! This way we can compute the hash on our own, decide on the cache to +//! target, and insert it into the right HashMap. +//! +//! #Thread -> caches +//! t1 -> [t1c1, t1c2, t1c3] +//! t2 -> [t2c1, t2c2, t2c3] +//! t3 -> [t3c1, t3c2, t3c3] +//! +//! When the extractors are done filling the caches, we want to merge +//! the content of all the caches. We do a transpose and each thread is +//! assigned the associated cache. By doing that we know that every key +//! is put in a known cache and will collide with keys in the other +//! caches of the other threads. +//! +//! #Thread -> caches +//! t1 -> [t1c1, t2c1, t3c1] +//! t2 -> [t1c2, t2c2, t3c2] +//! t3 -> [t1c3, t2c3, t3c3] +//! +//! When we encountered a miss in the other caches we must still try +//! to find it in the spilled entries. This is the reason why we use +//! a grenad sorter/reader so that we can seek "efficiently" for a key. +//! +//! ## More Detailled Algorithm +//! +//! Each sub-cache has an in-memory HashMap and some spilled +//! lexicographically ordered entries on disk (grenad). We first iterate +//! over the spilled entries of all the caches at once by using a merge +//! join algorithm. This algorithm will merge the entries by using its +//! merge function. +//! +//! Everytime a merged entry is emited by the merge join algorithm we also +//! fetch the value from the other in-memory caches (HashMaps) to finish +//! the merge. Everytime we retrieve an entry from the in-memory caches +//! we mark them with a tombstone for later. +//! +//! Once we are done with the spilled entries we iterate over the in-memory +//! HashMaps. We iterate over the first one, retrieve the content from the +//! other onces and mark them with a tombstone again. We also make sure +//! to ignore the dead (tombstoned) ones. +//! +//! ## Memory Control +//! +//! We can detect that there are no more memory available when the +//! bump allocator reaches a threshold. When this is the case we +//! freeze the cache. There is one bump allocator by thread and the +//! memory must be well balanced as we manage one type of extraction +//! at a time with well-balanced documents. +//! +//! It means that the unknown new keys added to the +//! cache are directly spilled to disk: basically a key followed by a +//! del/add bitmap. For the known keys we can keep modifying them in +//! the materialized version in the cache: update the del/add bitmaps. +//! +//! For now we can use a grenad sorter for spilling even thought I think +//! it's not the most efficient way (too many files open, sorting entries). + +use std::cmp::Ordering; +use std::collections::binary_heap::PeekMut; +use std::collections::BinaryHeap; +use std::fs::File; +use std::hash::BuildHasher; +use std::io::BufReader; +use std::{io, iter, mem}; + +use bumpalo::Bump; +use grenad::ReaderCursor; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use raw_collections::bbbul::{BitPacker, BitPacker4x}; +use raw_collections::map::FrozenMap; +use raw_collections::{Bbbul, FrozenBbbul}; +use roaring::RoaringBitmap; +use rustc_hash::FxBuildHasher; + +use crate::update::del_add::{DelAdd, KvWriterDelAdd}; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::KvReaderDelAdd; +use crate::update::MergeDeladdCboRoaringBitmaps; +use crate::{CboRoaringBitmapCodec, Result}; + +/// A cache that stores bytes keys associated to CboDelAddRoaringBitmaps. +/// +/// Internally balances the content over `N` buckets for future merging. +pub struct BalancedCaches<'extractor> { + hasher: FxBuildHasher, + alloc: &'extractor Bump, + max_memory: Option, + caches: InnerCaches<'extractor>, +} + +enum InnerCaches<'extractor> { + Normal(NormalCaches<'extractor>), + Spilling(SpillingCaches<'extractor>), +} + +impl<'extractor> BalancedCaches<'extractor> { + pub fn new_in(buckets: usize, max_memory: Option, alloc: &'extractor Bump) -> Self { + Self { + hasher: FxBuildHasher, + max_memory, + caches: InnerCaches::Normal(NormalCaches { + caches: iter::repeat_with(|| HashMap::with_hasher_in(FxBuildHasher, alloc)) + .take(buckets) + .collect(), + }), + alloc, + } + } + + fn buckets(&self) -> usize { + match &self.caches { + InnerCaches::Normal(caches) => caches.caches.len(), + InnerCaches::Spilling(caches) => caches.caches.len(), + } + } + + pub fn insert_del_u32(&mut self, key: &[u8], n: u32) -> Result<()> { + if self.max_memory.map_or(false, |mm| self.alloc.allocated_bytes() >= mm) { + self.start_spilling()?; + } + + let buckets = self.buckets(); + match &mut self.caches { + InnerCaches::Normal(normal) => { + normal.insert_del_u32(&self.hasher, self.alloc, buckets, key, n); + Ok(()) + } + InnerCaches::Spilling(spilling) => { + spilling.insert_del_u32(&self.hasher, self.alloc, buckets, key, n) + } + } + } + + pub fn insert_add_u32(&mut self, key: &[u8], n: u32) -> Result<()> { + if self.max_memory.map_or(false, |mm| self.alloc.allocated_bytes() >= mm) { + self.start_spilling()?; + } + + let buckets = self.buckets(); + match &mut self.caches { + InnerCaches::Normal(normal) => { + normal.insert_add_u32(&self.hasher, self.alloc, buckets, key, n); + Ok(()) + } + InnerCaches::Spilling(spilling) => { + spilling.insert_add_u32(&self.hasher, self.alloc, buckets, key, n) + } + } + } + + /// Make sure the cache is no longer allocating data + /// and writes every new and unknow entry to disk. + fn start_spilling(&mut self) -> Result<()> { + let BalancedCaches { hasher: _, alloc, max_memory: _, caches } = self; + + if let InnerCaches::Normal(normal_caches) = caches { + eprintln!( + "We are spilling after we allocated {} bytes on thread #{}", + alloc.allocated_bytes(), + rayon::current_thread_index().unwrap_or(0) + ); + + let allocated: usize = normal_caches.caches.iter().map(|m| m.allocation_size()).sum(); + eprintln!("The last allocated HashMap took {allocated} bytes"); + + let dummy = NormalCaches { caches: Vec::new() }; + let NormalCaches { caches: cache_maps } = mem::replace(normal_caches, dummy); + *caches = InnerCaches::Spilling(SpillingCaches::from_cache_maps(cache_maps)); + } + + Ok(()) + } + + pub fn freeze(&mut self) -> Result>> { + match &mut self.caches { + InnerCaches::Normal(NormalCaches { caches }) => caches + .iter_mut() + .enumerate() + .map(|(bucket, map)| { + // safety: we are transmuting the Bbbul into a FrozenBbbul + // that are the same size. + let map = unsafe { + std::mem::transmute::< + &mut HashMap< + &[u8], + DelAddBbbul, // from this + FxBuildHasher, + &Bump, + >, + &mut HashMap< + &[u8], + FrozenDelAddBbbul, // to that + FxBuildHasher, + &Bump, + >, + >(map) + }; + Ok(FrozenCache { bucket, cache: FrozenMap::new(map), spilled: Vec::new() }) + }) + .collect(), + InnerCaches::Spilling(SpillingCaches { caches, spilled_entries, .. }) => caches + .iter_mut() + .zip(mem::take(spilled_entries)) + .enumerate() + .map(|(bucket, (map, sorter))| { + let spilled = sorter + .into_reader_cursors()? + .into_iter() + .map(ReaderCursor::into_inner) + .map(BufReader::new) + .map(|bufreader| grenad::Reader::new(bufreader).map_err(Into::into)) + .collect::>()?; + // safety: we are transmuting the Bbbul into a FrozenBbbul + // that are the same size. + let map = unsafe { + std::mem::transmute::< + &mut HashMap< + &[u8], + DelAddBbbul, // from this + FxBuildHasher, + &Bump, + >, + &mut HashMap< + &[u8], + FrozenDelAddBbbul, // to that + FxBuildHasher, + &Bump, + >, + >(map) + }; + Ok(FrozenCache { bucket, cache: FrozenMap::new(map), spilled }) + }) + .collect(), + } + } +} + +unsafe impl MostlySend for BalancedCaches<'_> {} + +struct NormalCaches<'extractor> { + caches: Vec< + HashMap< + &'extractor [u8], + DelAddBbbul<'extractor, BitPacker4x>, + FxBuildHasher, + &'extractor Bump, + >, + >, +} + +impl<'extractor> NormalCaches<'extractor> { + pub fn insert_del_u32( + &mut self, + hasher: &FxBuildHasher, + alloc: &'extractor Bump, + buckets: usize, + key: &[u8], + n: u32, + ) { + let hash = hasher.hash_one(key); + let bucket = compute_bucket_from_hash(buckets, hash); + + match self.caches[bucket].raw_entry_mut().from_hash(hash, |&k| k == key) { + RawEntryMut::Occupied(mut entry) => { + entry.get_mut().del.get_or_insert_with(|| Bbbul::new_in(alloc)).insert(n); + } + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck( + hash, + alloc.alloc_slice_copy(key), + DelAddBbbul::new_del_u32_in(n, alloc), + ); + } + } + } + + pub fn insert_add_u32( + &mut self, + hasher: &FxBuildHasher, + alloc: &'extractor Bump, + buckets: usize, + key: &[u8], + n: u32, + ) { + let hash = hasher.hash_one(key); + let bucket = compute_bucket_from_hash(buckets, hash); + match self.caches[bucket].raw_entry_mut().from_hash(hash, |&k| k == key) { + RawEntryMut::Occupied(mut entry) => { + entry.get_mut().add.get_or_insert_with(|| Bbbul::new_in(alloc)).insert(n); + } + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck( + hash, + alloc.alloc_slice_copy(key), + DelAddBbbul::new_add_u32_in(n, alloc), + ); + } + } + } +} + +struct SpillingCaches<'extractor> { + caches: Vec< + HashMap< + &'extractor [u8], + DelAddBbbul<'extractor, BitPacker4x>, + FxBuildHasher, + &'extractor Bump, + >, + >, + spilled_entries: Vec>, + deladd_buffer: Vec, + cbo_buffer: Vec, +} + +impl<'extractor> SpillingCaches<'extractor> { + fn from_cache_maps( + caches: Vec< + HashMap< + &'extractor [u8], + DelAddBbbul<'extractor, BitPacker4x>, + FxBuildHasher, + &'extractor Bump, + >, + >, + ) -> SpillingCaches<'extractor> { + SpillingCaches { + spilled_entries: iter::repeat_with(|| { + let mut builder = grenad::SorterBuilder::new(MergeDeladdCboRoaringBitmaps); + builder.dump_threshold(0); + builder.allow_realloc(false); + builder.build() + }) + .take(caches.len()) + .collect(), + caches, + deladd_buffer: Vec::new(), + cbo_buffer: Vec::new(), + } + } + + pub fn insert_del_u32( + &mut self, + hasher: &FxBuildHasher, + alloc: &'extractor Bump, + buckets: usize, + key: &[u8], + n: u32, + ) -> Result<()> { + let hash = hasher.hash_one(key); + let bucket = compute_bucket_from_hash(buckets, hash); + match self.caches[bucket].raw_entry_mut().from_hash(hash, |&k| k == key) { + RawEntryMut::Occupied(mut entry) => { + entry.get_mut().del.get_or_insert_with(|| Bbbul::new_in(alloc)).insert(n); + Ok(()) + } + RawEntryMut::Vacant(_entry) => spill_entry_to_sorter( + &mut self.spilled_entries[bucket], + &mut self.deladd_buffer, + &mut self.cbo_buffer, + key, + DelAddRoaringBitmap::new_del_u32(n), + ), + } + } + + pub fn insert_add_u32( + &mut self, + hasher: &FxBuildHasher, + alloc: &'extractor Bump, + buckets: usize, + key: &[u8], + n: u32, + ) -> Result<()> { + let hash = hasher.hash_one(key); + let bucket = compute_bucket_from_hash(buckets, hash); + match self.caches[bucket].raw_entry_mut().from_hash(hash, |&k| k == key) { + RawEntryMut::Occupied(mut entry) => { + entry.get_mut().add.get_or_insert_with(|| Bbbul::new_in(alloc)).insert(n); + Ok(()) + } + RawEntryMut::Vacant(_entry) => spill_entry_to_sorter( + &mut self.spilled_entries[bucket], + &mut self.deladd_buffer, + &mut self.cbo_buffer, + key, + DelAddRoaringBitmap::new_add_u32(n), + ), + } + } +} + +#[inline] +fn compute_bucket_from_hash(buckets: usize, hash: u64) -> usize { + hash as usize % buckets +} + +fn spill_entry_to_sorter( + spilled_entries: &mut grenad::Sorter, + deladd_buffer: &mut Vec, + cbo_buffer: &mut Vec, + key: &[u8], + deladd: DelAddRoaringBitmap, +) -> Result<()> { + deladd_buffer.clear(); + let mut value_writer = KvWriterDelAdd::new(deladd_buffer); + + match deladd { + DelAddRoaringBitmap { del: Some(del), add: None } => { + cbo_buffer.clear(); + CboRoaringBitmapCodec::serialize_into(&del, cbo_buffer); + value_writer.insert(DelAdd::Deletion, &cbo_buffer)?; + } + DelAddRoaringBitmap { del: None, add: Some(add) } => { + cbo_buffer.clear(); + CboRoaringBitmapCodec::serialize_into(&add, cbo_buffer); + value_writer.insert(DelAdd::Addition, &cbo_buffer)?; + } + DelAddRoaringBitmap { del: Some(del), add: Some(add) } => { + cbo_buffer.clear(); + CboRoaringBitmapCodec::serialize_into(&del, cbo_buffer); + value_writer.insert(DelAdd::Deletion, &cbo_buffer)?; + + cbo_buffer.clear(); + CboRoaringBitmapCodec::serialize_into(&add, cbo_buffer); + value_writer.insert(DelAdd::Addition, &cbo_buffer)?; + } + DelAddRoaringBitmap { del: None, add: None } => return Ok(()), + } + + let bytes = value_writer.into_inner().unwrap(); + spilled_entries.insert(key, bytes).map_err(Into::into) +} + +pub struct FrozenCache<'a, 'extractor> { + bucket: usize, + cache: FrozenMap< + 'a, + 'extractor, + &'extractor [u8], + FrozenDelAddBbbul<'extractor, BitPacker4x>, + FxBuildHasher, + >, + spilled: Vec>>, +} + +pub fn transpose_and_freeze_caches<'a, 'extractor>( + caches: &'a mut [BalancedCaches<'extractor>], +) -> Result>>> { + let width = caches.first().map(BalancedCaches::buckets).unwrap_or(0); + let mut bucket_caches: Vec<_> = iter::repeat_with(Vec::new).take(width).collect(); + + for thread_cache in caches { + for frozen in thread_cache.freeze()? { + bucket_caches[frozen.bucket].push(frozen); + } + } + + Ok(bucket_caches) +} + +/// Merges the caches that must be all associated to the same bucket. +/// +/// # Panics +/// +/// - If the bucket IDs in these frozen caches are not exactly the same. +pub fn merge_caches(frozen: Vec, mut f: F) -> Result<()> +where + F: for<'a> FnMut(&'a [u8], DelAddRoaringBitmap) -> Result<()>, +{ + let mut maps = Vec::new(); + let mut readers = Vec::new(); + let mut current_bucket = None; + for FrozenCache { bucket, cache, ref mut spilled } in frozen { + assert_eq!(*current_bucket.get_or_insert(bucket), bucket); + maps.push(cache); + readers.append(spilled); + } + + // First manage the spilled entries by looking into the HashMaps, + // merge them and mark them as dummy. + let mut heap = BinaryHeap::new(); + for (source_index, source) in readers.into_iter().enumerate() { + let mut cursor = source.into_cursor()?; + if cursor.move_on_next()?.is_some() { + heap.push(Entry { cursor, source_index }); + } + } + + loop { + let mut first_entry = match heap.pop() { + Some(entry) => entry, + None => break, + }; + + let (first_key, first_value) = match first_entry.cursor.current() { + Some((key, value)) => (key, value), + None => break, + }; + + let mut output = DelAddRoaringBitmap::from_bytes(first_value)?; + while let Some(mut entry) = heap.peek_mut() { + if let Some((key, _value)) = entry.cursor.current() { + if first_key == key { + let new = DelAddRoaringBitmap::from_bytes(first_value)?; + output = output.merge(new); + // When we are done we the current value of this entry move make + // it move forward and let the heap reorganize itself (on drop) + if entry.cursor.move_on_next()?.is_none() { + PeekMut::pop(entry); + } + } else { + break; + } + } + } + + // Once we merged all of the spilled bitmaps we must also + // fetch the entries from the non-spilled entries (the HashMaps). + for (map_index, map) in maps.iter_mut().enumerate() { + if first_entry.source_index != map_index { + if let Some(new) = map.get_mut(first_key) { + output.union_and_clear_bbbul(new); + } + } + } + + // We send the merged entry outside. + (f)(first_key, output)?; + + // Don't forget to put the first entry back into the heap. + if first_entry.cursor.move_on_next()?.is_some() { + heap.push(first_entry) + } + } + + // Then manage the content on the HashMap entries that weren't taken (mem::take). + while let Some(mut map) = maps.pop() { + for (key, bbbul) in map.iter_mut() { + // Make sure we don't try to work with entries already managed by the spilled + if bbbul.is_empty() { + continue; + } + + let mut output = DelAddRoaringBitmap::empty(); + output.union_and_clear_bbbul(bbbul); + + for rhs in maps.iter_mut() { + if let Some(new) = rhs.get_mut(key) { + output.union_and_clear_bbbul(new); + } + } + + // We send the merged entry outside. + (f)(key, output)?; + } + } + + Ok(()) +} + +struct Entry { + cursor: ReaderCursor, + source_index: usize, +} + +impl Ord for Entry { + fn cmp(&self, other: &Entry) -> Ordering { + let skey = self.cursor.current().map(|(k, _)| k); + let okey = other.cursor.current().map(|(k, _)| k); + skey.cmp(&okey).then(self.source_index.cmp(&other.source_index)).reverse() + } +} + +impl Eq for Entry {} + +impl PartialEq for Entry { + fn eq(&self, other: &Entry) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for Entry { + fn partial_cmp(&self, other: &Entry) -> Option { + Some(self.cmp(other)) + } +} + +pub struct DelAddBbbul<'bump, B> { + pub del: Option>, + pub add: Option>, +} + +impl<'bump, B: BitPacker> DelAddBbbul<'bump, B> { + pub fn new_del_u32_in(n: u32, bump: &'bump Bump) -> Self { + let mut bbbul = Bbbul::new_in(bump); + bbbul.insert(n); + DelAddBbbul { del: Some(bbbul), add: None } + } + + pub fn new_add_u32_in(n: u32, bump: &'bump Bump) -> Self { + let mut bbbul = Bbbul::new_in(bump); + bbbul.insert(n); + DelAddBbbul { del: None, add: Some(bbbul) } + } +} + +pub struct FrozenDelAddBbbul<'bump, B> { + pub del: Option>, + pub add: Option>, +} + +impl<'bump, B> FrozenDelAddBbbul<'bump, B> { + fn is_empty(&self) -> bool { + self.del.is_none() && self.add.is_none() + } +} + +#[derive(Debug, Default, Clone)] +pub struct DelAddRoaringBitmap { + pub del: Option, + pub add: Option, +} + +impl DelAddRoaringBitmap { + fn from_bytes(bytes: &[u8]) -> io::Result { + let reader = KvReaderDelAdd::from_slice(bytes); + + let del = match reader.get(DelAdd::Deletion) { + Some(bytes) => CboRoaringBitmapCodec::deserialize_from(bytes).map(Some)?, + None => None, + }; + + let add = match reader.get(DelAdd::Addition) { + Some(bytes) => CboRoaringBitmapCodec::deserialize_from(bytes).map(Some)?, + None => None, + }; + + Ok(DelAddRoaringBitmap { del, add }) + } + + pub fn empty() -> DelAddRoaringBitmap { + DelAddRoaringBitmap { del: None, add: None } + } + + pub fn insert_del_u32(&mut self, n: u32) { + self.del.get_or_insert_with(RoaringBitmap::new).insert(n); + } + + pub fn insert_add_u32(&mut self, n: u32) { + self.add.get_or_insert_with(RoaringBitmap::new).insert(n); + } + + pub fn new_del_u32(n: u32) -> Self { + DelAddRoaringBitmap { del: Some(RoaringBitmap::from([n])), add: None } + } + + pub fn new_add_u32(n: u32) -> Self { + DelAddRoaringBitmap { del: None, add: Some(RoaringBitmap::from([n])) } + } + + pub fn union_and_clear_bbbul(&mut self, bbbul: &mut FrozenDelAddBbbul<'_, B>) { + let FrozenDelAddBbbul { del, add } = bbbul; + + if let Some(ref mut bbbul) = del.take() { + let del = self.del.get_or_insert_with(RoaringBitmap::new); + let mut iter = bbbul.iter_and_clear(); + while let Some(block) = iter.next_block() { + let iter = block.iter().copied(); + let block = RoaringBitmap::from_sorted_iter(iter).unwrap(); + *del |= block; + } + } + + if let Some(ref mut bbbul) = add.take() { + let add = self.add.get_or_insert_with(RoaringBitmap::new); + let mut iter = bbbul.iter_and_clear(); + while let Some(block) = iter.next_block() { + let iter = block.iter().copied(); + let block = RoaringBitmap::from_sorted_iter(iter).unwrap(); + *add |= block; + } + } + } + + pub fn merge(self, rhs: DelAddRoaringBitmap) -> DelAddRoaringBitmap { + let DelAddRoaringBitmap { del, add } = self; + let DelAddRoaringBitmap { del: ndel, add: nadd } = rhs; + + let del = match (del, ndel) { + (None, None) => None, + (None, Some(del)) | (Some(del), None) => Some(del), + (Some(del), Some(ndel)) => Some(del | ndel), + }; + + let add = match (add, nadd) { + (None, None) => None, + (None, Some(add)) | (Some(add), None) => Some(add), + (Some(add), Some(nadd)) => Some(add | nadd), + }; + + DelAddRoaringBitmap { del, add } + } + + pub fn apply_to(&self, documents_ids: &mut RoaringBitmap) { + let DelAddRoaringBitmap { del, add } = self; + + if let Some(del) = del { + *documents_ids -= del; + } + + if let Some(add) = add { + *documents_ids |= add; + } + } +} diff --git a/crates/milli/src/update/new/extract/documents.rs b/crates/milli/src/update/new/extract/documents.rs new file mode 100644 index 000000000..23d93a2c2 --- /dev/null +++ b/crates/milli/src/update/new/extract/documents.rs @@ -0,0 +1,150 @@ +use std::cell::RefCell; + +use bumpalo::Bump; +use hashbrown::HashMap; + +use super::DelAddRoaringBitmap; +use crate::update::new::channel::DocumentsSender; +use crate::update::new::document::{write_to_obkv, Document as _}; +use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::thread_local::FullySend; +use crate::update::new::DocumentChange; +use crate::vector::EmbeddingConfigs; +use crate::Result; +pub struct DocumentsExtractor<'a> { + document_sender: &'a DocumentsSender<'a>, + embedders: &'a EmbeddingConfigs, +} + +impl<'a> DocumentsExtractor<'a> { + pub fn new(document_sender: &'a DocumentsSender<'a>, embedders: &'a EmbeddingConfigs) -> Self { + Self { document_sender, embedders } + } +} + +#[derive(Default)] +pub struct DocumentExtractorData { + pub docids_delta: DelAddRoaringBitmap, + pub field_distribution_delta: HashMap, +} + +impl<'a, 'extractor> Extractor<'extractor> for DocumentsExtractor<'a> { + type Data = FullySend>; + + fn init_data(&self, _extractor_alloc: &'extractor Bump) -> Result { + Ok(FullySend(Default::default())) + } + + fn process<'doc>( + &self, + changes: impl Iterator>>, + context: &DocumentChangeContext, + ) -> Result<()> { + let mut document_buffer = bumpalo::collections::Vec::new_in(&context.doc_alloc); + let mut document_extractor_data = context.data.0.borrow_mut_or_yield(); + + for change in changes { + let change = change?; + // **WARNING**: the exclusive borrow on `new_fields_ids_map` needs to be taken **inside** of the `for change in changes` loop + // Otherwise, `BorrowMutError` will occur for document changes that also need the new_fields_ids_map (e.g.: UpdateByFunction) + let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield(); + + let external_docid = change.external_docid().to_owned(); + + // document but we need to create a function that collects and compresses documents. + match change { + DocumentChange::Deletion(deletion) => { + let docid = deletion.docid(); + let content = deletion.current( + &context.rtxn, + context.index, + &context.db_fields_ids_map, + )?; + let geo_iter = + content.geo_field().transpose().map(|res| res.map(|rv| ("_geo", rv))); + for res in content.iter_top_level_fields().chain(geo_iter) { + let (f, _) = res?; + let entry = document_extractor_data + .field_distribution_delta + .entry_ref(f) + .or_default(); + *entry -= 1; + } + document_extractor_data.docids_delta.insert_del_u32(docid); + self.document_sender.delete(docid, external_docid).unwrap(); + } + DocumentChange::Update(update) => { + let docid = update.docid(); + let content = + update.current(&context.rtxn, context.index, &context.db_fields_ids_map)?; + let geo_iter = + content.geo_field().transpose().map(|res| res.map(|rv| ("_geo", rv))); + for res in content.iter_top_level_fields().chain(geo_iter) { + let (f, _) = res?; + let entry = document_extractor_data + .field_distribution_delta + .entry_ref(f) + .or_default(); + *entry -= 1; + } + let content = update.updated(); + let geo_iter = + content.geo_field().transpose().map(|res| res.map(|rv| ("_geo", rv))); + for res in content.iter_top_level_fields().chain(geo_iter) { + let (f, _) = res?; + let entry = document_extractor_data + .field_distribution_delta + .entry_ref(f) + .or_default(); + *entry += 1; + } + + let content = + update.merged(&context.rtxn, context.index, &context.db_fields_ids_map)?; + let vector_content = update.merged_vectors( + &context.rtxn, + context.index, + &context.db_fields_ids_map, + &context.doc_alloc, + self.embedders, + )?; + let content = write_to_obkv( + &content, + vector_content.as_ref(), + &mut new_fields_ids_map, + &mut document_buffer, + )?; + self.document_sender.uncompressed(docid, external_docid, content).unwrap(); + } + DocumentChange::Insertion(insertion) => { + let docid = insertion.docid(); + let content = insertion.inserted(); + let geo_iter = + content.geo_field().transpose().map(|res| res.map(|rv| ("_geo", rv))); + for res in content.iter_top_level_fields().chain(geo_iter) { + let (f, _) = res?; + let entry = document_extractor_data + .field_distribution_delta + .entry_ref(f) + .or_default(); + *entry += 1; + } + let inserted_vectors = + insertion.inserted_vectors(&context.doc_alloc, self.embedders)?; + let content = write_to_obkv( + &content, + inserted_vectors.as_ref(), + &mut new_fields_ids_map, + &mut document_buffer, + )?; + document_extractor_data.docids_delta.insert_add_u32(docid); + self.document_sender.uncompressed(docid, external_docid, content).unwrap(); + // extracted_dictionary_sender.send(self, dictionary: &[u8]); + } + } + } + + Ok(()) + } +} diff --git a/crates/milli/src/update/new/extract/faceted/extract_facets.rs b/crates/milli/src/update/new/extract/faceted/extract_facets.rs new file mode 100644 index 000000000..5394a6e86 --- /dev/null +++ b/crates/milli/src/update/new/extract/faceted/extract_facets.rs @@ -0,0 +1,393 @@ +use std::cell::RefCell; +use std::collections::HashSet; +use std::ops::DerefMut as _; + +use bumpalo::collections::Vec as BVec; +use bumpalo::Bump; +use hashbrown::HashMap; +use heed::RoTxn; +use serde_json::Value; + +use super::super::cache::BalancedCaches; +use super::facet_document::extract_document_facets; +use super::FacetKind; +use crate::heed_codec::facet::OrderedF64Codec; +use crate::update::del_add::DelAdd; +use crate::update::new::channel::FieldIdDocidFacetSender; +use crate::update::new::indexer::document_changes::{ + extract, DocumentChangeContext, DocumentChanges, Extractor, IndexingContext, Progress, +}; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::steps::Step; +use crate::update::new::thread_local::{FullySend, ThreadLocal}; +use crate::update::new::DocumentChange; +use crate::update::GrenadParameters; +use crate::{DocumentId, FieldId, Index, Result, MAX_FACET_VALUE_LENGTH}; + +pub struct FacetedExtractorData<'a> { + attributes_to_extract: &'a [&'a str], + sender: &'a FieldIdDocidFacetSender<'a>, + grenad_parameters: GrenadParameters, + buckets: usize, +} + +impl<'a, 'extractor> Extractor<'extractor> for FacetedExtractorData<'a> { + type Data = RefCell>; + + fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result { + Ok(RefCell::new(BalancedCaches::new_in( + self.buckets, + self.grenad_parameters.max_memory_by_thread(), + extractor_alloc, + ))) + } + + fn process<'doc>( + &self, + changes: impl Iterator>>, + context: &DocumentChangeContext, + ) -> Result<()> { + for change in changes { + let change = change?; + FacetedDocidsExtractor::extract_document_change( + context, + self.attributes_to_extract, + change, + self.sender, + )? + } + Ok(()) + } +} + +pub struct FacetedDocidsExtractor; + +impl FacetedDocidsExtractor { + fn extract_document_change( + context: &DocumentChangeContext>, + attributes_to_extract: &[&str], + document_change: DocumentChange, + sender: &FieldIdDocidFacetSender, + ) -> Result<()> { + let index = &context.index; + let rtxn = &context.rtxn; + let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield(); + let mut cached_sorter = context.data.borrow_mut_or_yield(); + let mut del_add_facet_value = DelAddFacetValue::new(&context.doc_alloc); + let docid = document_change.docid(); + let res = match document_change { + DocumentChange::Deletion(inner) => extract_document_facets( + attributes_to_extract, + inner.current(rtxn, index, context.db_fields_ids_map)?, + inner.external_document_id(), + new_fields_ids_map.deref_mut(), + &mut |fid, value| { + Self::facet_fn_with_options( + &context.doc_alloc, + cached_sorter.deref_mut(), + BalancedCaches::insert_del_u32, + &mut del_add_facet_value, + DelAddFacetValue::insert_del, + docid, + fid, + value, + ) + }, + ), + DocumentChange::Update(inner) => { + extract_document_facets( + attributes_to_extract, + inner.current(rtxn, index, context.db_fields_ids_map)?, + inner.external_document_id(), + new_fields_ids_map.deref_mut(), + &mut |fid, value| { + Self::facet_fn_with_options( + &context.doc_alloc, + cached_sorter.deref_mut(), + BalancedCaches::insert_del_u32, + &mut del_add_facet_value, + DelAddFacetValue::insert_del, + docid, + fid, + value, + ) + }, + )?; + + extract_document_facets( + attributes_to_extract, + inner.merged(rtxn, index, context.db_fields_ids_map)?, + inner.external_document_id(), + new_fields_ids_map.deref_mut(), + &mut |fid, value| { + Self::facet_fn_with_options( + &context.doc_alloc, + cached_sorter.deref_mut(), + BalancedCaches::insert_add_u32, + &mut del_add_facet_value, + DelAddFacetValue::insert_add, + docid, + fid, + value, + ) + }, + ) + } + DocumentChange::Insertion(inner) => extract_document_facets( + attributes_to_extract, + inner.inserted(), + inner.external_document_id(), + new_fields_ids_map.deref_mut(), + &mut |fid, value| { + Self::facet_fn_with_options( + &context.doc_alloc, + cached_sorter.deref_mut(), + BalancedCaches::insert_add_u32, + &mut del_add_facet_value, + DelAddFacetValue::insert_add, + docid, + fid, + value, + ) + }, + ), + }; + + del_add_facet_value.send_data(docid, sender, &context.doc_alloc).unwrap(); + res + } + + #[allow(clippy::too_many_arguments)] + fn facet_fn_with_options<'extractor, 'doc>( + doc_alloc: &'doc Bump, + cached_sorter: &mut BalancedCaches<'extractor>, + cache_fn: impl Fn(&mut BalancedCaches<'extractor>, &[u8], u32) -> Result<()>, + del_add_facet_value: &mut DelAddFacetValue<'doc>, + facet_fn: impl Fn(&mut DelAddFacetValue<'doc>, FieldId, BVec<'doc, u8>, FacetKind), + docid: DocumentId, + fid: FieldId, + value: &Value, + ) -> Result<()> { + let mut buffer = BVec::new_in(doc_alloc); + // Exists + // key: fid + buffer.push(FacetKind::Exists as u8); + buffer.extend_from_slice(&fid.to_be_bytes()); + cache_fn(cached_sorter, &buffer, docid)?; + + match value { + // Number + // key: fid - level - orderedf64 - orignalf64 + Value::Number(number) => { + let mut ordered = [0u8; 16]; + if number + .as_f64() + .and_then(|n| OrderedF64Codec::serialize_into(n, &mut ordered).ok()) + .is_some() + { + let mut number = BVec::with_capacity_in(16, doc_alloc); + number.extend_from_slice(&ordered); + facet_fn(del_add_facet_value, fid, number, FacetKind::Number); + + buffer.clear(); + buffer.push(FacetKind::Number as u8); + buffer.extend_from_slice(&fid.to_be_bytes()); + buffer.push(0); // level 0 + buffer.extend_from_slice(&ordered); + cache_fn(cached_sorter, &buffer, docid) + } else { + Ok(()) + } + } + // String + // key: fid - level - truncated_string + Value::String(s) => { + let mut string = BVec::new_in(doc_alloc); + string.extend_from_slice(s.as_bytes()); + facet_fn(del_add_facet_value, fid, string, FacetKind::String); + + let normalized = crate::normalize_facet(s); + let truncated = truncate_str(&normalized); + buffer.clear(); + buffer.push(FacetKind::String as u8); + buffer.extend_from_slice(&fid.to_be_bytes()); + buffer.push(0); // level 0 + buffer.extend_from_slice(truncated.as_bytes()); + cache_fn(cached_sorter, &buffer, docid) + } + // Null + // key: fid + Value::Null => { + buffer.clear(); + buffer.push(FacetKind::Null as u8); + buffer.extend_from_slice(&fid.to_be_bytes()); + cache_fn(cached_sorter, &buffer, docid) + } + // Empty + // key: fid + Value::Array(a) if a.is_empty() => { + buffer.clear(); + buffer.push(FacetKind::Empty as u8); + buffer.extend_from_slice(&fid.to_be_bytes()); + cache_fn(cached_sorter, &buffer, docid) + } + Value::Object(o) if o.is_empty() => { + buffer.clear(); + buffer.push(FacetKind::Empty as u8); + buffer.extend_from_slice(&fid.to_be_bytes()); + cache_fn(cached_sorter, &buffer, docid) + } + // Otherwise, do nothing + /// TODO: What about Value::Bool? + _ => Ok(()), + } + } + + fn attributes_to_extract<'a>(rtxn: &'a RoTxn, index: &'a Index) -> Result> { + index.user_defined_faceted_fields(rtxn) + } +} + +struct DelAddFacetValue<'doc> { + strings: HashMap<(FieldId, BVec<'doc, u8>), DelAdd, hashbrown::DefaultHashBuilder, &'doc Bump>, + f64s: HashMap<(FieldId, BVec<'doc, u8>), DelAdd, hashbrown::DefaultHashBuilder, &'doc Bump>, +} + +impl<'doc> DelAddFacetValue<'doc> { + fn new(doc_alloc: &'doc Bump) -> Self { + Self { strings: HashMap::new_in(doc_alloc), f64s: HashMap::new_in(doc_alloc) } + } + + fn insert_add(&mut self, fid: FieldId, value: BVec<'doc, u8>, kind: FacetKind) { + let cache = match kind { + FacetKind::String => &mut self.strings, + FacetKind::Number => &mut self.f64s, + _ => return, + }; + + let key = (fid, value); + if let Some(DelAdd::Deletion) = cache.get(&key) { + cache.remove(&key); + } else { + cache.insert(key, DelAdd::Addition); + } + } + + fn insert_del(&mut self, fid: FieldId, value: BVec<'doc, u8>, kind: FacetKind) { + let cache = match kind { + FacetKind::String => &mut self.strings, + FacetKind::Number => &mut self.f64s, + _ => return, + }; + + let key = (fid, value); + if let Some(DelAdd::Addition) = cache.get(&key) { + cache.remove(&key); + } else { + cache.insert(key, DelAdd::Deletion); + } + } + + fn send_data( + self, + docid: DocumentId, + sender: &FieldIdDocidFacetSender, + doc_alloc: &Bump, + ) -> std::result::Result<(), crossbeam_channel::SendError<()>> { + let mut buffer = bumpalo::collections::Vec::new_in(doc_alloc); + for ((fid, value), deladd) in self.strings { + if let Ok(s) = std::str::from_utf8(&value) { + buffer.clear(); + buffer.extend_from_slice(&fid.to_be_bytes()); + buffer.extend_from_slice(&docid.to_be_bytes()); + let normalized = crate::normalize_facet(s); + let truncated = truncate_str(&normalized); + buffer.extend_from_slice(truncated.as_bytes()); + match deladd { + DelAdd::Deletion => sender.delete_facet_string(&buffer)?, + DelAdd::Addition => sender.write_facet_string(&buffer, &value)?, + } + } + } + + for ((fid, value), deladd) in self.f64s { + buffer.clear(); + buffer.extend_from_slice(&fid.to_be_bytes()); + buffer.extend_from_slice(&docid.to_be_bytes()); + buffer.extend_from_slice(&value); + match deladd { + DelAdd::Deletion => sender.delete_facet_f64(&buffer)?, + DelAdd::Addition => sender.write_facet_f64(&buffer)?, + } + } + + Ok(()) + } +} + +/// Truncates a string to the biggest valid LMDB key size. +fn truncate_str(s: &str) -> &str { + let index = s + .char_indices() + .map(|(idx, _)| idx) + .chain(std::iter::once(s.len())) + .take_while(|idx| idx <= &MAX_FACET_VALUE_LENGTH) + .last(); + + &s[..index.unwrap_or(0)] +} + +impl FacetedDocidsExtractor { + #[tracing::instrument(level = "trace", skip_all, target = "indexing::extract::faceted")] + pub fn run_extraction< + 'pl, + 'fid, + 'indexer, + 'index, + 'extractor, + DC: DocumentChanges<'pl>, + MSP, + SP, + >( + grenad_parameters: GrenadParameters, + document_changes: &DC, + indexing_context: IndexingContext<'fid, 'indexer, 'index, MSP, SP>, + extractor_allocs: &'extractor mut ThreadLocal>, + sender: &FieldIdDocidFacetSender, + step: Step, + ) -> Result>> + where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, + { + let index = indexing_context.index; + let rtxn = index.read_txn()?; + let attributes_to_extract = Self::attributes_to_extract(&rtxn, index)?; + let attributes_to_extract: Vec<_> = + attributes_to_extract.iter().map(|s| s.as_ref()).collect(); + let datastore = ThreadLocal::new(); + + { + let span = + tracing::trace_span!(target: "indexing::documents::extract", "docids_extraction"); + let _entered = span.enter(); + + let extractor = FacetedExtractorData { + attributes_to_extract: &attributes_to_extract, + grenad_parameters, + buckets: rayon::current_num_threads(), + sender, + }; + extract( + document_changes, + &extractor, + indexing_context, + extractor_allocs, + &datastore, + step, + )?; + } + + Ok(datastore.into_iter().map(RefCell::into_inner).collect()) + } +} diff --git a/crates/milli/src/update/new/extract/faceted/facet_document.rs b/crates/milli/src/update/new/extract/faceted/facet_document.rs new file mode 100644 index 000000000..141af7fbe --- /dev/null +++ b/crates/milli/src/update/new/extract/faceted/facet_document.rs @@ -0,0 +1,61 @@ +use serde_json::Value; + +use crate::update::new::document::Document; +use crate::update::new::extract::geo::extract_geo_coordinates; +use crate::update::new::extract::perm_json_p; +use crate::{FieldId, GlobalFieldsIdsMap, InternalError, Result, UserError}; + +pub fn extract_document_facets<'doc>( + attributes_to_extract: &[&str], + document: impl Document<'doc>, + external_document_id: &str, + field_id_map: &mut GlobalFieldsIdsMap, + facet_fn: &mut impl FnMut(FieldId, &Value) -> Result<()>, +) -> Result<()> { + for res in document.iter_top_level_fields() { + let (field_name, value) = res?; + + let mut tokenize_field = |name: &str, value: &Value| match field_id_map.id_or_insert(name) { + Some(field_id) => facet_fn(field_id, value), + None => Err(UserError::AttributeLimitReached.into()), + }; + + // if the current field is searchable or contains a searchable attribute + if perm_json_p::select_field(field_name, Some(attributes_to_extract), &[]) { + // parse json. + match serde_json::value::to_value(value).map_err(InternalError::SerdeJson)? { + Value::Object(object) => perm_json_p::seek_leaf_values_in_object( + &object, + Some(attributes_to_extract), + &[], // skip no attributes + field_name, + &mut tokenize_field, + )?, + Value::Array(array) => perm_json_p::seek_leaf_values_in_array( + &array, + Some(attributes_to_extract), + &[], // skip no attributes + field_name, + &mut tokenize_field, + )?, + value => tokenize_field(field_name, &value)?, + } + } + } + + if attributes_to_extract.contains(&"_geo") { + if let Some(geo_value) = document.geo_field()? { + if let Some([lat, lng]) = extract_geo_coordinates(external_document_id, geo_value)? { + let (lat_fid, lng_fid) = field_id_map + .id_or_insert("_geo.lat") + .zip(field_id_map.id_or_insert("_geo.lng")) + .ok_or(UserError::AttributeLimitReached)?; + + facet_fn(lat_fid, &lat.into())?; + facet_fn(lng_fid, &lng.into())?; + } + } + } + + Ok(()) +} diff --git a/crates/milli/src/update/new/extract/faceted/mod.rs b/crates/milli/src/update/new/extract/faceted/mod.rs new file mode 100644 index 000000000..0c012d739 --- /dev/null +++ b/crates/milli/src/update/new/extract/faceted/mod.rs @@ -0,0 +1,33 @@ +mod extract_facets; +mod facet_document; + +pub use extract_facets::FacetedDocidsExtractor; + +#[repr(u8)] +#[derive(Debug, Clone, Copy)] +pub enum FacetKind { + Number = 0, + String = 1, + Null = 2, + Empty = 3, + Exists, +} + +impl From for FacetKind { + fn from(value: u8) -> Self { + match value { + 0 => Self::Number, + 1 => Self::String, + 2 => Self::Null, + 3 => Self::Empty, + 4 => Self::Exists, + _ => unreachable!(), + } + } +} + +impl FacetKind { + pub fn extract_from_key(key: &[u8]) -> (FacetKind, &[u8]) { + (FacetKind::from(key[0]), &key[1..]) + } +} diff --git a/crates/milli/src/update/new/extract/geo/mod.rs b/crates/milli/src/update/new/extract/geo/mod.rs new file mode 100644 index 000000000..c3ea76c42 --- /dev/null +++ b/crates/milli/src/update/new/extract/geo/mod.rs @@ -0,0 +1,324 @@ +use std::cell::RefCell; +use std::fs::File; +use std::io::{self, BufReader, BufWriter, ErrorKind, Read, Write as _}; +use std::{iter, mem, result}; + +use bumpalo::Bump; +use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; +use heed::RoTxn; +use serde_json::value::RawValue; +use serde_json::Value; + +use crate::error::GeoError; +use crate::update::new::document::Document; +use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::DocumentChange; +use crate::update::GrenadParameters; +use crate::{lat_lng_to_xyz, DocumentId, GeoPoint, Index, InternalError, Result}; + +pub struct GeoExtractor { + grenad_parameters: GrenadParameters, +} + +impl GeoExtractor { + pub fn new( + rtxn: &RoTxn, + index: &Index, + grenad_parameters: GrenadParameters, + ) -> Result> { + let is_sortable = index.sortable_fields(rtxn)?.contains("_geo"); + let is_filterable = index.filterable_fields(rtxn)?.contains("_geo"); + if is_sortable || is_filterable { + Ok(Some(GeoExtractor { grenad_parameters })) + } else { + Ok(None) + } + } +} + +#[derive(Pod, Zeroable, Copy, Clone)] +#[repr(C, packed)] +pub struct ExtractedGeoPoint { + pub docid: DocumentId, + pub lat_lng: [f64; 2], +} + +impl From for GeoPoint { + /// Converts the latitude and longitude back to an xyz GeoPoint. + fn from(value: ExtractedGeoPoint) -> Self { + let [lat, lng] = value.lat_lng; + let point = [lat, lng]; + let xyz_point = lat_lng_to_xyz(&point); + GeoPoint::new(xyz_point, (value.docid, point)) + } +} + +pub struct GeoExtractorData<'extractor> { + /// The set of documents ids that were removed. If a document sees its geo + /// point being updated, we first put it in the deleted and then in the inserted. + removed: bumpalo::collections::Vec<'extractor, ExtractedGeoPoint>, + inserted: bumpalo::collections::Vec<'extractor, ExtractedGeoPoint>, + /// TODO Do the doc + spilled_removed: Option>, + /// TODO Do the doc + spilled_inserted: Option>, +} + +impl<'extractor> GeoExtractorData<'extractor> { + pub fn freeze(self) -> Result> { + let GeoExtractorData { removed, inserted, spilled_removed, spilled_inserted } = self; + + Ok(FrozenGeoExtractorData { + removed: removed.into_bump_slice(), + inserted: inserted.into_bump_slice(), + spilled_removed: spilled_removed + .map(|bw| bw.into_inner().map(BufReader::new).map_err(|iie| iie.into_error())) + .transpose()?, + spilled_inserted: spilled_inserted + .map(|bw| bw.into_inner().map(BufReader::new).map_err(|iie| iie.into_error())) + .transpose()?, + }) + } +} + +unsafe impl MostlySend for GeoExtractorData<'_> {} + +pub struct FrozenGeoExtractorData<'extractor> { + pub removed: &'extractor [ExtractedGeoPoint], + pub inserted: &'extractor [ExtractedGeoPoint], + pub spilled_removed: Option>, + pub spilled_inserted: Option>, +} + +impl<'extractor> FrozenGeoExtractorData<'extractor> { + pub fn iter_and_clear_removed( + &mut self, + ) -> impl IntoIterator> + '_ { + mem::take(&mut self.removed) + .iter() + .copied() + .map(Ok) + .chain(iterator_over_spilled_geopoints(&mut self.spilled_removed)) + } + + pub fn iter_and_clear_inserted( + &mut self, + ) -> impl IntoIterator> + '_ { + mem::take(&mut self.inserted) + .iter() + .copied() + .map(Ok) + .chain(iterator_over_spilled_geopoints(&mut self.spilled_inserted)) + } +} + +fn iterator_over_spilled_geopoints( + spilled: &mut Option>, +) -> impl IntoIterator> + '_ { + let mut spilled = spilled.take(); + iter::from_fn(move || match &mut spilled { + Some(file) => { + let geopoint_bytes = &mut [0u8; mem::size_of::()]; + match file.read_exact(geopoint_bytes) { + Ok(()) => Some(Ok(pod_read_unaligned(geopoint_bytes))), + Err(e) if e.kind() == ErrorKind::UnexpectedEof => None, + Err(e) => Some(Err(e)), + } + } + None => None, + }) +} + +impl<'extractor> Extractor<'extractor> for GeoExtractor { + type Data = RefCell>; + + fn init_data<'doc>(&'doc self, extractor_alloc: &'extractor Bump) -> Result { + Ok(RefCell::new(GeoExtractorData { + removed: bumpalo::collections::Vec::new_in(extractor_alloc), + // inserted: Uell::new_in(extractor_alloc), + inserted: bumpalo::collections::Vec::new_in(extractor_alloc), + spilled_inserted: None, + spilled_removed: None, + })) + } + + fn process<'doc>( + &'doc self, + changes: impl Iterator>>, + context: &'doc DocumentChangeContext, + ) -> Result<()> { + let rtxn = &context.rtxn; + let index = context.index; + let max_memory = self.grenad_parameters.max_memory_by_thread(); + let db_fields_ids_map = context.db_fields_ids_map; + let mut data_ref = context.data.borrow_mut_or_yield(); + + for change in changes { + if max_memory.map_or(false, |mm| context.extractor_alloc.allocated_bytes() >= mm) { + // We must spill as we allocated too much memory + data_ref.spilled_removed = tempfile::tempfile().map(BufWriter::new).map(Some)?; + data_ref.spilled_inserted = tempfile::tempfile().map(BufWriter::new).map(Some)?; + } + + match change? { + DocumentChange::Deletion(deletion) => { + let docid = deletion.docid(); + let external_id = deletion.external_document_id(); + let current = deletion.current(rtxn, index, db_fields_ids_map)?; + let current_geo = current + .geo_field()? + .map(|geo| extract_geo_coordinates(external_id, geo)) + .transpose()?; + + if let Some(lat_lng) = current_geo.flatten() { + let geopoint = ExtractedGeoPoint { docid, lat_lng }; + match &mut data_ref.spilled_removed { + Some(file) => file.write_all(bytes_of(&geopoint))?, + None => data_ref.removed.push(geopoint), + } + } + } + DocumentChange::Update(update) => { + let current = update.current(rtxn, index, db_fields_ids_map)?; + let external_id = update.external_document_id(); + let docid = update.docid(); + + let current_geo = current + .geo_field()? + .map(|geo| extract_geo_coordinates(external_id, geo)) + .transpose()?; + + let updated_geo = update + .updated() + .geo_field()? + .map(|geo| extract_geo_coordinates(external_id, geo)) + .transpose()?; + + if current_geo != updated_geo { + // If the current and new geo points are different it means that + // we need to replace the current by the new point and therefore + // delete the current point from the RTree. + if let Some(lat_lng) = current_geo.flatten() { + let geopoint = ExtractedGeoPoint { docid, lat_lng }; + match &mut data_ref.spilled_removed { + Some(file) => file.write_all(bytes_of(&geopoint))?, + None => data_ref.removed.push(geopoint), + } + } + + if let Some(lat_lng) = updated_geo.flatten() { + let geopoint = ExtractedGeoPoint { docid, lat_lng }; + match &mut data_ref.spilled_inserted { + Some(file) => file.write_all(bytes_of(&geopoint))?, + None => data_ref.inserted.push(geopoint), + } + } + } + } + DocumentChange::Insertion(insertion) => { + let external_id = insertion.external_document_id(); + let docid = insertion.docid(); + + let inserted_geo = insertion + .inserted() + .geo_field()? + .map(|geo| extract_geo_coordinates(external_id, geo)) + .transpose()?; + + if let Some(lat_lng) = inserted_geo.flatten() { + let geopoint = ExtractedGeoPoint { docid, lat_lng }; + match &mut data_ref.spilled_inserted { + Some(file) => file.write_all(bytes_of(&geopoint))?, + None => data_ref.inserted.push(geopoint), + } + } + } + } + } + + Ok(()) + } +} + +/// Extracts and validate the latitude and latitude from a document geo field. +/// +/// It can be of the form `{ "lat": 0.0, "lng": "1.0" }`. +pub fn extract_geo_coordinates( + external_id: &str, + raw_value: &RawValue, +) -> Result> { + let mut geo = match serde_json::from_str(raw_value.get()).map_err(InternalError::SerdeJson)? { + Value::Null => return Ok(None), + Value::Object(map) => map, + value => { + return Err( + GeoError::NotAnObject { document_id: Value::from(external_id), value }.into() + ) + } + }; + + let [lat, lng] = match (geo.remove("lat"), geo.remove("lng")) { + (Some(lat), Some(lng)) => { + if geo.is_empty() { + [lat, lng] + } else { + return Err(GeoError::UnexpectedExtraFields { + document_id: Value::from(external_id), + value: Value::from(geo), + } + .into()); + } + } + (Some(_), None) => { + return Err(GeoError::MissingLongitude { document_id: Value::from(external_id) }.into()) + } + (None, Some(_)) => { + return Err(GeoError::MissingLatitude { document_id: Value::from(external_id) }.into()) + } + (None, None) => { + return Err(GeoError::MissingLatitudeAndLongitude { + document_id: Value::from(external_id), + } + .into()) + } + }; + + match (extract_finite_float_from_value(lat), extract_finite_float_from_value(lng)) { + (Ok(lat), Ok(lng)) => Ok(Some([lat, lng])), + (Ok(_), Err(value)) => { + Err(GeoError::BadLongitude { document_id: Value::from(external_id), value }.into()) + } + (Err(value), Ok(_)) => { + Err(GeoError::BadLatitude { document_id: Value::from(external_id), value }.into()) + } + (Err(lat), Err(lng)) => Err(GeoError::BadLatitudeAndLongitude { + document_id: Value::from(external_id), + lat, + lng, + } + .into()), + } +} + +/// Extracts and validate that a serde JSON Value is actually a finite f64. +pub fn extract_finite_float_from_value(value: Value) -> result::Result { + let number = match value { + Value::Number(ref n) => match n.as_f64() { + Some(number) => number, + None => return Err(value), + }, + Value::String(ref s) => match s.parse::() { + Ok(number) => number, + Err(_) => return Err(value), + }, + value => return Err(value), + }; + + if number.is_finite() { + Ok(number) + } else { + Err(value) + } +} diff --git a/crates/milli/src/update/new/extract/mod.rs b/crates/milli/src/update/new/extract/mod.rs new file mode 100644 index 000000000..7364434ee --- /dev/null +++ b/crates/milli/src/update/new/extract/mod.rs @@ -0,0 +1,146 @@ +mod cache; +mod documents; +mod faceted; +mod geo; +mod searchable; +mod vectors; + +use bumpalo::Bump; +pub use cache::{merge_caches, transpose_and_freeze_caches, BalancedCaches, DelAddRoaringBitmap}; +pub use documents::*; +pub use faceted::*; +pub use geo::*; +pub use searchable::*; +pub use vectors::EmbeddingExtractor; + +use super::indexer::document_changes::{DocumentChanges, IndexingContext, Progress}; +use super::steps::Step; +use super::thread_local::{FullySend, ThreadLocal}; +use crate::update::GrenadParameters; +use crate::Result; + +pub trait DocidsExtractor { + fn run_extraction<'pl, 'fid, 'indexer, 'index, 'extractor, DC: DocumentChanges<'pl>, MSP, SP>( + grenad_parameters: GrenadParameters, + document_changes: &DC, + indexing_context: IndexingContext<'fid, 'indexer, 'index, MSP, SP>, + extractor_allocs: &'extractor mut ThreadLocal>, + step: Step, + ) -> Result>> + where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync; +} + +/// TODO move in permissive json pointer +pub mod perm_json_p { + use serde_json::{Map, Value}; + + use crate::Result; + const SPLIT_SYMBOL: char = '.'; + + /// Returns `true` if the `selector` match the `key`. + /// + /// ```text + /// Example: + /// `animaux` match `animaux` + /// `animaux.chien` match `animaux` + /// `animaux.chien` match `animaux` + /// `animaux.chien.nom` match `animaux` + /// `animaux.chien.nom` match `animaux.chien` + /// ----------------------------------------- + /// `animaux` doesn't match `animaux.chien` + /// `animaux.` doesn't match `animaux` + /// `animaux.ch` doesn't match `animaux.chien` + /// `animau` doesn't match `animaux` + /// ``` + pub fn contained_in(selector: &str, key: &str) -> bool { + selector.starts_with(key) + && selector[key.len()..].chars().next().map(|c| c == SPLIT_SYMBOL).unwrap_or(true) + } + + pub fn seek_leaf_values_in_object( + value: &Map, + selectors: Option<&[&str]>, + skip_selectors: &[&str], + base_key: &str, + seeker: &mut impl FnMut(&str, &Value) -> Result<()>, + ) -> Result<()> { + if value.is_empty() { + seeker(base_key, &Value::Object(Map::with_capacity(0)))?; + } + + for (key, value) in value.iter() { + let base_key = if base_key.is_empty() { + key.to_string() + } else { + format!("{}{}{}", base_key, SPLIT_SYMBOL, key) + }; + + // here if the user only specified `doggo` we need to iterate in all the fields of `doggo` + // so we check the contained_in on both side + let should_continue = select_field(&base_key, selectors, skip_selectors); + if should_continue { + match value { + Value::Object(object) => seek_leaf_values_in_object( + object, + selectors, + skip_selectors, + &base_key, + seeker, + ), + Value::Array(array) => seek_leaf_values_in_array( + array, + selectors, + skip_selectors, + &base_key, + seeker, + ), + value => seeker(&base_key, value), + }?; + } + } + + Ok(()) + } + + pub fn seek_leaf_values_in_array( + values: &[Value], + selectors: Option<&[&str]>, + skip_selectors: &[&str], + base_key: &str, + seeker: &mut impl FnMut(&str, &Value) -> Result<()>, + ) -> Result<()> { + if values.is_empty() { + seeker(base_key, &Value::Array(vec![]))?; + } + + for value in values { + match value { + Value::Object(object) => { + seek_leaf_values_in_object(object, selectors, skip_selectors, base_key, seeker) + } + Value::Array(array) => { + seek_leaf_values_in_array(array, selectors, skip_selectors, base_key, seeker) + } + value => seeker(base_key, value), + }?; + } + + Ok(()) + } + + pub fn select_field( + field_name: &str, + selectors: Option<&[&str]>, + skip_selectors: &[&str], + ) -> bool { + selectors.map_or(true, |selectors| { + selectors.iter().any(|selector| { + contained_in(selector, field_name) || contained_in(field_name, selector) + }) + }) && !skip_selectors.iter().any(|skip_selector| { + contained_in(skip_selector, field_name) || contained_in(field_name, skip_selector) + }) + } +} diff --git a/crates/milli/src/update/new/extract/searchable/extract_word_docids.rs b/crates/milli/src/update/new/extract/searchable/extract_word_docids.rs new file mode 100644 index 000000000..f3d4afcb8 --- /dev/null +++ b/crates/milli/src/update/new/extract/searchable/extract_word_docids.rs @@ -0,0 +1,421 @@ +use std::cell::RefCell; +use std::collections::HashMap; +use std::mem::size_of; +use std::ops::DerefMut as _; + +use bumpalo::collections::vec::Vec as BumpVec; +use bumpalo::Bump; +use heed::RoTxn; + +use super::tokenize_document::{tokenizer_builder, DocumentTokenizer}; +use crate::update::new::extract::cache::BalancedCaches; +use crate::update::new::extract::perm_json_p::contained_in; +use crate::update::new::indexer::document_changes::{ + extract, DocumentChangeContext, DocumentChanges, Extractor, IndexingContext, Progress, +}; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::steps::Step; +use crate::update::new::thread_local::{FullySend, MostlySend, ThreadLocal}; +use crate::update::new::DocumentChange; +use crate::update::GrenadParameters; +use crate::{bucketed_position, DocumentId, FieldId, Index, Result, MAX_POSITION_PER_ATTRIBUTE}; + +const MAX_COUNTED_WORDS: usize = 30; + +pub struct WordDocidsBalancedCaches<'extractor> { + word_fid_docids: BalancedCaches<'extractor>, + word_docids: BalancedCaches<'extractor>, + exact_word_docids: BalancedCaches<'extractor>, + word_position_docids: BalancedCaches<'extractor>, + fid_word_count_docids: BalancedCaches<'extractor>, + fid_word_count: HashMap, + current_docid: Option, +} + +unsafe impl<'extractor> MostlySend for WordDocidsBalancedCaches<'extractor> {} + +impl<'extractor> WordDocidsBalancedCaches<'extractor> { + /// TODO Make sure to give the same max_memory to all of them, without splitting it + pub fn new_in(buckets: usize, max_memory: Option, alloc: &'extractor Bump) -> Self { + Self { + word_fid_docids: BalancedCaches::new_in(buckets, max_memory, alloc), + word_docids: BalancedCaches::new_in(buckets, max_memory, alloc), + exact_word_docids: BalancedCaches::new_in(buckets, max_memory, alloc), + word_position_docids: BalancedCaches::new_in(buckets, max_memory, alloc), + fid_word_count_docids: BalancedCaches::new_in(buckets, max_memory, alloc), + fid_word_count: HashMap::new(), + current_docid: None, + } + } + + fn insert_add_u32( + &mut self, + field_id: FieldId, + position: u16, + word: &str, + exact: bool, + docid: u32, + bump: &Bump, + ) -> Result<()> { + let word_bytes = word.as_bytes(); + if exact { + self.exact_word_docids.insert_add_u32(word_bytes, docid)?; + } else { + self.word_docids.insert_add_u32(word_bytes, docid)?; + } + + let buffer_size = word_bytes.len() + 1 + size_of::(); + let mut buffer = BumpVec::with_capacity_in(buffer_size, bump); + + buffer.clear(); + buffer.extend_from_slice(word_bytes); + buffer.push(0); + buffer.extend_from_slice(&field_id.to_be_bytes()); + self.word_fid_docids.insert_add_u32(&buffer, docid)?; + + let position = bucketed_position(position); + buffer.clear(); + buffer.extend_from_slice(word_bytes); + buffer.push(0); + buffer.extend_from_slice(&position.to_be_bytes()); + self.word_position_docids.insert_add_u32(&buffer, docid)?; + + if self.current_docid.map_or(false, |id| docid != id) { + self.flush_fid_word_count(&mut buffer)?; + } + + self.fid_word_count + .entry(field_id) + .and_modify(|(_current_count, new_count)| *new_count += 1) + .or_insert((0, 1)); + self.current_docid = Some(docid); + + Ok(()) + } + + fn insert_del_u32( + &mut self, + field_id: FieldId, + position: u16, + word: &str, + exact: bool, + docid: u32, + bump: &Bump, + ) -> Result<()> { + let word_bytes = word.as_bytes(); + if exact { + self.exact_word_docids.insert_del_u32(word_bytes, docid)?; + } else { + self.word_docids.insert_del_u32(word_bytes, docid)?; + } + + let buffer_size = word_bytes.len() + 1 + size_of::(); + let mut buffer = BumpVec::with_capacity_in(buffer_size, bump); + + buffer.clear(); + buffer.extend_from_slice(word_bytes); + buffer.push(0); + buffer.extend_from_slice(&field_id.to_be_bytes()); + self.word_fid_docids.insert_del_u32(&buffer, docid)?; + + let position = bucketed_position(position); + buffer.clear(); + buffer.extend_from_slice(word_bytes); + buffer.push(0); + buffer.extend_from_slice(&position.to_be_bytes()); + self.word_position_docids.insert_del_u32(&buffer, docid)?; + + if self.current_docid.map_or(false, |id| docid != id) { + self.flush_fid_word_count(&mut buffer)?; + } + + self.fid_word_count + .entry(field_id) + .and_modify(|(current_count, _new_count)| *current_count += 1) + .or_insert((1, 0)); + + self.current_docid = Some(docid); + + Ok(()) + } + + fn flush_fid_word_count(&mut self, buffer: &mut BumpVec) -> Result<()> { + for (fid, (current_count, new_count)) in self.fid_word_count.drain() { + if current_count != new_count { + if current_count <= MAX_COUNTED_WORDS { + buffer.clear(); + buffer.extend_from_slice(&fid.to_be_bytes()); + buffer.push(current_count as u8); + self.fid_word_count_docids + .insert_del_u32(buffer, self.current_docid.unwrap())?; + } + if new_count <= MAX_COUNTED_WORDS { + buffer.clear(); + buffer.extend_from_slice(&fid.to_be_bytes()); + buffer.push(new_count as u8); + self.fid_word_count_docids + .insert_add_u32(buffer, self.current_docid.unwrap())?; + } + } + } + + Ok(()) + } +} + +pub struct WordDocidsCaches<'extractor> { + pub word_docids: Vec>, + pub word_fid_docids: Vec>, + pub exact_word_docids: Vec>, + pub word_position_docids: Vec>, + pub fid_word_count_docids: Vec>, +} + +impl<'extractor> WordDocidsCaches<'extractor> { + fn new() -> Self { + Self { + word_docids: Vec::new(), + word_fid_docids: Vec::new(), + exact_word_docids: Vec::new(), + word_position_docids: Vec::new(), + fid_word_count_docids: Vec::new(), + } + } + + fn push(&mut self, other: WordDocidsBalancedCaches<'extractor>) -> Result<()> { + let WordDocidsBalancedCaches { + word_docids, + word_fid_docids, + exact_word_docids, + word_position_docids, + fid_word_count_docids, + fid_word_count: _, + current_docid: _, + } = other; + + self.word_docids.push(word_docids); + self.word_fid_docids.push(word_fid_docids); + self.exact_word_docids.push(exact_word_docids); + self.word_position_docids.push(word_position_docids); + self.fid_word_count_docids.push(fid_word_count_docids); + + Ok(()) + } +} + +pub struct WordDocidsExtractorData<'a> { + tokenizer: &'a DocumentTokenizer<'a>, + grenad_parameters: GrenadParameters, + buckets: usize, +} + +impl<'a, 'extractor> Extractor<'extractor> for WordDocidsExtractorData<'a> { + type Data = RefCell>>; + + fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result { + Ok(RefCell::new(Some(WordDocidsBalancedCaches::new_in( + self.buckets, + self.grenad_parameters.max_memory_by_thread(), + extractor_alloc, + )))) + } + + fn process<'doc>( + &self, + changes: impl Iterator>>, + context: &DocumentChangeContext, + ) -> Result<()> { + for change in changes { + let change = change?; + WordDocidsExtractors::extract_document_change(context, self.tokenizer, change)?; + } + Ok(()) + } +} + +pub struct WordDocidsExtractors; + +impl WordDocidsExtractors { + pub fn run_extraction< + 'pl, + 'fid, + 'indexer, + 'index, + 'extractor, + DC: DocumentChanges<'pl>, + MSP, + SP, + >( + grenad_parameters: GrenadParameters, + document_changes: &DC, + indexing_context: IndexingContext<'fid, 'indexer, 'index, MSP, SP>, + extractor_allocs: &'extractor mut ThreadLocal>, + step: Step, + ) -> Result> + where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, + { + let index = indexing_context.index; + let rtxn = index.read_txn()?; + + let stop_words = index.stop_words(&rtxn)?; + let allowed_separators = index.allowed_separators(&rtxn)?; + let allowed_separators: Option> = + allowed_separators.as_ref().map(|s| s.iter().map(String::as_str).collect()); + let dictionary = index.dictionary(&rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect()); + let builder = tokenizer_builder( + stop_words.as_ref(), + allowed_separators.as_deref(), + dictionary.as_deref(), + ); + let tokenizer = builder.into_tokenizer(); + + let attributes_to_extract = Self::attributes_to_extract(&rtxn, index)?; + let attributes_to_skip = Self::attributes_to_skip(&rtxn, index)?; + let localized_attributes_rules = + index.localized_attributes_rules(&rtxn)?.unwrap_or_default(); + + let document_tokenizer = DocumentTokenizer { + tokenizer: &tokenizer, + attribute_to_extract: attributes_to_extract.as_deref(), + attribute_to_skip: attributes_to_skip.as_slice(), + localized_attributes_rules: &localized_attributes_rules, + max_positions_per_attributes: MAX_POSITION_PER_ATTRIBUTE, + }; + + let datastore = ThreadLocal::new(); + + { + let span = + tracing::trace_span!(target: "indexing::documents::extract", "docids_extraction"); + let _entered = span.enter(); + + let extractor = WordDocidsExtractorData { + tokenizer: &document_tokenizer, + grenad_parameters, + buckets: rayon::current_num_threads(), + }; + + extract( + document_changes, + &extractor, + indexing_context, + extractor_allocs, + &datastore, + step, + )?; + } + + let mut merger = WordDocidsCaches::new(); + for cache in datastore.into_iter().flat_map(RefCell::into_inner) { + merger.push(cache)?; + } + + Ok(merger) + } + + fn extract_document_change( + context: &DocumentChangeContext>>, + document_tokenizer: &DocumentTokenizer, + document_change: DocumentChange, + ) -> Result<()> { + let index = &context.index; + let rtxn = &context.rtxn; + let mut cached_sorter_ref = context.data.borrow_mut_or_yield(); + let cached_sorter = cached_sorter_ref.as_mut().unwrap(); + let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield(); + let new_fields_ids_map = new_fields_ids_map.deref_mut(); + let doc_alloc = &context.doc_alloc; + + let exact_attributes = index.exact_attributes(rtxn)?; + let is_exact_attribute = + |fname: &str| exact_attributes.iter().any(|attr| contained_in(fname, attr)); + match document_change { + DocumentChange::Deletion(inner) => { + let mut token_fn = |fname: &str, fid, pos, word: &str| { + cached_sorter.insert_del_u32( + fid, + pos, + word, + is_exact_attribute(fname), + inner.docid(), + doc_alloc, + ) + }; + document_tokenizer.tokenize_document( + inner.current(rtxn, index, context.db_fields_ids_map)?, + new_fields_ids_map, + &mut token_fn, + )?; + } + DocumentChange::Update(inner) => { + let mut token_fn = |fname: &str, fid, pos, word: &str| { + cached_sorter.insert_del_u32( + fid, + pos, + word, + is_exact_attribute(fname), + inner.docid(), + doc_alloc, + ) + }; + document_tokenizer.tokenize_document( + inner.current(rtxn, index, context.db_fields_ids_map)?, + new_fields_ids_map, + &mut token_fn, + )?; + + let mut token_fn = |fname: &str, fid, pos, word: &str| { + cached_sorter.insert_add_u32( + fid, + pos, + word, + is_exact_attribute(fname), + inner.docid(), + doc_alloc, + ) + }; + document_tokenizer.tokenize_document( + inner.merged(rtxn, index, context.db_fields_ids_map)?, + new_fields_ids_map, + &mut token_fn, + )?; + } + DocumentChange::Insertion(inner) => { + let mut token_fn = |fname: &str, fid, pos, word: &str| { + cached_sorter.insert_add_u32( + fid, + pos, + word, + is_exact_attribute(fname), + inner.docid(), + doc_alloc, + ) + }; + document_tokenizer.tokenize_document( + inner.inserted(), + new_fields_ids_map, + &mut token_fn, + )?; + } + } + + let buffer_size = size_of::(); + let mut buffer = BumpVec::with_capacity_in(buffer_size, &context.doc_alloc); + cached_sorter.flush_fid_word_count(&mut buffer) + } + + fn attributes_to_extract<'a>( + rtxn: &'a RoTxn, + index: &'a Index, + ) -> Result>> { + index.user_defined_searchable_fields(rtxn).map_err(Into::into) + } + + fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result> { + Ok(vec!["_geo"]) + } +} diff --git a/crates/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs b/crates/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs new file mode 100644 index 000000000..945f0b8b3 --- /dev/null +++ b/crates/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs @@ -0,0 +1,190 @@ +use std::cell::RefCell; +use std::collections::VecDeque; +use std::rc::Rc; + +use heed::RoTxn; + +use super::tokenize_document::DocumentTokenizer; +use super::SearchableExtractor; +use crate::proximity::{index_proximity, MAX_DISTANCE}; +use crate::update::new::document::Document; +use crate::update::new::extract::cache::BalancedCaches; +use crate::update::new::indexer::document_changes::DocumentChangeContext; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::DocumentChange; +use crate::{FieldId, GlobalFieldsIdsMap, Index, Result}; + +pub struct WordPairProximityDocidsExtractor; + +impl SearchableExtractor for WordPairProximityDocidsExtractor { + fn attributes_to_extract<'a>( + rtxn: &'a RoTxn, + index: &'a Index, + ) -> Result>> { + index.user_defined_searchable_fields(rtxn).map_err(Into::into) + } + + fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result> { + Ok(vec!["_geo"]) + } + + // This method is reimplemented to count the number of words in the document in each field + // and to store the docids of the documents that have a number of words in a given field + // equal to or under than MAX_COUNTED_WORDS. + fn extract_document_change( + context: &DocumentChangeContext>, + document_tokenizer: &DocumentTokenizer, + document_change: DocumentChange, + ) -> Result<()> { + let doc_alloc = &context.doc_alloc; + + let index = context.index; + let rtxn = &context.rtxn; + + let mut key_buffer = bumpalo::collections::Vec::new_in(doc_alloc); + let mut del_word_pair_proximity = bumpalo::collections::Vec::new_in(doc_alloc); + let mut add_word_pair_proximity = bumpalo::collections::Vec::new_in(doc_alloc); + + let mut new_fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield(); + let new_fields_ids_map = &mut *new_fields_ids_map; + + let mut cached_sorter = context.data.borrow_mut_or_yield(); + let cached_sorter = &mut *cached_sorter; + + // is a vecdequeue, and will be smol, so can stay on the heap for now + let mut word_positions: VecDeque<(Rc, u16)> = + VecDeque::with_capacity(MAX_DISTANCE as usize); + + let docid = document_change.docid(); + match document_change { + DocumentChange::Deletion(inner) => { + let document = inner.current(rtxn, index, context.db_fields_ids_map)?; + process_document_tokens( + document, + document_tokenizer, + new_fields_ids_map, + &mut word_positions, + &mut |(w1, w2), prox| { + del_word_pair_proximity.push(((w1, w2), prox)); + }, + )?; + } + DocumentChange::Update(inner) => { + let document = inner.current(rtxn, index, context.db_fields_ids_map)?; + process_document_tokens( + document, + document_tokenizer, + new_fields_ids_map, + &mut word_positions, + &mut |(w1, w2), prox| { + del_word_pair_proximity.push(((w1, w2), prox)); + }, + )?; + let document = inner.merged(rtxn, index, context.db_fields_ids_map)?; + process_document_tokens( + document, + document_tokenizer, + new_fields_ids_map, + &mut word_positions, + &mut |(w1, w2), prox| { + add_word_pair_proximity.push(((w1, w2), prox)); + }, + )?; + } + DocumentChange::Insertion(inner) => { + let document = inner.inserted(); + process_document_tokens( + document, + document_tokenizer, + new_fields_ids_map, + &mut word_positions, + &mut |(w1, w2), prox| { + add_word_pair_proximity.push(((w1, w2), prox)); + }, + )?; + } + } + + del_word_pair_proximity.sort_unstable(); + del_word_pair_proximity.dedup_by(|(k1, _), (k2, _)| k1 == k2); + for ((w1, w2), prox) in del_word_pair_proximity.iter() { + let key = build_key(*prox, w1, w2, &mut key_buffer); + cached_sorter.insert_del_u32(key, docid)?; + } + + add_word_pair_proximity.sort_unstable(); + add_word_pair_proximity.dedup_by(|(k1, _), (k2, _)| k1 == k2); + for ((w1, w2), prox) in add_word_pair_proximity.iter() { + let key = build_key(*prox, w1, w2, &mut key_buffer); + cached_sorter.insert_add_u32(key, docid)?; + } + Ok(()) + } +} + +fn build_key<'a>( + prox: u8, + w1: &str, + w2: &str, + key_buffer: &'a mut bumpalo::collections::Vec, +) -> &'a [u8] { + key_buffer.clear(); + key_buffer.push(prox); + key_buffer.extend_from_slice(w1.as_bytes()); + key_buffer.push(0); + key_buffer.extend_from_slice(w2.as_bytes()); + key_buffer.as_slice() +} + +fn word_positions_into_word_pair_proximity( + word_positions: &mut VecDeque<(Rc, u16)>, + word_pair_proximity: &mut impl FnMut((Rc, Rc), u8), +) { + let (head_word, head_position) = word_positions.pop_front().unwrap(); + for (word, position) in word_positions.iter() { + let prox = index_proximity(head_position as u32, *position as u32) as u8; + if prox > 0 && prox < MAX_DISTANCE as u8 { + word_pair_proximity((head_word.clone(), word.clone()), prox); + } + } +} + +fn drain_word_positions( + word_positions: &mut VecDeque<(Rc, u16)>, + word_pair_proximity: &mut impl FnMut((Rc, Rc), u8), +) { + while !word_positions.is_empty() { + word_positions_into_word_pair_proximity(word_positions, word_pair_proximity); + } +} + +fn process_document_tokens<'doc>( + document: impl Document<'doc>, + document_tokenizer: &DocumentTokenizer, + fields_ids_map: &mut GlobalFieldsIdsMap, + word_positions: &mut VecDeque<(Rc, u16)>, + word_pair_proximity: &mut impl FnMut((Rc, Rc), u8), +) -> Result<()> { + let mut field_id = None; + let mut token_fn = |_fname: &str, fid: FieldId, pos: u16, word: &str| { + if field_id != Some(fid) { + field_id = Some(fid); + drain_word_positions(word_positions, word_pair_proximity); + } + // drain the proximity window until the head word is considered close to the word we are inserting. + while word_positions + .front() + .map_or(false, |(_w, p)| index_proximity(*p as u32, pos as u32) >= MAX_DISTANCE) + { + word_positions_into_word_pair_proximity(word_positions, word_pair_proximity); + } + + // insert the new word. + word_positions.push_back((Rc::from(word), pos)); + Ok(()) + }; + document_tokenizer.tokenize_document(document, fields_ids_map, &mut token_fn)?; + + drain_word_positions(word_positions, word_pair_proximity); + Ok(()) +} diff --git a/crates/milli/src/update/new/extract/searchable/mod.rs b/crates/milli/src/update/new/extract/searchable/mod.rs new file mode 100644 index 000000000..b61dfcf92 --- /dev/null +++ b/crates/milli/src/update/new/extract/searchable/mod.rs @@ -0,0 +1,156 @@ +mod extract_word_docids; +mod extract_word_pair_proximity_docids; +mod tokenize_document; + +use std::cell::RefCell; +use std::marker::PhantomData; + +use bumpalo::Bump; +pub use extract_word_docids::{WordDocidsCaches, WordDocidsExtractors}; +pub use extract_word_pair_proximity_docids::WordPairProximityDocidsExtractor; +use heed::RoTxn; +use tokenize_document::{tokenizer_builder, DocumentTokenizer}; + +use super::cache::BalancedCaches; +use super::DocidsExtractor; +use crate::update::new::indexer::document_changes::{ + extract, DocumentChangeContext, DocumentChanges, Extractor, IndexingContext, Progress, +}; +use crate::update::new::steps::Step; +use crate::update::new::thread_local::{FullySend, ThreadLocal}; +use crate::update::new::DocumentChange; +use crate::update::GrenadParameters; +use crate::{Index, Result, MAX_POSITION_PER_ATTRIBUTE}; + +pub struct SearchableExtractorData<'a, EX: SearchableExtractor> { + tokenizer: &'a DocumentTokenizer<'a>, + grenad_parameters: GrenadParameters, + buckets: usize, + _ex: PhantomData, +} + +impl<'a, 'extractor, EX: SearchableExtractor + Sync> Extractor<'extractor> + for SearchableExtractorData<'a, EX> +{ + type Data = RefCell>; + + fn init_data(&self, extractor_alloc: &'extractor Bump) -> Result { + Ok(RefCell::new(BalancedCaches::new_in( + self.buckets, + self.grenad_parameters.max_memory_by_thread(), + extractor_alloc, + ))) + } + + fn process<'doc>( + &self, + changes: impl Iterator>>, + context: &DocumentChangeContext, + ) -> Result<()> { + for change in changes { + let change = change?; + EX::extract_document_change(context, self.tokenizer, change)?; + } + Ok(()) + } +} + +pub trait SearchableExtractor: Sized + Sync { + fn run_extraction<'pl, 'fid, 'indexer, 'index, 'extractor, DC: DocumentChanges<'pl>, MSP, SP>( + grenad_parameters: GrenadParameters, + document_changes: &DC, + indexing_context: IndexingContext<'fid, 'indexer, 'index, MSP, SP>, + extractor_allocs: &'extractor mut ThreadLocal>, + step: Step, + ) -> Result>> + where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, + { + let rtxn = indexing_context.index.read_txn()?; + let stop_words = indexing_context.index.stop_words(&rtxn)?; + let allowed_separators = indexing_context.index.allowed_separators(&rtxn)?; + let allowed_separators: Option> = + allowed_separators.as_ref().map(|s| s.iter().map(String::as_str).collect()); + let dictionary = indexing_context.index.dictionary(&rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect()); + let builder = tokenizer_builder( + stop_words.as_ref(), + allowed_separators.as_deref(), + dictionary.as_deref(), + ); + let tokenizer = builder.into_tokenizer(); + + let attributes_to_extract = Self::attributes_to_extract(&rtxn, indexing_context.index)?; + let attributes_to_skip = Self::attributes_to_skip(&rtxn, indexing_context.index)?; + let localized_attributes_rules = + indexing_context.index.localized_attributes_rules(&rtxn)?.unwrap_or_default(); + + let document_tokenizer = DocumentTokenizer { + tokenizer: &tokenizer, + attribute_to_extract: attributes_to_extract.as_deref(), + attribute_to_skip: attributes_to_skip.as_slice(), + localized_attributes_rules: &localized_attributes_rules, + max_positions_per_attributes: MAX_POSITION_PER_ATTRIBUTE, + }; + + let extractor_data: SearchableExtractorData = SearchableExtractorData { + tokenizer: &document_tokenizer, + grenad_parameters, + buckets: rayon::current_num_threads(), + _ex: PhantomData, + }; + + let datastore = ThreadLocal::new(); + + { + let span = + tracing::trace_span!(target: "indexing::documents::extract", "docids_extraction"); + let _entered = span.enter(); + extract( + document_changes, + &extractor_data, + indexing_context, + extractor_allocs, + &datastore, + step, + )?; + } + + Ok(datastore.into_iter().map(RefCell::into_inner).collect()) + } + + fn extract_document_change( + context: &DocumentChangeContext>, + document_tokenizer: &DocumentTokenizer, + document_change: DocumentChange, + ) -> Result<()>; + + fn attributes_to_extract<'a>(rtxn: &'a RoTxn, index: &'a Index) + -> Result>>; + + fn attributes_to_skip<'a>(rtxn: &'a RoTxn, index: &'a Index) -> Result>; +} + +impl DocidsExtractor for T { + fn run_extraction<'pl, 'fid, 'indexer, 'index, 'extractor, DC: DocumentChanges<'pl>, MSP, SP>( + grenad_parameters: GrenadParameters, + document_changes: &DC, + indexing_context: IndexingContext<'fid, 'indexer, 'index, MSP, SP>, + extractor_allocs: &'extractor mut ThreadLocal>, + step: Step, + ) -> Result>> + where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, + { + Self::run_extraction( + grenad_parameters, + document_changes, + indexing_context, + extractor_allocs, + step, + ) + } +} diff --git a/crates/milli/src/update/new/extract/searchable/tokenize_document.rs b/crates/milli/src/update/new/extract/searchable/tokenize_document.rs new file mode 100644 index 000000000..bc7a2acd3 --- /dev/null +++ b/crates/milli/src/update/new/extract/searchable/tokenize_document.rs @@ -0,0 +1,277 @@ +use std::collections::HashMap; + +use charabia::{SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder}; +use serde_json::Value; + +use crate::update::new::document::Document; +use crate::update::new::extract::perm_json_p::{ + seek_leaf_values_in_array, seek_leaf_values_in_object, select_field, +}; +use crate::{ + FieldId, GlobalFieldsIdsMap, InternalError, LocalizedAttributesRule, Result, UserError, + MAX_WORD_LENGTH, +}; + +// todo: should be crate::proximity::MAX_DISTANCE but it has been forgotten +const MAX_DISTANCE: u32 = 8; + +pub struct DocumentTokenizer<'a> { + pub tokenizer: &'a Tokenizer<'a>, + pub attribute_to_extract: Option<&'a [&'a str]>, + pub attribute_to_skip: &'a [&'a str], + pub localized_attributes_rules: &'a [LocalizedAttributesRule], + pub max_positions_per_attributes: u32, +} + +impl<'a> DocumentTokenizer<'a> { + pub fn tokenize_document<'doc>( + &self, + document: impl Document<'doc>, + field_id_map: &mut GlobalFieldsIdsMap, + token_fn: &mut impl FnMut(&str, FieldId, u16, &str) -> Result<()>, + ) -> Result<()> { + let mut field_position = HashMap::new(); + + for entry in document.iter_top_level_fields() { + let (field_name, value) = entry?; + + let mut tokenize_field = |field_name: &str, value: &Value| { + let Some(field_id) = field_id_map.id_or_insert(field_name) else { + return Err(UserError::AttributeLimitReached.into()); + }; + + let position = field_position + .entry(field_id) + .and_modify(|counter| *counter += MAX_DISTANCE) + .or_insert(0); + if *position >= self.max_positions_per_attributes { + return Ok(()); + } + + match value { + Value::Number(n) => { + let token = n.to_string(); + if let Ok(position) = (*position).try_into() { + token_fn(field_name, field_id, position, token.as_str())?; + } + + Ok(()) + } + Value::String(text) => { + // create an iterator of token with their positions. + let locales = self + .localized_attributes_rules + .iter() + .find(|rule| rule.match_str(field_name)) + .map(|rule| rule.locales()); + let tokens = process_tokens( + *position, + self.tokenizer.tokenize_with_allow_list(text.as_str(), locales), + ) + .take_while(|(p, _)| *p < self.max_positions_per_attributes); + + for (index, token) in tokens { + // keep a word only if it is not empty and fit in a LMDB key. + let token = token.lemma().trim(); + if !token.is_empty() && token.len() <= MAX_WORD_LENGTH { + *position = index; + if let Ok(position) = (*position).try_into() { + token_fn(field_name, field_id, position, token)?; + } + } + } + + Ok(()) + } + _ => Ok(()), + } + }; + + // if the current field is searchable or contains a searchable attribute + if select_field(field_name, self.attribute_to_extract, self.attribute_to_skip) { + // parse json. + match serde_json::to_value(value).map_err(InternalError::SerdeJson)? { + Value::Object(object) => seek_leaf_values_in_object( + &object, + self.attribute_to_extract, + self.attribute_to_skip, + field_name, + &mut tokenize_field, + )?, + Value::Array(array) => seek_leaf_values_in_array( + &array, + self.attribute_to_extract, + self.attribute_to_skip, + field_name, + &mut tokenize_field, + )?, + value => tokenize_field(field_name, &value)?, + } + } + } + + Ok(()) + } +} + +/// take an iterator on tokens and compute their relative position depending on separator kinds +/// if it's an `Hard` separator we add an additional relative proximity of MAX_DISTANCE between words, +/// else we keep the standard proximity of 1 between words. +fn process_tokens<'a>( + start_offset: u32, + tokens: impl Iterator>, +) -> impl Iterator)> { + tokens + .skip_while(|token| token.is_separator()) + .scan((start_offset, None), |(offset, prev_kind), mut token| { + match token.kind { + TokenKind::Word | TokenKind::StopWord if !token.lemma().is_empty() => { + *offset += match *prev_kind { + Some(TokenKind::Separator(SeparatorKind::Hard)) => MAX_DISTANCE, + Some(_) => 1, + None => 0, + }; + *prev_kind = Some(token.kind) + } + TokenKind::Separator(SeparatorKind::Hard) => { + *prev_kind = Some(token.kind); + } + TokenKind::Separator(SeparatorKind::Soft) + if *prev_kind != Some(TokenKind::Separator(SeparatorKind::Hard)) => + { + *prev_kind = Some(token.kind); + } + _ => token.kind = TokenKind::Unknown, + } + Some((*offset, token)) + }) + .filter(|(_, t)| t.is_word()) +} + +/// Factorize tokenizer building. +pub fn tokenizer_builder<'a>( + stop_words: Option<&'a fst::Set<&'a [u8]>>, + allowed_separators: Option<&'a [&str]>, + dictionary: Option<&'a [&str]>, +) -> TokenizerBuilder<'a, &'a [u8]> { + let mut tokenizer_builder = TokenizerBuilder::new(); + if let Some(stop_words) = stop_words { + tokenizer_builder.stop_words(stop_words); + } + if let Some(dictionary) = dictionary { + tokenizer_builder.words_dict(dictionary); + } + if let Some(separators) = allowed_separators { + tokenizer_builder.separators(separators); + } + + tokenizer_builder +} + +#[cfg(test)] +mod test { + use bumpalo::Bump; + use charabia::TokenizerBuilder; + use meili_snap::snapshot; + use raw_collections::RawMap; + use serde_json::json; + use serde_json::value::RawValue; + + use super::*; + use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder}; + use crate::update::new::document::{DocumentFromVersions, Versions}; + use crate::FieldsIdsMap; + + #[test] + fn test_tokenize_document() { + let mut fields_ids_map = FieldsIdsMap::new(); + + let document = json!({ + "doggo": { "name": "doggo", + "age": 10,}, + "catto": { + "catto": { + "name": "pesti", + "age": 23, + } + }, + "doggo.name": ["doggo", "catto"], + "not-me": "UNSEARCHABLE", + "me-nether": {"nope": "unsearchable"} + }); + + let _field_1_id = fields_ids_map.insert("doggo").unwrap(); + let _field_2_id = fields_ids_map.insert("catto").unwrap(); + let _field_3_id = fields_ids_map.insert("doggo.name").unwrap(); + let _field_4_id = fields_ids_map.insert("not-me").unwrap(); + let _field_5_id = fields_ids_map.insert("me-nether").unwrap(); + + let mut tb = TokenizerBuilder::default(); + let document_tokenizer = DocumentTokenizer { + tokenizer: &tb.build(), + attribute_to_extract: None, + attribute_to_skip: &["not-me", "me-nether.nope"], + localized_attributes_rules: &[], + max_positions_per_attributes: 1000, + }; + + let fields_ids_map = FieldIdMapWithMetadata::new( + fields_ids_map, + MetadataBuilder::new(Default::default(), Default::default(), Default::default(), None), + ); + + let fields_ids_map_lock = std::sync::RwLock::new(fields_ids_map); + let mut global_fields_ids_map = GlobalFieldsIdsMap::new(&fields_ids_map_lock); + + let mut words = std::collections::BTreeMap::new(); + + let document = document.to_string(); + + let bump = Bump::new(); + let document: &RawValue = serde_json::from_str(&document).unwrap(); + let document = RawMap::from_raw_value(document, &bump).unwrap(); + + let document = Versions::single(document); + let document = DocumentFromVersions::new(&document); + + document_tokenizer + .tokenize_document( + document, + &mut global_fields_ids_map, + &mut |_fname, fid, pos, word| { + words.insert([fid, pos], word.to_string()); + Ok(()) + }, + ) + .unwrap(); + + snapshot!(format!("{:#?}", words), @r###" + { + [ + 2, + 0, + ]: "doggo", + [ + 2, + 8, + ]: "doggo", + [ + 2, + 16, + ]: "catto", + [ + 5, + 0, + ]: "10", + [ + 6, + 0, + ]: "pesti", + [ + 7, + 0, + ]: "23", + } + "###); + } +} diff --git a/crates/milli/src/update/new/extract/vectors/mod.rs b/crates/milli/src/update/new/extract/vectors/mod.rs new file mode 100644 index 000000000..8ac73a8d7 --- /dev/null +++ b/crates/milli/src/update/new/extract/vectors/mod.rs @@ -0,0 +1,489 @@ +use std::cell::RefCell; + +use bumpalo::collections::Vec as BVec; +use bumpalo::Bump; +use hashbrown::{DefaultHashBuilder, HashMap}; + +use super::cache::DelAddRoaringBitmap; +use crate::error::FaultSource; +use crate::prompt::Prompt; +use crate::update::new::channel::EmbeddingSender; +use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::vector_document::VectorDocument; +use crate::update::new::DocumentChange; +use crate::vector::error::{ + EmbedErrorKind, PossibleEmbeddingMistakes, UnusedVectorsDistributionBump, +}; +use crate::vector::{Embedder, Embedding, EmbeddingConfigs}; +use crate::{DocumentId, FieldDistribution, InternalError, Result, ThreadPoolNoAbort, UserError}; + +pub struct EmbeddingExtractor<'a> { + embedders: &'a EmbeddingConfigs, + sender: &'a EmbeddingSender<'a>, + possible_embedding_mistakes: PossibleEmbeddingMistakes, + threads: &'a ThreadPoolNoAbort, +} + +impl<'a> EmbeddingExtractor<'a> { + pub fn new( + embedders: &'a EmbeddingConfigs, + sender: &'a EmbeddingSender<'a>, + field_distribution: &'a FieldDistribution, + threads: &'a ThreadPoolNoAbort, + ) -> Self { + let possible_embedding_mistakes = PossibleEmbeddingMistakes::new(field_distribution); + Self { embedders, sender, threads, possible_embedding_mistakes } + } +} + +pub struct EmbeddingExtractorData<'extractor>( + pub HashMap, +); + +unsafe impl MostlySend for EmbeddingExtractorData<'_> {} + +impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { + type Data = RefCell>; + + fn init_data<'doc>(&'doc self, extractor_alloc: &'extractor Bump) -> crate::Result { + Ok(RefCell::new(EmbeddingExtractorData(HashMap::new_in(extractor_alloc)))) + } + + fn process<'doc>( + &'doc self, + changes: impl Iterator>>, + context: &'doc DocumentChangeContext, + ) -> crate::Result<()> { + let embedders = self.embedders.inner_as_ref(); + let mut unused_vectors_distribution = + UnusedVectorsDistributionBump::new_in(&context.doc_alloc); + + let mut all_chunks = BVec::with_capacity_in(embedders.len(), &context.doc_alloc); + for (embedder_name, (embedder, prompt, _is_quantized)) in embedders { + let embedder_id = + context.index.embedder_category_id.get(&context.rtxn, embedder_name)?.ok_or_else( + || InternalError::DatabaseMissingEntry { + db_name: "embedder_category_id", + key: None, + }, + )?; + all_chunks.push(Chunks::new( + embedder, + embedder_id, + embedder_name, + prompt, + context.data, + &self.possible_embedding_mistakes, + self.threads, + self.sender, + &context.doc_alloc, + )) + } + + for change in changes { + let change = change?; + match change { + DocumentChange::Deletion(deletion) => { + // vector deletion is handled by document sender, + // we still need to accomodate deletion from user_provided + for chunks in &mut all_chunks { + // regenerate: true means we delete from user_provided + chunks.set_regenerate(deletion.docid(), true); + } + } + DocumentChange::Update(update) => { + let old_vectors = update.current_vectors( + &context.rtxn, + context.index, + context.db_fields_ids_map, + &context.doc_alloc, + )?; + let new_vectors = update.updated_vectors(&context.doc_alloc, self.embedders)?; + + if let Some(new_vectors) = &new_vectors { + unused_vectors_distribution.append(new_vectors)?; + } + + for chunks in &mut all_chunks { + let embedder_name = chunks.embedder_name(); + let prompt = chunks.prompt(); + + let old_vectors = old_vectors.vectors_for_key(embedder_name)?.unwrap(); + if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| { + new_vectors.vectors_for_key(embedder_name).transpose() + }) { + let new_vectors = new_vectors?; + if old_vectors.regenerate != new_vectors.regenerate { + chunks.set_regenerate(update.docid(), new_vectors.regenerate); + } + // do we have set embeddings? + if let Some(embeddings) = new_vectors.embeddings { + chunks.set_vectors( + update.docid(), + embeddings + .into_vec(&context.doc_alloc, embedder_name) + .map_err(|error| UserError::InvalidVectorsEmbedderConf { + document_id: update.external_document_id().to_string(), + error: error.to_string(), + })?, + ); + } else if new_vectors.regenerate { + let new_rendered = prompt.render_document( + update.current( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + let old_rendered = prompt.render_document( + update.merged( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + if new_rendered != old_rendered { + chunks.set_autogenerated( + update.docid(), + update.external_document_id(), + new_rendered, + &unused_vectors_distribution, + )?; + } + } + } else if old_vectors.regenerate { + let old_rendered = prompt.render_document( + update.current( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + let new_rendered = prompt.render_document( + update.merged( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + if new_rendered != old_rendered { + chunks.set_autogenerated( + update.docid(), + update.external_document_id(), + new_rendered, + &unused_vectors_distribution, + )?; + } + } + } + } + DocumentChange::Insertion(insertion) => { + let new_vectors = + insertion.inserted_vectors(&context.doc_alloc, self.embedders)?; + if let Some(new_vectors) = &new_vectors { + unused_vectors_distribution.append(new_vectors)?; + } + + for chunks in &mut all_chunks { + let embedder_name = chunks.embedder_name(); + let prompt = chunks.prompt(); + // if no inserted vectors, then regenerate: true + no embeddings => autogenerate + if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| { + new_vectors.vectors_for_key(embedder_name).transpose() + }) { + let new_vectors = new_vectors?; + chunks.set_regenerate(insertion.docid(), new_vectors.regenerate); + if let Some(embeddings) = new_vectors.embeddings { + chunks.set_vectors( + insertion.docid(), + embeddings + .into_vec(&context.doc_alloc, embedder_name) + .map_err(|error| UserError::InvalidVectorsEmbedderConf { + document_id: insertion + .external_document_id() + .to_string(), + error: error.to_string(), + })?, + ); + } else if new_vectors.regenerate { + let rendered = prompt.render_document( + insertion.inserted(), + context.new_fields_ids_map, + &context.doc_alloc, + )?; + chunks.set_autogenerated( + insertion.docid(), + insertion.external_document_id(), + rendered, + &unused_vectors_distribution, + )?; + } + } else { + let rendered = prompt.render_document( + insertion.inserted(), + context.new_fields_ids_map, + &context.doc_alloc, + )?; + chunks.set_autogenerated( + insertion.docid(), + insertion.external_document_id(), + rendered, + &unused_vectors_distribution, + )?; + } + } + } + } + } + + for chunk in all_chunks { + chunk.drain(&unused_vectors_distribution)?; + } + Ok(()) + } +} + +// **Warning**: the destructor of this struct is not normally run, make sure that all its fields: +// 1. don't have side effects tied to they destructors +// 2. if allocated, are allocated inside of the bumpalo +// +// Currently this is the case as: +// 1. BVec are inside of the bumaplo +// 2. All other fields are either trivial (u8) or references. +struct Chunks<'a, 'extractor> { + texts: BVec<'a, &'a str>, + ids: BVec<'a, DocumentId>, + + embedder: &'a Embedder, + embedder_id: u8, + embedder_name: &'a str, + prompt: &'a Prompt, + possible_embedding_mistakes: &'a PossibleEmbeddingMistakes, + user_provided: &'a RefCell>, + threads: &'a ThreadPoolNoAbort, + sender: &'a EmbeddingSender<'a>, + has_manual_generation: Option<&'a str>, +} + +impl<'a, 'extractor> Chunks<'a, 'extractor> { + #[allow(clippy::too_many_arguments)] + pub fn new( + embedder: &'a Embedder, + embedder_id: u8, + embedder_name: &'a str, + prompt: &'a Prompt, + user_provided: &'a RefCell>, + possible_embedding_mistakes: &'a PossibleEmbeddingMistakes, + threads: &'a ThreadPoolNoAbort, + sender: &'a EmbeddingSender<'a>, + doc_alloc: &'a Bump, + ) -> Self { + let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint(); + let texts = BVec::with_capacity_in(capacity, doc_alloc); + let ids = BVec::with_capacity_in(capacity, doc_alloc); + Self { + texts, + ids, + embedder, + prompt, + possible_embedding_mistakes, + threads, + sender, + embedder_id, + embedder_name, + user_provided, + has_manual_generation: None, + } + } + + pub fn set_autogenerated( + &mut self, + docid: DocumentId, + external_docid: &'a str, + rendered: &'a str, + unused_vectors_distribution: &UnusedVectorsDistributionBump, + ) -> Result<()> { + let is_manual = matches!(&self.embedder, &Embedder::UserProvided(_)); + if is_manual { + self.has_manual_generation.get_or_insert(external_docid); + } + + if self.texts.len() < self.texts.capacity() { + self.texts.push(rendered); + self.ids.push(docid); + return Ok(()); + } + + Self::embed_chunks( + &mut self.texts, + &mut self.ids, + self.embedder, + self.embedder_id, + self.embedder_name, + self.possible_embedding_mistakes, + unused_vectors_distribution, + self.threads, + self.sender, + self.has_manual_generation.take(), + ) + } + + pub fn drain( + mut self, + unused_vectors_distribution: &UnusedVectorsDistributionBump, + ) -> Result<()> { + let res = Self::embed_chunks( + &mut self.texts, + &mut self.ids, + self.embedder, + self.embedder_id, + self.embedder_name, + self.possible_embedding_mistakes, + unused_vectors_distribution, + self.threads, + self.sender, + self.has_manual_generation, + ); + // optimization: don't run bvec dtors as they only contain bumpalo allocated stuff + std::mem::forget(self); + res + } + + #[allow(clippy::too_many_arguments)] + pub fn embed_chunks( + texts: &mut BVec<'a, &'a str>, + ids: &mut BVec<'a, DocumentId>, + embedder: &Embedder, + embedder_id: u8, + embedder_name: &str, + possible_embedding_mistakes: &PossibleEmbeddingMistakes, + unused_vectors_distribution: &UnusedVectorsDistributionBump, + threads: &ThreadPoolNoAbort, + sender: &EmbeddingSender<'a>, + has_manual_generation: Option<&'a str>, + ) -> Result<()> { + if let Some(external_docid) = has_manual_generation { + let mut msg = format!( + r"While embedding documents for embedder `{embedder_name}`: no vectors provided for document `{}`{}", + external_docid, + if ids.len() > 1 { + format!(" and at least {} other document(s)", ids.len() - 1) + } else { + "".to_string() + } + ); + + msg += &format!("\n- Note: `{embedder_name}` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.{embedder_name}`."); + + let mut hint_count = 0; + + for (vector_misspelling, count) in possible_embedding_mistakes.vector_mistakes().take(2) + { + msg += &format!("\n- Hint: try replacing `{vector_misspelling}` by `_vectors` in {count} document(s)."); + hint_count += 1; + } + + for (embedder_misspelling, count) in possible_embedding_mistakes + .embedder_mistakes_bump(embedder_name, unused_vectors_distribution) + .take(2) + { + msg += &format!("\n- Hint: try replacing `_vectors.{embedder_misspelling}` by `_vectors.{embedder_name}` in {count} document(s)."); + hint_count += 1; + } + + if hint_count == 0 { + msg += &format!( + "\n- Hint: opt-out for a document with `_vectors.{embedder_name}: null`" + ); + } + + return Err(crate::Error::UserError(crate::UserError::DocumentEmbeddingError(msg))); + } + + let res = match embedder.embed_chunks_ref(texts.as_slice(), threads) { + Ok(embeddings) => { + for (docid, embedding) in ids.into_iter().zip(embeddings) { + sender.set_vector(*docid, embedder_id, embedding).unwrap(); + } + Ok(()) + } + Err(error) => { + if let FaultSource::Bug = error.fault { + Err(crate::Error::InternalError(crate::InternalError::VectorEmbeddingError( + error.into(), + ))) + } else { + let mut msg = format!( + r"While embedding documents for embedder `{embedder_name}`: {error}" + ); + + if let EmbedErrorKind::ManualEmbed(_) = &error.kind { + msg += &format!("\n- Note: `{embedder_name}` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.{embedder_name}`."); + } + + let mut hint_count = 0; + + for (vector_misspelling, count) in + possible_embedding_mistakes.vector_mistakes().take(2) + { + msg += &format!("\n- Hint: try replacing `{vector_misspelling}` by `_vectors` in {count} document(s)."); + hint_count += 1; + } + + for (embedder_misspelling, count) in possible_embedding_mistakes + .embedder_mistakes_bump(embedder_name, unused_vectors_distribution) + .take(2) + { + msg += &format!("\n- Hint: try replacing `_vectors.{embedder_misspelling}` by `_vectors.{embedder_name}` in {count} document(s)."); + hint_count += 1; + } + + if hint_count == 0 { + if let EmbedErrorKind::ManualEmbed(_) = &error.kind { + msg += &format!( + "\n- Hint: opt-out for a document with `_vectors.{embedder_name}: null`" + ); + } + } + + Err(crate::Error::UserError(crate::UserError::DocumentEmbeddingError(msg))) + } + } + }; + texts.clear(); + ids.clear(); + res + } + + pub fn prompt(&self) -> &'a Prompt { + self.prompt + } + + pub fn embedder_name(&self) -> &'a str { + self.embedder_name + } + + fn set_regenerate(&self, docid: DocumentId, regenerate: bool) { + let mut user_provided = self.user_provided.borrow_mut(); + let user_provided = user_provided.0.entry_ref(self.embedder_name).or_default(); + if regenerate { + // regenerate == !user_provided + user_provided.insert_del_u32(docid); + } else { + user_provided.insert_add_u32(docid); + } + } + + fn set_vectors(&self, docid: DocumentId, embeddings: Vec) { + self.sender.set_vectors(docid, self.embedder_id, embeddings).unwrap(); + } +} diff --git a/crates/milli/src/update/new/facet_search_builder.rs b/crates/milli/src/update/new/facet_search_builder.rs new file mode 100644 index 000000000..39e04a589 --- /dev/null +++ b/crates/milli/src/update/new/facet_search_builder.rs @@ -0,0 +1,252 @@ +use std::collections::hash_map::Entry; +use std::collections::{BTreeSet, HashMap}; + +use charabia::normalizer::NormalizerOption; +use charabia::{Language, Normalize, StrDetection, Token}; +use grenad::Sorter; +use heed::types::{Bytes, SerdeJson}; +use heed::{BytesDecode, BytesEncode, RoTxn, RwTxn}; + +use super::fst_merger_builder::FstMergerBuilder; +use super::KvReaderDelAdd; +use crate::heed_codec::facet::FacetGroupKey; +use crate::update::del_add::{DelAdd, KvWriterDelAdd}; +use crate::update::{create_sorter, MergeDeladdBtreesetString}; +use crate::{ + BEU16StrCodec, FieldId, GlobalFieldsIdsMap, Index, LocalizedAttributesRule, Result, + MAX_FACET_VALUE_LENGTH, +}; + +pub struct FacetSearchBuilder<'indexer> { + registered_facets: HashMap, + normalized_facet_string_docids_sorter: Sorter, + global_fields_ids_map: GlobalFieldsIdsMap<'indexer>, + localized_attributes_rules: Vec, + // Buffered data below + buffer: Vec, + localized_field_ids: HashMap>>, +} + +impl<'indexer> FacetSearchBuilder<'indexer> { + pub fn new( + global_fields_ids_map: GlobalFieldsIdsMap<'indexer>, + localized_attributes_rules: Vec, + ) -> Self { + let registered_facets = HashMap::new(); + let normalized_facet_string_docids_sorter = create_sorter( + grenad::SortAlgorithm::Stable, + MergeDeladdBtreesetString, + grenad::CompressionType::None, + None, + None, + Some(0), + true, + ); + + Self { + registered_facets, + normalized_facet_string_docids_sorter, + buffer: Vec::new(), + global_fields_ids_map, + localized_attributes_rules, + localized_field_ids: HashMap::new(), + } + } + + pub fn register_from_key( + &mut self, + deladd: DelAdd, + facet_key: FacetGroupKey<&str>, + ) -> Result<()> { + let FacetGroupKey { field_id, level: _level, left_bound } = facet_key; + + if deladd == DelAdd::Addition { + self.registered_facets.entry(field_id).and_modify(|count| *count += 1).or_insert(1); + } + + let locales = self.locales(field_id); + let hyper_normalized_value = normalize_facet_string(left_bound, locales); + + let set = BTreeSet::from_iter(std::iter::once(left_bound)); + + // as the facet string is the same, we can put the deletion and addition in the same obkv. + self.buffer.clear(); + let mut obkv = KvWriterDelAdd::new(&mut self.buffer); + let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?; + obkv.insert(deladd, val)?; + obkv.finish()?; + + let key: (u16, &str) = (field_id, hyper_normalized_value.as_ref()); + let key_bytes = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?; + self.normalized_facet_string_docids_sorter.insert(key_bytes, &self.buffer)?; + + Ok(()) + } + + fn locales(&mut self, field_id: FieldId) -> Option<&[Language]> { + if let Entry::Vacant(e) = self.localized_field_ids.entry(field_id) { + let Some(field_name) = self.global_fields_ids_map.name(field_id) else { + unreachable!("Field id {field_id} not found in the global fields ids map"); + }; + + let locales = self + .localized_attributes_rules + .iter() + .find(|rule| rule.match_str(field_name)) + .map(|rule| rule.locales.clone()); + + e.insert(locales); + } + + self.localized_field_ids.get(&field_id).unwrap().as_deref() + } + + #[tracing::instrument(level = "trace", skip_all, target = "indexing::facet_fst")] + pub fn merge_and_write(self, index: &Index, wtxn: &mut RwTxn, rtxn: &RoTxn) -> Result<()> { + let reader = self.normalized_facet_string_docids_sorter.into_reader_cursors()?; + let mut builder = grenad::MergerBuilder::new(MergeDeladdBtreesetString); + builder.extend(reader); + + let database = index.facet_id_normalized_string_strings.remap_types::(); + + let mut merger_iter = builder.build().into_stream_merger_iter()?; + let mut current_field_id = None; + let mut fst; + let mut fst_merger_builder: Option = None; + while let Some((key, deladd)) = merger_iter.next()? { + let (field_id, normalized_facet_string) = + BEU16StrCodec::bytes_decode(key).map_err(heed::Error::Encoding)?; + + if current_field_id != Some(field_id) { + if let Some(fst_merger_builder) = fst_merger_builder { + let mmap = fst_merger_builder.build(&mut callback)?; + index + .facet_id_string_fst + .remap_data_type::() + .put(wtxn, &field_id, &mmap)?; + } + + fst = index.facet_id_string_fst.get(rtxn, &field_id)?; + fst_merger_builder = Some(FstMergerBuilder::new(fst.as_ref())?); + current_field_id = Some(field_id); + } + + let previous = database.get(rtxn, key)?; + let deladd: &KvReaderDelAdd = deladd.into(); + let del = deladd.get(DelAdd::Deletion); + let add = deladd.get(DelAdd::Addition); + + match merge_btreesets(previous, del, add)? { + Operation::Write(value) => { + match fst_merger_builder.as_mut() { + Some(fst_merger_builder) => { + fst_merger_builder.register( + DelAdd::Addition, + normalized_facet_string.as_bytes(), + &mut callback, + )?; + } + None => unreachable!(), + } + let key = (field_id, normalized_facet_string); + let key_bytes = + BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?; + database.put(wtxn, &key_bytes, &value)?; + } + Operation::Delete => { + match fst_merger_builder.as_mut() { + Some(fst_merger_builder) => { + fst_merger_builder.register( + DelAdd::Deletion, + normalized_facet_string.as_bytes(), + &mut callback, + )?; + } + None => unreachable!(), + } + let key = (field_id, normalized_facet_string); + let key_bytes = + BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?; + database.delete(wtxn, &key_bytes)?; + } + Operation::Ignore => (), + } + } + + if let (Some(field_id), Some(fst_merger_builder)) = (current_field_id, fst_merger_builder) { + let mmap = fst_merger_builder.build(&mut callback)?; + index.facet_id_string_fst.remap_data_type::().put(wtxn, &field_id, &mmap)?; + } + + Ok(()) + } +} + +fn callback(_bytes: &[u8], _deladd: DelAdd, _is_modified: bool) -> Result<()> { + Ok(()) +} + +fn merge_btreesets( + current: Option<&[u8]>, + del: Option<&[u8]>, + add: Option<&[u8]>, +) -> Result { + let mut result: BTreeSet = match current { + Some(current) => SerdeJson::bytes_decode(current).map_err(heed::Error::Encoding)?, + None => BTreeSet::new(), + }; + if let Some(del) = del { + let del: BTreeSet = SerdeJson::bytes_decode(del).map_err(heed::Error::Encoding)?; + result = result.difference(&del).cloned().collect(); + } + if let Some(add) = add { + let add: BTreeSet = SerdeJson::bytes_decode(add).map_err(heed::Error::Encoding)?; + result.extend(add); + } + + /// TODO remove allocation + let result = SerdeJson::bytes_encode(&result).map_err(heed::Error::Encoding)?.into_owned(); + if Some(result.as_ref()) == current { + Ok(Operation::Ignore) + } else if result.is_empty() { + Ok(Operation::Delete) + } else { + Ok(Operation::Write(result)) + } +} + +/// Normalizes the facet string and truncates it to the max length. +fn normalize_facet_string(facet_string: &str, locales: Option<&[Language]>) -> String { + let options: NormalizerOption = NormalizerOption { lossy: true, ..Default::default() }; + let mut detection = StrDetection::new(facet_string, locales); + + let script = detection.script(); + // Detect the language of the facet string only if several locales are explicitly provided. + let language = match locales { + Some(&[language]) => Some(language), + Some(multiple_locales) if multiple_locales.len() > 1 => detection.language(), + _ => None, + }; + + let token = Token { + lemma: std::borrow::Cow::Borrowed(facet_string), + script, + language, + ..Default::default() + }; + + // truncate the facet string to the max length + token + .normalize(&options) + .lemma + .char_indices() + .take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH) + .map(|(_, c)| c) + .collect() +} + +enum Operation { + Write(Vec), + Delete, + Ignore, +} diff --git a/crates/milli/src/update/new/fst_merger_builder.rs b/crates/milli/src/update/new/fst_merger_builder.rs new file mode 100644 index 000000000..1c584ef53 --- /dev/null +++ b/crates/milli/src/update/new/fst_merger_builder.rs @@ -0,0 +1,157 @@ +use std::fs::File; +use std::io::BufWriter; + +use fst::{Set, SetBuilder, Streamer}; +use memmap2::Mmap; +use tempfile::tempfile; + +use crate::update::del_add::DelAdd; +use crate::{InternalError, Result}; + +pub struct FstMergerBuilder<'a> { + stream: Option>, + fst_builder: SetBuilder>, + last: Option>, + inserted_words: usize, +} + +impl<'a> FstMergerBuilder<'a> { + pub fn new>(fst: Option<&'a Set>) -> Result { + Ok(Self { + stream: fst.map(|fst| fst.stream()), + fst_builder: SetBuilder::new(BufWriter::new(tempfile()?))?, + last: None, + inserted_words: 0, + }) + } + + pub fn register( + &mut self, + deladd: DelAdd, + right: &[u8], + insertion_callback: &mut impl FnMut(&[u8], DelAdd, bool) -> Result<()>, + ) -> Result<()> { + if let Some(left) = self.last.take() { + let (left_inserted, right_inserted) = + self.compare_and_insert(deladd, left.as_slice(), right, insertion_callback)?; + + // left was not inserted, so we keep it for the next iteration + if !left_inserted { + self.last = Some(left); + } + + // right was inserted, so we can stop + if right_inserted { + return Ok(()); + } + } + + if let Some(mut stream) = self.stream.take() { + while let Some(left) = stream.next() { + let (left_inserted, right_inserted) = + self.compare_and_insert(deladd, left, right, insertion_callback)?; + + // left was not inserted, so we keep it for the next iteration + if !left_inserted { + self.last = Some(left.to_vec()); + } + + // right was inserted, so we can stop + if right_inserted { + self.stream = Some(stream); + return Ok(()); + } + } + } + + // If we reach this point, it means that the stream is empty + // and we need to insert the incoming word + self.insert(right, deladd, true, insertion_callback)?; + + Ok(()) + } + + fn compare_and_insert( + &mut self, + deladd: DelAdd, + left: &[u8], + right: &[u8], + insertion_callback: &mut impl FnMut(&[u8], DelAdd, bool) -> Result<()>, + ) -> Result<(bool, bool)> { + let mut left_inserted = false; + let mut right_inserted = false; + match left.cmp(right) { + std::cmp::Ordering::Less => { + // We need to insert the last word from the current fst + self.insert(left, DelAdd::Addition, false, insertion_callback)?; + + left_inserted = true; + } + std::cmp::Ordering::Equal => { + self.insert(right, deladd, true, insertion_callback)?; + + left_inserted = true; + right_inserted = true; + } + std::cmp::Ordering::Greater => { + self.insert(right, deladd, true, insertion_callback)?; + + right_inserted = true; + } + } + + Ok((left_inserted, right_inserted)) + } + + fn insert( + &mut self, + bytes: &[u8], + deladd: DelAdd, + is_modified: bool, + insertion_callback: &mut impl FnMut(&[u8], DelAdd, bool) -> Result<()>, + ) -> Result<()> { + // Addition: We insert the word + // Deletion: We delete the word by not inserting it + if deladd == DelAdd::Addition { + self.inserted_words += 1; + self.fst_builder.insert(bytes)?; + } + + insertion_callback(bytes, deladd, is_modified)?; + + Ok(()) + } + + fn drain_stream( + &mut self, + insertion_callback: &mut impl FnMut(&[u8], DelAdd, bool) -> Result<()>, + ) -> Result<()> { + if let Some(last) = self.last.take() { + self.insert(last.as_slice(), DelAdd::Addition, false, insertion_callback)?; + } + + if let Some(mut stream) = self.stream.take() { + while let Some(current) = stream.next() { + self.insert(current, DelAdd::Addition, false, insertion_callback)?; + } + } + + Ok(()) + } + + pub fn build( + mut self, + insertion_callback: &mut impl FnMut(&[u8], DelAdd, bool) -> Result<()>, + ) -> Result { + self.drain_stream(insertion_callback)?; + + let fst_file = self + .fst_builder + .into_inner()? + .into_inner() + .map_err(|_| InternalError::IndexingMergingKeys { process: "building-fst" })?; + let fst_mmap = unsafe { Mmap::map(&fst_file)? }; + + Ok(fst_mmap) + } +} diff --git a/crates/milli/src/update/new/indexer/de.rs b/crates/milli/src/update/new/indexer/de.rs new file mode 100644 index 000000000..c9808360e --- /dev/null +++ b/crates/milli/src/update/new/indexer/de.rs @@ -0,0 +1,640 @@ +use std::ops::ControlFlow; + +use bumpalo::Bump; +use serde::de::{DeserializeSeed, Deserializer as _, Visitor}; +use serde_json::value::RawValue; + +use crate::documents::{ + validate_document_id_str, DocumentIdExtractionError, FieldIdMapper, PrimaryKey, +}; +use crate::fields_ids_map::MutFieldIdMapper; +use crate::{FieldId, UserError}; + +// visits a document to fill the top level fields of the field id map and retrieve the external document id. +pub struct FieldAndDocidExtractor<'p, 'indexer, Mapper: MutFieldIdMapper> { + fields_ids_map: &'p mut Mapper, + primary_key: &'p PrimaryKey<'p>, + indexer: &'indexer Bump, +} + +impl<'p, 'indexer, Mapper: MutFieldIdMapper> FieldAndDocidExtractor<'p, 'indexer, Mapper> { + pub fn new( + fields_ids_map: &'p mut Mapper, + primary_key: &'p PrimaryKey<'p>, + indexer: &'indexer Bump, + ) -> Self { + Self { fields_ids_map, primary_key, indexer } + } +} + +impl<'de, 'p, 'indexer: 'de, Mapper: MutFieldIdMapper> Visitor<'de> + for FieldAndDocidExtractor<'p, 'indexer, Mapper> +{ + type Value = + Result, DocumentIdExtractionError>, crate::UserError>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a map") + } + + fn visit_map(mut self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + // We need to remember if we encountered a semantic error, because raw values don't like to be parsed partially + // (trying to do so results in parsing errors). + // So we'll exhaust all keys and values even if we encounter an error, and we'll then return any error we detected. + let mut attribute_limit_reached = false; + let mut document_id_extraction_error = None; + let mut docid = None; + + while let Some(((level_name, right), (fid, fields_ids_map))) = + map.next_key_seed(ComponentsSeed { + name: self.primary_key.name(), + visitor: MutFieldIdMapVisitor(self.fields_ids_map), + })? + { + self.fields_ids_map = fields_ids_map; + + let value: &'de RawValue = map.next_value()?; + if attribute_limit_reached || document_id_extraction_error.is_some() { + continue; + } + + let Some(_fid) = fid else { + attribute_limit_reached = true; + continue; + }; + + match match_component(level_name, right, value, self.indexer, &mut docid) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(Err(err)) => return Err(serde::de::Error::custom(err)), + ControlFlow::Break(Ok(err)) => { + document_id_extraction_error = Some(err); + continue; + } + } + } + + // return previously detected errors + if attribute_limit_reached { + return Ok(Err(UserError::AttributeLimitReached)); + } + if let Some(document_id_extraction_error) = document_id_extraction_error { + return Ok(Ok(Err(document_id_extraction_error))); + } + + Ok(Ok(match docid { + Some(docid) => Ok(docid), + None => Err(DocumentIdExtractionError::MissingDocumentId), + })) + } +} + +struct NestedPrimaryKeyVisitor<'a, 'bump> { + components: &'a str, + bump: &'bump Bump, +} + +impl<'de, 'a, 'bump: 'de> Visitor<'de> for NestedPrimaryKeyVisitor<'a, 'bump> { + type Value = std::result::Result>, DocumentIdExtractionError>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a map") + } + + fn visit_map(self, mut map: A) -> std::result::Result + where + A: serde::de::MapAccess<'de>, + { + let mut docid = None; + while let Some(((matched_component, right), _)) = map.next_key_seed(ComponentsSeed { + name: self.components, + visitor: serde::de::IgnoredAny, + })? { + let value: &'de RawValue = map.next_value()?; + + match match_component(matched_component, right, value, self.bump, &mut docid) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(Err(err)) => return Err(serde::de::Error::custom(err)), + ControlFlow::Break(Ok(err)) => return Ok(Err(err)), + } + } + Ok(Ok(docid)) + } +} + +/// Either a `&'de str` or a `&'bump str`. +pub enum DeOrBumpStr<'de, 'bump: 'de> { + /// Lifetime of the deserializer + De(&'de str), + /// Lifetime of the allocator + Bump(&'bump str), +} + +impl<'de, 'bump: 'de> DeOrBumpStr<'de, 'bump> { + /// Returns a `&'bump str`, possibly allocating to extend its lifetime. + pub fn to_bump(&self, bump: &'bump Bump) -> &'bump str { + match self { + DeOrBumpStr::De(de) => bump.alloc_str(de), + DeOrBumpStr::Bump(bump) => bump, + } + } + + /// Returns a `&'de str`. + /// + /// This function never allocates because `'bump: 'de`. + pub fn to_de(&self) -> &'de str { + match self { + DeOrBumpStr::De(de) => de, + DeOrBumpStr::Bump(bump) => bump, + } + } +} + +struct ComponentsSeed<'a, V> { + name: &'a str, + visitor: V, +} + +impl<'de, 'a, V: Visitor<'de>> DeserializeSeed<'de> for ComponentsSeed<'a, V> { + type Value = ((&'a str, &'a str), V::Value); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct ComponentsSeedVisitor<'a, V> { + name: &'a str, + visitor: V, + } + + impl<'a, V> ComponentsSeedVisitor<'a, V> { + fn match_str(&self, v: &str) -> (&'a str, &'a str) { + let p = PrimaryKey::Nested { name: self.name }; + for (name, right) in p.possible_level_names() { + if name == v { + return (name, right); + } + } + ("", self.name) + } + } + + impl<'de, 'a, V: Visitor<'de>> Visitor<'de> for ComponentsSeedVisitor<'a, V> { + type Value = ((&'a str, &'a str), V::Value); + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting a string") + } + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + let matched = self.match_str(v); + let inner = self.visitor.visit_borrowed_str(v)?; + Ok((matched, inner)) + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + let matched = self.match_str(v); + let inner = self.visitor.visit_str(v)?; + + Ok((matched, inner)) + } + } + deserializer + .deserialize_str(ComponentsSeedVisitor { name: self.name, visitor: self.visitor }) + } +} + +struct MutFieldIdMapVisitor<'a, Mapper: MutFieldIdMapper>(&'a mut Mapper); + +impl<'de, 'a, Mapper: MutFieldIdMapper> Visitor<'de> for MutFieldIdMapVisitor<'a, Mapper> { + type Value = (Option, &'a mut Mapper); + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting a string") + } + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + Ok((self.0.insert(v), self.0)) + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + Ok((self.0.insert(v), self.0)) + } +} + +pub struct FieldIdMapVisitor<'a, Mapper: FieldIdMapper>(pub &'a Mapper); + +impl<'de, 'a, Mapper: FieldIdMapper> Visitor<'de> for FieldIdMapVisitor<'a, Mapper> { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting a string") + } + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + Ok(self.0.id(v)) + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + Ok(self.0.id(v)) + } +} +pub struct DocumentIdVisitor<'indexer>(pub &'indexer Bump); + +impl<'de, 'indexer: 'de> Visitor<'de> for DocumentIdVisitor<'indexer> { + type Value = std::result::Result, DocumentIdExtractionError>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an integer or a string") + } + + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + Ok(validate_document_id_str(v) + .ok_or_else(|| { + DocumentIdExtractionError::InvalidDocumentId(UserError::InvalidDocumentId { + document_id: serde_json::Value::String(v.to_owned()), + }) + }) + .map(DeOrBumpStr::De)) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + let v = self.0.alloc_str(v); + Ok(match self.visit_borrowed_str(v)? { + Ok(_) => Ok(DeOrBumpStr::Bump(v)), + Err(err) => Err(err), + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + use std::fmt::Write as _; + + let mut out = bumpalo::collections::String::new_in(self.0); + write!(&mut out, "{v}").unwrap(); + Ok(Ok(DeOrBumpStr::Bump(out.into_bump_str()))) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + use std::fmt::Write as _; + + let mut out = bumpalo::collections::String::new_in(self.0); + write!(&mut out, "{v}").unwrap(); + Ok(Ok(DeOrBumpStr::Bump(out.into_bump_str()))) + } +} + +pub fn match_component<'de, 'indexer: 'de>( + first_level_name: &str, + right: &str, + value: &'de RawValue, + bump: &'indexer Bump, + docid: &mut Option>, +) -> ControlFlow, ()> { + if first_level_name.is_empty() { + return ControlFlow::Continue(()); + } + + let value = if right.is_empty() { + match value.deserialize_any(DocumentIdVisitor(bump)).map_err(|_err| { + DocumentIdExtractionError::InvalidDocumentId(UserError::InvalidDocumentId { + document_id: serde_json::to_value(value).unwrap(), + }) + }) { + Ok(Ok(value)) => value, + Ok(Err(err)) | Err(err) => return ControlFlow::Break(Ok(err)), + } + } else { + // if right is not empty, recursively extract right components from value + let res = value.deserialize_map(NestedPrimaryKeyVisitor { components: right, bump }); + match res { + Ok(Ok(Some(value))) => value, + Ok(Ok(None)) => return ControlFlow::Continue(()), + Ok(Err(err)) => return ControlFlow::Break(Ok(err)), + Err(err) if err.is_data() => return ControlFlow::Continue(()), // we expected the field to be a map, but it was not and that's OK. + Err(err) => return ControlFlow::Break(Err(err)), + } + }; + if let Some(_previous_value) = docid.replace(value) { + return ControlFlow::Break(Ok(DocumentIdExtractionError::TooManyDocumentIds(2))); + } + ControlFlow::Continue(()) +} + +pub struct DeserrRawValue<'a> { + value: &'a RawValue, + alloc: &'a Bump, +} + +impl<'a> DeserrRawValue<'a> { + pub fn new_in(value: &'a RawValue, alloc: &'a Bump) -> Self { + Self { value, alloc } + } +} + +pub struct DeserrRawVec<'a> { + vec: raw_collections::RawVec<'a>, + alloc: &'a Bump, +} + +impl<'a> deserr::Sequence for DeserrRawVec<'a> { + type Value = DeserrRawValue<'a>; + + type Iter = DeserrRawVecIter<'a>; + + fn len(&self) -> usize { + self.vec.len() + } + + fn into_iter(self) -> Self::Iter { + DeserrRawVecIter { it: self.vec.into_iter(), alloc: self.alloc } + } +} + +pub struct DeserrRawVecIter<'a> { + it: raw_collections::vec::iter::IntoIter<'a>, + alloc: &'a Bump, +} + +impl<'a> Iterator for DeserrRawVecIter<'a> { + type Item = DeserrRawValue<'a>; + + fn next(&mut self) -> Option { + let next = self.it.next()?; + Some(DeserrRawValue { value: next, alloc: self.alloc }) + } +} + +pub struct DeserrRawMap<'a> { + map: raw_collections::RawMap<'a>, + alloc: &'a Bump, +} + +impl<'a> deserr::Map for DeserrRawMap<'a> { + type Value = DeserrRawValue<'a>; + + type Iter = DeserrRawMapIter<'a>; + + fn len(&self) -> usize { + self.map.len() + } + + fn remove(&mut self, _key: &str) -> Option { + unimplemented!() + } + + fn into_iter(self) -> Self::Iter { + DeserrRawMapIter { it: self.map.into_iter(), alloc: self.alloc } + } +} + +pub struct DeserrRawMapIter<'a> { + it: raw_collections::map::iter::IntoIter<'a>, + alloc: &'a Bump, +} + +impl<'a> Iterator for DeserrRawMapIter<'a> { + type Item = (String, DeserrRawValue<'a>); + + fn next(&mut self) -> Option { + let (name, value) = self.it.next()?; + Some((name.to_string(), DeserrRawValue { value, alloc: self.alloc })) + } +} + +impl<'a> deserr::IntoValue for DeserrRawValue<'a> { + type Sequence = DeserrRawVec<'a>; + + type Map = DeserrRawMap<'a>; + + fn kind(&self) -> deserr::ValueKind { + self.value.deserialize_any(DeserrKindVisitor).unwrap() + } + + fn into_value(self) -> deserr::Value { + self.value.deserialize_any(DeserrRawValueVisitor { alloc: self.alloc }).unwrap() + } +} + +pub struct DeserrKindVisitor; + +impl<'de> Visitor<'de> for DeserrKindVisitor { + type Value = deserr::ValueKind; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "any value") + } + + fn visit_bool(self, _v: bool) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Boolean) + } + + fn visit_i64(self, _v: i64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::NegativeInteger) + } + + fn visit_u64(self, _v: u64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Integer) + } + + fn visit_f64(self, _v: f64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Float) + } + + fn visit_str(self, _v: &str) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::String) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Null) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::ValueKind::Null) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_seq(self, _seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + Ok(deserr::ValueKind::Sequence) + } + + fn visit_map(self, _map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + Ok(deserr::ValueKind::Map) + } +} + +pub struct DeserrRawValueVisitor<'a> { + alloc: &'a Bump, +} + +impl<'de> Visitor<'de> for DeserrRawValueVisitor<'de> { + type Value = deserr::Value>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "any value") + } + + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Boolean(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::NegativeInteger(v)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Integer(v)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Float(v)) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::String(v.to_string())) + } + + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::String(v)) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Null) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(deserr::Value::Null) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut raw_vec = raw_collections::RawVec::new_in(self.alloc); + while let Some(next) = seq.next_element()? { + raw_vec.push(next); + } + Ok(deserr::Value::Sequence(DeserrRawVec { vec: raw_vec, alloc: self.alloc })) + } + + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let _ = map; + Err(serde::de::Error::invalid_type(serde::de::Unexpected::Map, &self)) + } + + fn visit_enum(self, data: A) -> Result + where + A: serde::de::EnumAccess<'de>, + { + let _ = data; + Err(serde::de::Error::invalid_type(serde::de::Unexpected::Enum, &self)) + } +} diff --git a/crates/milli/src/update/new/indexer/document_changes.rs b/crates/milli/src/update/new/indexer/document_changes.rs new file mode 100644 index 000000000..4efebc586 --- /dev/null +++ b/crates/milli/src/update/new/indexer/document_changes.rs @@ -0,0 +1,306 @@ +use std::cell::{Cell, RefCell}; +use std::sync::{Arc, RwLock}; + +use bumpalo::Bump; +use heed::RoTxn; +use rayon::iter::IndexedParallelIterator; + +use super::super::document_change::DocumentChange; +use crate::fields_ids_map::metadata::FieldIdMapWithMetadata; +use crate::update::new::parallel_iterator_ext::ParallelIteratorExt as _; +use crate::update::new::steps::Step; +use crate::update::new::thread_local::{FullySend, MostlySend, ThreadLocal}; +use crate::{FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result}; + +pub struct DocumentChangeContext< + 'doc, // covariant lifetime of a single `process` call + 'extractor: 'doc, // invariant lifetime of the extractor_allocs + 'fid: 'doc, // invariant lifetime of the new_fields_ids_map + 'indexer: 'doc, // covariant lifetime of objects that outlive a single `process` call + T: MostlySend, +> { + /// The index we're indexing in + pub index: &'indexer Index, + /// The fields ids map as it was at the start of this indexing process. Contains at least all top-level fields from documents + /// inside of the DB. + pub db_fields_ids_map: &'indexer FieldsIdsMap, + /// A transaction providing data from the DB before all indexing operations + pub rtxn: RoTxn<'indexer>, + + /// Global field id map that is up to date with the current state of the indexing process. + /// + /// - Inserting a field will take a lock + /// - Retrieving a field may take a lock as well + pub new_fields_ids_map: &'doc std::cell::RefCell>, + + /// Data allocated in this allocator is cleared between each call to `process`. + pub doc_alloc: Bump, + + /// Data allocated in this allocator is not cleared between each call to `process`, unless the data spills. + pub extractor_alloc: &'extractor Bump, + + /// Pool of doc allocators, used to retrieve the doc allocator we provided for the documents + doc_allocs: &'doc ThreadLocal>>, + + /// Extractor-specific data + pub data: &'doc T, +} + +impl< + 'doc, // covariant lifetime of a single `process` call + 'data: 'doc, // invariant on T lifetime of the datastore + 'extractor: 'doc, // invariant lifetime of extractor_allocs + 'fid: 'doc, // invariant lifetime of fields ids map + 'indexer: 'doc, // covariant lifetime of objects that survive a `process` call + T: MostlySend, + > DocumentChangeContext<'doc, 'extractor, 'fid, 'indexer, T> +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + index: &'indexer Index, + db_fields_ids_map: &'indexer FieldsIdsMap, + new_fields_ids_map: &'fid RwLock, + extractor_allocs: &'extractor ThreadLocal>, + doc_allocs: &'doc ThreadLocal>>, + datastore: &'data ThreadLocal, + fields_ids_map_store: &'doc ThreadLocal>>>, + init_data: F, + ) -> Result + where + F: FnOnce(&'extractor Bump) -> Result, + { + let doc_alloc = + doc_allocs.get_or(|| FullySend(Cell::new(Bump::with_capacity(1024 * 1024 * 1024)))); + let doc_alloc = doc_alloc.0.take(); + let fields_ids_map = fields_ids_map_store + .get_or(|| RefCell::new(GlobalFieldsIdsMap::new(new_fields_ids_map)).into()); + + let fields_ids_map = &fields_ids_map.0; + let extractor_alloc = extractor_allocs.get_or_default(); + + let data = datastore.get_or_try(move || init_data(&extractor_alloc.0))?; + + let txn = index.read_txn()?; + Ok(DocumentChangeContext { + index, + rtxn: txn, + db_fields_ids_map, + new_fields_ids_map: fields_ids_map, + doc_alloc, + extractor_alloc: &extractor_alloc.0, + data, + doc_allocs, + }) + } +} + +/// An internal iterator (i.e. using `foreach`) of `DocumentChange`s +pub trait Extractor<'extractor>: Sync { + type Data: MostlySend; + + fn init_data<'doc>(&'doc self, extractor_alloc: &'extractor Bump) -> Result; + + fn process<'doc>( + &'doc self, + changes: impl Iterator>>, + context: &'doc DocumentChangeContext, + ) -> Result<()>; +} + +pub trait DocumentChanges<'pl // lifetime of the underlying payload +>: Sync { + type Item: Send; + + fn iter(&self, chunk_size: usize) -> impl IndexedParallelIterator>; + + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn item_to_document_change<'doc, // lifetime of a single `process` call + T: MostlySend>( + &'doc self, + context: &'doc DocumentChangeContext, + item: &'doc Self::Item, + ) -> Result>> where 'pl: 'doc // the payload must survive the process calls + ; +} + +pub struct IndexingContext< + 'fid, // invariant lifetime of fields ids map + 'indexer, // covariant lifetime of objects that are borrowed during the entire indexing operation + 'index, // covariant lifetime of the index + MSP, + SP, +> where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, +{ + pub index: &'index Index, + pub db_fields_ids_map: &'indexer FieldsIdsMap, + pub new_fields_ids_map: &'fid RwLock, + pub doc_allocs: &'indexer ThreadLocal>>, + pub fields_ids_map_store: &'indexer ThreadLocal>>>, + pub must_stop_processing: &'indexer MSP, + pub send_progress: &'indexer SP, +} + +impl< + 'fid, // invariant lifetime of fields ids map + 'indexer, // covariant lifetime of objects that are borrowed during the entire indexing operation + 'index, // covariant lifetime of the index + MSP, + SP, + > Copy + for IndexingContext< + 'fid, // invariant lifetime of fields ids map + 'indexer, // covariant lifetime of objects that are borrowed during the entire indexing operation + 'index, // covariant lifetime of the index + MSP, + SP, + > +where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, +{ +} + +impl< + 'fid, // invariant lifetime of fields ids map + 'indexer, // covariant lifetime of objects that are borrowed during the entire indexing operation + 'index, // covariant lifetime of the index + MSP, + SP, + > Clone + for IndexingContext< + 'fid, // invariant lifetime of fields ids map + 'indexer, // covariant lifetime of objects that are borrowed during the entire indexing operation + 'index, // covariant lifetime of the index + MSP, + SP, + > +where + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, +{ + fn clone(&self) -> Self { + *self + } +} + +const CHUNK_SIZE: usize = 100; + +pub fn extract< + 'pl, // covariant lifetime of the underlying payload + 'extractor, // invariant lifetime of extractor_alloc + 'fid, // invariant lifetime of fields ids map + 'indexer, // covariant lifetime of objects that are borrowed during the entire indexing + 'data, // invariant on EX::Data lifetime of datastore + 'index, // covariant lifetime of the index + EX, + DC: DocumentChanges<'pl>, + MSP, + SP, +>( + document_changes: &DC, + extractor: &EX, + IndexingContext { + index, + db_fields_ids_map, + new_fields_ids_map, + doc_allocs, + fields_ids_map_store, + must_stop_processing, + send_progress, + }: IndexingContext<'fid, 'indexer, 'index, MSP, SP>, + extractor_allocs: &'extractor mut ThreadLocal>, + datastore: &'data ThreadLocal, + step: Step, +) -> Result<()> +where + EX: Extractor<'extractor>, + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, +{ + eprintln!("We are resetting the extractor allocators"); + // Clean up and reuse the extractor allocs + for extractor_alloc in extractor_allocs.iter_mut() { + eprintln!("\tWith {} bytes resetted", extractor_alloc.0.allocated_bytes()); + extractor_alloc.0.reset(); + } + + let total_documents = document_changes.len() as u32; + + let pi = document_changes.iter(CHUNK_SIZE); + pi.enumerate().try_arc_for_each_try_init( + || { + DocumentChangeContext::new( + index, + db_fields_ids_map, + new_fields_ids_map, + extractor_allocs, + doc_allocs, + datastore, + fields_ids_map_store, + move |index_alloc| extractor.init_data(index_alloc), + ) + }, + |context, (finished_documents, items)| { + if (must_stop_processing)() { + return Err(Arc::new(InternalError::AbortedIndexation.into())); + } + let finished_documents = (finished_documents * CHUNK_SIZE) as u32; + + (send_progress)(Progress::from_step_documents( + step, + finished_documents, + total_documents, + )); + + // Clean up and reuse the document-specific allocator + context.doc_alloc.reset(); + + let items = items.as_ref(); + let changes = items.iter().filter_map(|item| { + document_changes.item_to_document_change(context, item).transpose() + }); + + let res = extractor.process(changes, context).map_err(Arc::new); + + // send back the doc_alloc in the pool + context.doc_allocs.get_or_default().0.set(std::mem::take(&mut context.doc_alloc)); + + res + }, + )?; + + (send_progress)(Progress::from_step_documents(step, total_documents, total_documents)); + + Ok(()) +} + +pub struct Progress { + pub finished_steps: u16, + pub total_steps: u16, + pub step_name: &'static str, + pub finished_total_documents: Option<(u32, u32)>, +} + +impl Progress { + pub fn from_step(step: Step) -> Self { + Self { + finished_steps: step.finished_steps(), + total_steps: Step::total_steps(), + step_name: step.name(), + finished_total_documents: None, + } + } + pub fn from_step_documents(step: Step, finished_documents: u32, total_documents: u32) -> Self { + Self { + finished_total_documents: Some((finished_documents, total_documents)), + ..Progress::from_step(step) + } + } +} diff --git a/crates/milli/src/update/new/indexer/document_deletion.rs b/crates/milli/src/update/new/indexer/document_deletion.rs new file mode 100644 index 000000000..fe3f08583 --- /dev/null +++ b/crates/milli/src/update/new/indexer/document_deletion.rs @@ -0,0 +1,191 @@ +use bumpalo::collections::CollectIn; +use bumpalo::Bump; +use rayon::iter::IndexedParallelIterator; +use rayon::slice::ParallelSlice as _; +use roaring::RoaringBitmap; + +use super::document_changes::{DocumentChangeContext, DocumentChanges}; +use crate::documents::PrimaryKey; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::{Deletion, DocumentChange}; +use crate::{DocumentId, Result}; + +#[derive(Default)] +pub struct DocumentDeletion { + pub to_delete: RoaringBitmap, +} + +impl DocumentDeletion { + pub fn new() -> Self { + Self { to_delete: Default::default() } + } + + pub fn delete_documents_by_docids(&mut self, docids: RoaringBitmap) { + self.to_delete |= docids; + } + + pub fn into_changes<'indexer>( + self, + indexer: &'indexer Bump, + primary_key: PrimaryKey<'indexer>, + ) -> DocumentDeletionChanges<'indexer> { + let to_delete: bumpalo::collections::Vec<_> = + self.to_delete.into_iter().collect_in(indexer); + + let to_delete = to_delete.into_bump_slice(); + + DocumentDeletionChanges { to_delete, primary_key } + } +} + +pub struct DocumentDeletionChanges<'indexer> { + to_delete: &'indexer [DocumentId], + primary_key: PrimaryKey<'indexer>, +} + +impl<'pl> DocumentChanges<'pl> for DocumentDeletionChanges<'pl> { + type Item = DocumentId; + + fn iter( + &self, + chunk_size: usize, + ) -> impl IndexedParallelIterator> { + self.to_delete.par_chunks(chunk_size) + } + + fn item_to_document_change< + 'doc, // lifetime of a single `process` call + T: MostlySend, + >( + &'doc self, + context: &'doc DocumentChangeContext, + docid: &'doc Self::Item, + ) -> Result>> + where + 'pl: 'doc, // the payload must survive the process calls + { + let current = context.index.document(&context.rtxn, *docid)?; + + let external_document_id = self.primary_key.extract_docid_from_db( + current, + &context.db_fields_ids_map, + &context.doc_alloc, + )?; + + let external_document_id = external_document_id.to_bump(&context.doc_alloc); + + Ok(Some(DocumentChange::Deletion(Deletion::create(*docid, external_document_id)))) + } + + fn len(&self) -> usize { + self.to_delete.len() + } +} + +#[cfg(test)] +mod test { + use std::cell::RefCell; + use std::marker::PhantomData; + use std::sync::RwLock; + + use bumpalo::Bump; + + use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder}; + use crate::index::tests::TempIndex; + use crate::update::new::indexer::document_changes::{ + extract, DocumentChangeContext, Extractor, IndexingContext, + }; + use crate::update::new::indexer::DocumentDeletion; + use crate::update::new::steps::Step; + use crate::update::new::thread_local::{MostlySend, ThreadLocal}; + use crate::update::new::DocumentChange; + use crate::DocumentId; + + #[test] + fn test_deletions() { + struct DeletionWithData<'extractor> { + deleted: RefCell< + hashbrown::HashSet, + >, + } + + unsafe impl<'extractor> MostlySend for DeletionWithData<'extractor> {} + + struct TrackDeletion<'extractor>(PhantomData<&'extractor ()>); + + impl<'extractor> Extractor<'extractor> for TrackDeletion<'extractor> { + type Data = DeletionWithData<'extractor>; + + fn init_data(&self, extractor_alloc: &'extractor Bump) -> crate::Result { + let deleted = RefCell::new(hashbrown::HashSet::new_in(extractor_alloc)); + Ok(DeletionWithData { deleted }) + } + + fn process<'doc>( + &self, + changes: impl Iterator>>, + context: &DocumentChangeContext, + ) -> crate::Result<()> { + for change in changes { + let change = change?; + context.data.deleted.borrow_mut().insert(change.docid()); + } + Ok(()) + } + } + + let mut deletions = DocumentDeletion::new(); + deletions.delete_documents_by_docids(Vec::::new().into_iter().collect()); + let indexer = Bump::new(); + + let index = TempIndex::new(); + + let rtxn = index.read_txn().unwrap(); + + let db_fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); + let metadata_builder = MetadataBuilder::from_index(&index, &rtxn).unwrap(); + let fields_ids_map = + RwLock::new(FieldIdMapWithMetadata::new(db_fields_ids_map.clone(), metadata_builder)); + + let fields_ids_map_store = ThreadLocal::new(); + + let mut extractor_allocs = ThreadLocal::new(); + let doc_allocs = ThreadLocal::new(); + + let deletion_tracker = TrackDeletion(PhantomData); + + let changes = deletions + .into_changes(&indexer, crate::documents::PrimaryKey::Flat { name: "id", field_id: 0 }); + + let context = IndexingContext { + index: &index, + db_fields_ids_map: &db_fields_ids_map, + new_fields_ids_map: &fields_ids_map, + doc_allocs: &doc_allocs, + fields_ids_map_store: &fields_ids_map_store, + must_stop_processing: &(|| false), + send_progress: &(|_progress| {}), + }; + + for _ in 0..3 { + let datastore = ThreadLocal::new(); + + extract( + &changes, + &deletion_tracker, + context, + &mut extractor_allocs, + &datastore, + Step::ExtractingDocuments, + ) + .unwrap(); + + for (index, data) in datastore.into_iter().enumerate() { + println!("deleted by {index}: {:?}", data.deleted.borrow()); + } + for alloc in extractor_allocs.iter_mut() { + alloc.0.reset(); + } + } + } +} diff --git a/crates/milli/src/update/new/indexer/document_operation.rs b/crates/milli/src/update/new/indexer/document_operation.rs new file mode 100644 index 000000000..71d410ea6 --- /dev/null +++ b/crates/milli/src/update/new/indexer/document_operation.rs @@ -0,0 +1,674 @@ +use bumpalo::collections::CollectIn; +use bumpalo::Bump; +use hashbrown::hash_map::Entry; +use heed::RoTxn; +use memmap2::Mmap; +use raw_collections::RawMap; +use rayon::slice::ParallelSlice; +use serde_json::value::RawValue; +use serde_json::Deserializer; + +use super::super::document_change::DocumentChange; +use super::document_changes::{DocumentChangeContext, DocumentChanges}; +use super::retrieve_or_guess_primary_key; +use crate::documents::PrimaryKey; +use crate::update::new::document::Versions; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::{Deletion, Insertion, Update}; +use crate::update::{AvailableIds, IndexDocumentsMethod}; +use crate::{DocumentId, Error, FieldsIdsMap, Index, InternalError, Result, UserError}; + +pub struct DocumentOperation<'pl> { + operations: Vec>, + method: MergeMethod, +} + +impl<'pl> DocumentOperation<'pl> { + pub fn new(method: IndexDocumentsMethod) -> Self { + Self { operations: Default::default(), method: MergeMethod::from(method) } + } + + /// TODO please give me a type + /// The payload is expected to be in the grenad format + pub fn add_documents(&mut self, payload: &'pl Mmap) -> Result<()> { + payload.advise(memmap2::Advice::Sequential)?; + self.operations.push(Payload::Addition(&payload[..])); + Ok(()) + } + + pub fn delete_documents(&mut self, to_delete: &'pl [&'pl str]) { + self.operations.push(Payload::Deletion(to_delete)) + } + + pub fn into_changes( + self, + indexer: &'pl Bump, + index: &Index, + rtxn: &'pl RoTxn<'pl>, + primary_key_from_op: Option<&'pl str>, + new_fields_ids_map: &mut FieldsIdsMap, + ) -> Result<(DocumentOperationChanges<'pl>, Vec, Option>)> { + let Self { operations, method } = self; + + let documents_ids = index.documents_ids(rtxn)?; + let mut operations_stats = Vec::new(); + let mut available_docids = AvailableIds::new(&documents_ids); + let mut docids_version_offsets = hashbrown::HashMap::new(); + let mut primary_key = None; + + for operation in operations { + let (bytes, document_count, result) = match operation { + Payload::Addition(payload) => extract_addition_payload_changes( + indexer, + index, + rtxn, + primary_key_from_op, + &mut primary_key, + new_fields_ids_map, + &mut available_docids, + &docids_version_offsets, + method, + payload, + ), + Payload::Deletion(to_delete) => extract_deletion_payload_changes( + index, + rtxn, + &mut available_docids, + &docids_version_offsets, + method, + to_delete, + ), + }; + + let error = match result { + Ok(new_docids_version_offsets) => { + // If we don't have any error then we can merge the content of this payload + // into to main payload. Else we just drop this payload extraction. + merge_version_offsets(&mut docids_version_offsets, new_docids_version_offsets); + None + } + Err(Error::UserError(user_error)) => Some(user_error), + Err(e) => return Err(e), + }; + + operations_stats.push(PayloadStats { document_count, bytes, error }); + } + + // TODO We must drain the HashMap into a Vec because rayon::hash_map::IntoIter: !Clone + let mut docids_version_offsets: bumpalo::collections::vec::Vec<_> = + docids_version_offsets.drain().collect_in(indexer); + + // Reorder the offsets to make sure we iterate on the file sequentially + // And finally sort them + docids_version_offsets.sort_unstable_by_key(|(_, po)| method.sort_key(&po.operations)); + + let docids_version_offsets = docids_version_offsets.into_bump_slice(); + Ok((DocumentOperationChanges { docids_version_offsets }, operations_stats, primary_key)) + } +} + +#[allow(clippy::too_many_arguments)] +fn extract_addition_payload_changes<'r, 'pl: 'r>( + indexer: &'pl Bump, + index: &Index, + rtxn: &'r RoTxn<'r>, + primary_key_from_op: Option<&'r str>, + primary_key: &mut Option>, + new_fields_ids_map: &mut FieldsIdsMap, + available_docids: &mut AvailableIds, + main_docids_version_offsets: &hashbrown::HashMap<&'pl str, PayloadOperations<'pl>>, + method: MergeMethod, + payload: &'pl [u8], +) -> (u64, u64, Result>>) { + let mut new_docids_version_offsets = hashbrown::HashMap::<&str, PayloadOperations<'pl>>::new(); + + /// TODO manage the error + let mut previous_offset = 0; + let mut iter = Deserializer::from_slice(payload).into_iter::<&RawValue>(); + loop { + let optdoc = match iter.next().transpose() { + Ok(optdoc) => optdoc, + Err(e) => { + return ( + payload.len() as u64, + new_docids_version_offsets.len() as u64, + Err(InternalError::SerdeJson(e).into()), + ) + } + }; + + // Only guess the primary key if it is the first document + let retrieved_primary_key = if previous_offset == 0 { + let optdoc = match optdoc { + Some(doc) => match RawMap::from_raw_value(doc, indexer) { + Ok(docmap) => Some(docmap), + Err(error) => { + return ( + payload.len() as u64, + new_docids_version_offsets.len() as u64, + Err(Error::UserError(UserError::SerdeJson(error))), + ) + } + }, + None => None, + }; + + let result = retrieve_or_guess_primary_key( + rtxn, + index, + new_fields_ids_map, + primary_key_from_op, + optdoc, + ); + + let (pk, _has_been_changed) = match result { + Ok(Ok(pk)) => pk, + Ok(Err(user_error)) => { + return ( + payload.len() as u64, + new_docids_version_offsets.len() as u64, + Err(Error::UserError(user_error)), + ) + } + Err(error) => { + return ( + payload.len() as u64, + new_docids_version_offsets.len() as u64, + Err(error), + ) + } + }; + + primary_key.get_or_insert(pk) + } else { + primary_key.as_ref().unwrap() + }; + + let doc = match optdoc { + Some(doc) => doc, + None => break, + }; + + let external_id = match retrieved_primary_key.extract_fields_and_docid( + doc, + new_fields_ids_map, + indexer, + ) { + Ok(edi) => edi, + Err(e) => { + return (payload.len() as u64, new_docids_version_offsets.len() as u64, Err(e)) + } + }; + + let external_id = external_id.to_de(); + let current_offset = iter.byte_offset(); + let document_offset = DocumentOffset { content: &payload[previous_offset..current_offset] }; + + match main_docids_version_offsets.get(external_id) { + None => { + let (docid, is_new) = match index.external_documents_ids().get(rtxn, external_id) { + Ok(Some(docid)) => (docid, false), + Ok(None) => ( + match available_docids.next() { + Some(docid) => docid, + None => { + return ( + payload.len() as u64, + new_docids_version_offsets.len() as u64, + Err(UserError::DocumentLimitReached.into()), + ) + } + }, + true, + ), + Err(e) => { + return ( + payload.len() as u64, + new_docids_version_offsets.len() as u64, + Err(e.into()), + ) + } + }; + + match new_docids_version_offsets.entry(external_id) { + Entry::Occupied(mut entry) => entry.get_mut().push_addition(document_offset), + Entry::Vacant(entry) => { + entry.insert(PayloadOperations::new_addition( + method, + docid, + is_new, + document_offset, + )); + } + } + } + Some(payload_operations) => match new_docids_version_offsets.entry(external_id) { + Entry::Occupied(mut entry) => entry.get_mut().push_addition(document_offset), + Entry::Vacant(entry) => { + entry.insert(PayloadOperations::new_addition( + method, + payload_operations.docid, + payload_operations.is_new, + document_offset, + )); + } + }, + } + + previous_offset = iter.byte_offset(); + } + + (payload.len() as u64, new_docids_version_offsets.len() as u64, Ok(new_docids_version_offsets)) +} + +fn extract_deletion_payload_changes<'s, 'pl: 's>( + index: &Index, + rtxn: &RoTxn, + available_docids: &mut AvailableIds, + main_docids_version_offsets: &hashbrown::HashMap<&'s str, PayloadOperations<'pl>>, + method: MergeMethod, + to_delete: &'pl [&'pl str], +) -> (u64, u64, Result>>) { + let mut new_docids_version_offsets = hashbrown::HashMap::<&str, PayloadOperations<'pl>>::new(); + let mut document_count = 0; + + for external_id in to_delete { + match main_docids_version_offsets.get(external_id) { + None => { + let (docid, is_new) = match index.external_documents_ids().get(rtxn, external_id) { + Ok(Some(docid)) => (docid, false), + Ok(None) => ( + match available_docids.next() { + Some(docid) => docid, + None => { + return ( + 0, + new_docids_version_offsets.len() as u64, + Err(UserError::DocumentLimitReached.into()), + ) + } + }, + true, + ), + Err(e) => return (0, new_docids_version_offsets.len() as u64, Err(e.into())), + }; + + match new_docids_version_offsets.entry(external_id) { + Entry::Occupied(mut entry) => entry.get_mut().push_deletion(), + Entry::Vacant(entry) => { + entry.insert(PayloadOperations::new_deletion(method, docid, is_new)); + } + } + } + Some(payload_operations) => match new_docids_version_offsets.entry(external_id) { + Entry::Occupied(mut entry) => entry.get_mut().push_deletion(), + Entry::Vacant(entry) => { + entry.insert(PayloadOperations::new_deletion( + method, + payload_operations.docid, + payload_operations.is_new, + )); + } + }, + } + document_count += 1; + } + + (0, document_count, Ok(new_docids_version_offsets)) +} + +fn merge_version_offsets<'s, 'pl>( + main: &mut hashbrown::HashMap<&'s str, PayloadOperations<'pl>>, + new: hashbrown::HashMap<&'s str, PayloadOperations<'pl>>, +) { + // We cannot swap like nothing because documents + // operations must be in the right order. + if main.is_empty() { + return *main = new; + } + + for (key, new_payload) in new { + match main.entry(key) { + Entry::Occupied(mut entry) => entry.get_mut().append_operations(new_payload.operations), + Entry::Vacant(entry) => { + entry.insert(new_payload); + } + } + } +} + +impl<'pl> DocumentChanges<'pl> for DocumentOperationChanges<'pl> { + type Item = (&'pl str, PayloadOperations<'pl>); + + fn iter( + &self, + chunk_size: usize, + ) -> impl rayon::prelude::IndexedParallelIterator> { + self.docids_version_offsets.par_chunks(chunk_size) + } + + fn item_to_document_change<'doc, T: MostlySend + 'doc>( + &'doc self, + context: &'doc DocumentChangeContext, + item: &'doc Self::Item, + ) -> Result>> + where + 'pl: 'doc, + { + let (external_doc, payload_operations) = item; + payload_operations.merge_method.merge( + payload_operations.docid, + external_doc, + payload_operations.is_new, + &context.doc_alloc, + &payload_operations.operations[..], + ) + } + + fn len(&self) -> usize { + self.docids_version_offsets.len() + } +} + +pub struct DocumentOperationChanges<'pl> { + docids_version_offsets: &'pl [(&'pl str, PayloadOperations<'pl>)], +} + +pub enum Payload<'pl> { + Addition(&'pl [u8]), + Deletion(&'pl [&'pl str]), +} + +pub struct PayloadStats { + pub bytes: u64, + pub document_count: u64, + pub error: Option, +} + +pub struct PayloadOperations<'pl> { + /// The internal document id of the document. + pub docid: DocumentId, + /// Wether this document is not in the current database (visible by the rtxn). + pub is_new: bool, + /// The operations to perform, in order, on this document. + pub operations: Vec>, + /// The merge method we are using to merge payloads and documents. + merge_method: MergeMethod, +} + +impl<'pl> PayloadOperations<'pl> { + fn new_deletion(merge_method: MergeMethod, docid: DocumentId, is_new: bool) -> Self { + Self { docid, is_new, operations: vec![InnerDocOp::Deletion], merge_method } + } + + fn new_addition( + merge_method: MergeMethod, + docid: DocumentId, + is_new: bool, + offset: DocumentOffset<'pl>, + ) -> Self { + Self { docid, is_new, operations: vec![InnerDocOp::Addition(offset)], merge_method } + } +} + +impl<'pl> PayloadOperations<'pl> { + fn push_addition(&mut self, offset: DocumentOffset<'pl>) { + if self.merge_method.useless_previous_changes() { + self.operations.clear(); + } + self.operations.push(InnerDocOp::Addition(offset)) + } + + fn push_deletion(&mut self) { + self.operations.clear(); + self.operations.push(InnerDocOp::Deletion); + } + + fn append_operations(&mut self, mut operations: Vec>) { + debug_assert!(!operations.is_empty()); + if self.merge_method.useless_previous_changes() { + self.operations.clear(); + } + self.operations.append(&mut operations); + } +} + +#[derive(Clone)] +pub enum InnerDocOp<'pl> { + Addition(DocumentOffset<'pl>), + Deletion, +} + +/// Represents an offset where a document lives +/// in an mmapped grenad reader file. +#[derive(Clone)] +pub struct DocumentOffset<'pl> { + /// The mmapped payload files. + pub content: &'pl [u8], +} + +trait MergeChanges { + /// Whether the payloads in the list of operations are useless or not. + fn useless_previous_changes(&self) -> bool; + + /// Returns a key that is used to order the payloads the right way. + fn sort_key(&self, docops: &[InnerDocOp]) -> usize; + + fn merge<'doc>( + &self, + docid: DocumentId, + external_docid: &'doc str, + is_new: bool, + doc_alloc: &'doc Bump, + operations: &'doc [InnerDocOp], + ) -> Result>>; +} + +#[derive(Debug, Clone, Copy)] +enum MergeMethod { + ForReplacement(MergeDocumentForReplacement), + ForUpdates(MergeDocumentForUpdates), +} + +impl MergeChanges for MergeMethod { + fn useless_previous_changes(&self) -> bool { + match self { + MergeMethod::ForReplacement(merge) => merge.useless_previous_changes(), + MergeMethod::ForUpdates(merge) => merge.useless_previous_changes(), + } + } + + fn sort_key(&self, docops: &[InnerDocOp]) -> usize { + match self { + MergeMethod::ForReplacement(merge) => merge.sort_key(docops), + MergeMethod::ForUpdates(merge) => merge.sort_key(docops), + } + } + + fn merge<'doc>( + &self, + docid: DocumentId, + external_docid: &'doc str, + is_new: bool, + doc_alloc: &'doc Bump, + operations: &'doc [InnerDocOp], + ) -> Result>> { + match self { + MergeMethod::ForReplacement(merge) => { + merge.merge(docid, external_docid, is_new, doc_alloc, operations) + } + MergeMethod::ForUpdates(merge) => { + merge.merge(docid, external_docid, is_new, doc_alloc, operations) + } + } + } +} + +impl From for MergeMethod { + fn from(method: IndexDocumentsMethod) -> Self { + match method { + IndexDocumentsMethod::ReplaceDocuments => { + MergeMethod::ForReplacement(MergeDocumentForReplacement) + } + IndexDocumentsMethod::UpdateDocuments => { + MergeMethod::ForUpdates(MergeDocumentForUpdates) + } + } + } +} + +#[derive(Debug, Clone, Copy)] +struct MergeDocumentForReplacement; + +impl MergeChanges for MergeDocumentForReplacement { + fn useless_previous_changes(&self) -> bool { + true + } + + /// Reorders to read only the last change. + fn sort_key(&self, docops: &[InnerDocOp]) -> usize { + let f = |ido: &_| match ido { + InnerDocOp::Addition(add) => Some(add.content.as_ptr() as usize), + InnerDocOp::Deletion => None, + }; + docops.iter().rev().find_map(f).unwrap_or(0) + } + + /// Returns only the most recent version of a document based on the updates from the payloads. + /// + /// This function is only meant to be used when doing a replacement and not an update. + fn merge<'doc>( + &self, + docid: DocumentId, + external_doc: &'doc str, + is_new: bool, + doc_alloc: &'doc Bump, + operations: &'doc [InnerDocOp], + ) -> Result>> { + match operations.last() { + Some(InnerDocOp::Addition(DocumentOffset { content })) => { + let document = serde_json::from_slice(content).unwrap(); + let document = raw_collections::RawMap::from_raw_value(document, doc_alloc) + .map_err(UserError::SerdeJson)?; + + if is_new { + Ok(Some(DocumentChange::Insertion(Insertion::create( + docid, + external_doc, + Versions::single(document), + )))) + } else { + Ok(Some(DocumentChange::Update(Update::create( + docid, + external_doc, + Versions::single(document), + true, + )))) + } + } + Some(InnerDocOp::Deletion) => { + return if is_new { + Ok(None) + } else { + let deletion = Deletion::create(docid, external_doc); + Ok(Some(DocumentChange::Deletion(deletion))) + }; + } + None => unreachable!("We must not have empty set of operations on a document"), + } + } +} + +#[derive(Debug, Clone, Copy)] +struct MergeDocumentForUpdates; + +impl MergeChanges for MergeDocumentForUpdates { + fn useless_previous_changes(&self) -> bool { + false + } + + /// Reorders to read the first changes first so that it's faster to read the first one and then the rest. + fn sort_key(&self, docops: &[InnerDocOp]) -> usize { + let f = |ido: &_| match ido { + InnerDocOp::Addition(add) => Some(add.content.as_ptr() as usize), + InnerDocOp::Deletion => None, + }; + docops.iter().find_map(f).unwrap_or(0) + } + + /// Reads the previous version of a document from the database, the new versions + /// in the grenad update files and merges them to generate a new boxed obkv. + /// + /// This function is only meant to be used when doing an update and not a replacement. + fn merge<'doc>( + &self, + docid: DocumentId, + external_docid: &'doc str, + is_new: bool, + doc_alloc: &'doc Bump, + operations: &'doc [InnerDocOp], + ) -> Result>> { + if operations.is_empty() { + unreachable!("We must not have empty set of operations on a document"); + } + + let last_deletion = operations.iter().rposition(|op| matches!(op, InnerDocOp::Deletion)); + let operations = &operations[last_deletion.map_or(0, |i| i + 1)..]; + + let has_deletion = last_deletion.is_some(); + + if operations.is_empty() { + return if is_new { + Ok(None) + } else { + let deletion = Deletion::create(docid, external_docid); + Ok(Some(DocumentChange::Deletion(deletion))) + }; + } + + let versions = match operations { + [single] => { + let DocumentOffset { content } = match single { + InnerDocOp::Addition(offset) => offset, + InnerDocOp::Deletion => { + unreachable!("Deletion in document operations") + } + }; + let document = serde_json::from_slice(content).unwrap(); + let document = raw_collections::RawMap::from_raw_value(document, doc_alloc) + .map_err(UserError::SerdeJson)?; + + Some(Versions::single(document)) + } + operations => { + let versions = operations.iter().map(|operation| { + let DocumentOffset { content } = match operation { + InnerDocOp::Addition(offset) => offset, + InnerDocOp::Deletion => { + unreachable!("Deletion in document operations") + } + }; + + let document = serde_json::from_slice(content).unwrap(); + let document = raw_collections::RawMap::from_raw_value(document, doc_alloc) + .map_err(UserError::SerdeJson)?; + Ok(document) + }); + Versions::multiple(versions)? + } + }; + + let Some(versions) = versions else { return Ok(None) }; + + if is_new { + Ok(Some(DocumentChange::Insertion(Insertion::create(docid, external_docid, versions)))) + } else { + Ok(Some(DocumentChange::Update(Update::create( + docid, + external_docid, + versions, + has_deletion, + )))) + } + } +} diff --git a/crates/milli/src/update/new/indexer/mod.rs b/crates/milli/src/update/new/indexer/mod.rs new file mode 100644 index 000000000..dfc3d9b02 --- /dev/null +++ b/crates/milli/src/update/new/indexer/mod.rs @@ -0,0 +1,726 @@ +use std::cmp::Ordering; +use std::sync::{OnceLock, RwLock}; +use std::thread::{self, Builder}; + +use big_s::S; +use document_changes::{extract, DocumentChanges, IndexingContext, Progress}; +pub use document_deletion::DocumentDeletion; +pub use document_operation::{DocumentOperation, PayloadStats}; +use hashbrown::HashMap; +use heed::types::{Bytes, DecodeIgnore, Str}; +use heed::{RoTxn, RwTxn}; +use itertools::{merge_join_by, EitherOrBoth}; +pub use partial_dump::PartialDump; +use rand::SeedableRng as _; +use raw_collections::RawMap; +use time::OffsetDateTime; +pub use update_by_function::UpdateByFunction; + +use super::channel::*; +use super::extract::*; +use super::facet_search_builder::FacetSearchBuilder; +use super::merger::FacetFieldIdsDelta; +use super::steps::Step; +use super::thread_local::ThreadLocal; +use super::word_fst_builder::{PrefixData, PrefixDelta, WordFstBuilder}; +use super::words_prefix_docids::{ + compute_word_prefix_docids, compute_word_prefix_fid_docids, compute_word_prefix_position_docids, +}; +use super::StdResult; +use crate::documents::{PrimaryKey, DEFAULT_PRIMARY_KEY}; +use crate::facet::FacetType; +use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder}; +use crate::index::main_key::{WORDS_FST_KEY, WORDS_PREFIXES_FST_KEY}; +use crate::proximity::ProximityPrecision; +use crate::update::del_add::DelAdd; +use crate::update::new::extract::EmbeddingExtractor; +use crate::update::new::merger::merge_and_send_rtree; +use crate::update::new::words_prefix_docids::compute_exact_word_prefix_docids; +use crate::update::new::{merge_and_send_docids, merge_and_send_facet_docids, FacetDatabases}; +use crate::update::settings::InnerIndexSettings; +use crate::update::{FacetsUpdateBulk, GrenadParameters}; +use crate::vector::{ArroyWrapper, EmbeddingConfigs, Embeddings}; +use crate::{ + FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, Result, ThreadPoolNoAbort, + ThreadPoolNoAbortBuilder, UserError, +}; + +pub(crate) mod de; +pub mod document_changes; +mod document_deletion; +mod document_operation; +mod partial_dump; +mod update_by_function; + +/// This is the main function of this crate. +/// +/// Give it the output of the [`Indexer::document_changes`] method and it will execute it in the [`rayon::ThreadPool`]. +/// +/// TODO return stats +#[allow(clippy::too_many_arguments)] // clippy: 😝 +pub fn index<'pl, 'indexer, 'index, DC, MSP, SP>( + wtxn: &mut RwTxn, + index: &'index Index, + grenad_parameters: GrenadParameters, + db_fields_ids_map: &'indexer FieldsIdsMap, + new_fields_ids_map: FieldsIdsMap, + new_primary_key: Option>, + document_changes: &DC, + embedders: EmbeddingConfigs, + must_stop_processing: &'indexer MSP, + send_progress: &'indexer SP, +) -> Result<()> +where + DC: DocumentChanges<'pl>, + MSP: Fn() -> bool + Sync, + SP: Fn(Progress) + Sync, +{ + let (extractor_sender, writer_receiver) = extractor_writer_channel(10_000); + + let metadata_builder = MetadataBuilder::from_index(index, wtxn)?; + let new_fields_ids_map = FieldIdMapWithMetadata::new(new_fields_ids_map, metadata_builder); + let new_fields_ids_map = RwLock::new(new_fields_ids_map); + let fields_ids_map_store = ThreadLocal::with_capacity(rayon::current_num_threads()); + let mut extractor_allocs = ThreadLocal::with_capacity(rayon::current_num_threads()); + let doc_allocs = ThreadLocal::with_capacity(rayon::current_num_threads()); + + let indexing_context = IndexingContext { + index, + db_fields_ids_map, + new_fields_ids_map: &new_fields_ids_map, + doc_allocs: &doc_allocs, + fields_ids_map_store: &fields_ids_map_store, + must_stop_processing, + send_progress, + }; + + let mut field_distribution = index.field_distribution(wtxn)?; + let mut document_ids = index.documents_ids(wtxn)?; + + thread::scope(|s| -> Result<()> { + let indexer_span = tracing::Span::current(); + let embedders = &embedders; + // prevent moving the field_distribution and document_ids in the inner closure... + let field_distribution = &mut field_distribution; + let document_ids = &mut document_ids; + // TODO manage the errors correctly + let extractor_handle = Builder::new().name(S("indexer-extractors")).spawn_scoped(s, move || { + let span = tracing::trace_span!(target: "indexing::documents", parent: &indexer_span, "extract"); + let _entered = span.enter(); + + let rtxn = index.read_txn()?; + + // document but we need to create a function that collects and compresses documents. + let document_sender = extractor_sender.documents(); + let document_extractor = DocumentsExtractor::new(&document_sender, embedders); + let datastore = ThreadLocal::with_capacity(rayon::current_num_threads()); + + extract(document_changes, + &document_extractor, + indexing_context, + &mut extractor_allocs, + &datastore, + Step::ExtractingDocuments, + )?; + + for document_extractor_data in datastore { + let document_extractor_data = document_extractor_data.0.into_inner(); + for (field, delta) in document_extractor_data.field_distribution_delta { + let current = field_distribution.entry(field).or_default(); + // adding the delta should never cause a negative result, as we are removing fields that previously existed. + *current = current.saturating_add_signed(delta); + } + document_extractor_data.docids_delta.apply_to(document_ids); + } + + field_distribution.retain(|_, v| *v != 0); + + let facet_field_ids_delta; + + { + let span = tracing::trace_span!(target: "indexing::documents::extract", "faceted"); + let _entered = span.enter(); + + facet_field_ids_delta = merge_and_send_facet_docids( + FacetedDocidsExtractor::run_extraction( + grenad_parameters, + document_changes, + indexing_context, + &mut extractor_allocs, + &extractor_sender.field_id_docid_facet_sender(), + Step::ExtractingFacets + )?, + FacetDatabases::new(index), + index, + extractor_sender.facet_docids(), + )?; + } + + { + let span = tracing::trace_span!(target: "indexing::documents::extract", "word_docids"); + let _entered = span.enter(); + + + let WordDocidsCaches { + word_docids, + word_fid_docids, + exact_word_docids, + word_position_docids, + fid_word_count_docids, + } = WordDocidsExtractors::run_extraction( + grenad_parameters, + document_changes, + indexing_context, + &mut extractor_allocs, + Step::ExtractingWords + )?; + + // TODO Word Docids Merger + // extractor_sender.send_searchable::(word_docids).unwrap(); + { + let span = tracing::trace_span!(target: "indexing::documents::merge", "word_docids"); + let _entered = span.enter(); + merge_and_send_docids( + word_docids, + index.word_docids.remap_types(), + index, + extractor_sender.docids::(), + &indexing_context.must_stop_processing, + )?; + } + + // Word Fid Docids Merging + // extractor_sender.send_searchable::(word_fid_docids).unwrap(); + { + let span = tracing::trace_span!(target: "indexing::documents::merge", "word_fid_docids"); + let _entered = span.enter(); + merge_and_send_docids( + word_fid_docids, + index.word_fid_docids.remap_types(), + index, + extractor_sender.docids::(), + &indexing_context.must_stop_processing, + )?; + } + + // Exact Word Docids Merging + // extractor_sender.send_searchable::(exact_word_docids).unwrap(); + { + let span = tracing::trace_span!(target: "indexing::documents::merge", "exact_word_docids"); + let _entered = span.enter(); + merge_and_send_docids( + exact_word_docids, + index.exact_word_docids.remap_types(), + index, + extractor_sender.docids::(), + &indexing_context.must_stop_processing, + )?; + } + + // Word Position Docids Merging + // extractor_sender.send_searchable::(word_position_docids).unwrap(); + { + let span = tracing::trace_span!(target: "indexing::documents::merge", "word_position_docids"); + let _entered = span.enter(); + merge_and_send_docids( + word_position_docids, + index.word_position_docids.remap_types(), + index, + extractor_sender.docids::(), + &indexing_context.must_stop_processing, + )?; + } + + // Fid Word Count Docids Merging + // extractor_sender.send_searchable::(fid_word_count_docids).unwrap(); + { + let span = tracing::trace_span!(target: "indexing::documents::merge", "fid_word_count_docids"); + let _entered = span.enter(); + merge_and_send_docids( + fid_word_count_docids, + index.field_id_word_count_docids.remap_types(), + index, + extractor_sender.docids::(), + &indexing_context.must_stop_processing, + )?; + } + } + + // run the proximity extraction only if the precision is by word + // this works only if the settings didn't change during this transaction. + let proximity_precision = index.proximity_precision(&rtxn)?.unwrap_or_default(); + if proximity_precision == ProximityPrecision::ByWord { + let span = tracing::trace_span!(target: "indexing::documents::extract", "word_pair_proximity_docids"); + let _entered = span.enter(); + + + let caches = ::run_extraction( + grenad_parameters, + document_changes, + indexing_context, + &mut extractor_allocs, + Step::ExtractingWordProximity, + )?; + + merge_and_send_docids( + caches, + index.word_pair_proximity_docids.remap_types(), + index, + extractor_sender.docids::(), + &indexing_context.must_stop_processing, + )?; + } + + 'vectors: { + let span = tracing::trace_span!(target: "indexing::documents::extract", "vectors"); + let _entered = span.enter(); + + let mut index_embeddings = index.embedding_configs(&rtxn)?; + if index_embeddings.is_empty() { + break 'vectors; + } + + let embedding_sender = extractor_sender.embeddings(); + let extractor = EmbeddingExtractor::new(embedders, &embedding_sender, field_distribution, request_threads()); + let mut datastore = ThreadLocal::with_capacity(rayon::current_num_threads()); + extract(document_changes, &extractor, indexing_context, &mut extractor_allocs, &datastore, Step::ExtractingEmbeddings)?; + + for config in &mut index_embeddings { + 'data: for data in datastore.iter_mut() { + let data = &mut data.get_mut().0; + let Some(deladd) = data.remove(&config.name) else { continue 'data; }; + deladd.apply_to(&mut config.user_provided); + } + } + + embedding_sender.finish(index_embeddings).unwrap(); + } + + 'geo: { + let span = tracing::trace_span!(target: "indexing::documents::extract", "geo"); + let _entered = span.enter(); + + // let geo_sender = extractor_sender.geo_points(); + let Some(extractor) = GeoExtractor::new(&rtxn, index, grenad_parameters)? else { + break 'geo; + }; + let datastore = ThreadLocal::with_capacity(rayon::current_num_threads()); + extract( + document_changes, + &extractor, + indexing_context, + &mut extractor_allocs, + &datastore, + Step::WritingGeoPoints + )?; + + merge_and_send_rtree( + datastore, + &rtxn, + index, + extractor_sender.geo(), + &indexing_context.must_stop_processing, + )?; + } + + // TODO THIS IS TOO MUCH + // - [ ] Extract fieldid docid facet number + // - [ ] Extract fieldid docid facet string + // - [ ] Extract facetid string fst + // - [ ] Extract facetid normalized string strings + + // TODO Inverted Indexes again + // - [x] Extract fieldid facet isempty docids + // - [x] Extract fieldid facet isnull docids + // - [x] Extract fieldid facet exists docids + + // TODO This is the normal system + // - [x] Extract fieldid facet number docids + // - [x] Extract fieldid facet string docids + + { + let span = tracing::trace_span!(target: "indexing::documents::extract", "FINISH"); + let _entered = span.enter(); + (indexing_context.send_progress)(Progress::from_step(Step::WritingToDatabase)); + } + + Result::Ok(facet_field_ids_delta) + })?; + + let global_fields_ids_map = GlobalFieldsIdsMap::new(&new_fields_ids_map); + + let vector_arroy = index.vector_arroy; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let indexer_span = tracing::Span::current(); + let arroy_writers: Result> = embedders + .inner_as_ref() + .iter() + .map(|(embedder_name, (embedder, _, was_quantized))| { + let embedder_index = index.embedder_category_id.get(wtxn, embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { + db_name: "embedder_category_id", + key: None, + }, + )?; + + let dimensions = embedder.dimensions(); + let writer = ArroyWrapper::new(vector_arroy, embedder_index, *was_quantized); + + Ok(( + embedder_index, + (embedder_name.as_str(), embedder.as_ref(), writer, dimensions), + )) + }) + .collect(); + + let mut arroy_writers = arroy_writers?; + for operation in writer_receiver { + match operation { + WriterOperation::DbOperation(db_operation) => { + let database = db_operation.database(index); + match db_operation.entry() { + EntryOperation::Delete(e) => { + if !database.delete(wtxn, e.entry())? { + unreachable!("We tried to delete an unknown key") + } + } + EntryOperation::Write(e) => database.put(wtxn, e.key(), e.value())?, + } + } + WriterOperation::ArroyOperation(arroy_operation) => match arroy_operation { + ArroyOperation::DeleteVectors { docid } => { + for (_embedder_index, (_embedder_name, _embedder, writer, dimensions)) in + &mut arroy_writers + { + let dimensions = *dimensions; + writer.del_items(wtxn, dimensions, docid)?; + } + } + ArroyOperation::SetVectors { + docid, + embedder_id, + embeddings: raw_embeddings, + } => { + let (_, _, writer, dimensions) = + arroy_writers.get(&embedder_id).expect("requested a missing embedder"); + // TODO: switch to Embeddings + let mut embeddings = Embeddings::new(*dimensions); + for embedding in raw_embeddings { + embeddings.append(embedding).unwrap(); + } + + writer.del_items(wtxn, *dimensions, docid)?; + writer.add_items(wtxn, docid, &embeddings)?; + } + ArroyOperation::SetVector { docid, embedder_id, embedding } => { + let (_, _, writer, dimensions) = + arroy_writers.get(&embedder_id).expect("requested a missing embedder"); + writer.del_items(wtxn, *dimensions, docid)?; + writer.add_item(wtxn, docid, &embedding)?; + } + ArroyOperation::Finish { configs } => { + let span = tracing::trace_span!(target: "indexing::vectors", parent: &indexer_span, "build"); + let _entered = span.enter(); + + (indexing_context.send_progress)(Progress::from_step( + Step::WritingEmbeddingsToDatabase, + )); + + for (_embedder_index, (_embedder_name, _embedder, writer, dimensions)) in + &mut arroy_writers + { + let dimensions = *dimensions; + writer.build_and_quantize( + wtxn, + &mut rng, + dimensions, + false, + &indexing_context.must_stop_processing, + )?; + } + + index.put_embedding_configs(wtxn, configs)?; + } + }, + } + } + + (indexing_context.send_progress)(Progress::from_step(Step::WaitingForExtractors)); + + let facet_field_ids_delta = extractor_handle.join().unwrap()?; + + (indexing_context.send_progress)(Progress::from_step(Step::PostProcessingFacets)); + + compute_facet_search_database(index, wtxn, global_fields_ids_map)?; + compute_facet_level_database(index, wtxn, facet_field_ids_delta)?; + + (indexing_context.send_progress)(Progress::from_step(Step::PostProcessingWords)); + + if let Some(prefix_delta) = compute_word_fst(index, wtxn)? { + compute_prefix_database(index, wtxn, prefix_delta)?; + } + (indexing_context.send_progress)(Progress::from_step(Step::Finalizing)); + + Ok(()) as Result<_> + })?; + + // required to into_inner the new_fields_ids_map + drop(fields_ids_map_store); + + let new_fields_ids_map = new_fields_ids_map.into_inner().unwrap(); + index.put_fields_ids_map(wtxn, new_fields_ids_map.as_fields_ids_map())?; + + if let Some(new_primary_key) = new_primary_key { + index.put_primary_key(wtxn, new_primary_key.name())?; + } + + // used to update the localized and weighted maps while sharing the update code with the settings pipeline. + let mut inner_index_settings = InnerIndexSettings::from_index(index, wtxn)?; + inner_index_settings.recompute_facets(wtxn, index)?; + inner_index_settings.recompute_searchables(wtxn, index)?; + index.put_field_distribution(wtxn, &field_distribution)?; + index.put_documents_ids(wtxn, &document_ids)?; + index.set_updated_at(wtxn, &OffsetDateTime::now_utc())?; + + Ok(()) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] +fn compute_prefix_database( + index: &Index, + wtxn: &mut RwTxn, + prefix_delta: PrefixDelta, +) -> Result<()> { + eprintln!("prefix_delta: {:?}", &prefix_delta); + let PrefixDelta { modified, deleted } = prefix_delta; + // Compute word prefix docids + compute_word_prefix_docids(wtxn, index, &modified, &deleted)?; + // Compute exact word prefix docids + compute_exact_word_prefix_docids(wtxn, index, &modified, &deleted)?; + // Compute word prefix fid docids + compute_word_prefix_fid_docids(wtxn, index, &modified, &deleted)?; + // Compute word prefix position docids + compute_word_prefix_position_docids(wtxn, index, &modified, &deleted) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing")] +fn compute_word_fst(index: &Index, wtxn: &mut RwTxn) -> Result> { + let rtxn = index.read_txn()?; + let words_fst = index.words_fst(&rtxn)?; + let mut word_fst_builder = WordFstBuilder::new(&words_fst)?; + let prefix_settings = index.prefix_settings(&rtxn)?; + word_fst_builder.with_prefix_settings(prefix_settings); + + let previous_words = index.word_docids.iter(&rtxn)?.remap_data_type::(); + let current_words = index.word_docids.iter(wtxn)?.remap_data_type::(); + for eob in merge_join_by(previous_words, current_words, |lhs, rhs| match (lhs, rhs) { + (Ok((l, _)), Ok((r, _))) => l.cmp(r), + (Err(_), _) | (_, Err(_)) => Ordering::Equal, + }) { + match eob { + EitherOrBoth::Both(lhs, rhs) => { + let (word, lhs_bytes) = lhs?; + let (_, rhs_bytes) = rhs?; + if lhs_bytes != rhs_bytes { + word_fst_builder.register_word(DelAdd::Addition, word.as_ref())?; + } + } + EitherOrBoth::Left(result) => { + let (word, _) = result?; + word_fst_builder.register_word(DelAdd::Deletion, word.as_ref())?; + } + EitherOrBoth::Right(result) => { + let (word, _) = result?; + word_fst_builder.register_word(DelAdd::Addition, word.as_ref())?; + } + } + } + + let span = tracing::trace_span!(target: "indexing::documents::merge", "words_fst"); + let _entered = span.enter(); + + let (word_fst_mmap, prefix_data) = word_fst_builder.build(index, &rtxn)?; + index.main.remap_types::().put(wtxn, WORDS_FST_KEY, &word_fst_mmap)?; + if let Some(PrefixData { prefixes_fst_mmap, prefix_delta }) = prefix_data { + index.main.remap_types::().put( + wtxn, + WORDS_PREFIXES_FST_KEY, + &prefixes_fst_mmap, + )?; + Ok(Some(prefix_delta)) + } else { + Ok(None) + } +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::facet_search")] +fn compute_facet_search_database( + index: &Index, + wtxn: &mut RwTxn, + global_fields_ids_map: GlobalFieldsIdsMap, +) -> Result<()> { + let rtxn = index.read_txn()?; + let localized_attributes_rules = index.localized_attributes_rules(&rtxn)?; + let mut facet_search_builder = FacetSearchBuilder::new( + global_fields_ids_map, + localized_attributes_rules.unwrap_or_default(), + ); + + let previous_facet_id_string_docids = index + .facet_id_string_docids + .iter(&rtxn)? + .remap_data_type::() + .filter(|r| r.as_ref().map_or(true, |(k, _)| k.level == 0)); + let current_facet_id_string_docids = index + .facet_id_string_docids + .iter(wtxn)? + .remap_data_type::() + .filter(|r| r.as_ref().map_or(true, |(k, _)| k.level == 0)); + for eob in merge_join_by( + previous_facet_id_string_docids, + current_facet_id_string_docids, + |lhs, rhs| match (lhs, rhs) { + (Ok((l, _)), Ok((r, _))) => l.cmp(r), + (Err(_), _) | (_, Err(_)) => Ordering::Equal, + }, + ) { + match eob { + EitherOrBoth::Both(lhs, rhs) => { + let (_, _) = lhs?; + let (_, _) = rhs?; + } + EitherOrBoth::Left(result) => { + let (key, _) = result?; + facet_search_builder.register_from_key(DelAdd::Deletion, key)?; + } + EitherOrBoth::Right(result) => { + let (key, _) = result?; + facet_search_builder.register_from_key(DelAdd::Addition, key)?; + } + } + } + + facet_search_builder.merge_and_write(index, wtxn, &rtxn) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::facet_field_ids")] +fn compute_facet_level_database( + index: &Index, + wtxn: &mut RwTxn, + facet_field_ids_delta: FacetFieldIdsDelta, +) -> Result<()> { + eprintln!("facet_field_ids_delta: {:?}", &facet_field_ids_delta); + if let Some(modified_facet_string_ids) = facet_field_ids_delta.modified_facet_string_ids() { + let span = tracing::trace_span!(target: "indexing::facet_field_ids", "string"); + let _entered = span.enter(); + FacetsUpdateBulk::new_not_updating_level_0( + index, + modified_facet_string_ids, + FacetType::String, + ) + .execute(wtxn)?; + } + if let Some(modified_facet_number_ids) = facet_field_ids_delta.modified_facet_number_ids() { + let span = tracing::trace_span!(target: "indexing::facet_field_ids", "number"); + let _entered = span.enter(); + FacetsUpdateBulk::new_not_updating_level_0( + index, + modified_facet_number_ids, + FacetType::Number, + ) + .execute(wtxn)?; + } + + Ok(()) +} + +/// Returns the primary key that has already been set for this index or the +/// one we will guess by searching for the first key that contains "id" as a substring, +/// and whether the primary key changed +/// TODO move this elsewhere +pub fn retrieve_or_guess_primary_key<'a>( + rtxn: &'a RoTxn<'a>, + index: &Index, + new_fields_ids_map: &mut FieldsIdsMap, + primary_key_from_op: Option<&'a str>, + first_document: Option>, +) -> Result, bool), UserError>> { + // make sure that we have a declared primary key, either fetching it from the index or attempting to guess it. + + // do we have an existing declared primary key? + let (primary_key, has_changed) = if let Some(primary_key_from_db) = index.primary_key(rtxn)? { + // did we request a primary key in the operation? + match primary_key_from_op { + // we did, and it is different from the DB one + Some(primary_key_from_op) if primary_key_from_op != primary_key_from_db => { + return Ok(Err(UserError::PrimaryKeyCannotBeChanged( + primary_key_from_db.to_string(), + ))); + } + _ => (primary_key_from_db, false), + } + } else { + // no primary key in the DB => let's set one + // did we request a primary key in the operation? + let primary_key = if let Some(primary_key_from_op) = primary_key_from_op { + // set primary key from operation + primary_key_from_op + } else { + // guess primary key + let first_document = match first_document { + Some(document) => document, + // previous indexer when no pk is set + we send an empty payload => index_primary_key_no_candidate_found + None => return Ok(Err(UserError::NoPrimaryKeyCandidateFound)), + }; + + let guesses: Result> = first_document + .keys() + .filter_map(|name| { + let Some(_) = new_fields_ids_map.insert(name) else { + return Some(Err(UserError::AttributeLimitReached.into())); + }; + name.to_lowercase().ends_with(DEFAULT_PRIMARY_KEY).then_some(Ok(name)) + }) + .collect(); + + let mut guesses = guesses?; + + // sort the keys in lexicographical order, so that fields are always in the same order. + guesses.sort_unstable(); + + match guesses.as_slice() { + [] => return Ok(Err(UserError::NoPrimaryKeyCandidateFound)), + [name] => { + tracing::info!("Primary key was not specified in index. Inferred to '{name}'"); + *name + } + multiple => { + return Ok(Err(UserError::MultiplePrimaryKeyCandidatesFound { + candidates: multiple + .iter() + .map(|candidate| candidate.to_string()) + .collect(), + })) + } + } + }; + (primary_key, true) + }; + + match PrimaryKey::new_or_insert(primary_key, new_fields_ids_map) { + Ok(primary_key) => Ok(Ok((primary_key, has_changed))), + Err(err) => Ok(Err(err)), + } +} + +fn request_threads() -> &'static ThreadPoolNoAbort { + static REQUEST_THREADS: OnceLock = OnceLock::new(); + + REQUEST_THREADS.get_or_init(|| { + ThreadPoolNoAbortBuilder::new() + .num_threads(crate::vector::REQUEST_PARALLELISM) + .thread_name(|index| format!("embedding-request-{index}")) + .build() + .unwrap() + }) +} diff --git a/crates/milli/src/update/new/indexer/partial_dump.rs b/crates/milli/src/update/new/indexer/partial_dump.rs new file mode 100644 index 000000000..8b5a8b650 --- /dev/null +++ b/crates/milli/src/update/new/indexer/partial_dump.rs @@ -0,0 +1,88 @@ +use std::ops::DerefMut; + +use rayon::iter::IndexedParallelIterator; +use serde_json::value::RawValue; + +use super::document_changes::{DocumentChangeContext, DocumentChanges}; +use crate::documents::PrimaryKey; +use crate::update::concurrent_available_ids::ConcurrentAvailableIds; +use crate::update::new::document::Versions; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::{DocumentChange, Insertion}; +use crate::{Error, InternalError, Result, UserError}; + +pub struct PartialDump { + iter: I, +} + +impl PartialDump { + pub fn new_from_jsonlines(iter: I) -> Self { + PartialDump { iter } + } + + pub fn into_changes<'index>( + self, + concurrent_available_ids: &'index ConcurrentAvailableIds, + primary_key: &'index PrimaryKey, + ) -> PartialDumpChanges<'index, I> { + /// Note for future self: + /// - We recommend sending chunks of documents in this `PartialDumpIndexer` we therefore need to create a custom take_while_size method (that doesn't drop items). + PartialDumpChanges { iter: self.iter, concurrent_available_ids, primary_key } + } +} + +pub struct PartialDumpChanges<'doc, I> { + iter: I, + concurrent_available_ids: &'doc ConcurrentAvailableIds, + primary_key: &'doc PrimaryKey<'doc>, +} + +impl<'index, Iter> DocumentChanges<'index> for PartialDumpChanges<'index, Iter> +where + Iter: IndexedParallelIterator> + Clone + Sync + 'index, +{ + type Item = Box; + + fn iter( + &self, + chunk_size: usize, + ) -> impl IndexedParallelIterator> { + self.iter.clone().chunks(chunk_size) + } + + fn item_to_document_change<'doc, T: MostlySend + 'doc>( + &'doc self, + context: &'doc DocumentChangeContext, + document: &'doc Self::Item, + ) -> Result>> + where + 'index: 'doc, + { + let doc_alloc = &context.doc_alloc; + let docid = match self.concurrent_available_ids.next() { + Some(id) => id, + None => return Err(Error::UserError(UserError::DocumentLimitReached)), + }; + + let mut fields_ids_map = context.new_fields_ids_map.borrow_mut_or_yield(); + let fields_ids_map = fields_ids_map.deref_mut(); + + let document = doc_alloc.alloc_str(document.get()); + let document: &RawValue = unsafe { std::mem::transmute(document) }; + + let external_document_id = + self.primary_key.extract_fields_and_docid(document, fields_ids_map, doc_alloc)?; + let external_document_id = external_document_id.to_de(); + + let document = raw_collections::RawMap::from_raw_value(document, doc_alloc) + .map_err(InternalError::SerdeJson)?; + + let insertion = Insertion::create(docid, external_document_id, Versions::single(document)); + Ok(Some(DocumentChange::Insertion(insertion))) + } + + fn len(&self) -> usize { + self.iter.len() + } +} diff --git a/crates/milli/src/update/new/indexer/update_by_function.rs b/crates/milli/src/update/new/indexer/update_by_function.rs new file mode 100644 index 000000000..a8e3e38a8 --- /dev/null +++ b/crates/milli/src/update/new/indexer/update_by_function.rs @@ -0,0 +1,213 @@ +use raw_collections::RawMap; +use rayon::iter::IndexedParallelIterator; +use rayon::slice::ParallelSlice as _; +use rhai::{Dynamic, Engine, OptimizationLevel, Scope, AST}; +use roaring::RoaringBitmap; + +use super::document_changes::DocumentChangeContext; +use super::DocumentChanges; +use crate::documents::Error::InvalidDocumentFormat; +use crate::documents::PrimaryKey; +use crate::error::{FieldIdMapMissingEntry, InternalError}; +use crate::update::new::document::Versions; +use crate::update::new::ref_cell_ext::RefCellExt as _; +use crate::update::new::thread_local::MostlySend; +use crate::update::new::{Deletion, DocumentChange, KvReaderFieldId, Update}; +use crate::{all_obkv_to_json, Error, FieldsIdsMap, Object, Result, UserError}; + +pub struct UpdateByFunction { + documents: RoaringBitmap, + context: Option, + code: String, +} + +pub struct UpdateByFunctionChanges<'doc> { + primary_key: &'doc PrimaryKey<'doc>, + engine: Engine, + ast: AST, + context: Option, + // It is sad that the RoaringBitmap doesn't + // implement IndexedParallelIterator + documents: Vec, +} + +impl UpdateByFunction { + pub fn new(documents: RoaringBitmap, context: Option, code: String) -> Self { + UpdateByFunction { documents, context, code } + } + + pub fn into_changes<'index>( + self, + primary_key: &'index PrimaryKey, + ) -> Result> { + let Self { documents, context, code } = self; + + // Setup the security and limits of the Engine + let mut engine = Engine::new(); + engine.set_optimization_level(OptimizationLevel::Full); + engine.set_max_call_levels(1000); + // It is an arbitrary value. We need to let users define this in the settings. + engine.set_max_operations(1_000_000); + engine.set_max_variables(1000); + engine.set_max_functions(30); + engine.set_max_expr_depths(100, 1000); + engine.set_max_string_size(1024 * 1024 * 1024); // 1 GiB + engine.set_max_array_size(10_000); + engine.set_max_map_size(10_000); + + let ast = engine.compile(code).map_err(UserError::DocumentEditionCompilationError)?; + let context = match context { + Some(context) => { + Some(serde_json::from_value(context.into()).map_err(InternalError::SerdeJson)?) + } + None => None, + }; + + Ok(UpdateByFunctionChanges { + primary_key, + engine, + ast, + context, + documents: documents.into_iter().collect(), + }) + } +} + +impl<'index> DocumentChanges<'index> for UpdateByFunctionChanges<'index> { + type Item = u32; + + fn iter( + &self, + chunk_size: usize, + ) -> impl IndexedParallelIterator> { + self.documents.as_slice().par_chunks(chunk_size) + } + + fn item_to_document_change<'doc, T: MostlySend + 'doc>( + &self, + context: &'doc DocumentChangeContext, + docid: &'doc Self::Item, + ) -> Result>> + where + 'index: 'doc, + { + let DocumentChangeContext { + index, + db_fields_ids_map, + rtxn: txn, + new_fields_ids_map, + doc_alloc, + .. + } = context; + + let docid = *docid; + + // safety: Both documents *must* exists in the database as + // their IDs comes from the list of documents ids. + let document = index.document(txn, docid)?; + let rhai_document = obkv_to_rhaimap(document, db_fields_ids_map)?; + let json_document = all_obkv_to_json(document, db_fields_ids_map)?; + + let document_id = self + .primary_key + .document_id(document, db_fields_ids_map)? + .map_err(|_| InvalidDocumentFormat)?; + + let mut scope = Scope::new(); + if let Some(context) = self.context.as_ref().cloned() { + scope.push_constant_dynamic("context", context.clone()); + } + scope.push("doc", rhai_document); + // We run the user script which edits "doc" scope variable reprensenting + // the document and ignore the output and even the type of it, i.e., Dynamic. + let _ = self + .engine + .eval_ast_with_scope::(&mut scope, &self.ast) + .map_err(UserError::DocumentEditionRuntimeError)?; + + match scope.remove::("doc") { + // If the "doc" variable has been set to (), we effectively delete the document. + Some(doc) if doc.is_unit() => Ok(Some(DocumentChange::Deletion(Deletion::create( + docid, + doc_alloc.alloc_str(&document_id), + )))), + None => unreachable!("missing doc variable from the Rhai scope"), + Some(new_document) => match new_document.try_cast() { + Some(new_rhai_document) => { + let mut buffer = bumpalo::collections::Vec::new_in(doc_alloc); + serde_json::to_writer(&mut buffer, &new_rhai_document) + .map_err(InternalError::SerdeJson)?; + let raw_new_doc = serde_json::from_slice(buffer.into_bump_slice()) + .map_err(InternalError::SerdeJson)?; + + // Note: This condition is not perfect. Sometimes it detect changes + // like with floating points numbers and consider updating + // the document even if nothing actually changed. + // + // Future: Use a custom function rhai function to track changes. + // + if json_document != rhaimap_to_object(new_rhai_document) { + let mut global_fields_ids_map = new_fields_ids_map.borrow_mut_or_yield(); + let new_document_id = self + .primary_key + .extract_fields_and_docid( + raw_new_doc, + &mut *global_fields_ids_map, + doc_alloc, + )? + .to_de(); + + if document_id != new_document_id { + Err(Error::UserError(UserError::DocumentEditionCannotModifyPrimaryKey)) + } else { + let raw_new_doc = RawMap::from_raw_value(raw_new_doc, doc_alloc) + .map_err(InternalError::SerdeJson)?; + + Ok(Some(DocumentChange::Update(Update::create( + docid, + new_document_id, + Versions::single(raw_new_doc), + true, // It is like document replacement + )))) + } + } else { + Ok(None) + } + } + None => Err(Error::UserError(UserError::DocumentEditionDocumentMustBeObject)), + }, + } + } + + fn len(&self) -> usize { + self.documents.len() + } +} + +fn obkv_to_rhaimap(obkv: &KvReaderFieldId, fields_ids_map: &FieldsIdsMap) -> Result { + let all_keys = obkv.iter().map(|(k, _v)| k).collect::>(); + let map: Result = all_keys + .iter() + .copied() + .flat_map(|id| obkv.get(id).map(|value| (id, value))) + .map(|(id, value)| { + let name = fields_ids_map.name(id).ok_or(FieldIdMapMissingEntry::FieldId { + field_id: id, + process: "all_obkv_to_rhaimap", + })?; + let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?; + Ok((name.into(), value)) + }) + .collect(); + + map +} + +fn rhaimap_to_object(map: rhai::Map) -> Object { + let mut output = Object::new(); + for (key, value) in map { + let value = serde_json::to_value(&value).unwrap(); + output.insert(key.into(), value); + } + output +} diff --git a/crates/milli/src/update/new/merger.rs b/crates/milli/src/update/new/merger.rs new file mode 100644 index 000000000..9d0d8e176 --- /dev/null +++ b/crates/milli/src/update/new/merger.rs @@ -0,0 +1,259 @@ +use std::cell::RefCell; + +use hashbrown::HashSet; +use heed::types::Bytes; +use heed::{Database, RoTxn}; +use memmap2::Mmap; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use roaring::RoaringBitmap; + +use super::channel::*; +use super::extract::{ + merge_caches, transpose_and_freeze_caches, BalancedCaches, DelAddRoaringBitmap, FacetKind, + GeoExtractorData, +}; +use crate::{CboRoaringBitmapCodec, FieldId, GeoPoint, Index, InternalError, Result}; + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::merge")] +pub fn merge_and_send_rtree<'extractor, MSP>( + datastore: impl IntoIterator>>, + rtxn: &RoTxn, + index: &Index, + geo_sender: GeoSender<'_>, + must_stop_processing: &MSP, +) -> Result<()> +where + MSP: Fn() -> bool + Sync, +{ + let mut rtree = index.geo_rtree(rtxn)?.unwrap_or_default(); + let mut faceted = index.geo_faceted_documents_ids(rtxn)?; + + for data in datastore { + if must_stop_processing() { + return Err(InternalError::AbortedIndexation.into()); + } + + let mut frozen = data.into_inner().freeze()?; + for result in frozen.iter_and_clear_removed() { + let extracted_geo_point = result?; + debug_assert!(rtree.remove(&GeoPoint::from(extracted_geo_point)).is_some()); + debug_assert!(faceted.remove(extracted_geo_point.docid)); + } + + for result in frozen.iter_and_clear_inserted() { + let extracted_geo_point = result?; + rtree.insert(GeoPoint::from(extracted_geo_point)); + debug_assert!(faceted.insert(extracted_geo_point.docid)); + } + } + + let mut file = tempfile::tempfile()?; + /// manage error + bincode::serialize_into(&mut file, &rtree).unwrap(); + file.sync_all()?; + + let rtree_mmap = unsafe { Mmap::map(&file)? }; + geo_sender.set_rtree(rtree_mmap).unwrap(); + geo_sender.set_geo_faceted(&faceted).unwrap(); + + Ok(()) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::merge")] +pub fn merge_and_send_docids<'extractor, MSP>( + mut caches: Vec>, + database: Database, + index: &Index, + docids_sender: impl DocidsSender + Sync, + must_stop_processing: &MSP, +) -> Result<()> +where + MSP: Fn() -> bool + Sync, +{ + transpose_and_freeze_caches(&mut caches)?.into_par_iter().try_for_each(|frozen| { + let rtxn = index.read_txn()?; + let mut buffer = Vec::new(); + if must_stop_processing() { + return Err(InternalError::AbortedIndexation.into()); + } + merge_caches(frozen, |key, DelAddRoaringBitmap { del, add }| { + let current = database.get(&rtxn, key)?; + match merge_cbo_bitmaps(current, del, add)? { + Operation::Write(bitmap) => { + let value = cbo_bitmap_serialize_into_vec(&bitmap, &mut buffer); + docids_sender.write(key, value).unwrap(); + Ok(()) + } + Operation::Delete => { + docids_sender.delete(key).unwrap(); + Ok(()) + } + Operation::Ignore => Ok(()), + } + }) + }) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::merge")] +pub fn merge_and_send_facet_docids<'extractor>( + mut caches: Vec>, + database: FacetDatabases, + index: &Index, + docids_sender: impl DocidsSender + Sync, +) -> Result { + transpose_and_freeze_caches(&mut caches)? + .into_par_iter() + .map(|frozen| { + let mut facet_field_ids_delta = FacetFieldIdsDelta::default(); + let rtxn = index.read_txn()?; + let mut buffer = Vec::new(); + merge_caches(frozen, |key, DelAddRoaringBitmap { del, add }| { + let current = database.get_cbo_roaring_bytes_value(&rtxn, key)?; + match merge_cbo_bitmaps(current, del, add)? { + Operation::Write(bitmap) => { + facet_field_ids_delta.register_from_key(key); + let value = cbo_bitmap_serialize_into_vec(&bitmap, &mut buffer); + docids_sender.write(key, value).unwrap(); + Ok(()) + } + Operation::Delete => { + facet_field_ids_delta.register_from_key(key); + docids_sender.delete(key).unwrap(); + Ok(()) + } + Operation::Ignore => Ok(()), + } + })?; + + Ok(facet_field_ids_delta) + }) + .reduce(|| Ok(FacetFieldIdsDelta::default()), |lhs, rhs| Ok(lhs?.merge(rhs?))) +} + +pub struct FacetDatabases<'a> { + index: &'a Index, +} + +impl<'a> FacetDatabases<'a> { + pub fn new(index: &'a Index) -> Self { + Self { index } + } + + fn get_cbo_roaring_bytes_value<'t>( + &self, + rtxn: &'t RoTxn<'_>, + key: &[u8], + ) -> heed::Result> { + let (facet_kind, key) = FacetKind::extract_from_key(key); + + let value = + super::channel::Database::from(facet_kind).database(self.index).get(rtxn, key)?; + match facet_kind { + // skip level group size + FacetKind::String | FacetKind::Number => Ok(value.map(|v| &v[1..])), + _ => Ok(value), + } + } +} + +#[derive(Debug, Default)] +pub struct FacetFieldIdsDelta { + /// The field ids that have been modified + modified_facet_string_ids: HashSet, + modified_facet_number_ids: HashSet, +} + +impl FacetFieldIdsDelta { + fn register_facet_string_id(&mut self, field_id: FieldId) { + self.modified_facet_string_ids.insert(field_id); + } + + fn register_facet_number_id(&mut self, field_id: FieldId) { + self.modified_facet_number_ids.insert(field_id); + } + + fn register_from_key(&mut self, key: &[u8]) { + let (facet_kind, field_id) = self.extract_key_data(key); + match facet_kind { + FacetKind::Number => self.register_facet_number_id(field_id), + FacetKind::String => self.register_facet_string_id(field_id), + _ => (), + } + } + + fn extract_key_data(&self, key: &[u8]) -> (FacetKind, FieldId) { + let facet_kind = FacetKind::from(key[0]); + let field_id = FieldId::from_be_bytes([key[1], key[2]]); + (facet_kind, field_id) + } + + pub fn modified_facet_string_ids(&self) -> Option> { + if self.modified_facet_string_ids.is_empty() { + None + } else { + Some(self.modified_facet_string_ids.iter().copied().collect()) + } + } + + pub fn modified_facet_number_ids(&self) -> Option> { + if self.modified_facet_number_ids.is_empty() { + None + } else { + Some(self.modified_facet_number_ids.iter().copied().collect()) + } + } + + pub fn merge(mut self, rhs: Self) -> Self { + let Self { modified_facet_number_ids, modified_facet_string_ids } = rhs; + modified_facet_number_ids.into_iter().for_each(|fid| { + self.modified_facet_number_ids.insert(fid); + }); + modified_facet_string_ids.into_iter().for_each(|fid| { + self.modified_facet_string_ids.insert(fid); + }); + self + } +} + +enum Operation { + Write(RoaringBitmap), + Delete, + Ignore, +} + +/// A function that merges the DelAdd CboRoaringBitmaps with the current bitmap. +fn merge_cbo_bitmaps( + current: Option<&[u8]>, + del: Option, + add: Option, +) -> Result { + let current = current.map(CboRoaringBitmapCodec::deserialize_from).transpose()?; + match (current, del, add) { + (None, None, None) => Ok(Operation::Ignore), // but it's strange + (None, None, Some(add)) => Ok(Operation::Write(add)), + (None, Some(_del), None) => Ok(Operation::Ignore), // but it's strange + (None, Some(_del), Some(add)) => Ok(Operation::Write(add)), + (Some(_current), None, None) => Ok(Operation::Ignore), // but it's strange + (Some(current), None, Some(add)) => Ok(Operation::Write(current | add)), + (Some(current), Some(del), add) => { + let output = match add { + Some(add) => (¤t - del) | add, + None => ¤t - del, + }; + if output.is_empty() { + Ok(Operation::Delete) + } else if current == output { + Ok(Operation::Ignore) + } else { + Ok(Operation::Write(output)) + } + } + } +} + +/// TODO Return the slice directly from the serialize_into method +fn cbo_bitmap_serialize_into_vec<'b>(bitmap: &RoaringBitmap, buffer: &'b mut Vec) -> &'b [u8] { + buffer.clear(); + CboRoaringBitmapCodec::serialize_into(bitmap, buffer); + buffer.as_slice() +} diff --git a/crates/milli/src/update/new/mod.rs b/crates/milli/src/update/new/mod.rs new file mode 100644 index 000000000..140f4ccf0 --- /dev/null +++ b/crates/milli/src/update/new/mod.rs @@ -0,0 +1,32 @@ +pub use document_change::{Deletion, DocumentChange, Insertion, Update}; +pub use merger::{ + merge_and_send_docids, merge_and_send_facet_docids, FacetDatabases, FacetFieldIdsDelta, +}; +pub use top_level_map::{CowStr, TopLevelMap}; + +use super::del_add::DelAdd; +use crate::FieldId; + +mod channel; +pub mod document; +mod document_change; +mod extract; +mod facet_search_builder; +mod fst_merger_builder; +pub mod indexer; +mod merger; +mod parallel_iterator_ext; +mod ref_cell_ext; +pub(crate) mod steps; +pub(crate) mod thread_local; +mod top_level_map; +pub mod vector_document; +mod word_fst_builder; +mod words_prefix_docids; + +/// TODO move them elsewhere +pub type StdResult = std::result::Result; +pub type KvReaderDelAdd = obkv::KvReader; +pub type KvReaderFieldId = obkv::KvReader; +pub type KvWriterDelAdd = obkv::KvWriter; +pub type KvWriterFieldId = obkv::KvWriter; diff --git a/crates/milli/src/update/new/parallel_iterator_ext.rs b/crates/milli/src/update/new/parallel_iterator_ext.rs new file mode 100644 index 000000000..ff69d7acf --- /dev/null +++ b/crates/milli/src/update/new/parallel_iterator_ext.rs @@ -0,0 +1,33 @@ +use std::sync::Arc; + +use rayon::iter::ParallelIterator; + +pub trait ParallelIteratorExt: ParallelIterator { + /// A method to run a closure of all the items and return an owned error. + /// + /// The init function is ran only as necessary which is basically once by thread. + fn try_arc_for_each_try_init(self, init: INIT, op: F) -> Result<(), E> + where + E: Send + Sync, + F: Fn(&mut T, Self::Item) -> Result<(), Arc> + Sync + Send + Clone, + INIT: Fn() -> Result + Sync + Send + Clone, + { + let result = self.try_for_each_init( + move || match init() { + Ok(t) => Ok(t), + Err(err) => Err(Arc::new(err)), + }, + move |result, item| match result { + Ok(t) => op(t, item), + Err(err) => Err(err.clone()), + }, + ); + + match result { + Ok(()) => Ok(()), + Err(err) => Err(Arc::into_inner(err).expect("the error must be only owned by us")), + } + } +} + +impl ParallelIteratorExt for T {} diff --git a/crates/milli/src/update/new/ref_cell_ext.rs b/crates/milli/src/update/new/ref_cell_ext.rs new file mode 100644 index 000000000..c66f4af0a --- /dev/null +++ b/crates/milli/src/update/new/ref_cell_ext.rs @@ -0,0 +1,31 @@ +use std::cell::{RefCell, RefMut}; + +pub trait RefCellExt { + fn try_borrow_mut_or_yield( + &self, + ) -> std::result::Result, std::cell::BorrowMutError>; + + fn borrow_mut_or_yield(&self) -> RefMut<'_, T> { + self.try_borrow_mut_or_yield().unwrap() + } +} + +impl RefCellExt for RefCell { + fn try_borrow_mut_or_yield( + &self, + ) -> std::result::Result, std::cell::BorrowMutError> { + loop { + match self.try_borrow_mut() { + Ok(borrow) => break Ok(borrow), + Err(error) => { + tracing::warn!("dynamic borrow failed, yielding to local tasks"); + + match rayon::yield_local() { + Some(rayon::Yield::Executed) => continue, + _ => return Err(error), + } + } + } + } + } +} diff --git a/crates/milli/src/update/new/steps.rs b/crates/milli/src/update/new/steps.rs new file mode 100644 index 000000000..60a0c872b --- /dev/null +++ b/crates/milli/src/update/new/steps.rs @@ -0,0 +1,45 @@ +use enum_iterator::Sequence; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Sequence)] +#[repr(u16)] +pub enum Step { + ExtractingDocuments, + ExtractingFacets, + ExtractingWords, + ExtractingWordProximity, + ExtractingEmbeddings, + WritingGeoPoints, + WritingToDatabase, + WritingEmbeddingsToDatabase, + WaitingForExtractors, + PostProcessingFacets, + PostProcessingWords, + Finalizing, +} + +impl Step { + pub fn name(&self) -> &'static str { + match self { + Step::ExtractingDocuments => "extracting documents", + Step::ExtractingFacets => "extracting facets", + Step::ExtractingWords => "extracting words", + Step::ExtractingWordProximity => "extracting word proximity", + Step::ExtractingEmbeddings => "extracting embeddings", + Step::WritingGeoPoints => "writing geo points", + Step::WritingToDatabase => "writing to database", + Step::WritingEmbeddingsToDatabase => "writing embeddings to database", + Step::WaitingForExtractors => "waiting for extractors", + Step::PostProcessingFacets => "post-processing facets", + Step::PostProcessingWords => "post-processing words", + Step::Finalizing => "finalizing", + } + } + + pub fn finished_steps(self) -> u16 { + self as u16 + } + + pub const fn total_steps() -> u16 { + Self::CARDINALITY as u16 + } +} diff --git a/crates/milli/src/update/new/thread_local.rs b/crates/milli/src/update/new/thread_local.rs new file mode 100644 index 000000000..acdc78c7b --- /dev/null +++ b/crates/milli/src/update/new/thread_local.rs @@ -0,0 +1,174 @@ +use std::cell::RefCell; + +/// A trait for types that are **not** [`Send`] only because they would then allow concurrent access to a type that is not [`Sync`]. +/// +/// The primary example of such a type is `&T`, with `T: !Sync`. +/// +/// In the authors' understanding, a type can be `!Send` for two distinct reasons: +/// +/// 1. Because it contains data that *genuinely* cannot be moved between threads, such as thread-local data. +/// 2. Because sending the type would allow concurrent access to a `!Sync` type, which is undefined behavior. +/// +/// `MostlySend` exists to be used in bounds where you need a type whose data is **not** *attached* to a thread +/// because you might access it from a different thread, but where you will never access the type **concurrently** from +/// multiple threads. +/// +/// Like [`Send`], `MostlySend` assumes properties on types that cannot be verified by the compiler, which is why implementing +/// this trait is unsafe. +/// +/// # Safety +/// +/// Implementers of this trait promises that the following properties hold on the implementing type: +/// +/// 1. Its data can be accessed from any thread and will be the same regardless of the thread accessing it. +/// 2. Any operation that can be performed on the type does not depend on the thread that executes it. +/// +/// As these properties are subtle and are not generally tracked by the Rust type system, great care should be taken before +/// implementing `MostlySend` on a type, especially a foreign type. +/// +/// - An example of a type that verifies (1) and (2) is [`std::rc::Rc`] (when `T` is `Send` and `Sync`). +/// - An example of a type that doesn't verify (1) is thread-local data. +/// - An example of a type that doesn't verify (2) is [`std::sync::MutexGuard`]: a lot of mutex implementations require that +/// a lock is returned to the operating system on the same thread that initially locked the mutex, failing to uphold this +/// invariant will cause Undefined Behavior +/// (see last § in [the nomicon](https://doc.rust-lang.org/nomicon/send-and-sync.html)). +/// +/// It is **always safe** to implement this trait on a type that is `Send`, but no placeholder impl is provided due to limitations in +/// coherency. Use the [`FullySend`] wrapper in this situation. +pub unsafe trait MostlySend {} + +#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct FullySend(pub T); + +// SAFETY: a type **fully** send is always mostly send as well. +unsafe impl MostlySend for FullySend where T: Send {} + +unsafe impl MostlySend for RefCell where T: MostlySend {} + +unsafe impl MostlySend for Option where T: MostlySend {} + +impl FullySend { + pub fn into(self) -> T { + self.0 + } +} + +impl From for FullySend { + fn from(value: T) -> Self { + Self(value) + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +struct MostlySendWrapper(T); + +impl MostlySendWrapper { + /// # Safety + /// + /// - (P1) Users of this type will never access the type concurrently from multiple threads without synchronization + unsafe fn new(t: T) -> Self { + Self(t) + } + + fn as_ref(&self) -> &T { + &self.0 + } + + fn as_mut(&mut self) -> &mut T { + &mut self.0 + } + + fn into_inner(self) -> T { + self.0 + } +} + +/// # Safety +/// +/// 1. `T` is [`MostlySend`], so by its safety contract it can be accessed by any thread and all of its operations are available +/// from any thread. +/// 2. (P1) of `MostlySendWrapper::new` forces the user to never access the value from multiple threads concurrently. +unsafe impl Send for MostlySendWrapper {} + +/// A wrapper around [`thread_local::ThreadLocal`] that accepts [`MostlySend`] `T`s. +#[derive(Default)] +pub struct ThreadLocal { + inner: thread_local::ThreadLocal>, + // FIXME: this should be necessary + //_no_send: PhantomData<*mut ()>, +} + +impl ThreadLocal { + pub fn new() -> Self { + Self { inner: thread_local::ThreadLocal::new() } + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { inner: thread_local::ThreadLocal::with_capacity(capacity) } + } + + pub fn clear(&mut self) { + self.inner.clear() + } + + pub fn get(&self) -> Option<&T> { + self.inner.get().map(|t| t.as_ref()) + } + + pub fn get_or(&self, create: F) -> &T + where + F: FnOnce() -> T, + { + self.inner.get_or(|| unsafe { MostlySendWrapper::new(create()) }).as_ref() + } + + pub fn get_or_try(&self, create: F) -> std::result::Result<&T, E> + where + F: FnOnce() -> std::result::Result, + { + self.inner + .get_or_try(|| unsafe { Ok(MostlySendWrapper::new(create()?)) }) + .map(MostlySendWrapper::as_ref) + } + + pub fn get_or_default(&self) -> &T + where + T: Default, + { + self.inner.get_or_default().as_ref() + } + + pub fn iter_mut(&mut self) -> IterMut { + IterMut(self.inner.iter_mut()) + } +} + +impl IntoIterator for ThreadLocal { + type Item = T; + + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.inner.into_iter()) + } +} + +pub struct IterMut<'a, T: MostlySend>(thread_local::IterMut<'a, MostlySendWrapper>); + +impl<'a, T: MostlySend> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + self.0.next().map(|t| t.as_mut()) + } +} + +pub struct IntoIter(thread_local::IntoIter>); + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next().map(|t| t.into_inner()) + } +} diff --git a/crates/milli/src/update/new/top_level_map.rs b/crates/milli/src/update/new/top_level_map.rs new file mode 100644 index 000000000..aebb64bc9 --- /dev/null +++ b/crates/milli/src/update/new/top_level_map.rs @@ -0,0 +1,66 @@ +use std::borrow::{Borrow, Cow}; +use std::collections::BTreeMap; +use std::{fmt, ops}; + +use serde::{Deserialize, Serialize}; +use serde_json::value::RawValue; +use serde_json::{Map, Value}; + +#[derive(Deserialize, Serialize)] +pub struct TopLevelMap<'p>(#[serde(borrow)] pub BTreeMap, &'p RawValue>); + +impl TryFrom<&'_ TopLevelMap<'_>> for Map { + type Error = serde_json::Error; + + fn try_from(tlmap: &TopLevelMap<'_>) -> Result { + let mut object = Map::new(); + for (k, v) in &tlmap.0 { + let value = serde_json::from_str(v.get())?; + object.insert(k.to_string(), value); + } + Ok(object) + } +} + +impl TryFrom> for Map { + type Error = serde_json::Error; + + fn try_from(tlmap: TopLevelMap<'_>) -> Result { + TryFrom::try_from(&tlmap) + } +} + +impl<'p> ops::Deref for TopLevelMap<'p> { + type Target = BTreeMap, &'p RawValue>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl ops::DerefMut for TopLevelMap<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] +pub struct CowStr<'p>(#[serde(borrow)] pub Cow<'p, str>); + +impl fmt::Display for CowStr<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl AsRef for CowStr<'_> { + fn as_ref(&self) -> &str { + self.0.as_ref() + } +} + +impl<'doc> Borrow for CowStr<'doc> { + fn borrow(&self) -> &str { + self.0.borrow() + } +} diff --git a/crates/milli/src/update/new/vector_document.rs b/crates/milli/src/update/new/vector_document.rs new file mode 100644 index 000000000..319730db0 --- /dev/null +++ b/crates/milli/src/update/new/vector_document.rs @@ -0,0 +1,345 @@ +use std::collections::BTreeSet; + +use bumpalo::Bump; +use deserr::{Deserr, IntoValue}; +use heed::RoTxn; +use raw_collections::RawMap; +use serde::Serialize; +use serde_json::value::RawValue; + +use super::document::{Document, DocumentFromDb, DocumentFromVersions, Versions}; +use super::indexer::de::DeserrRawValue; +use crate::documents::FieldIdMapper; +use crate::index::IndexEmbeddingConfig; +use crate::vector::parsed_vectors::{ + RawVectors, RawVectorsError, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME, +}; +use crate::vector::{ArroyWrapper, Embedding, EmbeddingConfigs}; +use crate::{DocumentId, Index, InternalError, Result, UserError}; + +#[derive(Serialize)] +#[serde(untagged)] +pub enum Embeddings<'doc> { + FromJsonExplicit(&'doc RawValue), + FromJsonImplicityUserProvided(&'doc RawValue), + FromDb(Vec), +} +impl<'doc> Embeddings<'doc> { + pub fn into_vec( + self, + doc_alloc: &'doc Bump, + embedder_name: &str, + ) -> std::result::Result, deserr::errors::JsonError> { + match self { + Embeddings::FromJsonExplicit(value) => { + let vectors_ref = deserr::ValuePointerRef::Key { + key: RESERVED_VECTORS_FIELD_NAME, + prev: &deserr::ValuePointerRef::Origin, + }; + let embedders_ref = + deserr::ValuePointerRef::Key { key: embedder_name, prev: &vectors_ref }; + + let embeddings_ref = + deserr::ValuePointerRef::Key { key: "embeddings", prev: &embedders_ref }; + + let v: VectorOrArrayOfVectors = VectorOrArrayOfVectors::deserialize_from_value( + DeserrRawValue::new_in(value, doc_alloc).into_value(), + embeddings_ref, + )?; + Ok(v.into_array_of_vectors().unwrap_or_default()) + } + Embeddings::FromJsonImplicityUserProvided(value) => { + let vectors_ref = deserr::ValuePointerRef::Key { + key: RESERVED_VECTORS_FIELD_NAME, + prev: &deserr::ValuePointerRef::Origin, + }; + let embedders_ref = + deserr::ValuePointerRef::Key { key: embedder_name, prev: &vectors_ref }; + + let v: VectorOrArrayOfVectors = VectorOrArrayOfVectors::deserialize_from_value( + DeserrRawValue::new_in(value, doc_alloc).into_value(), + embedders_ref, + )?; + Ok(v.into_array_of_vectors().unwrap_or_default()) + } + Embeddings::FromDb(vec) => Ok(vec), + } + } +} + +pub struct VectorEntry<'doc> { + pub has_configured_embedder: bool, + pub embeddings: Option>, + pub regenerate: bool, + pub implicit: bool, +} + +pub trait VectorDocument<'doc> { + fn iter_vectors(&self) -> impl Iterator)>>; + + fn vectors_for_key(&self, key: &str) -> Result>>; +} + +pub struct VectorDocumentFromDb<'t> { + docid: DocumentId, + embedding_config: Vec, + index: &'t Index, + vectors_field: Option>, + rtxn: &'t RoTxn<'t>, + doc_alloc: &'t Bump, +} + +impl<'t> VectorDocumentFromDb<'t> { + pub fn new( + docid: DocumentId, + index: &'t Index, + rtxn: &'t RoTxn, + db_fields_ids_map: &'t Mapper, + doc_alloc: &'t Bump, + ) -> Result> { + let Some(document) = DocumentFromDb::new(docid, rtxn, index, db_fields_ids_map)? else { + return Ok(None); + }; + let vectors = document.vectors_field()?; + let vectors_field = match vectors { + Some(vectors) => { + Some(RawMap::from_raw_value(vectors, doc_alloc).map_err(InternalError::SerdeJson)?) + } + None => None, + }; + + let embedding_config = index.embedding_configs(rtxn)?; + + Ok(Some(Self { docid, embedding_config, index, vectors_field, rtxn, doc_alloc })) + } + + fn entry_from_db( + &self, + embedder_id: u8, + config: &IndexEmbeddingConfig, + ) -> Result> { + let reader = + ArroyWrapper::new(self.index.vector_arroy, embedder_id, config.config.quantized()); + let vectors = reader.item_vectors(self.rtxn, self.docid)?; + + Ok(VectorEntry { + has_configured_embedder: true, + embeddings: Some(Embeddings::FromDb(vectors)), + regenerate: !config.user_provided.contains(self.docid), + implicit: false, + }) + } +} + +impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { + fn iter_vectors(&self) -> impl Iterator)>> { + self.embedding_config + .iter() + .map(|config| { + let embedder_id = + self.index.embedder_category_id.get(self.rtxn, &config.name)?.unwrap(); + let entry = self.entry_from_db(embedder_id, config)?; + let config_name = self.doc_alloc.alloc_str(config.name.as_str()); + Ok((&*config_name, entry)) + }) + .chain(self.vectors_field.iter().flat_map(|map| map.iter()).map(|(name, value)| { + Ok(( + name, + entry_from_raw_value(value, false).map_err(|_| { + InternalError::Serialization(crate::SerializationError::Decoding { + db_name: Some(crate::index::db_name::VECTOR_ARROY), + }) + })?, + )) + })) + } + + fn vectors_for_key(&self, key: &str) -> Result>> { + Ok(match self.index.embedder_category_id.get(self.rtxn, key)? { + Some(embedder_id) => { + let config = + self.embedding_config.iter().find(|config| config.name == key).unwrap(); + Some(self.entry_from_db(embedder_id, config)?) + } + None => match self.vectors_field.as_ref().and_then(|obkv| obkv.get(key)) { + Some(embedding_from_doc) => { + Some(entry_from_raw_value(embedding_from_doc, false).map_err(|_| { + InternalError::Serialization(crate::SerializationError::Decoding { + db_name: Some(crate::index::db_name::VECTOR_ARROY), + }) + })?) + } + None => None, + }, + }) + } +} + +fn entry_from_raw_value_user<'doc>( + external_docid: &str, + embedder_name: &str, + value: &'doc RawValue, + has_configured_embedder: bool, +) -> Result> { + entry_from_raw_value(value, has_configured_embedder).map_err(|error| { + UserError::InvalidVectorsEmbedderConf { + document_id: external_docid.to_string(), + error: error.msg(embedder_name), + } + .into() + }) +} + +fn entry_from_raw_value( + value: &RawValue, + has_configured_embedder: bool, +) -> std::result::Result, RawVectorsError> { + let value: RawVectors = RawVectors::from_raw_value(value)?; + + Ok(match value { + RawVectors::Explicit(raw_explicit_vectors) => VectorEntry { + has_configured_embedder, + embeddings: raw_explicit_vectors.embeddings.map(Embeddings::FromJsonExplicit), + regenerate: raw_explicit_vectors.regenerate, + implicit: false, + }, + RawVectors::ImplicitlyUserProvided(value) => VectorEntry { + has_configured_embedder, + // implicitly user provided always provide embeddings + // `None` here means that there are no embeddings + embeddings: Some( + value + .map(Embeddings::FromJsonImplicityUserProvided) + .unwrap_or(Embeddings::FromDb(Default::default())), + ), + regenerate: false, + implicit: true, + }, + }) +} + +pub struct VectorDocumentFromVersions<'doc> { + external_document_id: &'doc str, + vectors: RawMap<'doc>, + embedders: &'doc EmbeddingConfigs, +} + +impl<'doc> VectorDocumentFromVersions<'doc> { + pub fn new( + external_document_id: &'doc str, + versions: &Versions<'doc>, + bump: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result> { + let document = DocumentFromVersions::new(versions); + if let Some(vectors_field) = document.vectors_field()? { + let vectors = + RawMap::from_raw_value(vectors_field, bump).map_err(UserError::SerdeJson)?; + Ok(Some(Self { external_document_id, vectors, embedders })) + } else { + Ok(None) + } + } +} + +impl<'doc> VectorDocument<'doc> for VectorDocumentFromVersions<'doc> { + fn iter_vectors(&self) -> impl Iterator)>> { + self.vectors.iter().map(|(embedder, vectors)| { + let vectors = entry_from_raw_value_user( + self.external_document_id, + embedder, + vectors, + self.embedders.contains(embedder), + )?; + Ok((embedder, vectors)) + }) + } + + fn vectors_for_key(&self, key: &str) -> Result>> { + let Some(vectors) = self.vectors.get(key) else { return Ok(None) }; + let vectors = entry_from_raw_value_user( + self.external_document_id, + key, + vectors, + self.embedders.contains(key), + )?; + Ok(Some(vectors)) + } +} + +pub struct MergedVectorDocument<'doc> { + new_doc: Option>, + db: Option>, +} + +impl<'doc> MergedVectorDocument<'doc> { + #[allow(clippy::too_many_arguments)] + pub fn with_db( + docid: DocumentId, + external_document_id: &'doc str, + index: &'doc Index, + rtxn: &'doc RoTxn, + db_fields_ids_map: &'doc Mapper, + versions: &Versions<'doc>, + doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result> { + let db = VectorDocumentFromDb::new(docid, index, rtxn, db_fields_ids_map, doc_alloc)?; + let new_doc = + VectorDocumentFromVersions::new(external_document_id, versions, doc_alloc, embedders)?; + Ok(if db.is_none() && new_doc.is_none() { None } else { Some(Self { new_doc, db }) }) + } + + pub fn without_db( + external_document_id: &'doc str, + versions: &Versions<'doc>, + doc_alloc: &'doc Bump, + embedders: &'doc EmbeddingConfigs, + ) -> Result> { + let Some(new_doc) = + VectorDocumentFromVersions::new(external_document_id, versions, doc_alloc, embedders)? + else { + return Ok(None); + }; + Ok(Some(Self { new_doc: Some(new_doc), db: None })) + } +} + +impl<'doc> VectorDocument<'doc> for MergedVectorDocument<'doc> { + fn iter_vectors(&self) -> impl Iterator)>> { + let mut new_doc_it = self.new_doc.iter().flat_map(|new_doc| new_doc.iter_vectors()); + let mut db_it = self.db.iter().flat_map(|db| db.iter_vectors()); + let mut seen_fields = BTreeSet::new(); + + std::iter::from_fn(move || { + if let Some(next) = new_doc_it.next() { + if let Ok((name, _)) = next { + seen_fields.insert(name); + } + return Some(next); + } + loop { + match db_it.next()? { + Ok((name, value)) => { + if seen_fields.contains(name) { + continue; + } + return Some(Ok((name, value))); + } + Err(err) => return Some(Err(err)), + } + } + }) + } + + fn vectors_for_key(&self, key: &str) -> Result>> { + if let Some(new_doc) = &self.new_doc { + if let Some(entry) = new_doc.vectors_for_key(key)? { + return Ok(Some(entry)); + } + } + + let Some(db) = self.db.as_ref() else { return Ok(None) }; + db.vectors_for_key(key) + } +} diff --git a/crates/milli/src/update/new/word_fst_builder.rs b/crates/milli/src/update/new/word_fst_builder.rs new file mode 100644 index 000000000..2b1c4604b --- /dev/null +++ b/crates/milli/src/update/new/word_fst_builder.rs @@ -0,0 +1,199 @@ +use std::collections::HashSet; +use std::io::BufWriter; + +use fst::{Set, SetBuilder, Streamer}; +use memmap2::Mmap; +use tempfile::tempfile; + +use super::fst_merger_builder::FstMergerBuilder; +use crate::index::PrefixSettings; +use crate::update::del_add::DelAdd; +use crate::{InternalError, Prefix, Result}; + +pub struct WordFstBuilder<'a> { + word_fst_builder: FstMergerBuilder<'a>, + prefix_fst_builder: Option, + registered_words: usize, +} + +impl<'a> WordFstBuilder<'a> { + pub fn new(words_fst: &'a Set>) -> Result { + Ok(Self { + word_fst_builder: FstMergerBuilder::new(Some(words_fst))?, + prefix_fst_builder: None, + registered_words: 0, + }) + } + + pub fn with_prefix_settings(&mut self, prefix_settings: PrefixSettings) -> &Self { + self.prefix_fst_builder = PrefixFstBuilder::new(prefix_settings); + self + } + + pub fn register_word(&mut self, deladd: DelAdd, right: &[u8]) -> Result<()> { + if deladd == DelAdd::Addition { + self.registered_words += 1; + } + + self.word_fst_builder.register(deladd, right, &mut |bytes, deladd, is_modified| { + if let Some(prefix_fst_builder) = &mut self.prefix_fst_builder { + prefix_fst_builder.insert_word(bytes, deladd, is_modified) + } else { + Ok(()) + } + })?; + + Ok(()) + } + + pub fn build( + mut self, + index: &crate::Index, + rtxn: &heed::RoTxn, + ) -> Result<(Mmap, Option)> { + let words_fst_mmap = self.word_fst_builder.build(&mut |bytes, deladd, is_modified| { + if let Some(prefix_fst_builder) = &mut self.prefix_fst_builder { + prefix_fst_builder.insert_word(bytes, deladd, is_modified) + } else { + Ok(()) + } + })?; + + let prefix_data = self + .prefix_fst_builder + .map(|prefix_fst_builder| prefix_fst_builder.build(index, rtxn)) + .transpose()?; + + Ok((words_fst_mmap, prefix_data)) + } +} + +pub struct PrefixData { + pub prefixes_fst_mmap: Mmap, + pub prefix_delta: PrefixDelta, +} + +#[derive(Debug)] +pub struct PrefixDelta { + pub modified: HashSet, + pub deleted: HashSet, +} + +struct PrefixFstBuilder { + prefix_count_threshold: u64, + max_prefix_length: usize, + /// TODO: Replace the full memory allocation + prefix_fst_builders: Vec>>, + current_prefix: Vec, + current_prefix_count: Vec, + modified_prefixes: HashSet, + current_prefix_is_modified: Vec, +} + +impl PrefixFstBuilder { + pub fn new(prefix_settings: PrefixSettings) -> Option { + let PrefixSettings { prefix_count_threshold, max_prefix_length, compute_prefixes } = + prefix_settings; + + if !compute_prefixes { + return None; + } + + let mut prefix_fst_builders = Vec::new(); + for _ in 0..max_prefix_length { + prefix_fst_builders.push(SetBuilder::memory()); + } + + Some(Self { + prefix_count_threshold, + max_prefix_length, + prefix_fst_builders, + current_prefix: vec![Prefix::new(); max_prefix_length], + current_prefix_count: vec![0; max_prefix_length], + modified_prefixes: HashSet::new(), + current_prefix_is_modified: vec![false; max_prefix_length], + }) + } + + fn insert_word(&mut self, bytes: &[u8], deladd: DelAdd, is_modified: bool) -> Result<()> { + for n in 0..self.max_prefix_length { + let current_prefix = &mut self.current_prefix[n]; + let current_prefix_count = &mut self.current_prefix_count[n]; + let builder = &mut self.prefix_fst_builders[n]; + let current_prefix_is_modified = &mut self.current_prefix_is_modified[n]; + + // We try to get the first n bytes out of this string but we only want + // to split at valid characters bounds. If we try to split in the middle of + // a character we ignore this word and go to the next one. + let word = std::str::from_utf8(bytes)?; + let prefix = match word.get(..=n) { + Some(prefix) => prefix, + None => continue, + }; + + // This is the first iteration of the loop, + // or the current word doesn't starts with the current prefix. + if *current_prefix_count == 0 || prefix != current_prefix.as_str() { + *current_prefix = Prefix::from(prefix); + *current_prefix_count = 0; + *current_prefix_is_modified = false; + } + + if deladd == DelAdd::Addition { + *current_prefix_count += 1; + } + + if is_modified && !*current_prefix_is_modified { + if *current_prefix_count > self.prefix_count_threshold { + self.modified_prefixes.insert(current_prefix.clone()); + } + + *current_prefix_is_modified = true; + } + + // There is enough words corresponding to this prefix to add it to the cache. + if *current_prefix_count == self.prefix_count_threshold { + builder.insert(prefix)?; + + if *current_prefix_is_modified { + self.modified_prefixes.insert(current_prefix.clone()); + } + } + } + + Ok(()) + } + + fn build(self, index: &crate::Index, rtxn: &heed::RoTxn) -> Result { + // We merge all of the previously computed prefixes into on final set. + let mut prefix_fsts = Vec::new(); + for builder in self.prefix_fst_builders.into_iter() { + let prefix_fst = builder.into_set(); + prefix_fsts.push(prefix_fst); + } + let op = fst::set::OpBuilder::from_iter(prefix_fsts.iter()); + let mut builder = SetBuilder::new(BufWriter::new(tempfile()?))?; + builder.extend_stream(op.r#union())?; + let prefix_fst_file = builder.into_inner()?.into_inner().map_err(|_| { + InternalError::IndexingMergingKeys { process: "building-words-prefixes-fst" } + })?; + let prefix_fst_mmap = unsafe { Mmap::map(&prefix_fst_file)? }; + let new_prefix_fst = Set::new(&prefix_fst_mmap)?; + let old_prefix_fst = index.words_prefixes_fst(rtxn)?; + let mut deleted_prefixes = HashSet::new(); + { + let mut deleted_prefixes_stream = old_prefix_fst.op().add(&new_prefix_fst).difference(); + while let Some(prefix) = deleted_prefixes_stream.next() { + deleted_prefixes.insert(Prefix::from(std::str::from_utf8(prefix)?)); + } + } + + Ok(PrefixData { + prefixes_fst_mmap: prefix_fst_mmap, + prefix_delta: PrefixDelta { + modified: self.modified_prefixes, + deleted: deleted_prefixes, + }, + }) + } +} diff --git a/crates/milli/src/update/new/words_prefix_docids.rs b/crates/milli/src/update/new/words_prefix_docids.rs new file mode 100644 index 000000000..5454d815e --- /dev/null +++ b/crates/milli/src/update/new/words_prefix_docids.rs @@ -0,0 +1,343 @@ +use std::cell::RefCell; +use std::collections::HashSet; +use std::io::{BufReader, BufWriter, Read, Seek, Write}; + +use hashbrown::HashMap; +use heed::types::Bytes; +use heed::{BytesDecode, Database, Error, RoTxn, RwTxn}; +use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; +use roaring::MultiOps; +use tempfile::tempfile; +use thread_local::ThreadLocal; + +use super::ref_cell_ext::RefCellExt as _; +use crate::heed_codec::StrBEU16Codec; +use crate::{CboRoaringBitmapCodec, Index, Prefix, Result}; + +struct WordPrefixDocids { + database: Database, + prefix_database: Database, +} + +impl WordPrefixDocids { + fn new( + database: Database, + prefix_database: Database, + ) -> WordPrefixDocids { + WordPrefixDocids { database, prefix_database } + } + + fn execute( + self, + wtxn: &mut heed::RwTxn, + prefix_to_compute: &HashSet, + prefix_to_delete: &HashSet, + ) -> Result<()> { + delete_prefixes(wtxn, &self.prefix_database, prefix_to_delete)?; + self.recompute_modified_prefixes(wtxn, prefix_to_compute) + } + + #[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] + fn recompute_modified_prefixes( + &self, + wtxn: &mut RwTxn, + prefixes: &HashSet, + ) -> Result<()> { + // We fetch the docids associated to the newly added word prefix fst only. + // And collect the CboRoaringBitmaps pointers in an HashMap. + let frozen = FrozenPrefixBitmaps::from_prefixes(self.database, wtxn, prefixes)?; + + // We access this HashMap in parallel to compute the *union* of all + // of them and *serialize* them into files. There is one file by CPU. + let local_entries = ThreadLocal::with_capacity(rayon::current_num_threads()); + prefixes.into_par_iter().map(AsRef::as_ref).try_for_each(|prefix| { + let refcell = local_entries.get_or_try(|| { + tempfile().map(BufWriter::new).map(|f| RefCell::new((Vec::new(), f, Vec::new()))) + })?; + + let mut refmut = refcell.borrow_mut_or_yield(); + let (ref mut index, ref mut file, ref mut buffer) = *refmut; + + let output = frozen + .bitmaps(prefix) + .unwrap() + .iter() + .map(|bytes| CboRoaringBitmapCodec::deserialize_from(bytes)) + .union()?; + + buffer.clear(); + CboRoaringBitmapCodec::serialize_into(&output, buffer); + index.push(PrefixEntry { prefix, serialized_length: buffer.len() }); + file.write_all(buffer) + })?; + + drop(frozen); + + // We iterate over all the collected and serialized bitmaps through + // the files and entries to eventually put them in the final database. + for refcell in local_entries { + let (index, file, mut buffer) = refcell.into_inner(); + let mut file = file.into_inner().map_err(|e| e.into_error())?; + file.rewind()?; + let mut file = BufReader::new(file); + for PrefixEntry { prefix, serialized_length } in index { + buffer.resize(serialized_length, 0); + file.read_exact(&mut buffer)?; + self.prefix_database.remap_data_type::().put( + wtxn, + prefix.as_bytes(), + &buffer, + )?; + } + } + + Ok(()) + } +} + +/// Represents a prefix and the lenght the bitmap takes on disk. +struct PrefixEntry<'a> { + prefix: &'a str, + serialized_length: usize, +} + +/// Stores prefixes along with all the pointers to the associated +/// CBoRoaringBitmaps. +/// +/// They are collected synchronously and stored into an HashMap. The +/// Synchronous process is doing a small amount of work by just storing +/// pointers. It can then be accessed in parallel to get the associated +/// bitmaps pointers. +struct FrozenPrefixBitmaps<'a, 'rtxn> { + prefixes_bitmaps: HashMap<&'a str, Vec<&'rtxn [u8]>>, +} + +impl<'a, 'rtxn> FrozenPrefixBitmaps<'a, 'rtxn> { + #[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] + pub fn from_prefixes( + database: Database, + rtxn: &'rtxn RoTxn, + prefixes: &'a HashSet, + ) -> heed::Result { + let database = database.remap_data_type::(); + + let mut prefixes_bitmaps = HashMap::new(); + for prefix in prefixes { + let mut bitmap_bytes = Vec::new(); + for result in database.prefix_iter(rtxn, prefix.as_bytes())? { + let (_word, bytes) = result?; + bitmap_bytes.push(bytes); + } + assert!(prefixes_bitmaps.insert(prefix.as_str(), bitmap_bytes).is_none()); + } + + Ok(Self { prefixes_bitmaps }) + } + + pub fn bitmaps(&self, key: &str) -> Option<&[&'rtxn [u8]]> { + self.prefixes_bitmaps.get(key).map(AsRef::as_ref) + } +} + +unsafe impl<'a, 'rtxn> Sync for FrozenPrefixBitmaps<'a, 'rtxn> {} + +struct WordPrefixIntegerDocids { + database: Database, + prefix_database: Database, +} + +impl WordPrefixIntegerDocids { + fn new( + database: Database, + prefix_database: Database, + ) -> WordPrefixIntegerDocids { + WordPrefixIntegerDocids { database, prefix_database } + } + + fn execute( + self, + wtxn: &mut heed::RwTxn, + prefix_to_compute: &HashSet, + prefix_to_delete: &HashSet, + ) -> Result<()> { + delete_prefixes(wtxn, &self.prefix_database, prefix_to_delete)?; + self.recompute_modified_prefixes(wtxn, prefix_to_compute) + } + + #[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] + fn recompute_modified_prefixes( + &self, + wtxn: &mut RwTxn, + prefixes: &HashSet, + ) -> Result<()> { + // We fetch the docids associated to the newly added word prefix fst only. + // And collect the CboRoaringBitmaps pointers in an HashMap. + let frozen = FrozenPrefixIntegerBitmaps::from_prefixes(self.database, wtxn, prefixes)?; + + // We access this HashMap in parallel to compute the *union* of all + // of them and *serialize* them into files. There is one file by CPU. + let local_entries = ThreadLocal::with_capacity(rayon::current_num_threads()); + prefixes.into_par_iter().map(AsRef::as_ref).try_for_each(|prefix| { + let refcell = local_entries.get_or_try(|| { + tempfile().map(BufWriter::new).map(|f| RefCell::new((Vec::new(), f, Vec::new()))) + })?; + + let mut refmut = refcell.borrow_mut_or_yield(); + let (ref mut index, ref mut file, ref mut buffer) = *refmut; + + for (&pos, bitmaps_bytes) in frozen.bitmaps(prefix).unwrap() { + let output = bitmaps_bytes + .iter() + .map(|bytes| CboRoaringBitmapCodec::deserialize_from(bytes)) + .union()?; + + buffer.clear(); + CboRoaringBitmapCodec::serialize_into(&output, buffer); + index.push(PrefixIntegerEntry { prefix, pos, serialized_length: buffer.len() }); + file.write_all(buffer)?; + } + + Result::Ok(()) + })?; + + drop(frozen); + + // We iterate over all the collected and serialized bitmaps through + // the files and entries to eventually put them in the final database. + let mut key_buffer = Vec::new(); + for refcell in local_entries { + let (index, file, mut buffer) = refcell.into_inner(); + let mut file = file.into_inner().map_err(|e| e.into_error())?; + file.rewind()?; + let mut file = BufReader::new(file); + for PrefixIntegerEntry { prefix, pos, serialized_length } in index { + buffer.resize(serialized_length, 0); + file.read_exact(&mut buffer)?; + + key_buffer.clear(); + key_buffer.extend_from_slice(prefix.as_bytes()); + key_buffer.push(0); + key_buffer.extend_from_slice(&pos.to_be_bytes()); + self.prefix_database.remap_data_type::().put(wtxn, &key_buffer, &buffer)?; + } + } + + Ok(()) + } +} + +/// Represents a prefix and the lenght the bitmap takes on disk. +struct PrefixIntegerEntry<'a> { + prefix: &'a str, + pos: u16, + serialized_length: usize, +} + +/// TODO doc +struct FrozenPrefixIntegerBitmaps<'a, 'rtxn> { + prefixes_bitmaps: HashMap<&'a str, HashMap>>, +} + +impl<'a, 'rtxn> FrozenPrefixIntegerBitmaps<'a, 'rtxn> { + #[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] + pub fn from_prefixes( + database: Database, + rtxn: &'rtxn RoTxn, + prefixes: &'a HashSet, + ) -> heed::Result { + let database = database.remap_data_type::(); + + let mut prefixes_bitmaps = HashMap::new(); + for prefix in prefixes { + let mut positions = HashMap::new(); + for result in database.prefix_iter(rtxn, prefix.as_bytes())? { + let (key, bytes) = result?; + let (_word, pos) = StrBEU16Codec::bytes_decode(key).map_err(Error::Decoding)?; + positions.entry(pos).or_insert_with(Vec::new).push(bytes); + } + assert!(prefixes_bitmaps.insert(prefix.as_str(), positions).is_none()); + } + + Ok(Self { prefixes_bitmaps }) + } + + pub fn bitmaps(&self, key: &'a str) -> Option<&HashMap>> { + self.prefixes_bitmaps.get(&key) + } +} + +unsafe impl<'a, 'rtxn> Sync for FrozenPrefixIntegerBitmaps<'a, 'rtxn> {} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] +fn delete_prefixes( + wtxn: &mut RwTxn, + prefix_database: &Database, + prefixes: &HashSet, +) -> Result<()> { + // We remove all the entries that are no more required in this word prefix docids database. + for prefix in prefixes { + let mut iter = prefix_database.prefix_iter_mut(wtxn, prefix.as_bytes())?; + while iter.next().transpose()?.is_some() { + // safety: we do not keep a reference on database entries. + unsafe { iter.del_current()? }; + } + } + + Ok(()) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] +pub fn compute_word_prefix_docids( + wtxn: &mut RwTxn, + index: &Index, + prefix_to_compute: &HashSet, + prefix_to_delete: &HashSet, +) -> Result<()> { + WordPrefixDocids::new( + index.word_docids.remap_key_type(), + index.word_prefix_docids.remap_key_type(), + ) + .execute(wtxn, prefix_to_compute, prefix_to_delete) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] +pub fn compute_exact_word_prefix_docids( + wtxn: &mut RwTxn, + index: &Index, + prefix_to_compute: &HashSet, + prefix_to_delete: &HashSet, +) -> Result<()> { + WordPrefixDocids::new( + index.exact_word_docids.remap_key_type(), + index.exact_word_prefix_docids.remap_key_type(), + ) + .execute(wtxn, prefix_to_compute, prefix_to_delete) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] +pub fn compute_word_prefix_fid_docids( + wtxn: &mut RwTxn, + index: &Index, + prefix_to_compute: &HashSet, + prefix_to_delete: &HashSet, +) -> Result<()> { + WordPrefixIntegerDocids::new( + index.word_fid_docids.remap_key_type(), + index.word_prefix_fid_docids.remap_key_type(), + ) + .execute(wtxn, prefix_to_compute, prefix_to_delete) +} + +#[tracing::instrument(level = "trace", skip_all, target = "indexing::prefix")] +pub fn compute_word_prefix_position_docids( + wtxn: &mut RwTxn, + index: &Index, + prefix_to_compute: &HashSet, + prefix_to_delete: &HashSet, +) -> Result<()> { + WordPrefixIntegerDocids::new( + index.word_position_docids.remap_key_type(), + index.word_prefix_position_docids.remap_key_type(), + ) + .execute(wtxn, prefix_to_compute, prefix_to_delete) +} diff --git a/crates/milli/src/update/word_prefix_docids.rs b/crates/milli/src/update/word_prefix_docids.rs index 925635f80..d129d485e 100644 --- a/crates/milli/src/update/word_prefix_docids.rs +++ b/crates/milli/src/update/word_prefix_docids.rs @@ -6,9 +6,8 @@ use heed::Database; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvWriterDelAdd}; use crate::update::index_documents::{ - create_sorter, merge_deladd_cbo_roaring_bitmaps, - merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, valid_lmdb_key, - write_sorter_into_database, CursorClonableMmap, MergeFn, + create_sorter, merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, valid_lmdb_key, + write_sorter_into_database, CursorClonableMmap, MergeDeladdCboRoaringBitmaps, }; use crate::{CboRoaringBitmapCodec, Result}; @@ -47,7 +46,7 @@ impl<'t, 'i> WordPrefixDocids<'t, 'i> { )] pub fn execute( self, - new_word_docids: grenad::Merger, + new_word_docids: grenad::Merger, new_prefix_fst_words: &[String], common_prefix_fst_words: &[&[String]], del_prefix_fst_words: &HashSet>, @@ -56,11 +55,12 @@ impl<'t, 'i> WordPrefixDocids<'t, 'i> { // and write into it at the same time, therefore we write into another file. let mut prefix_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, self.chunk_compression_type, self.chunk_compression_level, self.max_nb_chunks, self.max_memory, + true, ); if !common_prefix_fst_words.is_empty() { @@ -139,7 +139,7 @@ impl<'t, 'i> WordPrefixDocids<'t, 'i> { fn write_prefixes_in_sorter( prefixes: &mut HashMap, Vec>>, - sorter: &mut grenad::Sorter, + sorter: &mut grenad::Sorter, ) -> Result<()> { for (key, data_slices) in prefixes.drain() { for data in data_slices { diff --git a/crates/milli/src/update/words_prefix_integer_docids.rs b/crates/milli/src/update/words_prefix_integer_docids.rs index 9b6aa21ae..ff974b797 100644 --- a/crates/milli/src/update/words_prefix_integer_docids.rs +++ b/crates/milli/src/update/words_prefix_integer_docids.rs @@ -11,9 +11,8 @@ use crate::heed_codec::StrBEU16Codec; use crate::index::main_key::WORDS_PREFIXES_FST_KEY; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvWriterDelAdd}; use crate::update::index_documents::{ - create_sorter, merge_deladd_cbo_roaring_bitmaps, - merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, valid_lmdb_key, - write_sorter_into_database, CursorClonableMmap, MergeFn, + create_sorter, merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap, valid_lmdb_key, + write_sorter_into_database, CursorClonableMmap, MergeDeladdCboRoaringBitmaps, }; use crate::{CboRoaringBitmapCodec, Result}; @@ -52,7 +51,7 @@ impl<'t, 'i> WordPrefixIntegerDocids<'t, 'i> { )] pub fn execute( self, - new_word_integer_docids: grenad::Merger, + new_word_integer_docids: grenad::Merger, new_prefix_fst_words: &[String], common_prefix_fst_words: &[&[String]], del_prefix_fst_words: &HashSet>, @@ -61,11 +60,12 @@ impl<'t, 'i> WordPrefixIntegerDocids<'t, 'i> { let mut prefix_integer_docids_sorter = create_sorter( grenad::SortAlgorithm::Unstable, - merge_deladd_cbo_roaring_bitmaps, + MergeDeladdCboRoaringBitmaps, self.chunk_compression_type, self.chunk_compression_level, self.max_nb_chunks, self.max_memory, + true, ); if !common_prefix_fst_words.is_empty() { @@ -173,7 +173,7 @@ impl<'t, 'i> WordPrefixIntegerDocids<'t, 'i> { fn write_prefixes_in_sorter( prefixes: &mut HashMap, Vec>>, - sorter: &mut grenad::Sorter, + sorter: &mut grenad::Sorter, ) -> Result<()> { // TODO: Merge before insertion. for (key, data_slices) in prefixes.drain() { diff --git a/crates/milli/src/vector/error.rs b/crates/milli/src/vector/error.rs index 3c8cb4b06..97bbe5d68 100644 --- a/crates/milli/src/vector/error.rs +++ b/crates/milli/src/vector/error.rs @@ -1,11 +1,13 @@ use std::collections::BTreeMap; use std::path::PathBuf; +use bumpalo::Bump; use hf_hub::api::sync::ApiError; use super::parsed_vectors::ParsedVectorsDiff; use super::rest::ConfigurationSource; use crate::error::FaultSource; +use crate::update::new::vector_document::VectorDocument; use crate::{FieldDistribution, PanicCatched}; #[derive(Debug, thiserror::Error)] @@ -58,7 +60,7 @@ pub enum EmbedErrorKind { ManualEmbed(String), #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))] OllamaModelNotFoundError(Option), - #[error("error deserialization the response body as JSON:\n - {0}")] + #[error("error deserializing the response body as JSON:\n - {0}")] RestResponseDeserialization(std::io::Error), #[error("expected a response containing {0} embeddings, got only {1}")] RestResponseEmbeddingCount(usize, usize), @@ -417,6 +419,23 @@ impl PossibleEmbeddingMistakes { } }) } + + pub fn embedder_mistakes_bump<'a, 'doc: 'a>( + &'a self, + embedder_name: &'a str, + unused_vectors_distribution: &'a UnusedVectorsDistributionBump<'doc>, + ) -> impl Iterator + 'a { + let builder = levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true); + let automata = builder.build_dfa(embedder_name); + + unused_vectors_distribution.0.iter().filter_map(move |(field, count)| { + match automata.eval(field) { + levenshtein_automata::Distance::Exact(0) => None, + levenshtein_automata::Distance::Exact(_) => Some((*field, *count)), + levenshtein_automata::Distance::AtLeast(_) => None, + } + }) + } } #[derive(Default)] @@ -433,3 +452,23 @@ impl UnusedVectorsDistribution { } } } + +pub struct UnusedVectorsDistributionBump<'doc>( + hashbrown::HashMap<&'doc str, u64, hashbrown::DefaultHashBuilder, &'doc Bump>, +); + +impl<'doc> UnusedVectorsDistributionBump<'doc> { + pub fn new_in(doc_alloc: &'doc Bump) -> Self { + Self(hashbrown::HashMap::new_in(doc_alloc)) + } + + pub fn append(&mut self, vectors: &impl VectorDocument<'doc>) -> Result<(), crate::Error> { + for res in vectors.iter_vectors() { + let (embedder_name, entry) = res?; + if !entry.has_configured_embedder { + *self.0.entry(embedder_name).or_default() += 1; + } + } + Ok(()) + } +} diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs index dc1e7d324..3fe28e53a 100644 --- a/crates/milli/src/vector/hf.rs +++ b/crates/milli/src/vector/hf.rs @@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType}; use tokenizers::{PaddingParams, Tokenizer}; pub use super::error::{EmbedError, Error, NewEmbedderError}; -use super::{DistributionShift, Embedding, Embeddings}; +use super::{DistributionShift, Embedding}; #[derive( Debug, @@ -139,15 +139,12 @@ impl Embedder { let embeddings = this .embed(vec!["test".into()]) .map_err(NewEmbedderError::could_not_determine_dimension)?; - this.dimensions = embeddings.first().unwrap().dimension(); + this.dimensions = embeddings.first().unwrap().len(); Ok(this) } - pub fn embed( - &self, - mut texts: Vec, - ) -> std::result::Result>, EmbedError> { + pub fn embed(&self, mut texts: Vec) -> std::result::Result, EmbedError> { let tokens = match texts.len() { 1 => vec![self .tokenizer @@ -177,13 +174,34 @@ impl Embedder { .map_err(EmbedError::tensor_shape)?; let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; - Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) + Ok(embeddings) + } + + pub fn embed_one(&self, text: &str) -> std::result::Result { + let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?; + let token_ids = tokens.get_ids(); + let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids }; + let token_ids = + Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?; + let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?; + let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; + let embeddings = + self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; + + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = + embeddings.dims3().map_err(EmbedError::tensor_shape)?; + let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + .map_err(EmbedError::tensor_shape)?; + let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?; + let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?; + Ok(embedding) } pub fn embed_chunks( &self, text_chunks: Vec>, - ) -> std::result::Result>>, EmbedError> { + ) -> std::result::Result>, EmbedError> { text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() } @@ -211,4 +229,8 @@ impl Embedder { } }) } + + pub(crate) fn embed_chunks_ref(&self, texts: &[&str]) -> Result, EmbedError> { + texts.iter().map(|text| self.embed_one(text)).collect() + } } diff --git a/crates/milli/src/vector/manual.rs b/crates/milli/src/vector/manual.rs index 4cfbb0d3c..8c2ef97b2 100644 --- a/crates/milli/src/vector/manual.rs +++ b/crates/milli/src/vector/manual.rs @@ -1,5 +1,6 @@ use super::error::EmbedError; -use super::{DistributionShift, Embeddings}; +use super::DistributionShift; +use crate::vector::Embedding; #[derive(Debug, Clone, Copy)] pub struct Embedder { @@ -18,11 +19,13 @@ impl Embedder { Self { dimensions: options.dimensions, distribution: options.distribution } } - pub fn embed(&self, mut texts: Vec) -> Result>, EmbedError> { - let Some(text) = texts.pop() else { return Ok(Default::default()) }; - Err(EmbedError::embed_on_manual_embedder(text.chars().take(250).collect())) + pub fn embed>(&self, texts: &[S]) -> Result, EmbedError> { + texts.as_ref().iter().map(|text| self.embed_one(text)).collect() } + pub fn embed_one>(&self, text: S) -> Result { + Err(EmbedError::embed_on_manual_embedder(text.as_ref().chars().take(250).collect())) + } pub fn dimensions(&self) -> usize { self.dimensions } @@ -30,11 +33,15 @@ impl Embedder { pub fn embed_chunks( &self, text_chunks: Vec>, - ) -> Result>>, EmbedError> { - text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() + ) -> Result>, EmbedError> { + text_chunks.into_iter().map(|prompts| self.embed(&prompts)).collect() } pub fn distribution(&self) -> Option { self.distribution } + + pub(crate) fn embed_chunks_ref(&self, texts: &[&str]) -> Result, EmbedError> { + texts.iter().map(|text| self.embed_one(text)).collect() + } } diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs index 571c02c8c..3047e6dfc 100644 --- a/crates/milli/src/vector/mod.rs +++ b/crates/milli/src/vector/mod.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::Instant; use arroy::distances::{BinaryQuantizedCosine, Cosine}; use arroy::ItemId; @@ -531,6 +532,10 @@ impl EmbeddingConfigs { Self(data) } + pub fn contains(&self, name: &str) -> bool { + self.0.contains_key(name) + } + /// Get an embedder configuration and template from its name. pub fn get(&self, name: &str) -> Option<(Arc, Arc, bool)> { self.0.get(name).cloned() @@ -594,25 +599,25 @@ impl Embedder { pub fn embed( &self, texts: Vec, - ) -> std::result::Result>, EmbedError> { + deadline: Option, + ) -> std::result::Result, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts), - Embedder::OpenAi(embedder) => embedder.embed(texts), - Embedder::Ollama(embedder) => embedder.embed(texts), - Embedder::UserProvided(embedder) => embedder.embed(texts), - Embedder::Rest(embedder) => embedder.embed(texts), + Embedder::OpenAi(embedder) => embedder.embed(&texts, deadline), + Embedder::Ollama(embedder) => embedder.embed(&texts, deadline), + Embedder::UserProvided(embedder) => embedder.embed(&texts), + Embedder::Rest(embedder) => embedder.embed(texts, deadline), } } - pub fn embed_one(&self, text: String) -> std::result::Result { - let mut embeddings = self.embed(vec![text])?; - let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?; - Ok(if embeddings.iter().nth(1).is_some() { - tracing::warn!("Ignoring embeddings past the first one in long search query"); - embeddings.iter().next().unwrap().to_vec() - } else { - embeddings.into_inner() - }) + pub fn embed_one( + &self, + text: String, + deadline: Option, + ) -> std::result::Result { + let mut embedding = self.embed(vec![text], deadline)?; + let embedding = embedding.pop().ok_or_else(EmbedError::missing_embedding)?; + Ok(embedding) } /// Embed multiple chunks of texts. @@ -622,7 +627,7 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> std::result::Result>>, EmbedError> { + ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads), @@ -632,13 +637,27 @@ impl Embedder { } } + pub fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> std::result::Result, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed_chunks_ref(texts), + Embedder::OpenAi(embedder) => embedder.embed_chunks_ref(texts, threads), + Embedder::Ollama(embedder) => embedder.embed_chunks_ref(texts, threads), + Embedder::UserProvided(embedder) => embedder.embed_chunks_ref(texts), + Embedder::Rest(embedder) => embedder.embed_chunks_ref(texts, threads), + } + } + /// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`] pub fn chunk_count_hint(&self) -> usize { match self { Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), Embedder::Ollama(embedder) => embedder.chunk_count_hint(), - Embedder::UserProvided(_) => 1, + Embedder::UserProvided(_) => 100, Embedder::Rest(embedder) => embedder.chunk_count_hint(), } } diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs index 7d41ab4e9..7ee775cbf 100644 --- a/crates/milli/src/vector/ollama.rs +++ b/crates/milli/src/vector/ollama.rs @@ -1,9 +1,13 @@ +use std::time::Instant; + use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use rayon::slice::ParallelSlice as _; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; -use super::{DistributionShift, Embeddings}; +use super::DistributionShift; use crate::error::FaultSource; +use crate::vector::Embedding; use crate::ThreadPoolNoAbort; #[derive(Debug)] @@ -75,8 +79,12 @@ impl Embedder { Ok(Self { rest_embedder }) } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - match self.rest_embedder.embed(texts) { + pub fn embed + serde::Serialize>( + &self, + texts: &[S], + deadline: Option, + ) -> Result, EmbedError> { + match self.rest_embedder.embed_ref(texts, deadline) { Ok(embeddings) => Ok(embeddings), Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError::ollama_model_not_found(error)) @@ -89,10 +97,31 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> Result>>, EmbedError> { + ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } + + pub(crate) fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> Result>, EmbedError> { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed(chunk, None)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs index 152d1fb7a..7262bfef8 100644 --- a/crates/milli/src/vector/openai.rs +++ b/crates/milli/src/vector/openai.rs @@ -1,11 +1,15 @@ +use std::time::Instant; + use ordered_float::OrderedFloat; use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; +use rayon::slice::ParallelSlice as _; use super::error::{EmbedError, NewEmbedderError}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; -use super::{DistributionShift, Embeddings}; +use super::DistributionShift; use crate::error::FaultSource; use crate::vector::error::EmbedErrorKind; +use crate::vector::Embedding; use crate::ThreadPoolNoAbort; #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] @@ -206,37 +210,42 @@ impl Embedder { Ok(Self { options, rest_embedder, tokenizer }) } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - match self.rest_embedder.embed_ref(&texts) { + pub fn embed + serde::Serialize>( + &self, + texts: &[S], + deadline: Option, + ) -> Result, EmbedError> { + match self.rest_embedder.embed_ref(texts, deadline) { Ok(embeddings) => Ok(embeddings), Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); - self.try_embed_tokenized(&texts) + self.try_embed_tokenized(texts, deadline) } Err(error) => Err(error), } } - fn try_embed_tokenized(&self, text: &[String]) -> Result>, EmbedError> { + fn try_embed_tokenized>( + &self, + text: &[S], + deadline: Option, + ) -> Result, EmbedError> { let mut all_embeddings = Vec::with_capacity(text.len()); for text in text { + let text = text.as_ref(); let max_token_count = self.options.embedding_model.max_token(); - let encoded = self.tokenizer.encode_ordinary(text.as_str()); + let encoded = self.tokenizer.encode_ordinary(text); let len = encoded.len(); if len < max_token_count { - all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); + all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?); continue; } let tokens = &encoded.as_slice()[0..max_token_count]; - let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); - let embedding = self.rest_embedder.embed_tokens(tokens)?; - embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { - EmbedError::rest_unexpected_dimension(self.dimensions(), got.len()) - })?; + let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?; - all_embeddings.push(embeddings_for_prompt); + all_embeddings.push(embedding); } Ok(all_embeddings) } @@ -245,10 +254,31 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> Result>>, EmbedError> { + ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(&chunk, None)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } + + pub(crate) fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> Result>, EmbedError> { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed(chunk, None)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), diff --git a/crates/milli/src/vector/parsed_vectors.rs b/crates/milli/src/vector/parsed_vectors.rs index 9dbf025e6..da41d1771 100644 --- a/crates/milli/src/vector/parsed_vectors.rs +++ b/crates/milli/src/vector/parsed_vectors.rs @@ -2,6 +2,7 @@ use std::collections::{BTreeMap, BTreeSet}; use deserr::{take_cf_content, DeserializeError, Deserr, Sequence}; use obkv::KvReader; +use serde_json::value::RawValue; use serde_json::{from_slice, Value}; use super::Embedding; @@ -11,6 +12,250 @@ use crate::{DocumentId, FieldId, InternalError, UserError}; pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors"; +#[derive(serde::Serialize, Debug)] +#[serde(untagged)] +pub enum RawVectors<'doc> { + Explicit(#[serde(borrow)] RawExplicitVectors<'doc>), + ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>), +} + +pub enum RawVectorsError { + DeserializeSeq { index: usize, error: String }, + DeserializeKey { error: String }, + DeserializeRegenerate { error: String }, + DeserializeEmbeddings { error: String }, + UnknownField { field: String }, + MissingRegenerate, + WrongKind { kind: &'static str, value: String }, + Parsing(serde_json::Error), +} + +impl RawVectorsError { + pub fn msg(self, embedder_name: &str) -> String { + match self { + RawVectorsError::DeserializeSeq { index, error } => format!( + "Could not parse `._vectors.{embedder_name}[{index}]`: {error}" + ), + RawVectorsError::DeserializeKey { error } => format!( + "Could not parse a field at `._vectors.{embedder_name}`: {error}" + ), + RawVectorsError::DeserializeRegenerate { error } => format!( + "Could not parse `._vectors.{embedder_name}.regenerate`: {error}" + ), + RawVectorsError::DeserializeEmbeddings { error } => format!( + "Could not parse `._vectors.{embedder_name}.embeddings`: {error}" + ), + RawVectorsError::UnknownField { field } => format!( + "Unexpected field `._vectors.{embedder_name}.{field}`\n \ + - note: the allowed fields are `regenerate` and `embeddings`" + ), + RawVectorsError::MissingRegenerate => format!( + "Missing field `._vectors.{embedder_name}.regenerate`\n \ + - note: `._vectors.{embedder_name}` must be an array of floats, an array of arrays of floats, or an object with field `regenerate`" + ), + RawVectorsError::WrongKind { kind, value } => format!( + "Expected `._vectors.{embedder_name}` to be an array of floats, an array of arrays of floats, or an object with at least the field `regenerate`, but got the {kind} `{value}`" + ), + RawVectorsError::Parsing(error) => format!( + "Could not parse `._vectors.{embedder_name}`: {error}" + ), + } + } +} + +impl<'doc> RawVectors<'doc> { + pub fn from_raw_value(raw: &'doc RawValue) -> Result { + use serde::de::Deserializer as _; + Ok(match raw.deserialize_any(RawVectorsVisitor).map_err(RawVectorsError::Parsing)?? { + RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None), + RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)), + RawVectorsVisitorValue::Explicit { regenerate, embeddings } => { + RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate }) + } + }) + } +} + +struct RawVectorsVisitor; + +enum RawVectorsVisitorValue<'doc> { + ImplicitNone, + Implicit, + Explicit { regenerate: bool, embeddings: Option<&'doc RawValue> }, +} + +impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor { + type Value = std::result::Result, RawVectorsError>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a map containing at least `regenerate`, or an array of floats`") + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(Ok(RawVectorsVisitorValue::ImplicitNone)) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: serde::Deserializer<'doc>, + { + deserializer.deserialize_any(self) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Ok(RawVectorsVisitorValue::ImplicitNone)) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'doc>, + { + let mut index = 0; + // must consume all elements or parsing fails + loop { + match seq.next_element::<&RawValue>() { + Ok(Some(_)) => index += 1, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeSeq { + index, + error: error.to_string(), + })) + } + Ok(None) => break, + }; + } + Ok(Ok(RawVectorsVisitorValue::Implicit)) + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'doc>, + { + let mut regenerate = None; + let mut embeddings = None; + loop { + match map.next_key::<&str>() { + Ok(Some("regenerate")) => { + let value: bool = match map.next_value() { + Ok(value) => value, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeRegenerate { + error: error.to_string(), + })) + } + }; + regenerate = Some(value); + } + Ok(Some("embeddings")) => { + let value: &RawValue = match map.next_value() { + Ok(value) => value, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeEmbeddings { + error: error.to_string(), + })) + } + }; + embeddings = Some(value); + } + Ok(Some(other)) => { + return Ok(Err(RawVectorsError::UnknownField { field: other.to_string() })) + } + Ok(None) => break, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeKey { error: error.to_string() })) + } + } + } + let Some(regenerate) = regenerate else { + return Ok(Err(RawVectorsError::MissingRegenerate)); + }; + Ok(Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings })) + } + + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "boolean", value: v.to_string() })) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_i128(self, v: i128) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_u128(self, v: u128) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "number", value: v.to_string() })) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "string", value: v.to_string() })) + } + + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "string", value: v })) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "bytes", value: format!("{v:?}") })) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'doc>, + { + deserializer.deserialize_any(self) + } + + fn visit_enum(self, _data: A) -> Result + where + A: serde::de::EnumAccess<'doc>, + { + Ok(Err(RawVectorsError::WrongKind { kind: "enum", value: "a variant".to_string() })) + } +} + #[derive(serde::Serialize, Debug)] #[serde(untagged)] pub enum Vectors { @@ -69,6 +314,21 @@ impl Vectors { } } +impl<'doc> RawVectors<'doc> { + pub fn must_regenerate(&self) -> bool { + match self { + RawVectors::ImplicitlyUserProvided(_) => false, + RawVectors::Explicit(RawExplicitVectors { regenerate, .. }) => *regenerate, + } + } + pub fn embeddings(&self) -> Option<&'doc RawValue> { + match self { + RawVectors::ImplicitlyUserProvided(embeddings) => *embeddings, + RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings, + } + } +} + #[derive(serde::Serialize, Deserr, Debug)] #[serde(rename_all = "camelCase")] pub struct ExplicitVectors { @@ -78,6 +338,15 @@ pub struct ExplicitVectors { pub regenerate: bool, } +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct RawExplicitVectors<'doc> { + #[serde(borrow)] + #[serde(default)] + pub embeddings: Option<&'doc RawValue>, + pub regenerate: bool, +} + pub enum VectorState { Inline(Vectors), Manual, @@ -109,14 +378,13 @@ impl ParsedVectorsDiff { pub fn new( docid: DocumentId, embedders_configs: &[IndexEmbeddingConfig], - documents_diff: KvReader<'_, FieldId>, + documents_diff: &KvReader, old_vectors_fid: Option, new_vectors_fid: Option, ) -> Result { let mut old = match old_vectors_fid .and_then(|vectors_fid| documents_diff.get(vectors_fid)) - .map(KvReaderDelAdd::new) - .map(|obkv| to_vector_map(obkv, DelAdd::Deletion)) + .map(|bytes| to_vector_map(bytes.into(), DelAdd::Deletion)) .transpose() { Ok(del) => del, @@ -143,8 +411,7 @@ impl ParsedVectorsDiff { let Some(bytes) = documents_diff.get(new_vectors_fid) else { break 'new VectorsState::NoVectorsFieldInDocument; }; - let obkv = KvReaderDelAdd::new(bytes); - match to_vector_map(obkv, DelAdd::Addition)? { + match to_vector_map(bytes.into(), DelAdd::Addition)? { Some(new) => VectorsState::Vectors(new), None => VectorsState::NoVectorsFieldInDocument, } @@ -228,7 +495,7 @@ impl Error { Error::InvalidEmbedderConf { error } => { crate::Error::UserError(UserError::InvalidVectorsEmbedderConf { document_id, - error, + error: error.to_string(), }) } Error::InternalSerdeJson(error) => { @@ -239,7 +506,7 @@ impl Error { } fn to_vector_map( - obkv: KvReaderDelAdd<'_>, + obkv: &KvReaderDelAdd, side: DelAdd, ) -> Result>, Error> { Ok(if let Some(value) = obkv.get(side) { diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs index 2538f2fff..98be311d4 100644 --- a/crates/milli/src/vector/rest.rs +++ b/crates/milli/src/vector/rest.rs @@ -1,15 +1,15 @@ use std::collections::BTreeMap; +use std::time::Instant; use deserr::Deserr; use rand::Rng; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use rayon::slice::ParallelSlice as _; use serde::{Deserialize, Serialize}; use super::error::EmbedErrorKind; use super::json_template::ValueTemplate; -use super::{ - DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, -}; +use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM}; use crate::error::FaultSource; use crate::ThreadPoolNoAbort; @@ -154,19 +154,31 @@ impl Embedder { Ok(Self { data, dimensions, distribution: options.distribution }) } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) + pub fn embed( + &self, + texts: Vec, + deadline: Option, + ) -> Result, EmbedError> { + embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline) } - pub fn embed_ref(&self, texts: &[S]) -> Result>, EmbedError> + pub fn embed_ref( + &self, + texts: &[S], + deadline: Option, + ) -> Result, EmbedError> where S: AsRef + Serialize, { - embed(&self.data, texts, texts.len(), Some(self.dimensions)) + embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline) } - pub fn embed_tokens(&self, tokens: &[usize]) -> Result, EmbedError> { - let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; + pub fn embed_tokens( + &self, + tokens: &[usize], + deadline: Option, + ) -> Result { + let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?; // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error Ok(embeddings.pop().unwrap()) } @@ -175,10 +187,31 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - ) -> Result>>, EmbedError> { + ) -> Result>, EmbedError> { threads .install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? + } + + pub(crate) fn embed_chunks_ref( + &self, + texts: &[&str], + threads: &ThreadPoolNoAbort, + ) -> Result, EmbedError> { + threads + .install(move || { + let embeddings: Result>, _> = texts + .par_chunks(self.prompt_count_in_chunk_hint()) + .map(move |chunk| self.embed_ref(chunk, None)) + .collect(); + + let embeddings = embeddings?; + Ok(embeddings.into_iter().flatten().collect()) }) .map_err(|error| EmbedError { kind: EmbedErrorKind::PanicInThreadPool(error), @@ -207,10 +240,10 @@ impl Embedder { } fn infer_dimensions(data: &EmbedderData) -> Result { - let v = embed(data, ["test"].as_slice(), 1, None) + let v = embed(data, ["test"].as_slice(), 1, None, None) .map_err(NewEmbedderError::could_not_determine_dimension)?; // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error - Ok(v.first().unwrap().dimension()) + Ok(v.first().unwrap().len()) } fn embed( @@ -218,7 +251,8 @@ fn embed( inputs: &[S], expected_count: usize, expected_dimension: Option, -) -> Result>, EmbedError> + deadline: Option, +) -> Result, EmbedError> where S: Serialize, { @@ -237,15 +271,26 @@ where for attempt in 0..10 { let response = request.clone().send_json(&body); - let result = check_response(response, data.configuration_source); + let result = check_response(response, data.configuration_source).and_then(|response| { + response_to_embedding(response, data, expected_count, expected_dimension) + }); let retry_duration = match result { - Ok(response) => { - return response_to_embedding(response, data, expected_count, expected_dimension) - } + Ok(response) => return Ok(response), Err(retry) => { tracing::warn!("Failed: {}", retry.error); - retry.into_duration(attempt) + if let Some(deadline) = deadline { + let now = std::time::Instant::now(); + if now > deadline { + tracing::warn!("Could not embed due to deadline"); + return Err(retry.into_error()); + } + + let duration_to_deadline = deadline - now; + retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline)) + } else { + retry.into_duration(attempt) + } } }?; @@ -263,6 +308,7 @@ where let result = check_response(response, data.configuration_source); result.map_err(Retry::into_error).and_then(|response| { response_to_embedding(response, data, expected_count, expected_dimension) + .map_err(Retry::into_error) }) } @@ -304,23 +350,28 @@ fn response_to_embedding( data: &EmbedderData, expected_count: usize, expected_dimensions: Option, -) -> Result>, EmbedError> { - let response: serde_json::Value = - response.into_json().map_err(EmbedError::rest_response_deserialization)?; +) -> Result, Retry> { + let response: serde_json::Value = response + .into_json() + .map_err(EmbedError::rest_response_deserialization) + .map_err(Retry::retry_later)?; - let embeddings = data.response.extract_embeddings(response)?; + let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?; if embeddings.len() != expected_count { - return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); + return Err(Retry::give_up(EmbedError::rest_response_embedding_count( + expected_count, + embeddings.len(), + ))); } if let Some(dimensions) = expected_dimensions { for embedding in &embeddings { - if embedding.dimension() != dimensions { - return Err(EmbedError::rest_unexpected_dimension( + if embedding.len() != dimensions { + return Err(Retry::give_up(EmbedError::rest_unexpected_dimension( dimensions, - embedding.dimension(), - )); + embedding.len(), + ))); } } } @@ -394,7 +445,7 @@ impl Response { pub fn extract_embeddings( &self, response: serde_json::Value, - ) -> Result>, EmbedError> { + ) -> Result, EmbedError> { let extracted_values: Vec = match self.template.extract(response) { Ok(extracted_values) => extracted_values, Err(error) => { @@ -403,8 +454,7 @@ impl Response { return Err(EmbedError::rest_extraction_error(error_message)); } }; - let embeddings: Vec> = - extracted_values.into_iter().map(Embeddings::from_single_embedding).collect(); + let embeddings: Vec = extracted_values.into_iter().collect(); Ok(embeddings) } diff --git a/workloads/movies.json b/workloads/movies.json index 445ff3aca..9ad3fb7eb 100644 --- a/workloads/movies.json +++ b/workloads/movies.json @@ -1,6 +1,6 @@ { "name": "movies.json", - "run_count": 10, + "run_count": 1, "extra_cli_args": [], "assets": { "movies.json": {