diff --git a/Cargo.lock b/Cargo.lock index 904d1c225..ccf79f9a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2595,6 +2595,7 @@ dependencies = [ "num_cpus", "obkv", "once_cell", + "ordered-float", "parking_lot", "permissive-json-pointer", "pin-project-lite", diff --git a/meilisearch/Cargo.toml b/meilisearch/Cargo.toml index 8fcd69591..d90dd24dd 100644 --- a/meilisearch/Cargo.toml +++ b/meilisearch/Cargo.toml @@ -48,6 +48,7 @@ mime = "0.3.17" num_cpus = "1.15.0" obkv = "0.2.0" once_cell = "1.17.1" +ordered-float = "3.7.0" parking_lot = "0.12.1" permissive-json-pointer = { path = "../permissive-json-pointer" } pin-project-lite = "0.2.9" diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index a85c0a437..c0d707657 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -10,6 +10,7 @@ use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::dot_product_similarity; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -18,6 +19,7 @@ use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; +use ordered_float::OrderedFloat; use regex::Regex; use serde::Serialize; use serde_json::{json, Value}; @@ -457,6 +459,10 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } + if let Some(vector) = query.vector.as_ref() { + insert_semantic_similarity(&vector, &mut document); + } + let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); let ranking_score_details = @@ -542,6 +548,22 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } +fn insert_semantic_similarity(query: &[f32], document: &mut Document) { + if let Some(value) = document.get("_vectors") { + let vectors: Vec> = match serde_json::from_value(value.clone()) { + Ok(Either::Left(vector)) => vec![vector], + Ok(Either::Right(vectors)) => vectors, + Err(_) => return, + }; + let similarity = vectors + .into_iter() + .map(|v| OrderedFloat(dot_product_similarity(query, &v))) + .max() + .map(OrderedFloat::into_inner); + document.insert("_semanticSimilarity".to_string(), json!(similarity)); + } +} + fn compute_formatted_options( attr_to_highlight: &HashSet, attr_to_crop: &[String], diff --git a/milli/src/distance.rs b/milli/src/distance.rs index c26a745a4..1b91b4654 100644 --- a/milli/src/distance.rs +++ b/milli/src/distance.rs @@ -12,13 +12,18 @@ impl Metric> for DotProduct { // // Following . fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { - let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum(); - let dist = 1.0 - dist; + let dist = 1.0 - dot_product_similarity(a, b); debug_assert!(!dist.is_nan()); dist.to_bits() } } +/// Returns the dot product similarity score that will between 0.0 and 1.0 +/// if both vectors are normalized. The higher the more similar the vectors are. +pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| a * b).sum() +} + #[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] pub struct Euclidean; @@ -26,9 +31,14 @@ impl Metric> for Euclidean { type Unit = u32; fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { - let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum(); - let dist = squared.sqrt(); + let dist = euclidean_squared_distance(a, b).sqrt(); debug_assert!(!dist.is_nan()); dist.to_bits() } } + +/// Return the squared euclidean distance between both vectors that will +/// between 0.0 and +inf. The smaller the nearer the vectors are. +pub fn euclidean_squared_distance(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum() +} diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 04c81039a..c93bf88ff 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -31,6 +31,7 @@ use std::convert::{TryFrom, TryInto}; use std::hash::BuildHasherDefault; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; +pub use distance::{dot_product_similarity, euclidean_squared_distance}; pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType;