From 65e49b7092475d11afc97152395190cdd3e954e9 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 10:05:06 +0100 Subject: [PATCH] Remove stuff, add distribution shift (WIP) --- Cargo.lock | 219 ++++-------------- meilisearch/src/search.rs | 45 +--- milli/Cargo.toml | 16 +- milli/src/distance.rs | 41 ---- milli/src/index.rs | 4 - milli/src/lib.rs | 2 - milli/src/search/new/mod.rs | 13 +- milli/src/search/new/vector_sort.rs | 16 +- .../src/update/index_documents/typed_chunk.rs | 4 +- milli/src/vector/mod.rs | 44 ++++ 10 files changed, 126 insertions(+), 278 deletions(-) delete mode 100644 milli/src/distance.rs diff --git a/Cargo.lock b/Cargo.lock index fba78b3b6..3c2f38840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,7 +56,7 @@ dependencies = [ "flate2", "futures-core", "h2", - "http", + "http 0.2.9", "httparse", "httpdate", "itoa", @@ -90,7 +90,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66ff4d247d2b160861fa2866457e85706833527840e4133f8f49aa423a38799" dependencies = [ "bytestring", - "http", + "http 0.2.9", "regex", "serde", "tracing", @@ -189,7 +189,7 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "http", + "http 0.2.9", "itoa", "language-tags", "log", @@ -1407,7 +1407,7 @@ dependencies = [ "anyhow", "big_s", "flate2", - "http", + "http 0.2.9", "log", "maplit", "meili-snap", @@ -1702,21 +1702,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.2.0" @@ -2047,7 +2032,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 1.9.3", "slab", "tokio", @@ -2171,13 +2156,12 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hf-hub" version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" +source = "git+https://github.com/dureuill/hf-hub.git?branch=rust_tls#88d4f11cb9fa079f2912bacb96f5080b16825ce8" dependencies = [ "dirs", + "http 1.0.0", "indicatif", "log", - "native-tls", "rand", "serde", "serde_json", @@ -2205,6 +2189,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -2212,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -2245,7 +2240,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -2265,7 +2260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" dependencies = [ "futures-util", - "http", + "http 0.2.9", "hyper", "rustls 0.21.6", "tokio", @@ -2868,21 +2863,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "instant-distance" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a" -dependencies = [ - "num_cpus", - "ordered-float 3.7.0", - "parking_lot", - "rand", - "rayon", - "serde", - "serde-big-array", -] - [[package]] name = "io-lifetimes" version = "1.0.11" @@ -3531,7 +3511,7 @@ dependencies = [ "futures", "futures-util", "hex", - "http", + "http 0.2.9", "index-scheduler", "indexmap 2.0.0", "insta", @@ -3718,7 +3698,6 @@ dependencies = [ "hf-hub", "indexmap 2.0.0", "insta", - "instant-distance", "itertools 0.11.0", "json-depth-checker", "levenshtein_automata", @@ -3730,7 +3709,6 @@ dependencies = [ "meili-snap", "memmap2 0.7.1", "mimalloc", - "nolife", "obkv", "once_cell", "ordered-float 3.7.0", @@ -3829,35 +3807,11 @@ dependencies = [ "syn 2.0.28", ] -[[package]] -name = "native-tls" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nelson" version = "0.1.0" source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a" -[[package]] -name = "nolife" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc" - [[package]] name = "nom" version = "7.1.3" @@ -3994,50 +3948,6 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" -[[package]] -name = "openssl" -version = "0.10.59" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a257ad03cd8fb16ad4172fedf8094451e1af1c4b70097636ef2eac9a5f0cc33" -dependencies = [ - "bitflags 2.4.1", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.28", -] - -[[package]] -name = "openssl-probe" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" - -[[package]] -name = "openssl-sys" -version = "0.9.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "option-ext" version = "0.2.0" @@ -4655,7 +4565,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-rustls", @@ -4802,7 +4712,7 @@ checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" dependencies = [ "log", "ring", - "rustls-webpki 0.101.3", + "rustls-webpki", "sct", ] @@ -4815,16 +4725,6 @@ dependencies = [ "base64 0.21.5", ] -[[package]] -name = "rustls-webpki" -version = "0.100.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "rustls-webpki" version = "0.101.3" @@ -4866,15 +4766,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" -dependencies = [ - "windows-sys 0.48.0", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -4891,29 +4782,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "security-framework" -version = "2.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "segment" version = "0.2.2" @@ -4949,15 +4817,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-big-array" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" -dependencies = [ - "serde", -] - [[package]] name = "serde-cs" version = "0.2.4" @@ -5151,6 +5010,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.5.2" @@ -5713,21 +5583,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "ureq" -version = "2.7.1" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97" dependencies = [ "base64 0.21.5", "flate2", "log", - "native-tls", "once_cell", "rustls 0.21.6", - "rustls-webpki 0.100.2", + "rustls-webpki", "serde", "serde_json", + "socks", "url", - "webpki-roots 0.23.1", + "webpki-roots 0.25.3", ] [[package]] @@ -5958,15 +5828,6 @@ dependencies = [ "webpki", ] -[[package]] -name = "webpki-roots" -version = "0.23.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" -dependencies = [ - "rustls-webpki 0.100.2", -] - [[package]] name = "webpki-roots" version = "0.25.3" diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 235b745a9..9136157f9 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -14,18 +14,14 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; -use meilisearch_types::milli::{ - dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, - VectorQuery, -}; +use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, + SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; -use ordered_float::OrderedFloat; use regex::Regex; use serde::Serialize; use serde_json::{json, Value}; @@ -550,13 +546,8 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - let semantic_score = /*match query.vector.as_ref() { - Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { - Some(vectors) => compute_semantic_score(vector, vectors)?, - None => None, - }, - None => None, - };*/ None; + /// FIXME: remove this or set to value from the score details + let semantic_score = None; let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -689,18 +680,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } -fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result> { - let vectors = serde_json::from_value(vectors) - .map(VectorOrArrayOfVectors::into_array_of_vectors) - .map_err(InternalError::SerdeJson)?; - Ok(vectors - .into_iter() - .flatten() - .map(|v| OrderedFloat(dot_product_similarity(query, &v))) - .max() - .map(OrderedFloat::into_inner)) -} - fn compute_formatted_options( attr_to_highlight: &HashSet, attr_to_crop: &[String], @@ -828,22 +807,6 @@ fn make_document( Ok(document) } -/// Extract the JSON value under the field name specified -/// but doesn't support nested objects. -fn extract_field( - field_name: &str, - field_ids_map: &FieldsIdsMap, - obkv: obkv::KvReaderU16, -) -> Result, MeilisearchHttpError> { - match field_ids_map.id(field_name) { - Some(fid) => match obkv.get(fid) { - Some(value) => Ok(serde_json::from_slice(value).map(Some)?), - None => Ok(None), - }, - None => Ok(None), - } -} - fn format_fields<'a>( document: &Document, field_ids_map: &FieldsIdsMap, diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 0aee03b2f..b977d64f1 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -36,7 +36,6 @@ heed = { version = "0.20.0-alpha.9", default-features = false, features = [ "read-txn-no-tls", ] } indexmap = { version = "2.0.0", features = ["serde"] } -instant-distance = { version = "0.6.1", features = ["with-serde"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.7.1" @@ -79,10 +78,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", version = "0. candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } -hf-hub = "0.3.2" +hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ + "online", +] } tokio = { version = "1.34.0", features = ["rt"] } futures = "0.3.29" -nolife = { version = "0.3.1" } reqwest = { version = "0.11.16", features = [ "rustls-tls", "json", @@ -102,7 +102,15 @@ meili-snap = { path = "../meili-snap" } rand = { version = "0.8.5", features = ["small_rng"] } [features] -all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] +all-tokenizations = [ + "charabia/chinese", + "charabia/hebrew", + "charabia/japanese", + "charabia/thai", + "charabia/korean", + "charabia/greek", + "charabia/khmer", +] # Use POSIX semaphores instead of SysV semaphores in LMDB # For more information on this feature, see heed's Cargo.toml diff --git a/milli/src/distance.rs b/milli/src/distance.rs deleted file mode 100644 index e9e17e647..000000000 --- a/milli/src/distance.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::ops; - -use instant_distance::Point; -use serde::{Deserialize, Serialize}; - -use crate::normalize_vector; - -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub struct NDotProductPoint(Vec); - -impl NDotProductPoint { - pub fn new(point: Vec) -> Self { - NDotProductPoint(normalize_vector(point)) - } - - pub fn into_inner(self) -> Vec { - self.0 - } -} - -impl ops::Deref for NDotProductPoint { - type Target = [f32]; - - fn deref(&self) -> &Self::Target { - self.0.as_slice() - } -} - -impl Point for NDotProductPoint { - fn distance(&self, other: &Self) -> f32 { - let dist = 1.0 - dot_product_similarity(&self.0, &other.0); - debug_assert!(!dist.is_nan()); - dist - } -} - -/// Returns the dot product similarity score that will between 0.0 and 1.0 -/// if both vectors are normalized. The higher the more similar the vectors are. -pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b).map(|(a, b)| a * b).sum() -} diff --git a/milli/src/index.rs b/milli/src/index.rs index c5e190d38..05babf410 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -10,7 +10,6 @@ use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::distance::NDotProductPoint; use crate::documents::PrimaryKey; use crate::error::{InternalError, UserError}; use crate::fields_ids_map::FieldsIdsMap; @@ -30,9 +29,6 @@ use crate::{ BEU32, BEU64, }; -/// The HNSW data-structure that we serialize, fill and search in. -pub type Hnsw = instant_distance::Hnsw; - pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; diff --git a/milli/src/lib.rs b/milli/src/lib.rs index b865747e0..ce37fe375 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,7 +10,6 @@ pub mod documents; mod asc_desc; mod criterion; -pub mod distance; mod error; mod external_documents_ids; pub mod facet; @@ -33,7 +32,6 @@ use std::convert::{TryFrom, TryInto}; use std::hash::BuildHasherDefault; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; -pub use distance::dot_product_similarity; pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index ad5c59f99..bba6cf119 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -50,6 +50,7 @@ use self::vector_sort::VectorSort; use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; +use crate::vector::DistributionShift; use crate::{ AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, }; @@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>( geo_strategy: geo_sort::Strategy, limit_plus_offset: usize, target: &[f32], + distribution_shift: Option, ) -> Result>> { // query graph search @@ -289,6 +291,7 @@ fn get_ranking_rules_for_vector<'ctx>( target.to_vec(), vector_candidates, limit_plus_offset, + distribution_shift, )?; ranking_rules.push(Box::new(vector_sort)); vector = true; @@ -515,8 +518,14 @@ pub fn execute_vector_search( /// FIXME: input universe = universe & documents_with_vectors // for now if we're computing embeddings for ALL documents, we can assume that this is just universe - let ranking_rules = - get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?; + let ranking_rules = get_ranking_rules_for_vector( + ctx, + sort_criteria, + geo_strategy, + from + length, + vector, + None, + )?; let mut placeholder_search_logger = logger::DefaultSearchLogger; let placeholder_search_logger: &mut dyn SearchLogger = diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 9bf13c631..2d7cdbe39 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -5,6 +5,7 @@ use roaring::RoaringBitmap; use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; use crate::score_details::{self, ScoreDetails}; +use crate::vector::DistributionShift; use crate::{DocumentId, Result, SearchContext, SearchLogger}; pub struct VectorSort { @@ -13,6 +14,7 @@ pub struct VectorSort { vector_candidates: RoaringBitmap, cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec)>, limit: usize, + distribution_shift: Option, } impl VectorSort { @@ -21,6 +23,7 @@ impl VectorSort { target: Vec, vector_candidates: RoaringBitmap, limit: usize, + distribution_shift: Option, ) -> Result { Ok(Self { query: None, @@ -28,6 +31,7 @@ impl VectorSort { vector_candidates, cached_sorted_docids: Default::default(), limit, + distribution_shift, }) } @@ -52,7 +56,7 @@ impl VectorSort { for reader in readers.iter() { let nns_by_vector = reader.nns_by_vector( ctx.txn, - &target, + target, self.limit, None, Some(&self.vector_candidates), @@ -66,6 +70,7 @@ impl VectorSort { } results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance)); self.cached_sorted_docids = results.into_iter(); + Ok(()) } } @@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { })); } - while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() { + for (docid, distance, vector) in self.cached_sorted_docids.by_ref() { if self.vector_candidates.contains(docid) { + let score = 1.0 - distance; + let score = self + .distribution_shift + .map(|distribution| distribution.shift(score)) + .unwrap_or(score); return Ok(Some(RankingRuleOutput { query, candidates: RoaringBitmap::from_iter([docid]), score: ScoreDetails::Vector(score_details::Vector { target_vector: self.target.clone(), - value_similarity: Some((vector, 1.0 - distance)), + value_similarity: Some((vector, score)), }), })); } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 84b17dca9..da99ed685 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -415,7 +415,7 @@ pub(crate) fn write_typed_chunk_into_index( let mut deleted_index = None; for (index, writer) in writers.iter().enumerate() { - let Some(candidate) = writer.item_vector(&wtxn, docid)? else { + let Some(candidate) = writer.item_vector(wtxn, docid)? else { // uses invariant: vectors are packed in the first writers. break; }; @@ -429,7 +429,7 @@ pub(crate) fn write_typed_chunk_into_index( if let Some(deleted_index) = deleted_index { let mut last_index_with_a_vector = None; for (index, writer) in writers.iter().enumerate().skip(deleted_index) { - let Some(candidate) = writer.item_vector(&wtxn, docid)? else { + let Some(candidate) = writer.item_vector(wtxn, docid)? else { break; }; last_index_with_a_vector = Some((index, candidate)); diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index faaa7bf2a..91640b8fb 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -140,3 +140,47 @@ impl Embedder { } } } + +#[derive(Debug, Clone, Copy)] +pub struct DistributionShift { + pub current_mean: f32, + pub current_sigma: f32, +} + +impl DistributionShift { + /// `None` if sigma <= 0. + pub fn new(mean: f32, sigma: f32) -> Option { + if sigma <= 0.0 { + None + } else { + Some(Self { current_mean: mean, current_sigma: sigma }) + } + } + + pub fn shift(&self, score: f32) -> f32 { + // + // We're somewhat abusively mapping the distribution of distances to a gaussian. + // The parameters we're given is the mean and sigma of the native result distribution. + // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4. + + let target_mean = 0.5; + let target_sigma = 0.4; + + // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. + let factor = target_sigma / self.current_sigma; + // a*mu1 + b = mu2 => b = mu2 - a*mu1 + let offset = target_mean - (factor * self.current_mean); + + let mut score = factor * score + offset; + + // clamp the final score in the ]0, 1] interval. + if score <= 0.0 { + score = f32::EPSILON; + } + if score > 1.0 { + score = 1.0; + } + + score + } +}