diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 81bcb6aaa..f1fb341a2 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -295,6 +295,10 @@ pub fn perform_search( let mut search = index.search(&rtxn); + if let Some(ref vector) = query.vector { + search.vector(vector.clone()); + } + if let Some(ref query) = query.q { search.query(query); } diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 87c9a004d..82de56434 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -52,6 +52,7 @@ fn main() -> Result<(), Box> { let docs = execute_search( &mut ctx, &(!query.trim().is_empty()).then(|| query.trim().to_owned()), + &None, TermsMatchingStrategy::Last, milli::score_details::ScoringStrategy::Skip, false, diff --git a/milli/src/dot_product.rs b/milli/src/dot_product.rs index 2f5f1e474..86dd2f1d4 100644 --- a/milli/src/dot_product.rs +++ b/milli/src/dot_product.rs @@ -7,9 +7,13 @@ pub struct DotProduct; impl Metric> for DotProduct { type Unit = u32; + // TODO explain me this function, I don't understand why f32.to_bits is ordered. + // I tried to do this and it wasn't OK + // // Following . fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum(); + let dist = 1.0 - dist; debug_assert!(!dist.is_nan()); dist.to_bits() } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 3c972d9b0..970c0b7ab 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -23,6 +23,7 @@ pub mod new; pub struct Search<'a> { query: Option, + vector: Option>, // this should be linked to the String in the query filter: Option>, offset: usize, @@ -41,6 +42,7 @@ impl<'a> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { Search { query: None, + vector: None, filter: None, offset: 0, limit: 20, @@ -60,6 +62,11 @@ impl<'a> Search<'a> { self } + pub fn vector(&mut self, vector: impl Into>) -> &mut Search<'a> { + self.vector = Some(vector.into()); + self + } + pub fn offset(&mut self, offset: usize) -> &mut Search<'a> { self.offset = offset; self @@ -114,6 +121,7 @@ impl<'a> Search<'a> { execute_search( &mut ctx, &self.query, + &self.vector, self.terms_matching_strategy, self.scoring_strategy, self.exhaustive_number_hits, @@ -141,6 +149,7 @@ impl fmt::Debug for Search<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let Search { query, + vector: _, filter, offset, limit, @@ -155,6 +164,7 @@ impl fmt::Debug for Search<'_> { } = self; f.debug_struct("Search") .field("query", query) + .field("vector", &"[...]") .field("filter", filter) .field("offset", offset) .field("limit", limit) diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index f33d595e5..ce28e16c1 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -509,6 +509,7 @@ mod tests { let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( &mut ctx, &Some(query.to_string()), + &None, crate::TermsMatchingStrategy::default(), crate::score_details::ScoringStrategy::Skip, false, diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 8df764f29..948a2fa21 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,6 +28,7 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; +use hnsw::Searcher; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -39,6 +40,7 @@ use ranking_rules::{ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use roaring::RoaringBitmap; use sort::Sort; +use space::Neighbor; use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; @@ -46,7 +48,9 @@ use self::graph_based_ranking_rule::Words; use self::interner::Interned; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; -use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError}; +use crate::{ + AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, +}; /// A structure used throughout the execution of a search query. pub struct SearchContext<'ctx> { @@ -350,6 +354,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( pub fn execute_search( ctx: &mut SearchContext, query: &Option, + vector: &Option>, terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, exhaustive_number_hits: bool, @@ -442,6 +447,34 @@ pub fn execute_search( let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?; + let docids = match vector { + Some(vector) => { + // return the nearest documents that are also part of the candidates. + let mut searcher = Searcher::new(); + let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); + let ef = hnsw.len().min(100); + let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; + let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); + + let mut docids = Vec::new(); + for Neighbor { index, distance } in neighbors.iter() { + let index = BEU32::new(*index as u32); + let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); + dbg!(distance, f32::from_bits(*distance)); + if universe.contains(docid) { + docids.push(docid); + if docids.len() == length { + break; + } + } + } + + docids + } + // return the search docids if the vector field is not specified + None => docids, + }; + // The candidates is the universe unless the exhaustive number of hits // is requested and a distinct attribute is set. if exhaustive_number_hits {