From 13c2c6c16beda22942029326348db0e9929df421 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 15 Nov 2023 15:46:37 +0100 Subject: [PATCH] Small commit to add hybrid search and autoembedding --- Cargo.lock | 281 +++++++++--- dump/src/lib.rs | 1 + dump/src/reader/compat/v5_to_v6.rs | 1 + index-scheduler/src/batch.rs | 9 + index-scheduler/src/features.rs | 4 +- index-scheduler/src/insta_snapshot.rs | 1 + index-scheduler/src/lib.rs | 41 ++ meilisearch-types/src/error.rs | 8 +- meilisearch-types/src/settings.rs | 23 + .../src/analytics/segment_analytics.rs | 2 +- meilisearch/src/main.rs | 6 +- .../src/routes/indexes/facet_search.rs | 3 +- meilisearch/src/routes/indexes/search.rs | 40 +- meilisearch/src/routes/multi_search.rs | 3 + meilisearch/src/search.rs | 31 +- milli/Cargo.toml | 26 +- milli/examples/search.rs | 11 +- milli/src/error.rs | 28 ++ milli/src/index.rs | 29 ++ milli/src/lib.rs | 8 +- milli/src/prompt/context.rs | 97 ++++ milli/src/prompt/document.rs | 131 ++++++ milli/src/prompt/error.rs | 56 +++ milli/src/prompt/fields.rs | 172 ++++++++ milli/src/prompt/mod.rs | 144 ++++++ milli/src/prompt/template_checker.rs | 282 ++++++++++++ milli/src/score_details.rs | 164 ++++++- milli/src/search/hybrid.rs | 336 ++++++++++++++ milli/src/search/mod.rs | 102 ++++- milli/src/search/new/matches/mod.rs | 8 +- milli/src/search/new/mod.rs | 175 +++++--- milli/src/search/new/vector_sort.rs | 150 +++++++ .../extract/extract_vector_points.rs | 330 ++++++++++++-- .../src/update/index_documents/extract/mod.rs | 63 ++- milli/src/update/index_documents/mod.rs | 35 +- .../src/update/index_documents/typed_chunk.rs | 90 +++- milli/src/update/settings.rs | 113 ++++- milli/src/vector/error.rs | 229 ++++++++++ milli/src/vector/hf.rs | 192 ++++++++ milli/src/vector/mod.rs | 142 ++++++ milli/src/vector/openai.rs | 416 ++++++++++++++++++ milli/src/vector/settings.rs | 308 +++++++++++++ 42 files changed, 4045 insertions(+), 246 deletions(-) create mode 100644 
milli/src/prompt/context.rs create mode 100644 milli/src/prompt/document.rs create mode 100644 milli/src/prompt/error.rs create mode 100644 milli/src/prompt/fields.rs create mode 100644 milli/src/prompt/mod.rs create mode 100644 milli/src/prompt/template_checker.rs create mode 100644 milli/src/search/hybrid.rs create mode 100644 milli/src/search/new/vector_sort.rs create mode 100644 milli/src/vector/error.rs create mode 100644 milli/src/vector/hf.rs create mode 100644 milli/src/vector/mod.rs create mode 100644 milli/src/vector/openai.rs create mode 100644 milli/src/vector/settings.rs diff --git a/Cargo.lock b/Cargo.lock index f6ce4b26b..a407244b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,7 +46,7 @@ dependencies = [ "actix-tls", "actix-utils", "ahash 0.8.3", - "base64 0.21.2", + "base64 0.21.5", "bitflags 1.3.2", "brotli", "bytes", @@ -120,7 +120,7 @@ dependencies = [ "futures-util", "mio", "num_cpus", - "socket2", + "socket2 0.4.9", "tokio", "tracing", ] @@ -201,7 +201,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "smallvec", - "socket2", + "socket2 0.4.9", "time", "url", ] @@ -365,6 +365,12 @@ dependencies = [ "backtrace", ] +[[package]] +name = "anymap2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" + [[package]] name = "arbitrary" version = "1.3.0" @@ -455,9 +461,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64ct" @@ -508,6 +514,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -555,12 +576,12 @@ dependencies = [ [[package]] name = "bstr" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6798148dccfbff0fae41c7574d2fa8f1ef3492fba0face179de5d8d447d67b05" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" dependencies = [ "memchr", - "regex-automata 0.3.6", + "regex-automata 0.4.3", "serde", ] @@ -1346,6 +1367,12 @@ dependencies = [ "syn 2.0.28", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "doxygen-rs" version = "0.2.2" @@ -1562,6 +1589,16 @@ dependencies = [ "cc", ] +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fastrand" version = "2.0.0" @@ -1690,9 +1727,9 @@ checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -1705,9 +1742,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -1715,15 +1752,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -1732,15 +1769,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", @@ -1749,21 +1786,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" 
-version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -2207,7 +2244,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -2949,7 +2986,7 @@ version = "8.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "pem", "ring", "serde", @@ -2957,6 +2994,16 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "kstring" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3066350882a1cd6d950d055997f379ac37fd39f81cd4d8ed186032eb3c5747" +dependencies = [ + "serde", + "static_assertions", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -2980,9 +3027,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libgit2-sys" @@ -3251,6 +3298,63 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +[[package]] +name = "liquid" +version = "0.26.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f68ae1011499ae2ef879f631891f21c78e309755f4a5e483c4a8f12e10b609" +dependencies = [ + "doc-comment", + "liquid-core", + "liquid-derive", + "liquid-lib", + "serde", +] + +[[package]] +name = "liquid-core" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79e0724dfcaad5cfb7965ea0f178ca0870b8d7315178f4a7179f5696f7f04d5f" +dependencies = [ + "anymap2", + "itertools 0.10.5", + "kstring", + "liquid-derive", + "num-traits", + "pest", + "pest_derive", + "regex", + "serde", + "time", +] + +[[package]] +name = "liquid-derive" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + +[[package]] +name = "liquid-lib" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2a17e273a6fb1fb6268f7a5867ddfd0bd4683c7e19b51084f3d567fad4348c0" +dependencies = [ + "itertools 0.10.5", + "liquid-core", + "once_cell", + "percent-encoding", + "regex", + "time", + "unicode-segmentation", +] + [[package]] name = "litemap" version = "0.6.1" @@ -3483,7 +3587,7 @@ dependencies = [ name = "meilisearch-auth" version = "1.5.1" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "enum-iterator", "hmac", "maplit", @@ -3544,9 +3648,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "memmap2" @@ -3589,6 +3693,7 @@ dependencies = [ "filter-parser", "flatten-serde-json", "fst", + "futures", "fxhash", "geoutils", "grenad", @@ -3600,6 +3705,7 @@ dependencies = [ "itertools 0.11.0", 
"json-depth-checker", "levenshtein_automata", + "liquid", "log", "logging_timer", "maplit", @@ -3607,6 +3713,7 @@ dependencies = [ "meili-snap", "memmap2", "mimalloc", + "nolife", "obkv", "once_cell", "ordered-float", @@ -3614,6 +3721,7 @@ dependencies = [ "rand", "rand_pcg", "rayon", + "reqwest", "roaring", "rstar", "serde", @@ -3624,8 +3732,10 @@ dependencies = [ "smartstring", "tempfile", "thiserror", + "tiktoken-rs", "time", "tokenizers", + "tokio", "uuid 1.5.0", ] @@ -3671,9 +3781,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "log", @@ -3725,6 +3835,12 @@ name = "nelson" version = "0.1.0" source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a" +[[package]] +name = "nolife" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc" + [[package]] name = "nom" version = "7.1.3" @@ -4480,6 +4596,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" + [[package]] name = "regex-syntax" version = "0.7.4" @@ -4488,11 +4610,11 @@ checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ - 
"base64 0.21.2", + "base64 0.21.5", "bytes", "encoding_rs", "futures-core", @@ -4514,6 +4636,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -4521,7 +4644,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots 0.22.6", + "webpki-roots 0.25.3", "winreg", ] @@ -4582,6 +4705,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -4648,7 +4777,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", ] [[package]] @@ -4977,6 +5106,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.5.2" @@ -5097,6 +5236,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" 
+dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tar" version = "0.4.40" @@ -5159,6 +5319,21 @@ dependencies = [ "syn 2.0.28", ] +[[package]] +name = "tiktoken-rs" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4427b6b1c6b38215b92dd47a83a0ecc6735573d0a5a4c14acc0ac5b33b28adb" +dependencies = [ + "anyhow", + "base64 0.21.5", + "bstr", + "fancy-regex", + "lazy_static", + "parking_lot", + "rustc-hash", +] + [[package]] name = "time" version = "0.3.30" @@ -5258,11 +5433,10 @@ dependencies = [ [[package]] name = "tokio" -version = "1.29.1" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ - "autocfg", "backtrace", "bytes", "libc", @@ -5271,16 +5445,16 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.5", "tokio-macros", "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -5508,7 +5682,7 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "flate2", "log", "native-tls", @@ -5758,6 +5932,12 @@ dependencies = [ "rustls-webpki 0.100.2", ] +[[package]] +name = "webpki-roots" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" + [[package]] name = 
"whatlang" version = "0.16.2" @@ -5942,11 +6122,12 @@ dependencies = [ [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] diff --git a/dump/src/lib.rs b/dump/src/lib.rs index 15b281c41..be0053a7c 100644 --- a/dump/src/lib.rs +++ b/dump/src/lib.rs @@ -276,6 +276,7 @@ pub(crate) mod test { ), }), pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: std::marker::PhantomData, }; settings.check() diff --git a/dump/src/reader/compat/v5_to_v6.rs b/dump/src/reader/compat/v5_to_v6.rs index 8a0d6e5e1..9351ae70d 100644 --- a/dump/src/reader/compat/v5_to_v6.rs +++ b/dump/src/reader/compat/v5_to_v6.rs @@ -378,6 +378,7 @@ impl From> for v6::Settings { v5::Setting::Reset => v6::Setting::Reset, v5::Setting::NotSet => v6::Setting::NotSet, }, + embedders: v6::Setting::NotSet, _kind: std::marker::PhantomData, } } diff --git a/index-scheduler/src/batch.rs b/index-scheduler/src/batch.rs index 94a8b3f07..cf8544ae7 100644 --- a/index-scheduler/src/batch.rs +++ b/index-scheduler/src/batch.rs @@ -1202,6 +1202,10 @@ impl IndexScheduler { let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; + let embedder_configs = index.embedding_configs(index_wtxn)?; + // TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense) + let embedders = self.embedders(embedder_configs)?; + let mut builder = milli::update::IndexDocuments::new( index_wtxn, index, @@ -1220,6 +1224,8 @@ impl IndexScheduler { let (new_builder, user_result) = builder.add_documents(reader)?; builder = new_builder; + builder = builder.with_embedders(embedders.clone()); + let received_documents = if let 
Some(Details::DocumentAdditionOrUpdate { received_documents, @@ -1345,6 +1351,9 @@ impl IndexScheduler { for (task, (_, settings)) in tasks.iter_mut().zip(settings) { let checked_settings = settings.clone().check(); + if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) { + self.features().check_vector("Passing `embedders` in settings")? + } if checked_settings.proximity_precision.set().is_some() { self.features.features().check_proximity_precision()?; } diff --git a/index-scheduler/src/features.rs b/index-scheduler/src/features.rs index ae2823c30..d6ce3cae4 100644 --- a/index-scheduler/src/features.rs +++ b/index-scheduler/src/features.rs @@ -56,12 +56,12 @@ impl RoFeatures { } } - pub fn check_vector(&self) -> Result<()> { + pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> { if self.runtime.vector_store { Ok(()) } else { Err(FeatureNotEnabledError { - disabled_action: "Passing `vector` as a query parameter", + disabled_action, feature: "vector store", issue_link: "https://github.com/meilisearch/product/discussions/677", } diff --git a/index-scheduler/src/insta_snapshot.rs b/index-scheduler/src/insta_snapshot.rs index bd8fa5148..ddb9e934a 100644 --- a/index-scheduler/src/insta_snapshot.rs +++ b/index-scheduler/src/insta_snapshot.rs @@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String { planned_failures: _, run_loop_iteration: _, currently_updating_index: _, + embedders: _, } = scheduler; let rtxn = env.read_txn().unwrap(); diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index a1b6497d9..fbe38a7fb 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; use meilisearch_types::milli::documents::DocumentsBatchBuilder; use meilisearch_types::milli::update::IndexerConfig; +use 
meilisearch_types::milli::vector::{Embedder, EmbedderOptions}; use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; use puffin::FrameView; @@ -341,6 +342,8 @@ pub struct IndexScheduler { /// so that a handle to the index is available from other threads (search) in an optimized manner. currently_updating_index: Arc>>, + embedders: Arc>>>, + // ================= test // The next entry is dedicated to the tests. /// Provide a way to set a breakpoint in multiple part of the scheduler. @@ -386,6 +389,7 @@ impl IndexScheduler { auth_path: self.auth_path.clone(), version_file_path: self.version_file_path.clone(), currently_updating_index: self.currently_updating_index.clone(), + embedders: self.embedders.clone(), #[cfg(test)] test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), #[cfg(test)] @@ -484,6 +488,7 @@ impl IndexScheduler { auth_path: options.auth_path, version_file_path: options.version_file_path, currently_updating_index: Arc::new(RwLock::new(None)), + embedders: Default::default(), #[cfg(test)] test_breakpoint_sdr, @@ -1333,6 +1338,42 @@ impl IndexScheduler { } } + // TODO: consider using a type alias or a struct embedder/template + #[allow(clippy::type_complexity)] + pub fn embedders( + &self, + embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>, + ) -> Result, Arc)>> { + let res: Result<_> = embedding_configs + .into_iter() + .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| { + let prompt = + Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?); + // optimistically return existing embedder + { + let embedders = self.embedders.read().unwrap(); + if let Some(embedder) = embedders.get(&embedder_options) { + return Ok((name, (embedder.clone(), prompt))); + } + } + + // add missing embedder + let embedder = Arc::new( + Embedder::new(embedder_options.clone()) + 
.map_err(meilisearch_types::milli::vector::Error::from) + .map_err(meilisearch_types::milli::UserError::from) + .map_err(meilisearch_types::milli::Error::from)?, + ); + { + let mut embedders = self.embedders.write().unwrap(); + embedders.insert(embedder_options, embedder.clone()); + } + Ok((name, (embedder, prompt))) + }) + .collect(); + res + } + /// Blocks the thread until the test handle asks to progress to/through this breakpoint. /// /// Two messages are sent through the channel for each breakpoint. diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index b1dc6b777..b1cc7cf82 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -256,6 +256,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ; InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ; InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ; +InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ; InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ; InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ; @@ -303,7 +304,8 @@ TaskNotFound , InvalidRequest , NOT_FOUND ; TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ; UnretrievableDocument , Internal , BAD_REQUEST ; UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; -UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE +UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ; +VectorEmbeddingError , InvalidRequest , BAD_REQUEST } impl ErrorCode for JoinError { @@ -336,6 +338,9 @@ impl ErrorCode for milli::Error { UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { Code::InvalidDocumentId } + UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, + UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + UserError::InvalidPromptForEmbeddings(..) 
=> Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::MultiplePrimaryKeyCandidatesFound { .. } => { Code::IndexPrimaryKeyMultipleCandidatesFound @@ -358,6 +363,7 @@ impl ErrorCode for milli::Error { UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance } + UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, } } } diff --git a/meilisearch-types/src/settings.rs b/meilisearch-types/src/settings.rs index 487354b8e..da06d5264 100644 --- a/meilisearch-types/src/settings.rs +++ b/meilisearch-types/src/settings.rs @@ -199,6 +199,10 @@ pub struct Settings { #[deserr(default, error = DeserrJsonError)] pub pagination: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default, error = DeserrJsonError)] + pub embedders: Setting>>, + #[serde(skip)] #[deserr(skip)] pub _kind: PhantomData, @@ -222,6 +226,7 @@ impl Settings { typo_tolerance: Setting::Reset, faceting: Setting::Reset, pagination: Setting::Reset, + embedders: Setting::Reset, _kind: PhantomData, } } @@ -243,6 +248,7 @@ impl Settings { typo_tolerance, faceting, pagination, + embedders, .. } = self; @@ -262,6 +268,7 @@ impl Settings { typo_tolerance, faceting, pagination, + embedders, _kind: PhantomData, } } @@ -307,6 +314,7 @@ impl Settings { typo_tolerance: self.typo_tolerance, faceting: self.faceting, pagination: self.pagination, + embedders: self.embedders, _kind: PhantomData, } } @@ -490,6 +498,12 @@ pub fn apply_settings_to_builder( Setting::Reset => builder.reset_pagination_max_total_hits(), Setting::NotSet => (), } + + match settings.embedders.clone() { + Setting::Set(value) => builder.set_embedder_settings(value), + Setting::Reset => builder.reset_embedder_settings(), + Setting::NotSet => (), + } } pub fn settings( @@ -571,6 +585,12 @@ pub fn settings( ), }; + let embedders = index + .embedding_configs(rtxn)? 
+ .into_iter() + .map(|(name, config)| (name, Setting::Set(config.into()))) + .collect(); + Ok(Settings { displayed_attributes: match displayed_attributes { Some(attrs) => Setting::Set(attrs), @@ -599,6 +619,7 @@ pub fn settings( typo_tolerance: Setting::Set(typo_tolerance), faceting: Setting::Set(faceting), pagination: Setting::Set(pagination), + embedders: Setting::Set(embedders), _kind: PhantomData, }) } @@ -747,6 +768,7 @@ pub(crate) mod test { typo_tolerance: Setting::NotSet, faceting: Setting::NotSet, pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: PhantomData::, }; @@ -772,6 +794,7 @@ pub(crate) mod test { typo_tolerance: Setting::NotSet, faceting: Setting::NotSet, pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: PhantomData::, }; diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index f75516731..d5f08936d 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -686,7 +686,7 @@ impl SearchAggregator { ret.max_terms_number = q.split_whitespace().count(); } - if let Some(ref vector) = vector { + if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector { ret.max_vector_size = vector.len(); } diff --git a/meilisearch/src/main.rs b/meilisearch/src/main.rs index 246d62c3b..ddd37bbb6 100644 --- a/meilisearch/src/main.rs +++ b/meilisearch/src/main.rs @@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; /// does all the setup before meilisearch is launched fn setup(opt: &Opt) -> anyhow::Result<()> { let mut log_builder = env_logger::Builder::new(); - log_builder.parse_filters(&opt.log_level.to_string()); + let log_filters = format!( + "{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn", + opt.log_level + ); + log_builder.parse_filters(&log_filters); log_builder.init(); diff --git a/meilisearch/src/routes/indexes/facet_search.rs 
b/meilisearch/src/routes/indexes/facet_search.rs index 142a424c0..72440711c 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::VectorQuery; use serde_json::Value; use crate::analytics::{Analytics, FacetSearchAggregator}; @@ -117,7 +118,7 @@ impl From for SearchQuery { highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), crop_marker: DEFAULT_CROP_MARKER(), matching_strategy, - vector, + vector: vector.map(VectorQuery::Vector), attributes_to_search_on, } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 5a0a9e92b..e63a95e60 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -2,12 +2,13 @@ use actix_web::web::Data; use actix_web::{web, HttpRequest, HttpResponse}; use deserr::actix_web::{AwebJson, AwebQueryParameter}; use index_scheduler::IndexScheduler; -use log::debug; +use log::{debug, warn}; use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::VectorQuery; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -88,7 +89,7 @@ impl From for SearchQuery { Self { q: other.q, - vector: other.vector.map(CS::into_inner), + vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), offset: other.offset.0, limit: other.limit.0, page: other.page.as_deref().copied(), @@ -193,6 +194,9 @@ pub async fn search_with_post( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); + + embed(&mut query, 
index_scheduler.get_ref(), &index).await?; + let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; if let Ok(ref search_result) = search_result { @@ -206,6 +210,38 @@ pub async fn search_with_post( Ok(HttpResponse::Ok().json(search_result)) } +pub async fn embed( + query: &mut SearchQuery, + index_scheduler: &IndexScheduler, + index: &meilisearch_types::milli::Index, +) -> Result<(), ResponseError> { + if let Some(VectorQuery::String(prompt)) = query.vector.take() { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedder = index_scheduler.embedders(embedder_configs)?; + + /// FIXME: add error if no embedder, remove unwrap, support multiple embedders + let embeddings = embedder + .get("default") + .unwrap() + .0 + .embed(vec![prompt]) + .await + .map_err(meilisearch_types::milli::vector::Error::from) + .map_err(meilisearch_types::milli::UserError::from) + .map_err(meilisearch_types::milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); + + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); + } else { + query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); + } + }; + Ok(()) +} + #[cfg(test)] mod test { use super::*; diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index bcb8bb2a1..4e578572d 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator}; use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::{AuthenticationError, GuardedData}; use crate::extractors::sequential_extractor::SeqHandler; +use crate::routes::indexes::search::embed; use crate::search::{ add_search_rules, perform_search, 
SearchQueryWithIndex, SearchResultWithIndex, }; @@ -74,6 +75,8 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; + embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?; + let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features)) .await diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 41f073b48..235b745a9 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -16,6 +16,7 @@ use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::{ dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, + VectorQuery, }; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -46,7 +47,7 @@ pub struct SearchQuery { #[deserr(default, error = DeserrJsonError)] pub q: Option, #[deserr(default, error = DeserrJsonError)] - pub vector: Option>, + pub vector: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -105,7 +106,7 @@ pub struct SearchQueryWithIndex { #[deserr(default, error = DeserrJsonError)] pub q: Option, #[deserr(default, error = DeserrJsonError)] - pub vector: Option>, + pub vector: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -339,11 +340,18 @@ fn prepare_search<'t>( let mut search = index.search(rtxn); if query.vector.is_some() && query.q.is_some() { - warn!("Ignoring the query string `q` when used with the `vector` parameter."); + warn!("Attempting hybrid search"); } if let Some(ref vector) = query.vector { - search.vector(vector.clone()); + match vector { + VectorQuery::Vector(vector) => { + search.vector(vector.clone()); + } 
+ VectorQuery::String(_) => { + panic!("Failed while preparing search; caller did not generate embedding for query") + } + } } if let Some(ref query) = query.q { @@ -375,7 +383,7 @@ fn prepare_search<'t>( } if query.vector.is_some() { - features.check_vector()?; + features.check_vector("Passing `vector` as a query parameter")?; } // compute the offset on the limit depending on the pagination mode. @@ -429,7 +437,11 @@ pub fn perform_search( prepare_search(index, &rtxn, &query, features)?; let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = - search.execute()?; + if query.q.is_some() && query.vector.is_some() { + search.execute_hybrid()? + } else { + search.execute()? + }; let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); @@ -538,13 +550,13 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - let semantic_score = match query.vector.as_ref() { + let semantic_score = /*match query.vector.as_ref() { Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? 
{ Some(vectors) => compute_semantic_score(vector, vectors)?, None => None, }, None => None, - }; + };*/ None; let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -629,7 +641,8 @@ pub fn perform_search( hits: documents, hits_info, query: query.q.unwrap_or_default(), - vector: query.vector, + // FIXME: display input vector + vector: None, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 0c1c5ab97..38931ca0f 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -27,10 +27,13 @@ fst = "0.4.7" fxhash = "0.2.1" geoutils = "0.5.1" grenad = { version = "0.4.5", default-features = false, features = [ - "rayon", "tempfile" + "rayon", + "tempfile", ] } heed = { version = "0.20.0-alpha.9", default-features = false, features = [ - "serde-json", "serde-bincode", "read-txn-no-tls" + "serde-json", + "serde-bincode", + "read-txn-no-tls", ] } indexmap = { version = "2.0.0", features = ["serde"] } instant-distance = { version = "0.6.1", features = ["with-serde"] } @@ -77,6 +80,15 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } hf-hub = "0.3.2" +tokio = { version = "1.34.0", features = ["rt"] } +futures = "0.3.29" +nolife = { version = "0.3.1" } +reqwest = { version = "0.11.16", features = [ + "rustls-tls", + "json", +], default-features = false } +tiktoken-rs = "0.5.7" +liquid = "0.26.4" [dev-dependencies] mimalloc = { version = "0.1.37", default-features = false } @@ -88,7 +100,15 @@ meili-snap = { path = "../meili-snap" } rand = { version = "0.8.5", features = ["small_rng"] } [features] -all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", 
"charabia/greek", "charabia/khmer"] +all-tokenizations = [ + "charabia/chinese", + "charabia/hebrew", + "charabia/japanese", + "charabia/thai", + "charabia/korean", + "charabia/greek", + "charabia/khmer", +] # Use POSIX semaphores instead of SysV semaphores in LMDB # For more information on this feature, see heed's Cargo.toml diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 82de56434..a94677771 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -5,8 +5,8 @@ use std::time::Instant; use heed::EnvOpenOptions; use milli::{ - execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger, - TermsMatchingStrategy, + execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, + SearchLogger, TermsMatchingStrategy, }; #[global_allocator] @@ -49,14 +49,15 @@ fn main() -> Result<(), Box> { let start = Instant::now(); let mut ctx = SearchContext::new(&index, &txn); + let universe = filtered_universe(&ctx, &None)?; + let docs = execute_search( &mut ctx, - &(!query.trim().is_empty()).then(|| query.trim().to_owned()), - &None, + (!query.trim().is_empty()).then(|| query.trim()), TermsMatchingStrategy::Last, milli::score_details::ScoringStrategy::Skip, false, - &None, + universe, &None, GeoSortStrategy::default(), 0, diff --git a/milli/src/error.rs b/milli/src/error.rs index cbbd8a3e5..032fd63a7 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -180,6 +180,14 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco UnknownInternalDocumentId { document_id: DocumentId }, #[error("`minWordSizeForTypos` setting is invalid. 
`oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] InvalidMinTypoWordLenSetting(u8, u8), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), + #[error(transparent)] + MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), + #[error(transparent)] + InvalidPrompt(#[from] crate::prompt::error::NewPromptError), + #[error("Invalid prompt in for embeddings with name '{0}': {1}")] + InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), } #[derive(Error, Debug)] @@ -336,6 +344,26 @@ impl From for Error { } } +#[derive(Debug, Clone, Copy)] +pub enum FaultSource { + User, + Runtime, + Bug, + Undecided, +} + +impl std::fmt::Display for FaultSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + FaultSource::User => "user error", + FaultSource::Runtime => "runtime error", + FaultSource::Bug => "coding error", + FaultSource::Undecided => "error", + }; + f.write_str(s) + } +} + #[test] fn conditionally_lookup_for_error_message() { let prefix = "Attribute `name` is not sortable."; diff --git a/milli/src/index.rs b/milli/src/index.rs index 01a01ac37..307d87906 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -23,6 +23,7 @@ use crate::heed_codec::{ }; use crate::proximity::ProximityPrecision; use crate::readable_slices::ReadableSlices; +use crate::vector::EmbeddingConfig; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, @@ -74,6 +75,7 @@ pub mod main_key { pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; pub const PROXIMITY_PRECISION: &str = "proximity-precision"; + pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; 
} pub mod db_name { @@ -1528,6 +1530,33 @@ impl Index { Ok(script_language) } + + pub(crate) fn put_embedding_configs( + &self, + wtxn: &mut RwTxn<'_>, + configs: Vec<(String, EmbeddingConfig)>, + ) -> heed::Result<()> { + self.main.remap_types::>>().put( + wtxn, + main_key::EMBEDDING_CONFIGS, + &configs, + ) + } + + pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result { + self.main.remap_key_type::().delete(wtxn, main_key::EMBEDDING_CONFIGS) + } + + pub fn embedding_configs( + &self, + rtxn: &RoTxn<'_>, + ) -> Result> { + Ok(self + .main + .remap_types::>>() + .get(rtxn, main_key::EMBEDDING_CONFIGS)? + .unwrap_or_default()) + } } #[cfg(test)] diff --git a/milli/src/lib.rs b/milli/src/lib.rs index acea72c41..b3c15e205 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -17,11 +17,13 @@ pub mod facet; mod fields_ids_map; pub mod heed_codec; pub mod index; +pub mod prompt; pub mod proximity; mod readable_slices; pub mod score_details; mod search; pub mod update; +pub mod vector; #[cfg(test)] #[macro_use] @@ -37,8 +39,8 @@ pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; pub use search::new::{ - execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger, - VisualSearchLogger, + execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, + SearchLogger, VisualSearchLogger, }; use serde_json::Value; pub use {charabia as tokenizer, heed}; @@ -60,7 +62,7 @@ pub use self::index::Index; pub use self::search::{ FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, - DEFAULT_VALUES_PER_FACET, + VectorQuery, DEFAULT_VALUES_PER_FACET, }; pub type Result = std::result::Result; diff --git a/milli/src/prompt/context.rs b/milli/src/prompt/context.rs new file mode 100644 index 
000000000..a28a87caa --- /dev/null +++ b/milli/src/prompt/context.rs @@ -0,0 +1,97 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; +use super::fields::Fields; +use crate::FieldsIdsMap; + +#[derive(Debug, Clone)] +pub struct Context<'a> { + document: &'a Document<'a>, + fields: Fields<'a>, +} + +impl<'a> Context<'a> { + pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { + Self { document, fields: Fields::new(document, field_id_map) } + } +} + +impl<'a> ObjectView for Context<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(self.document.as_value()) + .chain(std::iter::once(self.fields.as_value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "doc" || index == "fields" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "doc" => Some(self.document.as_value()), + "fields" => Some(self.fields.as_value()), + _ => None, + } + } +} + +impl<'a> ValueView for Context<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn 
to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/prompt/document.rs b/milli/src/prompt/document.rs new file mode 100644 index 000000000..b5d43b5be --- /dev/null +++ b/milli/src/prompt/document.rs @@ -0,0 +1,131 @@ +use std::cell::OnceCell; +use std::collections::BTreeMap; + +use liquid::model::{ + DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use crate::update::del_add::{DelAdd, KvReaderDelAdd}; +use crate::FieldsIdsMap; + +#[derive(Debug, Clone)] +pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>); + +#[derive(Debug, Clone)] +struct ParsedValue(std::cell::OnceCell); + +impl ParsedValue { + fn empty() -> ParsedValue { + ParsedValue(OnceCell::new()) + } + + fn get(&self, raw: &[u8]) -> &LiquidValue { + self.0.get_or_init(|| { + let value: serde_json::Value = serde_json::from_slice(raw).unwrap(); + liquid::model::to_value(&value).unwrap() + }) + } +} + +impl<'a> Document<'a> { + pub fn new( + data: obkv::KvReaderU16<'a>, + side: DelAdd, + inverted_field_map: &'a FieldsIdsMap, + ) -> Self { + let mut out_data = BTreeMap::new(); + for (fid, raw) in data { + let obkv = KvReaderDelAdd::new(raw); + let Some(raw) = obkv.get(side) else { + continue; + }; + let Some(name) = inverted_field_map.name(fid) else { + continue; + }; + out_data.insert(name, (raw, ParsedValue::empty())); + } + Self(out_data) + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn iter(&self) -> impl Iterator + '_ { + self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned())) + } +} + 
+impl<'a> ObjectView for Document<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.len() as i64 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + let keys = BTreeMap::keys(&self.0).map(|&s| s.into()); + Box::new(keys) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView)) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView))) + } + + fn contains_key(&self, index: &str) -> bool { + self.0.contains_key(index) + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView) + } +} + +impl<'a> ValueView for Document<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object(self.iter().collect()) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/prompt/error.rs b/milli/src/prompt/error.rs new file mode 100644 index 000000000..8a762b60a --- /dev/null +++ b/milli/src/prompt/error.rs @@ -0,0 +1,56 @@ +use crate::error::FaultSource; + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct NewPromptError { + pub kind: NewPromptErrorKind, + pub fault: 
FaultSource, +} + +impl From for crate::Error { + fn from(value: NewPromptError) -> Self { + crate::Error::UserError(crate::UserError::InvalidPrompt(value)) + } +} + +impl NewPromptError { + pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError { + Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User } + } + + pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError { + Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum NewPromptErrorKind { + #[error("cannot parse template: {0}")] + CannotParseTemplate(liquid::Error), + #[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")] + InvalidFieldsInTemplate(liquid::Error), +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct RenderPromptError { + pub kind: RenderPromptErrorKind, + pub fault: FaultSource, +} +impl RenderPromptError { + pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError { + Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RenderPromptErrorKind { + #[error("missing field in document: {0}")] + MissingContext(liquid::Error), +} + +impl From for crate::Error { + fn from(value: RenderPromptError) -> Self { + crate::Error::UserError(crate::UserError::MissingDocumentField(value)) + } +} diff --git a/milli/src/prompt/fields.rs b/milli/src/prompt/fields.rs new file mode 100644 index 000000000..3187485f1 --- /dev/null +++ b/milli/src/prompt/fields.rs @@ -0,0 +1,172 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; +use crate::FieldsIdsMap; +#[derive(Debug, Clone)] +pub struct Fields<'a>(Vec>); + +impl<'a> 
Fields<'a> { + pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { + Self( + std::iter::repeat(document) + .zip(field_id_map.iter()) + .map(|(document, (_fid, name))| FieldValue { document, name }) + .collect(), + ) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct FieldValue<'a> { + name: &'a str, + document: &'a Document<'a>, +} + +impl<'a> ValueView for FieldValue<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl<'a> FieldValue<'a> { + pub fn name(&self) -> &&'a str { + &self.name + } + + pub fn value(&self) -> &dyn ValueView { + self.document.get(self.name).unwrap_or(&LiquidValue::Nil) + } + + pub fn is_empty(&self) -> bool { + self.size() == 0 + } +} + +impl<'a> ObjectView for FieldValue<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn 
ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "name" || index == "value" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "name" => Some(self.name()), + "value" => Some(self.value()), + _ => None, + } + } +} + +impl<'a> ArrayView for Fields<'a> { + fn as_value(&self) -> &dyn ValueView { + self.0.as_value() + } + + fn size(&self) -> i64 { + self.0.len() as i64 + } + + fn values<'k>(&'k self) -> Box + 'k> { + self.0.values() + } + + fn contains_key(&self, index: i64) -> bool { + self.0.contains_key(index) + } + + fn get(&self, index: i64) -> Option<&dyn ValueView> { + ArrayView::get(&self.0, index) + } +} + +impl<'a> ValueView for Fields<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + self.0.render() + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + self.0.source() + } + + fn type_name(&self) -> &'static str { + self.0.type_name() + } + + fn query_state(&self, state: liquid::model::State) -> bool { + self.0.query_state(state) + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + self.0.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + self.0.to_value() + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } +} diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs new file mode 100644 index 000000000..351a51bb1 --- /dev/null +++ b/milli/src/prompt/mod.rs @@ -0,0 +1,144 @@ +mod context; +mod document; +pub(crate) mod error; +mod fields; +mod template_checker; + +use std::convert::TryFrom; + +use error::{NewPromptError, RenderPromptError}; + +use self::context::Context; +use self::document::Document; +use crate::update::del_add::DelAdd; +use crate::FieldsIdsMap; + +pub struct Prompt { + template: liquid::Template, + template_text: String, + strategy: PromptFallbackStrategy, + fallback: String, +} + +#[derive(Debug, Clone, 
serde::Serialize, serde::Deserialize)] +pub struct PromptData { + pub template: String, + pub strategy: PromptFallbackStrategy, + pub fallback: String, +} + +impl From for PromptData { + fn from(value: Prompt) -> Self { + Self { template: value.template_text, strategy: value.strategy, fallback: value.fallback } + } +} + +impl TryFrom for Prompt { + type Error = NewPromptError; + + fn try_from(value: PromptData) -> Result { + Prompt::new(value.template, Some(value.strategy), Some(value.fallback)) + } +} + +impl Clone for Prompt { + fn clone(&self) -> Self { + let template_text = self.template_text.clone(); + Self { + template: new_template(&template_text).unwrap(), + template_text, + strategy: self.strategy, + fallback: self.fallback.clone(), + } + } +} + +fn new_template(text: &str) -> Result { + liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text) +} + +fn default_template() -> liquid::Template { + new_template(default_template_text()).unwrap() +} + +fn default_template_text() -> &'static str { + "{% for field in fields %} \ + {{ field.name }}: {{ field.value }}\n\ + {% endfor %}" +} + +fn default_fallback() -> &'static str { + "" +} + +impl Default for Prompt { + fn default() -> Self { + Self { + template: default_template(), + template_text: default_template_text().into(), + strategy: Default::default(), + fallback: default_fallback().into(), + } + } +} + +impl Default for PromptData { + fn default() -> Self { + Self { + template: default_template_text().into(), + strategy: Default::default(), + fallback: default_fallback().into(), + } + } +} + +impl Prompt { + pub fn new( + template: String, + strategy: Option, + fallback: Option, + ) -> Result { + let this = Self { + template: liquid::ParserBuilder::with_stdlib() + .build() + .unwrap() + .parse(&template) + .map_err(NewPromptError::cannot_parse_template)?, + template_text: template, + strategy: strategy.unwrap_or_default(), + fallback: fallback.unwrap_or_default(), + }; + + // render template with 
special object that's OK with `doc.*` and `fields.*` + /// FIXME: doesn't work for nested objects e.g. `doc.a.b` + this.template + .render(&template_checker::TemplateChecker) + .map_err(NewPromptError::invalid_fields_in_template)?; + + Ok(this) + } + + pub fn render( + &self, + document: obkv::KvReaderU16<'_>, + side: DelAdd, + field_id_map: &FieldsIdsMap, + ) -> Result { + let document = Document::new(document, side, field_id_map); + let context = Context::new(&document, field_id_map); + + self.template.render(&context).map_err(RenderPromptError::missing_context) + } +} + +#[derive( + Debug, Default, Clone, PartialEq, Eq, Copy, serde::Serialize, serde::Deserialize, deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum PromptFallbackStrategy { + Fallback, + Skip, + #[default] + Error, +} diff --git a/milli/src/prompt/template_checker.rs b/milli/src/prompt/template_checker.rs new file mode 100644 index 000000000..641a9ed64 --- /dev/null +++ b/milli/src/prompt/template_checker.rs @@ -0,0 +1,282 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +#[derive(Debug)] +pub struct TemplateChecker; + +#[derive(Debug)] +pub struct DummyDoc; + +#[derive(Debug)] +pub struct DummyFields; + +#[derive(Debug)] +pub struct DummyField; + +const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil; + +impl ObjectView for DummyField { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, index: &str) -> bool { + index == "name" || index == "value" 
+ } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + if self.contains_key(index) { + Some(DUMMY_VALUE.as_view()) + } else { + None + } + } +} + +impl ValueView for DummyField { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: State) -> bool { + DUMMY_VALUE.query_state(state) + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl ValueView for DummyFields { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "array" + } + + fn query_state(&self, state: State) -> bool { + DUMMY_VALUE.query_state(state) + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } +} + +impl ArrayView for DummyFields { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + i64::MAX + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, _index: i64) -> bool { + true + } + + fn get(&self, _index: i64) -> Option<&dyn ValueView> { + Some(DummyField.as_value()) + } +} + +impl ObjectView for DummyDoc { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 1000 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(std::iter::empty()) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn iter<'k>(&'k self) -> Box, 
&'k dyn ValueView)> + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, _index: &str) -> bool { + true + } + + fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { + Some(DUMMY_VALUE.as_view()) + } +} + +impl ValueView for DummyDoc { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: State) -> bool { + DUMMY_VALUE.query_state(state) + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl ObjectView for TemplateChecker { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "doc" || index == "fields" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "doc" => Some(DummyDoc.as_value()), + "fields" => Some(DummyFields.as_value()), + _ => None, + } + } +} + +impl ValueView for TemplateChecker { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn 
query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/score_details.rs b/milli/src/score_details.rs index 8fc998ae4..f6b9db58c 100644 --- a/milli/src/score_details.rs +++ b/milli/src/score_details.rs @@ -1,3 +1,6 @@ +use std::cmp::Ordering; + +use itertools::Itertools; use serde::Serialize; use crate::distance_between_two_points; @@ -12,9 +15,24 @@ pub enum ScoreDetails { ExactAttribute(ExactAttribute), ExactWords(ExactWords), Sort(Sort), + Vector(Vector), GeoSort(GeoSort), } +#[derive(Clone, Copy)] +pub enum ScoreValue<'a> { + Score(f64), + Sort(&'a Sort), + GeoSort(&'a GeoSort), +} + +enum RankOrValue<'a> { + Rank(Rank), + Sort(&'a Sort), + GeoSort(&'a GeoSort), + Score(f64), +} + impl ScoreDetails { pub fn local_score(&self) -> Option { self.rank().map(Rank::local_score) @@ -31,11 +49,55 @@ impl ScoreDetails { ScoreDetails::ExactWords(details) => Some(details.rank()), ScoreDetails::Sort(_) => None, ScoreDetails::GeoSort(_) => None, + ScoreDetails::Vector(_) => None, } } - pub fn global_score<'a>(details: impl Iterator) -> f64 { - Rank::global_score(details.filter_map(Self::rank)) + pub fn global_score<'a>(details: impl Iterator + 'a) -> f64 { + Self::score_values(details) + .find_map(|x| { + let ScoreValue::Score(score) = x else { + return None; + }; + Some(score) + }) + .unwrap_or(1.0f64) + } + + pub fn score_values<'a>( + details: impl Iterator + 'a, + ) -> impl Iterator> + 'a { + details + .map(ScoreDetails::rank_or_value) + .coalesce(|left, right| match (left, right) { + 
(RankOrValue::Rank(left), RankOrValue::Rank(right)) => { + Ok(RankOrValue::Rank(Rank::merge(left, right))) + } + (left, right) => Err((left, right)), + }) + .map(|rank_or_value| match rank_or_value { + RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()), + RankOrValue::Sort(s) => ScoreValue::Sort(s), + RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g), + RankOrValue::Score(s) => ScoreValue::Score(s), + }) + } + + fn rank_or_value(&self) -> RankOrValue<'_> { + match self { + ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()), + ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()), + ScoreDetails::Proximity(p) => RankOrValue::Rank(*p), + ScoreDetails::Fid(f) => RankOrValue::Rank(*f), + ScoreDetails::Position(p) => RankOrValue::Rank(*p), + ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()), + ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()), + ScoreDetails::Sort(sort) => RankOrValue::Sort(sort), + ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort), + ScoreDetails::Vector(vector) => RankOrValue::Score( + vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64), + ), + } } /// Panics @@ -181,6 +243,19 @@ impl ScoreDetails { details_map.insert(sort, sort_details); order += 1; } + ScoreDetails::Vector(s) => { + let vector = format!("vectorSort({:?})", s.target_vector); + let value = s.value_similarity.as_ref().map(|(v, _)| v); + let similarity = s.value_similarity.as_ref().map(|(_, s)| s); + + let details = serde_json::json!({ + "order": order, + "value": value, + "similarity": similarity, + }); + details_map.insert(vector, details); + order += 1; + } } } details_map @@ -297,15 +372,21 @@ impl Rank { pub fn global_score(details: impl Iterator) -> f64 { let mut rank = Rank { rank: 1, max_rank: 1 }; for inner_rank in details { - rank.rank -= 1; - - rank.rank *= inner_rank.max_rank; - rank.max_rank *= inner_rank.max_rank; - - rank.rank += inner_rank.rank; + rank = Rank::merge(rank, inner_rank); } 
rank.local_score() } + + pub fn merge(mut outer: Rank, inner: Rank) -> Rank { + outer.rank = outer.rank.saturating_sub(1); + + outer.rank *= inner.max_rank; + outer.max_rank *= inner.max_rank; + + outer.rank += inner.rank; + + outer + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] @@ -335,13 +416,78 @@ pub struct Sort { pub value: serde_json::Value, } -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +impl PartialOrd for Sort { + fn partial_cmp(&self, other: &Self) -> Option { + if self.field_name != other.field_name { + return None; + } + if self.ascending != other.ascending { + return None; + } + match (&self.value, &other.value) { + (serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal), + (serde_json::Value::Null, _) => Some(Ordering::Less), + (_, serde_json::Value::Null) => Some(Ordering::Greater), + // numbers are always before strings + (serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater), + (serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less), + (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { + // FIXME: unwrap permitted here? + let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?; + // 12 < 42, and when ascending, we want to see 12 first, so the smallest. + // Hence, when ascending, smaller is better + Some(if self.ascending { order.reverse() } else { order }) + } + (serde_json::Value::String(left), serde_json::Value::String(right)) => { + let order = left.cmp(right); + // Taking e.g. "a" and "z" + // "a" < "z", and when ascending, we want to see "a" first, so the smallest. 
+ // Hence, when ascending, smaller is better + Some(if self.ascending { order.reverse() } else { order }) + } + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] pub struct GeoSort { pub target_point: [f64; 2], pub ascending: bool, pub value: Option<[f64; 2]>, } +impl PartialOrd for GeoSort { + fn partial_cmp(&self, other: &Self) -> Option { + if self.target_point != other.target_point { + return None; + } + if self.ascending != other.ascending { + return None; + } + Some(match (self.distance(), other.distance()) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(left), Some(right)) => { + let order = left.partial_cmp(&right)?; + if self.ascending { + // when ascending, the one with the smallest distance has the best score + order.reverse() + } else { + order + } + } + }) + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub struct Vector { + pub target_vector: Vec, + pub value_similarity: Option<(Vec, f32)>, +} + impl GeoSort { pub fn distance(&self) -> Option { self.value.map(|value| distance_between_two_points(&self.target_point, &value)) diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs new file mode 100644 index 000000000..02c518126 --- /dev/null +++ b/milli/src/search/hybrid.rs @@ -0,0 +1,336 @@ +use std::cmp::Ordering; +use std::collections::HashMap; + +use itertools::Itertools; +use roaring::RoaringBitmap; + +use super::new::{execute_vector_search, PartialSearchResult}; +use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; +use crate::{ + execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult, +}; + +struct CombinedSearchResult { + matching_words: MatchingWords, + candidates: RoaringBitmap, + document_scores: Vec<(u32, CombinedScore)>, +} + +type CombinedScore = (Vec, Option>); + +fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { + let mut left_main_it = 
ScoreDetails::score_values(left.0.iter()); + let mut left_sub_it = + ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten()); + + let mut right_main_it = ScoreDetails::score_values(right.0.iter()); + let mut right_sub_it = + ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten()); + + let mut left_main = left_main_it.next(); + let mut left_sub = left_sub_it.next(); + let mut right_main = right_main_it.next(); + let mut right_sub = right_sub_it.next(); + + loop { + let left = + take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it); + + let right = + take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it); + + match (left, right) { + (None, None) => return Ordering::Equal, + (None, Some(_)) => return Ordering::Less, + (Some(_), None) => return Ordering::Greater, + (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { + if (left - right).abs() <= f64::EPSILON { + continue; + } + return left.partial_cmp(&right).unwrap(); + } + (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { + match left.partial_cmp(right).unwrap() { + Ordering::Equal => continue, + order => return order, + } + } + (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { + match left.partial_cmp(right).unwrap() { + Ordering::Equal => continue, + order => return order, + } + } + (Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater, + (Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less, + // if we have this, we're bad + (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) + | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { + unreachable!("Unexpected geo and sort comparison") + } + } + } +} + +fn take_best_score<'a>( + main_score: &mut Option>, + sub_score: &mut Option>, + main_it: &mut impl Iterator>, + sub_it: &mut impl Iterator>, +) -> Option> { + match (*main_score, *sub_score) { + 
(Some(main), None) => { + *main_score = main_it.next(); + Some(main) + } + (None, Some(sub)) => { + *sub_score = sub_it.next(); + Some(sub) + } + (main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => { + // take max, both advance + *main_score = main_it.next(); + *sub_score = sub_it.next(); + if main_f >= sub_v { + main + } else { + sub + } + } + (main @ Some(ScoreValue::Score(_)), _) => { + *main_score = main_it.next(); + main + } + (_, sub @ Some(ScoreValue::Score(_))) => { + *sub_score = sub_it.next(); + sub + } + (main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => { + // take best advance both + *main_score = main_it.next(); + *sub_score = sub_it.next(); + if main_geo >= sub_geo { + main + } else { + sub + } + } + (main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => { + // take best advance both + *main_score = main_it.next(); + *sub_score = sub_it.next(); + if main_sort >= sub_sort { + main + } else { + sub + } + } + ( + Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), + Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), + ) => None, + + (None, None) => None, + } +} + +impl CombinedSearchResult { + fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self { + let mut docid_scores = HashMap::new(); + for (docid, score) in + main_results.documents_ids.iter().zip(main_results.document_scores.into_iter()) + { + docid_scores.insert(*docid, (score, None)); + } + + for (docid, score) in ancillary_results + .documents_ids + .iter() + .zip(ancillary_results.document_scores.into_iter()) + { + docid_scores + .entry(*docid) + .and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score)); + } + + let mut document_scores: Vec<_> = docid_scores.into_iter().collect(); + + document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse()); + + Self { + matching_words: main_results.matching_words, + candidates: 
main_results.candidates, + document_scores, + } + } + + fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult { + let mut documents_ids = + Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); + let mut document_scores = + Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); + + let mut documents_seen = RoaringBitmap::new(); + for (docid, (main_score, _sub_score)) in left + .document_scores + .into_iter() + .merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| { + // the first value is the one with the greatest score + compare_scores(left, right).is_ge() + }) + // remove documents we already saw + .filter(|(docid, _)| documents_seen.insert(*docid)) + // start skipping **after** the filter + .skip(from) + // take **after** skipping + .take(length) + { + documents_ids.push(docid); + // TODO: pass both scores to documents_score in some way? + document_scores.push(main_score); + } + + SearchResult { + matching_words: left.matching_words, + candidates: left.candidates | right.candidates, + documents_ids, + document_scores, + } + } +} + +impl<'a> Search<'a> { + pub fn execute_hybrid(&self) -> Result { + // TODO: find classier way to achieve that than to reset vector and query params + // create separate keyword and semantic searches + let mut search = Search { + query: self.query.clone(), + vector: self.vector.clone(), + filter: self.filter.clone(), + offset: 0, + limit: self.limit + self.offset, + sort_criteria: self.sort_criteria.clone(), + searchable_attributes: self.searchable_attributes, + geo_strategy: self.geo_strategy, + terms_matching_strategy: self.terms_matching_strategy, + scoring_strategy: ScoringStrategy::Detailed, + words_limit: self.words_limit, + exhaustive_number_hits: self.exhaustive_number_hits, + rtxn: self.rtxn, + index: self.index, + }; + + let vector_query = search.vector.take(); + let keyword_query = self.query.as_deref(); + + let keyword_results = 
search.execute()?; + + // skip semantic search if we don't have a vector query (placeholder search) + let Some(vector_query) = vector_query else { + return Ok(keyword_results); + }; + + // completely skip semantic search if the results of the keyword search are good enough + if self.results_good_enough(&keyword_results) { + return Ok(keyword_results); + } + + search.vector = Some(vector_query); + search.query = None; + + // TODO: would be better to have two distinct functions at this point + let vector_results = search.execute()?; + + // Compute keyword scores for vector_results + let keyword_results_for_vector = + self.keyword_results_for_vector(keyword_query, &vector_results)?; + + // compute vector scores for keyword_results + let vector_results_for_keyword = + // can unwrap because we returned already if there was no vector query + self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?; + + let keyword_results = + CombinedSearchResult::new(keyword_results, vector_results_for_keyword); + let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector); + + let merge_results = + CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit); + assert!(merge_results.documents_ids.len() <= self.limit); + Ok(merge_results) + } + + fn vector_results_for_keyword( + &self, + vector: &[f32], + keyword_results: &SearchResult, + ) -> Result { + let mut ctx = SearchContext::new(self.index, self.rtxn); + + if let Some(searchable_attributes) = self.searchable_attributes { + ctx.searchable_attributes(searchable_attributes)?; + } + + let universe = keyword_results.documents_ids.iter().collect(); + + execute_vector_search( + &mut ctx, + vector, + ScoringStrategy::Detailed, + universe, + &self.sort_criteria, + self.geo_strategy, + 0, + self.limit + self.offset, + ) + } + + fn keyword_results_for_vector( + &self, + query: Option<&str>, + vector_results: &SearchResult, + ) -> Result { + let mut ctx = 
SearchContext::new(self.index, self.rtxn); + + if let Some(searchable_attributes) = self.searchable_attributes { + ctx.searchable_attributes(searchable_attributes)?; + } + + let universe = vector_results.documents_ids.iter().collect(); + + execute_search( + &mut ctx, + query, + self.terms_matching_strategy, + ScoringStrategy::Detailed, + self.exhaustive_number_hits, + universe, + &self.sort_criteria, + self.geo_strategy, + 0, + self.limit + self.offset, + Some(self.words_limit), + &mut DefaultSearchLogger, + &mut DefaultSearchLogger, + ) + } + + fn results_good_enough(&self, keyword_results: &SearchResult) -> bool { + const GOOD_ENOUGH_SCORE: f64 = 0.9; + + // 1. we check that we got a sufficient number of results + if keyword_results.document_scores.len() < self.limit + self.offset { + return false; + } + + // 2. and that all results have a good enough score. + // we need to check all results because due to sort like rules, they're not necessarily in relevancy order + for score in &keyword_results.document_scores { + let score = ScoreDetails::global_score(score.iter()); + if score < GOOD_ENOUGH_SCORE { + return false; + } + } + true + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index ee8cd1faf..8b541ffcd 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -3,6 +3,7 @@ use std::ops::ControlFlow; use charabia::normalizer::NormalizerOption; use charabia::Normalize; +use deserr::{DeserializeError, Deserr, Sequence}; use fst::automaton::{Automaton, Str}; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; @@ -12,12 +13,13 @@ use roaring::bitmap::RoaringBitmap; pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; -use self::new::PartialSearchResult; +use self::new::{execute_vector_search, PartialSearchResult}; use crate::error::UserError; use 
crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::{ - execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, + execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, + Result, SearchContext, }; // Building these factories is not free. @@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100; pub mod facet; mod fst_utils; +pub mod hybrid; pub mod new; pub struct Search<'a> { @@ -50,6 +53,53 @@ pub struct Search<'a> { index: &'a Index, } +#[derive(Debug, Clone, PartialEq)] +pub enum VectorQuery { + Vector(Vec), + String(String), +} + +impl Deserr for VectorQuery +where + E: DeserializeError, +{ + fn deserialize_from_value( + value: deserr::Value, + location: deserr::ValuePointerRef, + ) -> std::result::Result { + match value { + deserr::Value::String(s) => Ok(VectorQuery::String(s)), + deserr::Value::Sequence(seq) => { + let v: std::result::Result, _> = seq + .into_iter() + .enumerate() + .map(|(index, v)| match v.into_value() { + deserr::Value::Float(f) => Ok(f as f32), + deserr::Value::Integer(i) => Ok(i as f32), + v => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: v, + accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer], + }, + location.push_index(index), + ))), + }) + .collect(); + Ok(VectorQuery::Vector(v?)) + } + _ => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: value, + accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence], + }, + location, + ))), + } + } +} + impl<'a> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { Search { @@ -75,8 +125,8 @@ impl<'a> Search<'a> { self } - pub fn vector(&mut self, vector: impl Into>) -> &mut Search<'a> { - self.vector = Some(vector.into()); + pub fn vector(&mut self, vector: Vec) -> &mut Search<'a> { 
+ self.vector = Some(vector); self } @@ -140,23 +190,35 @@ impl<'a> Search<'a> { ctx.searchable_attributes(searchable_attributes)?; } + let universe = filtered_universe(&ctx, &self.filter)?; let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = - execute_search( - &mut ctx, - &self.query, - &self.vector, - self.terms_matching_strategy, - self.scoring_strategy, - self.exhaustive_number_hits, - &self.filter, - &self.sort_criteria, - self.geo_strategy, - self.offset, - self.limit, - Some(self.words_limit), - &mut DefaultSearchLogger, - &mut DefaultSearchLogger, - )?; + match self.vector.as_ref() { + Some(vector) => execute_vector_search( + &mut ctx, + vector, + self.scoring_strategy, + universe, + &self.sort_criteria, + self.geo_strategy, + self.offset, + self.limit, + )?, + None => execute_search( + &mut ctx, + self.query.as_deref(), + self.terms_matching_strategy, + self.scoring_strategy, + self.exhaustive_number_hits, + universe, + &self.sort_criteria, + self.geo_strategy, + self.offset, + self.limit, + Some(self.words_limit), + &mut DefaultSearchLogger, + &mut DefaultSearchLogger, + )?, + }; // consume context and located_query_terms to build MatchingWords. let matching_words = match located_query_terms { diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 5d61de0f4..067fa1efd 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -498,19 +498,19 @@ mod tests { use super::*; use crate::index::tests::TempIndex; - use crate::{execute_search, SearchContext}; + use crate::{execute_search, filtered_universe, SearchContext}; impl<'a> MatcherBuilder<'a> { fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { let mut ctx = SearchContext::new(index, rtxn); + let universe = filtered_universe(&ctx, &None).unwrap(); let crate::search::PartialSearchResult { located_query_terms, .. 
} = execute_search( &mut ctx, - &Some(query.to_string()), - &None, + Some(query), crate::TermsMatchingStrategy::default(), crate::score_details::ScoringStrategy::Skip, false, - &None, + universe, &None, crate::search::new::GeoSortStrategy::default(), 0, diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index a1b5da4e8..372c89601 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -16,6 +16,7 @@ mod small_bitmap; mod exact_attribute; mod sort; +mod vector_sort; #[cfg(test)] mod tests; @@ -28,7 +29,6 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; -use instant_distance::Search; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -46,7 +46,7 @@ use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; use self::graph_based_ranking_rule::Words; use self::interner::Interned; -use crate::distance::NDotProductPoint; +use self::vector_sort::VectorSort; use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; @@ -258,6 +258,70 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( Ok(ranking_rules) } +fn get_ranking_rules_for_vector<'ctx>( + ctx: &SearchContext<'ctx>, + sort_criteria: &Option>, + geo_strategy: geo_sort::Strategy, + target: &[f32], +) -> Result>> { + // query graph search + + let mut sort = false; + let mut sorted_fields = HashSet::new(); + let mut geo_sorted = false; + + let mut vector = false; + let mut ranking_rules: Vec> = vec![]; + + let settings_ranking_rules = ctx.index.criteria(ctx.txn)?; + for rr in settings_ranking_rules { + match rr { + crate::Criterion::Words + | crate::Criterion::Typo + | crate::Criterion::Proximity + | crate::Criterion::Attribute + | crate::Criterion::Exactness => 
{ + if !vector { + let vector_candidates = ctx.index.documents_ids(ctx.txn)?; + let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?; + ranking_rules.push(Box::new(vector_sort)); + vector = true; + } + } + crate::Criterion::Sort => { + if sort { + continue; + } + resolve_sort_criteria( + sort_criteria, + ctx, + &mut ranking_rules, + &mut sorted_fields, + &mut geo_sorted, + geo_strategy, + )?; + sort = true; + } + crate::Criterion::Asc(field_name) => { + if sorted_fields.contains(&field_name) { + continue; + } + sorted_fields.insert(field_name.clone()); + ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?)); + } + crate::Criterion::Desc(field_name) => { + if sorted_fields.contains(&field_name) { + continue; + } + sorted_fields.insert(field_name.clone()); + ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?)); + } + } + } + + Ok(ranking_rules) +} + /// Return the list of initialised ranking rules to be used for a query graph search. fn get_ranking_rules_for_query_graph_search<'ctx>( ctx: &SearchContext<'ctx>, @@ -422,15 +486,62 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( Ok(()) } +pub fn filtered_universe(ctx: &SearchContext, filters: &Option) -> Result { + Ok(if let Some(filters) = filters { + filters.evaluate(ctx.txn, ctx.index)? + } else { + ctx.index.documents_ids(ctx.txn)? 
+ }) +} + +#[allow(clippy::too_many_arguments)] +pub fn execute_vector_search( + ctx: &mut SearchContext, + vector: &[f32], + scoring_strategy: ScoringStrategy, + universe: RoaringBitmap, + sort_criteria: &Option>, + geo_strategy: geo_sort::Strategy, + from: usize, + length: usize, +) -> Result { + check_sort_criteria(ctx, sort_criteria.as_ref())?; + + /// FIXME: input universe = universe & documents_with_vectors + // for now if we're computing embeddings for ALL documents, we can assume that this is just universe + let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?; + + let mut placeholder_search_logger = logger::DefaultSearchLogger; + let placeholder_search_logger: &mut dyn SearchLogger = + &mut placeholder_search_logger; + + let BucketSortOutput { docids, scores, all_candidates } = bucket_sort( + ctx, + ranking_rules, + &PlaceholderQuery, + &universe, + from, + length, + scoring_strategy, + placeholder_search_logger, + )?; + + Ok(PartialSearchResult { + candidates: all_candidates, + document_scores: scores, + documents_ids: docids, + located_query_terms: None, + }) +} + #[allow(clippy::too_many_arguments)] pub fn execute_search( ctx: &mut SearchContext, - query: &Option, - vector: &Option>, + query: Option<&str>, terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, exhaustive_number_hits: bool, - filters: &Option, + mut universe: RoaringBitmap, sort_criteria: &Option>, geo_strategy: geo_sort::Strategy, from: usize, @@ -439,60 +550,8 @@ pub fn execute_search( placeholder_search_logger: &mut dyn SearchLogger, query_graph_logger: &mut dyn SearchLogger, ) -> Result { - let mut universe = if let Some(filters) = filters { - filters.evaluate(ctx.txn, ctx.index)? - } else { - ctx.index.documents_ids(ctx.txn)? - }; - check_sort_criteria(ctx, sort_criteria.as_ref())?; - if let Some(vector) = vector { - let mut search = Search::default(); - let docids = match ctx.index.vector_hnsw(ctx.txn)? 
{ - Some(hnsw) => { - if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { - if vector.len() != expected_size { - return Err(UserError::InvalidVectorDimensions { - expected: expected_size, - found: vector.len(), - } - .into()); - } - } - - let vector = NDotProductPoint::new(vector.clone()); - - let neighbors = hnsw.search(&vector, &mut search); - - let mut docids = Vec::new(); - let mut uniq_docids = RoaringBitmap::new(); - for instant_distance::Item { distance: _, pid, point: _ } in neighbors { - let index = pid.into_inner(); - let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); - if universe.contains(docid) && uniq_docids.insert(docid) { - docids.push(docid); - if docids.len() == (from + length) { - break; - } - } - } - - // return the nearest documents that are also part of the candidates - // along with a dummy list of scores that are useless in this context. - docids.into_iter().skip(from).take(length).collect() - } - None => Vec::new(), - }; - - return Ok(PartialSearchResult { - candidates: universe, - document_scores: vec![Vec::new(); docids.len()], - documents_ids: docids, - located_query_terms: None, - }); - } - let mut located_query_terms = None; let query_terms = if let Some(query) = query { // We make sure that the analyzer is aware of the stop words @@ -546,7 +605,7 @@ pub fn execute_search( terms_matching_strategy, )?; - universe = + universe &= resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; bucket_sort( diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs new file mode 100644 index 000000000..831ed45cd --- /dev/null +++ b/milli/src/search/new/vector_sort.rs @@ -0,0 +1,150 @@ +use std::future::Future; +use std::iter::FromIterator; +use std::pin::Pin; + +use nolife::DynBoxScope; +use roaring::RoaringBitmap; + +use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; +use crate::distance::NDotProductPoint; 
+use crate::index::Hnsw; +use crate::score_details::{self, ScoreDetails}; +use crate::{Result, SearchContext, SearchLogger, UserError}; + +pub struct VectorSort { + query: Option, + target: Vec, + vector_candidates: RoaringBitmap, + scope: nolife::DynBoxScope, +} + +type Item<'a> = instant_distance::Item<'a, NDotProductPoint>; +type SearchFut = Pin>>; + +struct SearchFamily; +impl<'a> nolife::Family<'a> for SearchFamily { + type Family = Box> + 'a>; +} + +async fn search_scope( + mut time_capsule: nolife::TimeCapsule, + hnsw: Hnsw, + target: Vec, +) -> nolife::Never { + let mut search = instant_distance::Search::default(); + let it = Box::new(hnsw.search(&NDotProductPoint::new(target), &mut search)); + let mut it: Box> = it; + loop { + time_capsule.freeze(&mut it).await; + } +} + +impl VectorSort { + pub fn new( + ctx: &SearchContext, + target: Vec, + vector_candidates: RoaringBitmap, + ) -> Result { + let hnsw = + ctx.index.vector_hnsw(ctx.txn)?.unwrap_or(Hnsw::builder().build_hnsw(Vec::default()).0); + + if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { + if target.len() != expected_size { + return Err(UserError::InvalidVectorDimensions { + expected: expected_size, + found: target.len(), + } + .into()); + } + } + + let target_clone = target.clone(); + let producer = move |time_capsule| -> SearchFut { + Box::pin(search_scope(time_capsule, hnsw, target_clone)) + }; + let scope = DynBoxScope::new(producer); + + Ok(Self { query: None, target, vector_candidates, scope }) + } +} + +impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { + fn id(&self) -> String { + "vector_sort".to_owned() + } + + fn start_iteration( + &mut self, + _ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + universe: &RoaringBitmap, + query: &Q, + ) -> Result<()> { + assert!(self.query.is_none()); + + self.query = Some(query.clone()); + self.vector_candidates &= universe; + + Ok(()) + } + + 
#[allow(clippy::only_used_in_recursion)] + fn next_bucket( + &mut self, + ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + universe: &RoaringBitmap, + ) -> Result>> { + let query = self.query.as_ref().unwrap().clone(); + self.vector_candidates &= universe; + + if self.vector_candidates.is_empty() { + return Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: None, + }), + })); + } + + let scope = &mut self.scope; + let target = &self.target; + let vector_candidates = &self.vector_candidates; + + scope.enter(|it| { + for item in it.by_ref() { + let item: Item = item; + let index = item.pid.into_inner(); + let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); + + if vector_candidates.contains(docid) { + return Ok(Some(RankingRuleOutput { + query, + candidates: RoaringBitmap::from_iter([docid]), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: target.clone(), + value_similarity: Some(( + item.point.clone().into_inner(), + 1.0 - item.distance, + )), + }), + })); + } + } + Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: target.clone(), + value_similarity: None, + }), + })) + }) + } + + fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger) { + self.query = None; + } +} diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 317a9aec3..8399c220b 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,9 +1,10 @@ use std::cmp::Ordering; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io::{self, BufReader, BufWriter}; use 
std::mem::size_of; use std::str::from_utf8; +use std::sync::Arc; use bytemuck::cast_slice; use grenad::Writer; @@ -13,13 +14,56 @@ use serde_json::{from_slice, Value}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::error::UserError; +use crate::prompt::Prompt; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::try_split_at; -use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; +use crate::vector::Embedder; +use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors}; /// The length of the elements that are always in the buffer when inserting new values. const TRUNCATE_SIZE: usize = size_of::(); +pub struct ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + pub manual_vectors: grenad::Reader>, + // docid -> () + pub remove_vectors: grenad::Reader>, + // docid -> prompt + pub prompts: grenad::Reader>, +} + +enum VectorStateDelta { + NoChange, + // Remove all vectors, generated or manual, from this document + NowRemoved, + + // Add the manually specified vectors, passed in the other grenad + // Remove any previously generated vectors + // Note: changing the value of the manually specified vector **should not record** this delta + WasGeneratedNowManual(Vec>), + + ManualDelta(Vec>, Vec>), + + // Add the vector computed from the specified prompt + // Remove any previous vector + // Note: changing the value of the prompt **does require** recording this delta + NowGenerated(String), +} + +impl VectorStateDelta { + fn into_values(self) -> (bool, String, (Vec>, Vec>)) { + match self { + VectorStateDelta::NoChange => Default::default(), + VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()), + VectorStateDelta::WasGeneratedNowManual(add) => { + (true, Default::default(), (Default::default(), add)) + } + VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), 
(del, add)), + VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()), + } + } +} + /// Extracts the embedding vector contained in each document under the `_vectors` field. /// /// Returns the generated grenad reader containing the docid as key associated to the Vec @@ -27,16 +71,34 @@ const TRUNCATE_SIZE: usize = size_of::(); pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - vectors_fid: FieldId, -) -> Result>> { + field_id_map: FieldsIdsMap, + prompt: Option<&Prompt>, +) -> Result { puffin::profile_function!(); - let mut writer = create_writer( + // (docid, _index) -> KvWriterDelAdd -> Vector + let mut manual_vectors_writer = create_writer( indexer.chunk_compression_type, indexer.chunk_compression_level, tempfile::tempfile()?, ); + // (docid) -> (prompt) + let mut prompts_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + // (docid) -> () + let mut remove_vectors_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let vectors_fid = field_id_map.id("_vectors"); + let mut key_buffer = Vec::new(); let mut cursor = obkv_documents.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? 
{ @@ -53,43 +115,148 @@ pub fn extract_vector_points( // lazily get it when needed let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; - // first we retrieve the _vectors field - if let Some(value) = obkv.get(vectors_fid) { + let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) { let vectors_obkv = KvReaderDelAdd::new(value); + match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) { + (Some(old), Some(new)) => { + // no autogeneration + let del_vectors = extract_vectors(old, document_id)?; + let add_vectors = extract_vectors(new, document_id)?; - // then we extract the values - let del_vectors = vectors_obkv - .get(DelAdd::Deletion) - .map(|vectors| extract_vectors(vectors, document_id)) - .transpose()? - .flatten(); - let add_vectors = vectors_obkv - .get(DelAdd::Addition) - .map(|vectors| extract_vectors(vectors, document_id)) - .transpose()? - .flatten(); + VectorStateDelta::ManualDelta( + del_vectors.unwrap_or_default(), + add_vectors.unwrap_or_default(), + ) + } + (None, Some(new)) => { + // was possibly autogenerated, remove all vectors for that document + let add_vectors = extract_vectors(new, document_id)?; - // and we finally push the unique vectors into the writer - push_vectors_diff( - &mut writer, - &mut key_buffer, - del_vectors.unwrap_or_default(), - add_vectors.unwrap_or_default(), - )?; - } + VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default()) + } + (Some(_old), None) => { + // Do we keep this document? 
+ let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + if document_is_kept { + // becomes autogenerated + match prompt { + Some(prompt) => VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + &field_id_map, + )?), + None => VectorStateDelta::NowRemoved, + } + } else { + VectorStateDelta::NowRemoved + } + } + (None, None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + match prompt { + Some(prompt) => { + // Don't give up if the old prompt was failing + let old_prompt = prompt + .render(obkv, DelAdd::Deletion, &field_id_map) + .unwrap_or_default(); + let new_prompt = + prompt.render(obkv, DelAdd::Addition, &field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + VectorStateDelta::NoChange + } + } + // We no longer have a prompt, so we need to remove any existing vector + None => VectorStateDelta::NowRemoved, + } + } else { + VectorStateDelta::NowRemoved + } + } + } + } else { + // Do we keep this document? 
+ let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + match prompt { + Some(prompt) => { + // Don't give up if the old prompt was failing + let old_prompt = prompt + .render(obkv, DelAdd::Deletion, &field_id_map) + .unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + VectorStateDelta::NoChange + } + } + None => VectorStateDelta::NowRemoved, + } + } else { + VectorStateDelta::NowRemoved + } + }; + + // and we finally push the unique vectors into the writer + push_vectors_diff( + &mut remove_vectors_writer, + &mut prompts_writer, + &mut manual_vectors_writer, + &mut key_buffer, + delta, + )?; } - writer_into_reader(writer) + Ok(ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + manual_vectors: writer_into_reader(manual_vectors_writer)?, + // docid -> () + remove_vectors: writer_into_reader(remove_vectors_writer)?, + // docid -> prompt + prompts: writer_into_reader(prompts_writer)?, + }) } /// Computes the diff between both Del and Add numbers and /// only inserts the parts that differ in the sorter. 
fn push_vectors_diff( - writer: &mut Writer>, + remove_vectors_writer: &mut Writer>, + prompts_writer: &mut Writer>, + manual_vectors_writer: &mut Writer>, key_buffer: &mut Vec, - mut del_vectors: Vec>, - mut add_vectors: Vec>, + delta: VectorStateDelta, ) -> Result<()> { + let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values(); + if must_remove { + key_buffer.truncate(TRUNCATE_SIZE); + remove_vectors_writer.insert(&key_buffer, [])?; + } + if !prompt.is_empty() { + key_buffer.truncate(TRUNCATE_SIZE); + prompts_writer.insert(&key_buffer, prompt.as_bytes())?; + } + // We sort and dedup the vectors del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); @@ -114,7 +281,7 @@ fn push_vectors_diff( let mut obkv = KvWriterDelAdd::memory(); obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; let bytes = obkv.into_inner()?; - writer.insert(&key_buffer, bytes)?; + manual_vectors_writer.insert(&key_buffer, bytes)?; } EitherOrBoth::Right(vector) => { // We insert only the Add part of the Obkv to inform @@ -122,7 +289,7 @@ fn push_vectors_diff( let mut obkv = KvWriterDelAdd::memory(); obkv.insert(DelAdd::Addition, cast_slice(&vector))?; let bytes = obkv.into_inner()?; - writer.insert(&key_buffer, bytes)?; + manual_vectors_writer.insert(&key_buffer, bytes)?; } } } @@ -146,3 +313,102 @@ fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result( + // docid, prompt + prompt_reader: grenad::Reader, + indexer: GrenadParameters, + embedder: Arc, +) -> Result<(grenad::Reader>, Option)> { + let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; + + let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism + let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk + + // docid, state with embedding + let mut state_writer = create_writer( + indexer.chunk_compression_type, + 
indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let mut chunks = Vec::with_capacity(n_chunks); + let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk); + let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk); + let mut chunks_ids = Vec::with_capacity(n_chunks); + let mut cursor = prompt_reader.into_cursor()?; + + let mut expected_dimension = None; + + while let Some((key, value)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + // SAFETY: precondition, the grenad value was saved from a string + let prompt = unsafe { std::str::from_utf8_unchecked(value) }; + if current_chunk.len() == current_chunk.capacity() { + chunks.push(std::mem::replace( + &mut current_chunk, + Vec::with_capacity(n_vectors_per_chunk), + )); + chunks_ids.push(std::mem::replace( + &mut current_chunk_ids, + Vec::with_capacity(n_vectors_per_chunk), + )); + }; + current_chunk.push(prompt.to_owned()); + current_chunk_ids.push(docid); + + if chunks.len() == chunks.capacity() { + let chunked_embeds = rt + .block_on( + embedder + .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), + ) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?; + + for (docid, embeddings) in chunks_ids + .iter() + .flat_map(|docids| docids.iter()) + .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) + { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + expected_dimension = Some(embeddings.dimension()); + } + chunks_ids.clear(); + } + } + + // send last chunk + if !chunks.is_empty() { + let chunked_embeds = rt + .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?; + for (docid, embeddings) in chunks_ids + .iter() + .flat_map(|docids| docids.iter()) + .zip(chunked_embeds.iter().flat_map(|embeds| 
embeds.iter())) + { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + expected_dimension = Some(embeddings.dimension()); + } + } + + if !current_chunk.is_empty() { + let embeds = rt + .block_on(embedder.embed(std::mem::take(&mut current_chunk))) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?; + + for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + expected_dimension = Some(embeddings.dimension()); + } + } + + Ok((writer_into_reader(state_writer)?, expected_dimension)) +} diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 57f349894..40b0dcd61 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -9,9 +9,10 @@ mod extract_word_docids; mod extract_word_pair_proximity_docids; mod extract_word_position_docids; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::BufReader; +use std::sync::Arc; use crossbeam_channel::Sender; use log::debug; @@ -23,7 +24,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids; use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; use self::extract_fid_word_count_docids::extract_fid_word_count_docids; use self::extract_geo_points::extract_geo_points; -use self::extract_vector_points::extract_vector_points; +use self::extract_vector_points::{ + extract_embeddings, extract_vector_points, ExtractedVectorPoints, +}; use self::extract_word_docids::extract_word_docids; use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; use self::extract_word_position_docids::extract_word_position_docids; @@ -32,8 +35,10 @@ use super::helpers::{ MergeFn, MergeableReader, }; use super::{helpers, 
TypedChunk}; +use crate::prompt::Prompt; use crate::proximity::ProximityPrecision; -use crate::{FieldId, Result}; +use crate::vector::Embedder; +use crate::{FieldId, FieldsIdsMap, Result}; /// Extract data for each databases from obkv documents in parallel. /// Send data in grenad file over provided Sender. @@ -47,13 +52,14 @@ pub(crate) fn data_from_obkv_documents( faceted_fields: HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, - vectors_field_id: Option, + field_id_map: FieldsIdsMap, stop_words: Option>, allowed_separators: Option<&[&str]>, dictionary: Option<&[&str]>, max_positions_per_attributes: Option, exact_attributes: HashSet, proximity_precision: ProximityPrecision, + embedders: HashMap, Arc)>, ) -> Result<()> { puffin::profile_function!(); @@ -64,7 +70,8 @@ pub(crate) fn data_from_obkv_documents( original_documents_chunk, indexer, lmdb_writer_sx.clone(), - vectors_field_id, + field_id_map.clone(), + embedders.clone(), ) }) .collect::>()?; @@ -276,24 +283,42 @@ fn send_original_documents_data( original_documents_chunk: Result>>, indexer: GrenadParameters, lmdb_writer_sx: Sender>, - vectors_field_id: Option, + field_id_map: FieldsIdsMap, + embedders: HashMap, Arc)>, ) -> Result<()> { let original_documents_chunk = original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; - if let Some(vectors_field_id) = vectors_field_id { - let documents_chunk_cloned = original_documents_chunk.clone(); - let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); - rayon::spawn(move || { - let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); - let _ = match result { - Ok(vector_points) => { - lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) - } - Err(error) => lmdb_writer_sx_cloned.send(Err(error)), - }; - }); - } + let documents_chunk_cloned = original_documents_chunk.clone(); + let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); + rayon::spawn(move || { + let (embedder, 
prompt) = embedders.get("default").cloned().unzip(); + let result = + extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); + let _ = match result { + Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { + /// FIXME: support multiple embedders + let results = embedder.and_then(|embedder| { + match extract_embeddings(prompts, indexer, embedder.clone()) { + Ok(results) => Some(results), + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); + None + } + } + }); + let (embeddings, expected_dimension) = results.unzip(); + let expected_dimension = expected_dimension.flatten(); + lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { + remove_vectors, + embeddings, + expected_dimension, + manual_vectors, + })) + } + Err(error) => lmdb_writer_sx_cloned.send(Err(error)), + }; + }); // TODO: create a custom internal error lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index f825cad1c..76848b628 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -4,11 +4,12 @@ mod helpers; mod transform; mod typed_chunk; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::io::{Cursor, Read, Seek}; use std::iter::FromIterator; use std::num::NonZeroU32; use std::result::Result as StdResult; +use std::sync::Arc; use crossbeam_channel::{Receiver, Sender}; use heed::types::Str; @@ -32,10 +33,12 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; pub use self::transform::{Transform, TransformOutput}; use crate::documents::{obkv_to_object, DocumentsBatchReader}; use crate::error::{Error, InternalError, UserError}; +use crate::prompt::Prompt; pub use crate::update::index_documents::helpers::CursorClonableMmap; use crate::update::{ IndexerConfig, UpdateIndexingStep, WordPrefixDocids, 
WordPrefixIntegerDocids, WordsPrefixesFst, }; +use crate::vector::Embedder; use crate::{CboRoaringBitmapCodec, Index, Result}; static MERGED_DATABASE_COUNT: usize = 7; @@ -78,6 +81,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> { should_abort: FA, added_documents: u64, deleted_documents: u64, + embedders: HashMap, Arc)>, } #[derive(Default, Debug, Clone)] @@ -121,6 +125,7 @@ where index, added_documents: 0, deleted_documents: 0, + embedders: Default::default(), }) } @@ -167,6 +172,14 @@ where Ok((self, Ok(indexed_documents))) } + pub fn with_embedders( + mut self, + embedders: HashMap, Arc)>, + ) -> Self { + self.embedders = embedders; + self + } + /// Remove a batch of documents from the current builder. /// /// Returns the number of documents deleted from the builder. @@ -322,17 +335,18 @@ where // get filterable fields for facet databases let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; // get the fid of the `_geo.lat` and `_geo.lng` fields. - let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") { + let mut field_id_map = self.index.fields_ids_map(self.wtxn)?; + + // self.index.fields_ids_map($a)? ==>> field_id_map + let geo_fields_ids = match field_id_map.id("_geo") { Some(gfid) => { let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); // if `_geo` is faceted then we get the `lat` and `lng` if is_sortable || is_filterable { - let field_ids = self - .index - .fields_ids_map(self.wtxn)? + let field_ids = field_id_map .insert("_geo.lat") - .zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng")) + .zip(field_id_map.insert("_geo.lng")) .ok_or(UserError::AttributeLimitReached)?; Some(field_ids) } else { @@ -341,8 +355,6 @@ where } None => None, }; - // get the fid of the `_vectors` field. 
- let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); let stop_words = self.index.stop_words(self.wtxn)?; let separators = self.index.allowed_separators(self.wtxn)?; @@ -364,6 +376,8 @@ where self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; + let cloned_embedder = self.embedders.clone(); + // Run extraction pipeline in parallel. pool.install(|| { puffin::profile_scope!("extract_and_send_grenad_chunks"); @@ -387,13 +401,14 @@ where faceted_fields, primary_key_id, geo_fields_ids, - vectors_field_id, + field_id_map, stop_words, separators.as_deref(), dictionary.as_deref(), max_positions_per_attributes, exact_attributes, proximity_precision, + cloned_embedder, ) }); @@ -2505,7 +2520,7 @@ mod tests { .unwrap(); let rtxn = index.read_txn().unwrap(); - let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap(); + let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); assert_eq!(res.documents_ids.len(), 3); } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 49e36b87e..36d230d00 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -47,7 +47,12 @@ pub(crate) enum TypedChunk { FieldIdFacetIsNullDocids(grenad::Reader>), FieldIdFacetIsEmptyDocids(grenad::Reader>), GeoPoints(grenad::Reader>), - VectorPoints(grenad::Reader>), + VectorPoints { + remove_vectors: grenad::Reader>, + embeddings: Option>>, + expected_dimension: Option, + manual_vectors: grenad::Reader>, + }, ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), } @@ -100,8 +105,8 @@ impl TypedChunk { TypedChunk::GeoPoints(grenad) => { format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) } - TypedChunk::VectorPoints(grenad) => { - format!("VectorPoints {{ number_of_entries: 
{} }}", grenad.len()) + TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { + format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension.unwrap_or_default()) } TypedChunk::ScriptLanguageDocids(sl_map) => { format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) @@ -355,19 +360,64 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_rtree(wtxn, &rtree)?; index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } - TypedChunk::VectorPoints(vector_points) => { - let mut vectors_set = HashSet::new(); + TypedChunk::VectorPoints { + remove_vectors, + manual_vectors, + embeddings, + expected_dimension, + } => { + if remove_vectors.is_empty() + && manual_vectors.is_empty() + && embeddings.as_ref().map_or(true, |e| e.is_empty()) + { + return Ok((RoaringBitmap::new(), is_merged_database)); + } + + let mut docid_vectors_map: HashMap>>> = + HashMap::new(); + // We extract and store the previous vectors if let Some(hnsw) = index.vector_hnsw(wtxn)? { for (pid, point) in hnsw.iter() { let pid_key = pid.into_inner(); let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); - vectors_set.insert((docid, vector)); + docid_vectors_map.entry(docid).or_default().insert(vector); } } - let mut cursor = vector_points.into_cursor()?; + // remove vectors for docids we want them removed + let mut cursor = remove_vectors.into_cursor()?; + while let Some((key, _)) = cursor.move_on_next()? 
{ + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + + docid_vectors_map.remove(&docid); + } + + // add generated embeddings + if let Some((embeddings, expected_dimension)) = embeddings.zip(expected_dimension) { + let mut cursor = embeddings.into_cursor()?; + while let Some((key, value)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + let data: Vec> = + pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + // it is a code error to have embeddings and not expected_dimension + let embeddings = + crate::vector::Embeddings::from_inner(data, expected_dimension) + // code error if we somehow got the wrong dimension + .unwrap(); + + let mut set = HashSet::new(); + for embedding in embeddings.iter() { + set.insert(embedding.to_vec()); + } + + docid_vectors_map.insert(docid, set); + } + } + + // perform the manual diff + let mut cursor = manual_vectors.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? 
{ // convert the key back to a u32 (4 bytes) let (left, _index) = try_split_array_at(key).unwrap(); @@ -376,23 +426,30 @@ pub(crate) fn write_typed_chunk_into_index( let vector_deladd_obkv = KvReaderDelAdd::new(value); if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { // convert the vector back to a Vec - let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - let key = (docid, vector); - if !vectors_set.remove(&key) { - error!("Unable to delete the vector: {:?}", key.1); - } + let vector: Vec> = + pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + docid_vectors_map.entry(docid).and_modify(|v| { + if !v.remove(&vector) { + error!("Unable to delete the vector: {:?}", vector); + } + }); } if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { // convert the vector back to a Vec let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - vectors_set.insert((docid, vector)); + docid_vectors_map.entry(docid).and_modify(|v| { + v.insert(vector); + }); } } // Extract the most common vector dimension let expected_dimension_size = { let mut dims = HashMap::new(); - vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); + docid_vectors_map + .values() + .flat_map(|v| v.iter()) + .for_each(|v| *dims.entry(v.len()).or_insert(0) += 1); dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) }; @@ -400,7 +457,10 @@ pub(crate) fn write_typed_chunk_into_index( // prepare the vectors before inserting them in the HNSW. 
let mut points = Vec::new(); let mut docids = Vec::new(); - for (docid, vector) in vectors_set { + for (docid, vector) in docid_vectors_map + .into_iter() + .flat_map(|(docid, vectors)| std::iter::repeat(docid).zip(vectors)) + { if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { return Err(UserError::InvalidVectorDimensions { expected: expected_dimension_size.unwrap_or(vector.len()), diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 712e595e9..5e3683f32 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -3,7 +3,7 @@ use std::result::Result as StdResult; use charabia::{Normalize, Tokenizer, TokenizerBuilder}; use deserr::{DeserializeError, Deserr}; -use itertools::Itertools; +use itertools::{EitherOrBoth, Itertools}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use time::OffsetDateTime; @@ -15,6 +15,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; +use crate::vector::settings::{EmbeddingSettings, PromptSettings}; +use crate::vector::EmbeddingConfig; use crate::{FieldsIdsMap, Index, OrderBy, Result}; #[derive(Debug, Clone, PartialEq, Eq, Copy)] @@ -73,6 +75,13 @@ impl Setting { otherwise => otherwise, } } + + pub fn apply(&mut self, new: Self) { + if let Setting::NotSet = new { + return; + } + *self = new; + } } impl Serialize for Setting { @@ -129,6 +138,7 @@ pub struct Settings<'a, 't, 'i> { sort_facet_values_by: Setting>, pagination_max_total_hits: Setting, proximity_precision: Setting, + embedder_settings: Setting>>, } impl<'a, 't, 'i> Settings<'a, 't, 'i> { @@ -161,6 +171,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { sort_facet_values_by: Setting::NotSet, pagination_max_total_hits: Setting::NotSet, proximity_precision: Setting::NotSet, + embedder_settings: 
Setting::NotSet, indexer_config, } } @@ -343,6 +354,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { self.proximity_precision = Setting::Reset; } + pub fn set_embedder_settings(&mut self, value: BTreeMap>) { + self.embedder_settings = Setting::Set(value); + } + + pub fn reset_embedder_settings(&mut self) { + self.embedder_settings = Setting::Reset; + } + fn reindex( &mut self, progress_callback: &FP, @@ -890,6 +909,60 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Ok(changed) } + fn update_embedding_configs(&mut self) -> Result { + let update = match std::mem::take(&mut self.embedder_settings) { + Setting::Set(configs) => { + let mut changed = false; + let old_configs = self.index.embedding_configs(self.wtxn)?; + let old_configs: BTreeMap> = + old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect(); + + let mut new_configs = BTreeMap::new(); + for joined in old_configs + .into_iter() + .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) + { + match joined { + EitherOrBoth::Both((name, mut old), (_, new)) => { + old.apply(new); + let new = validate_prompt(&name, old)?; + changed = true; + new_configs.insert(name, new); + } + EitherOrBoth::Left((name, setting)) => { + new_configs.insert(name, setting); + } + EitherOrBoth::Right((name, setting)) => { + let setting = validate_prompt(&name, setting)?; + changed = true; + new_configs.insert(name, setting); + } + } + } + let new_configs: Vec<(String, EmbeddingConfig)> = new_configs + .into_iter() + .filter_map(|(name, setting)| match setting { + Setting::Set(value) => Some((name, value.into())), + Setting::Reset => None, + Setting::NotSet => Some((name, EmbeddingSettings::default().into())), + }) + .collect(); + if new_configs.is_empty() { + self.index.delete_embedding_configs(self.wtxn)?; + } else { + self.index.put_embedding_configs(self.wtxn, new_configs)?; + } + changed + } + Setting::Reset => { + self.index.delete_embedding_configs(self.wtxn)?; + true + } + Setting::NotSet => 
false, + }; + Ok(update) + } + pub fn execute(mut self, progress_callback: FP, should_abort: FA) -> Result<()> where FP: Fn(UpdateIndexingStep) + Sync, @@ -927,6 +1000,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let searchable_updated = self.update_searchable()?; let exact_attributes_updated = self.update_exact_attributes()?; let proximity_precision = self.update_proximity_precision()?; + // TODO: very rough approximation of the needs for reindexing where any change will result in + // a full reindexing. + // What can be done instead: + // 1. Only change the distance on a distance change + // 2. Only change the name -> embedder mapping on a name change + // 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage + let embedding_configs_updated = self.update_embedding_configs()?; if stop_words_updated || non_separator_tokens_updated @@ -937,6 +1017,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { || searchable_updated || exact_attributes_updated || proximity_precision + || embedding_configs_updated { self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; } @@ -945,6 +1026,34 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { } } +fn validate_prompt( + name: &str, + new: Setting, +) -> Result> { + match new { + Setting::Set(EmbeddingSettings { + embedder_options, + prompt: + Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), + }) => { + // validate + let template = crate::prompt::Prompt::new(template, None, None) + .map(|prompt| crate::prompt::PromptData::from(prompt).template) + .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; + + Ok(Setting::Set(EmbeddingSettings { + embedder_options, + prompt: Setting::Set(PromptSettings { + template: Setting::Set(template), + strategy, + fallback, + }), + })) + } + new => Ok(new), + } +} + #[cfg(test)] mod tests { use big_s::S; @@ -1763,6 +1872,7 @@ mod tests { sort_facet_values_by, 
pagination_max_total_hits, proximity_precision, + embedder_settings, } = settings; assert!(matches!(searchable_fields, Setting::NotSet)); assert!(matches!(displayed_fields, Setting::NotSet)); @@ -1785,6 +1895,7 @@ mod tests { assert!(matches!(sort_facet_values_by, Setting::NotSet)); assert!(matches!(pagination_max_total_hits, Setting::NotSet)); assert!(matches!(proximity_precision, Setting::NotSet)); + assert!(matches!(embedder_settings, Setting::NotSet)); }) .unwrap(); } diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs new file mode 100644 index 000000000..1ae7a4678 --- /dev/null +++ b/milli/src/vector/error.rs @@ -0,0 +1,229 @@ +use std::path::PathBuf; + +use hf_hub::api::sync::ApiError; + +use crate::error::FaultSource; +use crate::vector::openai::OpenAiError; + +#[derive(Debug, thiserror::Error)] +#[error("Error while generating embeddings: {inner}")] +pub struct Error { + pub inner: Box, +} + +impl> From for Error { + fn from(value: I) -> Self { + Self { inner: Box::new(value.into()) } + } +} + +impl Error { + pub fn fault(&self) -> FaultSource { + match &*self.inner { + ErrorKind::NewEmbedderError(inner) => inner.fault, + ErrorKind::EmbedError(inner) => inner.fault, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ErrorKind { + #[error(transparent)] + NewEmbedderError(#[from] NewEmbedderError), + #[error(transparent)] + EmbedError(#[from] EmbedError), +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct EmbedError { + pub kind: EmbedErrorKind, + pub fault: FaultSource, +} + +#[derive(Debug, thiserror::Error)] +pub enum EmbedErrorKind { + #[error("could not tokenize: {0}")] + Tokenize(Box), + #[error("unexpected tensor shape: {0}")] + TensorShape(candle_core::Error), + #[error("unexpected tensor value: {0}")] + TensorValue(candle_core::Error), + #[error("could not run model: {0}")] + ModelForward(candle_core::Error), + #[error("could not reach OpenAI: {0}")] + OpenAiNetwork(reqwest::Error), + 
#[error("unexpected response from OpenAI: {0}")] + OpenAiUnexpected(reqwest::Error), + #[error("could not authenticate against OpenAI: {0}")] + OpenAiAuth(OpenAiError), + #[error("sent too many requests to OpenAI: {0}")] + OpenAiTooManyRequests(OpenAiError), + #[error("received internal error from OpenAI: {0}")] + OpenAiInternalServerError(OpenAiError), + #[error("sent too many tokens in a request to OpenAI: {0}")] + OpenAiTooManyTokens(OpenAiError), + #[error("received unhandled HTTP status code {0} from OpenAI")] + OpenAiUnhandledStatusCode(u16), +} + +impl EmbedError { + pub fn tokenize(inner: Box) -> Self { + Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime } + } + + pub fn tensor_shape(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug } + } + + pub fn tensor_value(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug } + } + + pub fn model_forward(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } + } + + pub fn openai_network(inner: reqwest::Error) -> Self { + Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } + } + + pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } + } + + pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } + } + + pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn openai_too_many_tokens(inner: 
OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } + } + + pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct NewEmbedderError { + pub kind: NewEmbedderErrorKind, + pub fault: FaultSource, +} + +impl NewEmbedderError { + pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError { + let open_config = OpenConfig { filename: config_filename, inner }; + + Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime } + } + + pub fn deserialize_config( + config: String, + config_filename: PathBuf, + inner: serde_json::Error, + ) -> NewEmbedderError { + let deserialize_config = DeserializeConfig { config, filename: config_filename, inner }; + Self { + kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config), + fault: FaultSource::Runtime, + } + } + + pub fn open_tokenizer( + tokenizer_filename: PathBuf, + inner: Box, + ) -> NewEmbedderError { + let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner }; + Self { + kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer), + fault: FaultSource::Runtime, + } + } + + pub fn new_api_fail(inner: ApiError) -> Self { + Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug } + } + + pub fn api_get(inner: ApiError) -> Self { + Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided } + } + + pub fn pytorch_weight(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } + } + + pub fn safetensor_weight(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime } + } + + pub fn load_model(inner: candle_core::Error) ->
Self { + Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } + } + + pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { + Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } + } + + pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { + Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("could not open config at {filename:?}: {inner}")] +pub struct OpenConfig { + pub filename: PathBuf, + pub inner: std::io::Error, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not deserialize config at {filename:?}: {inner}. Config follows:\n{config}")] +pub struct DeserializeConfig { + pub config: String, + pub filename: PathBuf, + pub inner: serde_json::Error, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not open tokenizer at {filename:?}: {inner}")] +pub struct OpenTokenizer { + pub filename: PathBuf, + #[source] + pub inner: Box, +} + +#[derive(Debug, thiserror::Error)] +pub enum NewEmbedderErrorKind { + // hf + #[error(transparent)] + OpenConfig(OpenConfig), + #[error(transparent)] + DeserializeConfig(DeserializeConfig), + #[error(transparent)] + OpenTokenizer(OpenTokenizer), + #[error("could not build weights from Pytorch weights: {0}")] + PytorchWeight(candle_core::Error), + #[error("could not build weights from Safetensor weights: {0}")] + SafetensorWeight(candle_core::Error), + #[error("could not spawn HG_HUB API client: {0}")] + NewApiFail(ApiError), + #[error("fetching file from HG_HUB failed: {0}")] + ApiGet(ApiError), + #[error("loading model failed: {0}")] + LoadModel(candle_core::Error), + // openai + #[error("initializing web client for sending embedding requests failed: {0}")] + InitWebClient(reqwest::Error), + #[error("The API key passed to Authorization error was in an invalid format: {0}")] + 
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), +} diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs new file mode 100644 index 000000000..81cdd4b34 --- /dev/null +++ b/milli/src/vector/hf.rs @@ -0,0 +1,192 @@ +use candle_core::Tensor; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself +use hf_hub::api::sync::Api; +use hf_hub::{Repo, RepoType}; +use tokenizers::{PaddingParams, Tokenizer}; + +pub use super::error::{EmbedError, Error, NewEmbedderError}; +use super::{Embedding, Embeddings}; + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Deserialize, + serde::Serialize, + deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum WeightSource { + #[default] + Safetensors, + Pytorch, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub model: String, + pub revision: Option, + pub weight_source: WeightSource, + pub normalize_embeddings: bool, +} + +impl EmbedderOptions { + pub fn new() -> Self { + Self { + //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(), + model: "BAAI/bge-base-en-v1.5".to_string(), + //revision: Some("refs/pr/21".to_string()), + revision: None, + //weight_source: Default::default(), + weight_source: WeightSource::Pytorch, + normalize_embeddings: true, + } + } +} + +impl Default for EmbedderOptions { + fn default() -> Self { + Self::new() + } +} + +/// Perform embedding of documents and queries +pub struct Embedder { + model: BertModel, + tokenizer: Tokenizer, + options: EmbedderOptions, +} + +impl std::fmt::Debug for Embedder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Embedder") + .field("model", 
&self.options.model) + .field("tokenizer", &self.tokenizer) + .field("options", &self.options) + .finish() + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> std::result::Result { + let device = candle_core::Device::Cpu; + let repo = match options.revision.clone() { + Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), + None => Repo::model(options.model.clone()), + }; + let (config_filename, tokenizer_filename, weights_filename) = { + let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; + let api = api.repo(repo); + let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; + let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; + let weights = match options.weight_source { + WeightSource::Pytorch => { + api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)? + } + WeightSource::Safetensors => { + api.get("model.safetensors").map_err(NewEmbedderError::api_get)? + } + }; + (config, tokenizer, weights) + }; + + let config = std::fs::read_to_string(&config_filename) + .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; + let config: Config = serde_json::from_str(&config).map_err(|inner| { + NewEmbedderError::deserialize_config(config, config_filename, inner) + })?; + let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) + .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; + + let vb = match options.weight_source { + WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) + .map_err(NewEmbedderError::pytorch_weight)?, + WeightSource::Safetensors => unsafe { + VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) + .map_err(NewEmbedderError::safetensor_weight)? 
+ }, + }; + + let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; + + if let Some(pp) = tokenizer.get_padding_mut() { + pp.strategy = tokenizers::PaddingStrategy::BatchLongest + } else { + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + tokenizer.with_padding(Some(pp)); + } + + Ok(Self { model, tokenizer, options }) + } + + pub async fn embed( + &self, + mut texts: Vec, + ) -> std::result::Result>, EmbedError> { + let tokens = match texts.len() { + 1 => vec![self + .tokenizer + .encode(texts.pop().unwrap(), true) + .map_err(EmbedError::tokenize)?], + _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, + }; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) + }) + .collect::, EmbedError>>()?; + + let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; + let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; + let embeddings = + self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; + + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = + embeddings.dims3().map_err(EmbedError::tensor_shape)?; + + let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + .map_err(EmbedError::tensor_shape)?; + + let embeddings: Tensor = if self.options.normalize_embeddings { + normalize_l2(&embeddings).map_err(EmbedError::tensor_value)? 
+ } else { + embeddings + }; + + let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; + Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> std::result::Result>>, EmbedError> { + futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) + .await + } + + pub fn chunk_count_hint(&self) -> usize { + 1 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) + } +} + +fn normalize_l2(v: &Tensor) -> Result { + v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) +} diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs new file mode 100644 index 000000000..faaa7bf2a --- /dev/null +++ b/milli/src/vector/mod.rs @@ -0,0 +1,142 @@ +use self::error::{EmbedError, NewEmbedderError}; +use crate::prompt::PromptData; + +pub mod error; +pub mod hf; +pub mod openai; +pub mod settings; + +pub use self::error::Error; + +pub type Embedding = Vec; + +pub struct Embeddings { + data: Vec, + dimension: usize, +} + +impl Embeddings { + pub fn new(dimension: usize) -> Self { + Self { data: Default::default(), dimension } + } + + pub fn from_single_embedding(embedding: Vec) -> Self { + Self { dimension: embedding.len(), data: embedding } + } + + pub fn from_inner(data: Vec, dimension: usize) -> Result> { + let mut this = Self::new(dimension); + this.append(data)?; + Ok(this) + } + + pub fn dimension(&self) -> usize { + self.dimension + } + + pub fn into_inner(self) -> Vec { + self.data + } + + pub fn as_inner(&self) -> &[F] { + &self.data + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.data.as_slice().chunks_exact(self.dimension) + } + + pub fn push(&mut self, mut embedding: Vec) -> Result<(), Vec> { + if embedding.len() != self.dimension { + return Err(embedding); + } + self.data.append(&mut embedding); + Ok(()) + } + + pub fn append(&mut self, mut 
embeddings: Vec) -> Result<(), Vec> { + if embeddings.len() % self.dimension != 0 { + return Err(embeddings); + } + self.data.append(&mut embeddings); + Ok(()) + } +} + +#[derive(Debug)] +pub enum Embedder { + HuggingFace(hf::Embedder), + OpenAi(openai::Embedder), +} + +#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] +pub struct EmbeddingConfig { + pub embedder_options: EmbedderOptions, + pub prompt: PromptData, + // TODO: add metrics and anything needed +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub enum EmbedderOptions { + HuggingFace(hf::EmbedderOptions), + OpenAi(openai::EmbedderOptions), +} + +impl Default for EmbedderOptions { + fn default() -> Self { + Self::HuggingFace(Default::default()) + } +} + +impl EmbedderOptions { + pub fn huggingface() -> Self { + Self::HuggingFace(hf::EmbedderOptions::new()) + } + + pub fn openai(api_key: String) -> Self { + Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> std::result::Result { + Ok(match options { + EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), + EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + }) + } + + pub async fn embed( + &self, + texts: Vec, + ) -> std::result::Result>, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed(texts).await, + Embedder::OpenAi(embedder) => embedder.embed(texts).await, + } + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> std::result::Result>>, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, + } + } + + pub fn chunk_count_hint(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), + Embedder::OpenAi(embedder) => 
embedder.chunk_count_hint(), + } + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), + } + } +} diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs new file mode 100644 index 000000000..670dc8526 --- /dev/null +++ b/milli/src/vector/openai.rs @@ -0,0 +1,416 @@ +use std::fmt::Display; + +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; + +use super::error::{EmbedError, NewEmbedderError}; +use super::{Embedding, Embeddings}; + +#[derive(Debug)] +pub struct Embedder { + client: reqwest::Client, + tokenizer: tiktoken_rs::CoreBPE, + options: EmbedderOptions, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub api_key: String, + pub embedding_model: EmbeddingModel, +} + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Serialize, + serde::Deserialize, + deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum EmbeddingModel { + #[default] + TextEmbeddingAda002, +} + +impl EmbeddingModel { + pub fn max_token(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 8191, + } + } + + pub fn dimensions(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 1536, + } + } + + pub fn name(&self) -> &'static str { + match self { + EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", + } + } + + pub fn from_name(name: &'static str) -> Option { + match name { + "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), + _ => None, + } + } +} + +pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; + +impl EmbedderOptions { + pub fn with_default_model(api_key: String) -> Self { + Self { api_key, 
embedding_model: Default::default() } + } + + pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self { + Self { api_key, embedding_model } + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key)) + .map_err(NewEmbedderError::openai_invalid_api_key_format)?, + ); + headers.insert( + reqwest::header::CONTENT_TYPE, + reqwest::header::HeaderValue::from_static("application/json"), + ); + let client = reqwest::ClientBuilder::new() + .default_headers(headers) + .build() + .map_err(NewEmbedderError::openai_initialize_web_client)?; + + // looking at the code it is very unclear that this can actually fail. + let tokenizer = tiktoken_rs::cl100k_base().unwrap(); + + Ok(Self { options, client, tokenizer }) + } + + pub async fn embed(&self, texts: Vec) -> Result>, EmbedError> { + let mut tokenized = false; + + for attempt in 0..7 { + let result = if tokenized { + self.try_embed_tokenized(&texts).await + } else { + self.try_embed(&texts).await + }; + + let retry_duration = match result { + Ok(embeddings) => return Ok(embeddings), + Err(retry) => { + log::warn!("Failed: {}", retry.error); + tokenized |= retry.must_tokenize(); + retry.into_duration(attempt) + } + }?; + log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis()); + tokio::time::sleep(retry_duration).await; + } + + let result = if tokenized { + self.try_embed_tokenized(&texts).await + } else { + self.try_embed(&texts).await + }; + + result.map_err(Retry::into_error) + } + + async fn check_response(response: reqwest::Response) -> Result { + if !response.status().is_success() { + match response.status() { + StatusCode::UNAUTHORIZED => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + 
.map_err(Retry::retry_later)?; + + return Err(Retry::give_up(EmbedError::openai_auth_error( + error_response.error, + ))); + } + StatusCode::TOO_MANY_REQUESTS => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( + error_response.error, + ))); + } + StatusCode::INTERNAL_SERVER_ERROR => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::openai_internal_server_error( + error_response.error, + ))); + } + StatusCode::SERVICE_UNAVAILABLE => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::openai_internal_server_error( + error_response.error, + ))); + } + StatusCode::BAD_REQUEST => { + // Most probably, one text contained too many tokens + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + log::warn!("OpenAI: input was too long, retrying on tokenized version. 
For best performance, limit the size of your prompt."); + + return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( + error_response.error, + ))); + } + code => { + return Err(Retry::give_up(EmbedError::openai_unhandled_status_code( + code.as_u16(), + ))); + } + } + } + Ok(response) + } + + async fn try_embed + serde::Serialize>( + &self, + texts: &[S], + ) -> Result>, Retry> { + for text in texts { + log::trace!("Received prompt: {}", text.as_ref()) + } + let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; + let response = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let response: OpenAiResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + log::trace!("response: {:?}", response.data); + + Ok(response + .data + .into_iter() + .map(|data| Embeddings::from_single_embedding(data.embedding)) + .collect()) + } + + async fn try_embed_tokenized(&self, text: &[String]) -> Result>, Retry> { + pub const OVERLAP_SIZE: usize = 200; + let mut all_embeddings = Vec::with_capacity(text.len()); + for text in text { + let max_token_count = self.options.embedding_model.max_token(); + let encoded = self.tokenizer.encode_ordinary(text.as_str()); + let len = encoded.len(); + if len < max_token_count { + all_embeddings.append(&mut self.try_embed(&[text]).await?); + continue; + } + + let mut tokens = encoded.as_slice(); + let mut embeddings_for_prompt = + Embeddings::new(self.options.embedding_model.dimensions()); + while tokens.len() > max_token_count { + let window = &tokens[..max_token_count]; + embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap(); + + tokens = &tokens[max_token_count - OVERLAP_SIZE..]; + } + + // end of text + 
embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap(); + + all_embeddings.push(embeddings_for_prompt); + } + Ok(all_embeddings) + } + + async fn embed_tokens(&self, tokens: &[usize]) -> Result { + for attempt in 0..9 { + let duration = match self.try_embed_tokens(tokens).await { + Ok(embedding) => return Ok(embedding), + Err(retry) => retry.into_duration(attempt), + } + .map_err(Retry::retry_later)?; + + tokio::time::sleep(duration).await; + } + + self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) + } + + async fn try_embed_tokens(&self, tokens: &[usize]) -> Result { + let request = + OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; + let response = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let mut response: OpenAiResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) + .await + } + + pub fn chunk_count_hint(&self) -> usize { + 10 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + 10 + } +} + +// retrying in case of failure + +struct Retry { + error: EmbedError, + strategy: RetryStrategy, +} + +enum RetryStrategy { + GiveUp, + Retry, + RetryTokenized, + RetryAfterRateLimit, +} + +impl Retry { + fn give_up(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::GiveUp } + } + + fn retry_later(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::Retry } + } + + fn retry_tokenized(error: EmbedError) -> Self { + Self { error, strategy: 
RetryStrategy::RetryTokenized } + } + + fn rate_limited(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryAfterRateLimit } + } + + fn into_duration(self, attempt: u32) -> Result { + match self.strategy { + RetryStrategy::GiveUp => Err(self.error), + RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), + RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), + RetryStrategy::RetryAfterRateLimit => { + Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt))) + } + } + } + + fn must_tokenize(&self) -> bool { + matches!(self.strategy, RetryStrategy::RetryTokenized) + } + + fn into_error(self) -> EmbedError { + self.error + } +} + +// openai api structs + +#[derive(Debug, Serialize)] +struct OpenAiRequest<'a, S: AsRef + serde::Serialize> { + model: &'a str, + input: &'a [S], +} + +#[derive(Debug, Serialize)] +struct OpenAiTokensRequest<'a> { + model: &'a str, + input: &'a [usize], +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiErrorResponse { + error: OpenAiError, +} + +#[derive(Debug, Deserialize)] +pub struct OpenAiError { + message: String, + // type: String, + code: Option, +} + +impl Display for OpenAiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.code { + Some(code) => write!(f, "{} ({})", self.message, code), + None => write!(f, "{}", self.message), + } + } +} + +#[derive(Debug, Deserialize)] +struct OpenAiEmbedding { + embedding: Embedding, + // object: String, + // index: usize, +} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs new file mode 100644 index 000000000..2c0cf7924 --- /dev/null +++ b/milli/src/vector/settings.rs @@ -0,0 +1,308 @@ +use deserr::Deserr; +use serde::{Deserialize, Serialize}; + +use crate::prompt::{PromptData, PromptFallbackStrategy}; +use crate::update::Setting; +use crate::vector::hf::WeightSource; +use 
crate::vector::EmbeddingConfig; + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct EmbeddingSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] + #[deserr(default, rename = "source")] + pub embedder_options: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub prompt: Setting, +} + +impl EmbeddingSettings { + pub fn apply(&mut self, new: Self) { + let EmbeddingSettings { embedder_options, prompt } = new; + self.embedder_options.apply(embedder_options); + self.prompt.apply(prompt); + } +} + +impl From for EmbeddingSettings { + fn from(value: EmbeddingConfig) -> Self { + Self { + embedder_options: Setting::Set(value.embedder_options.into()), + prompt: Setting::Set(value.prompt.into()), + } + } +} + +impl From for EmbeddingConfig { + fn from(value: EmbeddingSettings) -> Self { + let mut this = Self::default(); + let EmbeddingSettings { embedder_options, prompt } = value; + if let Some(embedder_options) = embedder_options.set() { + this.embedder_options = embedder_options.into(); + } + if let Some(prompt) = prompt.set() { + this.prompt = prompt.into(); + } + this + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct PromptSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub template: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub strategy: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub fallback: Setting, +} + +impl PromptSettings { + pub fn apply(&mut self, new: Self) { + let PromptSettings { template, strategy, 
fallback } = new; + self.template.apply(template); + self.strategy.apply(strategy); + self.fallback.apply(fallback); + } +} + +impl From for PromptSettings { + fn from(value: PromptData) -> Self { + Self { + template: Setting::Set(value.template), + strategy: Setting::Set(value.strategy), + fallback: Setting::Set(value.fallback), + } + } +} + +impl From for PromptData { + fn from(value: PromptSettings) -> Self { + let mut this = PromptData::default(); + let PromptSettings { template, strategy, fallback } = value; + if let Some(template) = template.set() { + this.template = template; + } + if let Some(strategy) = strategy.set() { + this.strategy = strategy; + } + if let Some(fallback) = fallback.set() { + this.fallback = fallback; + } + this + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +pub enum EmbedderSettings { + HuggingFace(Setting), + OpenAi(Setting), +} + +impl Deserr for EmbedderSettings +where + E: deserr::DeserializeError, +{ + fn deserialize_from_value( + value: deserr::Value, + location: deserr::ValuePointerRef, + ) -> Result { + match value { + deserr::Value::Map(map) => { + if deserr::Map::len(&map) != 1 { + return Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::Unexpected { + msg: format!( + "Expected a single field, got {} fields", + deserr::Map::len(&map) + ), + }, + location, + ))); + } + let mut it = deserr::Map::into_iter(map); + let (k, v) = it.next().unwrap(); + + match k.as_str() { + "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set( + HfEmbedderSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + ))), + "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set( + OpenAiEmbedderSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + ))), + other => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::UnknownKey { + key: other, + accepted: &["huggingFace", 
"openAi"], + }, + location, + ))), + } + } + _ => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: value, + accepted: &[deserr::ValueKind::Map], + }, + location, + ))), + } + } +} + +impl Default for EmbedderSettings { + fn default() -> Self { + Self::HuggingFace(Default::default()) + } +} + +impl From for EmbedderSettings { + fn from(value: crate::vector::EmbedderOptions) -> Self { + match value { + crate::vector::EmbedderOptions::HuggingFace(hf) => { + Self::HuggingFace(Setting::Set(hf.into())) + } + crate::vector::EmbedderOptions::OpenAi(openai) => { + Self::OpenAi(Setting::Set(openai.into())) + } + } + } +} + +impl From for crate::vector::EmbedderOptions { + fn from(value: EmbedderSettings) -> Self { + match value { + EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), + EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), + EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), + EmbedderSettings::OpenAi(_setting) => Self::OpenAi( + crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()), + ), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct HfEmbedderSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub model: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub revision: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub weight_source: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub normalize_embeddings: Setting, +} + +impl HfEmbedderSettings { + pub fn apply(&mut self, new: Self) { + let HfEmbedderSettings { + model, + revision, + weight_source, + 
normalize_embeddings: normalize_embedding, + } = new; + self.model.apply(model); + self.revision.apply(revision); + self.weight_source.apply(weight_source); + self.normalize_embeddings.apply(normalize_embedding); + } +} + +impl From for HfEmbedderSettings { + fn from(value: crate::vector::hf::EmbedderOptions) -> Self { + Self { + model: Setting::Set(value.model), + revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), + weight_source: Setting::Set(value.weight_source), + normalize_embeddings: Setting::Set(value.normalize_embeddings), + } + } +} + +impl From for crate::vector::hf::EmbedderOptions { + fn from(value: HfEmbedderSettings) -> Self { + let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value; + let mut this = Self::default(); + if let Some(model) = model.set() { + this.model = model; + } + if let Some(revision) = revision.set() { + this.revision = Some(revision); + } + if let Some(weight_source) = weight_source.set() { + this.weight_source = weight_source; + } + if let Some(normalize_embeddings) = normalize_embeddings.set() { + this.normalize_embeddings = normalize_embeddings; + } + this + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct OpenAiEmbedderSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub api_key: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub embedding_model: Setting, +} + +impl OpenAiEmbedderSettings { + pub fn apply(&mut self, new: Self) { + let Self { api_key, embedding_model: embedding_mode } = new; + self.api_key.apply(api_key); + self.embedding_model.apply(embedding_mode); + } +} + +impl From for OpenAiEmbedderSettings { + fn from(value: crate::vector::openai::EmbedderOptions) -> Self { + Self { + api_key: 
Setting::Set(value.api_key), + embedding_model: Setting::Set(value.embedding_model), + } + } +} + +impl From for crate::vector::openai::EmbedderOptions { + fn from(value: OpenAiEmbedderSettings) -> Self { + let OpenAiEmbedderSettings { api_key, embedding_model } = value; + Self { + api_key: api_key.set().unwrap_or_else(infer_api_key), + embedding_model: embedding_model.set().unwrap_or_default(), + } + } +} + +fn infer_api_key() -> String { + // FIXME: get key from instance options? + std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default() +}