diff --git a/src/best_proximity.rs b/src/best_proximity.rs index 6dda4d1fb..6f822ee6d 100644 --- a/src/best_proximity.rs +++ b/src/best_proximity.rs @@ -23,7 +23,7 @@ pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 { } // Returns the attribute and index parts. -fn extract_position(position: u32) -> (u32, u32) { +pub fn extract_position(position: u32) -> (u32, u32) { (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE) } @@ -66,7 +66,7 @@ impl Node { parent_position: *position, }; // We do not produce the nodes we have already seen in previous iterations loops. - if proximity > 7 || (node.is_complete(positions) && acc_proximity + proximity < best_proximity) { + if node.is_complete(positions) && acc_proximity + proximity < best_proximity { None } else { Some((node, proximity)) @@ -138,7 +138,7 @@ impl BestProximity { { let before = Instant::now(); - if self.best_proximity == self.positions.len() as u32 * (MAX_DISTANCE - 1) { + if self.best_proximity == self.positions.len() as u32 * MAX_DISTANCE { return None; } @@ -177,6 +177,11 @@ impl BestProximity { mod tests { use super::*; + fn sort(mut val: (u32, Vec)) -> (u32, Vec) { + val.1.sort_unstable(); + val + } + #[test] fn same_attribute() { let positions = vec![ @@ -190,7 +195,7 @@ mod tests { assert_eq!(iter.next(f), Some((1+2, vec![vec![0, 1, 3]]))); // 3 assert_eq!(iter.next(f), Some((2+2, vec![vec![2, 1, 3]]))); // 4 assert_eq!(iter.next(f), Some((3+2, vec![vec![3, 1, 3]]))); // 5 - assert_eq!(iter.next(f), Some((1+5, vec![vec![0, 1, 6], vec![4, 1, 3]]))); // 6 + assert_eq!(iter.next(f).map(sort), Some((1+5, vec![vec![0, 1, 6], vec![4, 1, 3]]))); // 6 assert_eq!(iter.next(f), Some((2+5, vec![vec![2, 1, 6]]))); // 7 assert_eq!(iter.next(f), Some((3+5, vec![vec![3, 1, 6]]))); // 8 assert_eq!(iter.next(f), Some((4+5, vec![vec![4, 1, 6]]))); // 9 diff --git a/src/lib.rs b/src/lib.rs index c9ac19de1..6af75e875 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,30 +198,66 @@ impl Index { union_docids }; + // Returns the union of the same attribute for all the derived words. + let unions_word_attr = |word: usize, attr: u32| { + let mut union_docids = RoaringBitmap::new(); + for (word, _) in &words[word] { + let mut key = word.clone(); + key.extend_from_slice(&attr.to_be_bytes()); + if let Some(right) = self.word_attribute_docids.get(rtxn, &key).unwrap() { + union_docids.union_with(&right); + } + } + union_docids + }; + let mut union_cache = HashMap::new(); let mut intersect_cache = HashMap::new(); + + let mut attribute_union_cache = HashMap::new(); + let mut attribute_intersect_cache = HashMap::new(); + // Returns `true` if there is documents in common between the two words and positions given. let mut contains_documents = |(lword, lpos), (rword, rpos), union_cache: &mut HashMap<_, _>, candidates: &RoaringBitmap| { - let proximity = best_proximity::positions_proximity(lpos, rpos); + if lpos == rpos { return false } - if proximity == 0 { return false } + let (lattr, _) = best_proximity::extract_position(lpos); + let (rattr, _) = best_proximity::extract_position(rpos); - // We retrieve or compute the intersection between the two given words and positions. - *intersect_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - union_cache.entry((lword, lpos)).or_insert_with(|| unions_word_pos(lword, lpos)); - union_cache.entry((rword, rpos)).or_insert_with(|| unions_word_pos(rword, rpos)); + if lattr == rattr { + // We retrieve or compute the intersection between the two given words and positions. + *intersect_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { + // We retrieve or compute the unions for the two words and positions. + union_cache.entry((lword, lpos)).or_insert_with(|| unions_word_pos(lword, lpos)); + union_cache.entry((rword, rpos)).or_insert_with(|| unions_word_pos(rword, rpos)); - // TODO is there a way to avoid this double gets? - let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); - let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); + // TODO is there a way to avoid this double gets? + let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); + let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } + // We first check that the docids of these unions are part of the candidates. + if lunion_docids.is_disjoint(candidates) { return false } + if runion_docids.is_disjoint(candidates) { return false } - !lunion_docids.is_disjoint(&runion_docids) - }) + !lunion_docids.is_disjoint(&runion_docids) + }) + } else { + *attribute_intersect_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { + // We retrieve or compute the unions for the two words and positions. + attribute_union_cache.entry((lword, lattr)).or_insert_with(|| unions_word_attr(lword, lattr)); + attribute_union_cache.entry((rword, rattr)).or_insert_with(|| unions_word_attr(rword, rattr)); + + // TODO is there a way to avoid this double gets? + let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); + let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); + + // We first check that the docids of these unions are part of the candidates. + if lunion_docids.is_disjoint(candidates) { return false } + if runion_docids.is_disjoint(candidates) { return false } + + !lunion_docids.is_disjoint(&runion_docids) + }) + } }; let mut documents = Vec::new();