diff --git a/Cargo.toml b/Cargo.toml index eac65404d..79707af9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,6 @@ criterion = "0.3" [features] default = [] -intersect-to-csv = [] [[bench]] name = "search" diff --git a/src/lib.rs b/src/lib.rs index c24c77563..ead9ee278 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,68 +123,42 @@ impl Index { let mut documents = Vec::new(); - let mut debug_intersects = HashMap::new(); + // Returns the union of the same position for all the derived words. + let unions_word_pos = |word: usize, pos: u32| { + let mut union_docids = RoaringBitmap::new(); + for (word, attrs) in &words[word] { + if attrs.contains(pos) { + let mut key = word.clone(); + key.extend_from_slice(&pos.to_be_bytes()); + if let Some(attrs) = self.postings_ids.get(rtxn, &key).unwrap() { + let right = RoaringBitmap::deserialize_from_slice(attrs).unwrap(); + union_docids.union_with(&right); + } + } + } + + union_docids + }; + + let mut union_cache = HashMap::new(); let mut intersect_cache = HashMap::new(); - let mut lunion_docids = RoaringBitmap::default(); - let mut runion_docids = RoaringBitmap::default(); + // Returns `true` if there is documents in common between the two words and positions given. let contains_documents = |(lword, lpos): (usize, u32), (rword, rpos): (usize, u32)| { let proximity = best_proximity::positions_proximity(lpos, rpos); + if proximity == 0 { return false } + // We retrieve or compute the intersection between the two given words and positions. *intersect_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { - let (nb_words, nb_docs_intersect, lnblookups, lnbbitmaps, rnblookups, rnbbitmaps) = - debug_intersects.entry((lword, lpos, rword, rpos, proximity)).or_default(); + // We retrieve or compute the unions for the two words and positions. + union_cache.entry((lword, lpos)).or_insert_with(|| unions_word_pos(lword, lpos)); + union_cache.entry((rword, rpos)).or_insert_with(|| unions_word_pos(rword, rpos)); - let left = &words[lword]; - let right = &words[rword]; + // TODO is there a way to avoid this double gets? + let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); + let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); - *nb_words = left.len() + right.len(); - - let mut l_lookups = 0; - let mut l_bitmaps = 0; - let mut r_lookups = 0; - let mut r_bitmaps = 0; - - // This for the left word - lunion_docids.clear(); - for (word, attrs) in left { - if attrs.contains(lpos) { - l_lookups += 1; - let mut key = word.clone(); - key.extend_from_slice(&lpos.to_be_bytes()); - if let Some(attrs) = self.postings_ids.get(rtxn, &key).unwrap() { - l_bitmaps += 1; - let right = RoaringBitmap::deserialize_from_slice(attrs).unwrap(); - lunion_docids.union_with(&right); - } - } - } - - // This for the right word - runion_docids.clear(); - for (word, attrs) in right { - if attrs.contains(rpos) { - r_lookups += 1; - let mut key = word.clone(); - key.extend_from_slice(&rpos.to_be_bytes()); - if let Some(attrs) = self.postings_ids.get(rtxn, &key).unwrap() { - r_bitmaps += 1; - let right = RoaringBitmap::deserialize_from_slice(attrs).unwrap(); - runion_docids.union_with(&right); - } - } - } - - let intersect_docids = &mut lunion_docids; - intersect_docids.intersect_with(&runion_docids); - - *lnblookups = l_lookups; - *lnbbitmaps = l_bitmaps; - *rnblookups = r_lookups; - *rnbbitmaps = r_bitmaps; - *nb_docs_intersect += intersect_docids.len(); - - !intersect_docids.is_empty() + !lunion_docids.is_disjoint(&runion_docids) }) }; @@ -262,65 +236,7 @@ impl Index { } } - if cfg!(feature = "intersect-to-csv") { - debug_intersects_to_csv(debug_intersects); - } - eprintln!("{} candidates", documents.iter().map(RoaringBitmap::len).sum::()); Ok(documents.iter().flatten().take(20).collect()) } } - -fn debug_intersects_to_csv(intersects: HashMap<(usize, u32, usize, u32, u32), (usize, u64, usize, usize, usize, usize)>) { - let mut wrt = csv::Writer::from_path("intersects-stats.csv").unwrap(); - wrt.write_record(&[ - "proximity", - "lword", - "lpos", - "rword", - "rpos", - "nb_derived_words", - "nb_docs_intersect", - "lnblookups", - "lnbbitmaps", - "rnblookups", - "rnbbitmaps", - ]).unwrap(); - - for ((lword, lpos, rword, rpos, proximity), vals) in intersects { - let ( - nb_derived_words, - nb_docs_intersect, - lnblookups, - lnbbitmaps, - rnblookups, - rnbbitmaps, - ) = vals; - - let proximity = proximity.to_string(); - let lword = lword.to_string(); - let lpos = lpos.to_string(); - let rword = rword.to_string(); - let rpos = rpos.to_string(); - let nb_derived_words = nb_derived_words.to_string(); - let nb_docs_intersect = nb_docs_intersect.to_string(); - let lnblookups = lnblookups.to_string(); - let lnbbitmaps = lnbbitmaps.to_string(); - let rnblookups = rnblookups.to_string(); - let rnbbitmaps = rnbbitmaps.to_string(); - - wrt.write_record(&[ - &proximity, - &lword, - &lpos, - &rword, - &rpos, - &nb_derived_words, - &nb_docs_intersect, - &lnblookups, - &lnbbitmaps, - &rnblookups, - &rnbbitmaps, - ]).unwrap(); - } -}