diff --git a/src/lib.rs b/src/lib.rs
index 55d2f583e..91fc9ae42 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,7 @@ mod criterion;
 mod query_tokens;
 mod search;
 pub mod heed_codec;
+pub mod proximity;
 pub mod tokenizer;
 
 use std::collections::HashMap;
diff --git a/src/proximity.rs b/src/proximity.rs
new file mode 100644
index 000000000..0186eb3d0
--- /dev/null
+++ b/src/proximity.rs
@@ -0,0 +1,28 @@
+use std::cmp;
+use crate::{Attribute, Position};
+
+const ONE_ATTRIBUTE: u32 = 1000;
+const MAX_DISTANCE: u32 = 8;
+
+pub fn index_proximity(lhs: u32, rhs: u32) -> u32 {
+    if lhs <= rhs {
+        cmp::min(rhs - lhs, MAX_DISTANCE)
+    } else {
+        cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
+    }
+}
+
+pub fn positions_proximity(lhs: Position, rhs: Position) -> u32 {
+    let (lhs_attr, lhs_index) = extract_position(lhs);
+    let (rhs_attr, rhs_index) = extract_position(rhs);
+    if lhs_attr != rhs_attr { MAX_DISTANCE }
+    else { index_proximity(lhs_index, rhs_index) }
+}
+
+pub fn extract_position(position: Position) -> (Attribute, Position) {
+    (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
+}
+
+pub fn path_proximity(path: &[Position]) -> u32 {
+    path.windows(2).map(|w| positions_proximity(w[0], w[1])).sum::<u32>()
+}
diff --git a/src/search.rs b/src/search.rs
index b392ad0b6..46e5f5280 100644
--- a/src/search.rs
+++ b/src/search.rs
@@ -10,6 +10,7 @@ use roaring::bitmap::{IntoIter, RoaringBitmap};
 
 use near_proximity::near_proximity;
 
+use crate::proximity::path_proximity;
 use crate::query_tokens::{QueryTokens, QueryToken};
 use crate::{Index, DocumentId, Position};
 
@@ -194,33 +195,6 @@ impl<'a> Search<'a> {
 
         let mut documents = Vec::new();
 
-        // TODO move this function elsewhere
-        fn compute_proximity(path: &[Position]) -> u32 {
-            const ONE_ATTRIBUTE: u32 = 1000;
-            const MAX_DISTANCE: u32 = 8;
-
-            fn index_proximity(lhs: u32, rhs: u32) -> u32 {
-                if lhs <= rhs {
-                    cmp::min(rhs - lhs, MAX_DISTANCE)
-                } else {
-                    cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
-                }
-            }
-
-            fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
-                let (lhs_attr, lhs_index) = extract_position(lhs);
-                let (rhs_attr, rhs_index) = extract_position(rhs);
-                if lhs_attr != rhs_attr { MAX_DISTANCE }
-                else { index_proximity(lhs_index, rhs_index) }
-            }
-
-            fn extract_position(position: u32) -> (u32, u32) {
-                (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
-            }
-
-            path.windows(2).map(|w| positions_proximity(w[0], w[1])).sum::<u32>()
-        }
-
         // If there only is one word, no need to compute the best proximities.
         if derived_words.len() == 1 {
             let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect();
@@ -231,7 +205,7 @@ impl<'a> Search<'a> {
         let mut paths = Vec::new();
         for candidate in candidates {
             let keywords = Self::fecth_keywords(rtxn, index, &derived_words, candidate)?;
-            near_proximity(keywords, &mut paths, compute_proximity);
+            near_proximity(keywords, &mut paths, path_proximity);
             if let Some((prox, _path)) = paths.first() {
                 documents.push((*prox, candidate));
             }