diff --git a/Cargo.lock b/Cargo.lock index a8d659d1a..dc7fb41bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,15 +85,6 @@ dependencies = [ "warp", ] -[[package]] -name = "astar-iter" -version = "0.1.0" -source = "git+https://github.com/Kerollmops/astar-iter#87cb97a11c701f1a6025b72b673a8bfd0ca249a5" -dependencies = [ - "indexmap", - "num-traits", -] - [[package]] name = "atty" version = "0.2.11" @@ -990,7 +981,6 @@ dependencies = [ "arc-cache", "askama", "askama_warp", - "astar-iter", "bitpacking", "bstr", "byteorder", diff --git a/Cargo.toml b/Cargo.toml index 8077e3b35..a9f685ea0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ default-run = "indexer" [dependencies] anyhow = "1.0.28" arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" } -astar-iter = { git = "https://github.com/Kerollmops/astar-iter" } bitpacking = "0.8.2" bstr = "0.2.13" byteorder = "1.3.4" diff --git a/src/bin/indexer.rs b/src/bin/indexer.rs index ee3880a26..0bec93964 100644 --- a/src/bin/indexer.rs +++ b/src/bin/indexer.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io::{self, Read, Write}; @@ -13,6 +14,7 @@ use cow_utils::CowUtils; use csv::StringRecord; use flate2::read::GzDecoder; use fst::IntoStreamer; +use heed::BytesDecode; use heed::BytesEncode; use heed::EnvOpenOptions; use heed::types::*; @@ -25,7 +27,7 @@ use structopt::StructOpt; use milli::heed_codec::CsvStringRecordCodec; use milli::tokenizer::{simple_tokenizer, only_words}; -use milli::{SmallVec32, Index, DocumentId, Position, Attribute, BEU32}; +use milli::{SmallVec32, Index, DocumentId, BEU32, StrBEU32Codec}; const LMDB_MAX_KEY_LENGTH: usize = 511; const ONE_MILLION: usize = 1_000_000; @@ -37,10 +39,8 @@ const HEADERS_KEY: &[u8] = b"\0headers"; const DOCUMENTS_IDS_KEY: &[u8] = b"\x04documents-ids"; const WORDS_FST_KEY: &[u8] = b"\x06words-fst"; const DOCUMENTS_IDS_BYTE: u8 = 4; -const WORD_ATTRIBUTE_DOCIDS_BYTE: u8 = 3; -const WORD_FOUR_POSITIONS_DOCIDS_BYTE: u8 = 5; -const WORD_POSITION_DOCIDS_BYTE: u8 = 2; -const WORD_POSITIONS_BYTE: u8 = 1; +const WORD_DOCIDS_BYTE: u8 = 2; +const WORD_DOCID_POSITIONS_BYTE: u8 = 1; #[cfg(target_os = "linux")] #[global_allocator] @@ -125,10 +125,7 @@ fn lmdb_key_valid_size(key: &[u8]) -> bool { type MergeFn = fn(&[u8], &[Vec]) -> Result, ()>; struct Store { - word_positions: ArcCache, RoaringBitmap>, - word_position_docids: ArcCache<(SmallVec32, Position), RoaringBitmap>, - word_four_positions_docids: ArcCache<(SmallVec32, Position), RoaringBitmap>, - word_attribute_docids: ArcCache<(SmallVec32, Attribute), RoaringBitmap>, + word_docids: ArcCache, RoaringBitmap>, documents_ids: RoaringBitmap, sorter: Sorter, documents_sorter: Sorter, @@ -162,10 +159,7 @@ impl Store { } Store { - word_positions: ArcCache::new(arc_cache_size), - word_position_docids: ArcCache::new(arc_cache_size), - word_four_positions_docids: ArcCache::new(arc_cache_size), - word_attribute_docids: ArcCache::new(arc_cache_size), + word_docids: ArcCache::new(arc_cache_size), documents_ids: RoaringBitmap::new(), sorter: builder.build(), documents_sorter: documents_builder.build(), @@ -173,65 +167,48 @@ impl Store { } // Save the documents ids under the position and word we have seen it. - pub fn insert_word_position_docid(&mut self, word: &str, position: Position, id: DocumentId) -> anyhow::Result<()> { + pub fn insert_word_docid(&mut self, word: &str, id: DocumentId) -> anyhow::Result<()> { let word_vec = SmallVec32::from(word.as_bytes()); let ids = RoaringBitmap::from_iter(Some(id)); - let (_, lrus) = self.word_position_docids.insert((word_vec, position), ids, |old, new| old.union_with(&new)); - Self::write_word_position_docids(&mut self.sorter, lrus)?; - self.insert_word_position(word, position)?; - self.insert_word_four_positions_docid(word, position, id)?; - self.insert_word_attribute_docid(word, position / MAX_POSITION as u32, id)?; + let (_, lrus) = self.word_docids.insert(word_vec, ids, |old, new| old.union_with(&new)); + Self::write_word_docids(&mut self.sorter, lrus)?; Ok(()) } - pub fn insert_word_four_positions_docid(&mut self, word: &str, position: Position, id: DocumentId) -> anyhow::Result<()> { - let position = position - position % 4; - let word_vec = SmallVec32::from(word.as_bytes()); - let ids = RoaringBitmap::from_iter(Some(id)); - let (_, lrus) = self.word_four_positions_docids.insert((word_vec, position), ids, |old, new| old.union_with(&new)); - Self::write_word_four_positions_docids(&mut self.sorter, lrus) - } - - // Save the positions where this word has been seen. - pub fn insert_word_position(&mut self, word: &str, position: Position) -> anyhow::Result<()> { - let word = SmallVec32::from(word.as_bytes()); - let position = RoaringBitmap::from_iter(Some(position)); - let (_, lrus) = self.word_positions.insert(word, position, |old, new| old.union_with(&new)); - Self::write_word_positions(&mut self.sorter, lrus) - } - - // Save the documents ids under the attribute and word we have seen it. - fn insert_word_attribute_docid(&mut self, word: &str, attribute: Attribute, id: DocumentId) -> anyhow::Result<()> { - let word = SmallVec32::from(word.as_bytes()); - let ids = RoaringBitmap::from_iter(Some(id)); - let (_, lrus) = self.word_attribute_docids.insert((word, attribute), ids, |old, new| old.union_with(&new)); - Self::write_word_attribute_docids(&mut self.sorter, lrus) - } - pub fn write_headers(&mut self, headers: &StringRecord) -> anyhow::Result<()> { let headers = CsvStringRecordCodec::bytes_encode(headers) .with_context(|| format!("could not encode csv record"))?; Ok(self.sorter.insert(HEADERS_KEY, headers)?) } - pub fn write_document(&mut self, id: DocumentId, record: &StringRecord) -> anyhow::Result<()> { + pub fn write_document( + &mut self, + id: DocumentId, + iter: impl IntoIterator, + record: &StringRecord, + ) -> anyhow::Result<()> + { let record = CsvStringRecordCodec::bytes_encode(record) .with_context(|| format!("could not encode csv record"))?; self.documents_ids.insert(id); - Ok(self.documents_sorter.insert(id.to_be_bytes(), record)?) + self.documents_sorter.insert(id.to_be_bytes(), record)?; + Self::write_docid_word_positions(&mut self.sorter, id, iter)?; + Ok(()) } - fn write_word_positions(sorter: &mut Sorter, iter: I) -> anyhow::Result<()> - where I: IntoIterator, RoaringBitmap)> + fn write_docid_word_positions(sorter: &mut Sorter, id: DocumentId, iter: I) -> anyhow::Result<()> + where I: IntoIterator { - // postings ids keys are all prefixed - let mut key = vec![WORD_POSITIONS_BYTE]; + // postings positions ids keys are all prefixed + let mut key = vec![WORD_DOCID_POSITIONS_BYTE]; let mut buffer = Vec::new(); for (word, positions) in iter { key.truncate(1); - key.extend_from_slice(&word); - // We serialize the positions into a buffer + key.extend_from_slice(word.as_bytes()); + // We prefix the words by the document id. + key.extend_from_slice(&id.to_be_bytes()); + // We serialize the document ids into a buffer buffer.clear(); buffer.reserve(positions.serialized_size()); positions.serialize_into(&mut buffer)?; @@ -244,68 +221,16 @@ impl Store { Ok(()) } - fn write_word_position_docids(sorter: &mut Sorter, iter: I) -> anyhow::Result<()> - where I: IntoIterator, Position), RoaringBitmap)> + fn write_word_docids(sorter: &mut Sorter, iter: I) -> anyhow::Result<()> + where I: IntoIterator, RoaringBitmap)> { // postings positions ids keys are all prefixed - let mut key = vec![WORD_POSITION_DOCIDS_BYTE]; + let mut key = vec![WORD_DOCIDS_BYTE]; let mut buffer = Vec::new(); - for ((word, pos), ids) in iter { + for (word, ids) in iter { key.truncate(1); key.extend_from_slice(&word); - // we postfix the word by the positions it appears in - key.extend_from_slice(&pos.to_be_bytes()); - // We serialize the document ids into a buffer - buffer.clear(); - buffer.reserve(ids.serialized_size()); - ids.serialize_into(&mut buffer)?; - // that we write under the generated key into MTBL - if lmdb_key_valid_size(&key) { - sorter.insert(&key, &buffer)?; - } - } - - Ok(()) - } - - fn write_word_four_positions_docids(sorter: &mut Sorter, iter: I) -> anyhow::Result<()> - where I: IntoIterator, Position), RoaringBitmap)> - { - // postings positions ids keys are all prefixed - let mut key = vec![WORD_FOUR_POSITIONS_DOCIDS_BYTE]; - let mut buffer = Vec::new(); - - for ((word, pos), ids) in iter { - key.truncate(1); - key.extend_from_slice(&word); - // we postfix the word by the positions it appears in - key.extend_from_slice(&pos.to_be_bytes()); - // We serialize the document ids into a buffer - buffer.clear(); - buffer.reserve(ids.serialized_size()); - ids.serialize_into(&mut buffer)?; - // that we write under the generated key into MTBL - if lmdb_key_valid_size(&key) { - sorter.insert(&key, &buffer)?; - } - } - - Ok(()) - } - - fn write_word_attribute_docids(sorter: &mut Sorter, iter: I) -> anyhow::Result<()> - where I: IntoIterator, Attribute), RoaringBitmap)> - { - // postings attributes keys are all prefixed - let mut key = vec![WORD_ATTRIBUTE_DOCIDS_BYTE]; - let mut buffer = Vec::new(); - - for ((word, attr), ids) in iter { - key.truncate(1); - key.extend_from_slice(&word); - // we postfix the word by the positions it appears in - key.extend_from_slice(&attr.to_be_bytes()); // We serialize the document ids into a buffer buffer.clear(); buffer.reserve(ids.serialized_size()); @@ -327,10 +252,7 @@ impl Store { } pub fn finish(mut self) -> anyhow::Result<(Reader, Reader)> { - Self::write_word_positions(&mut self.sorter, self.word_positions)?; - Self::write_word_position_docids(&mut self.sorter, self.word_position_docids)?; - Self::write_word_four_positions_docids(&mut self.sorter, self.word_four_positions_docids)?; - Self::write_word_attribute_docids(&mut self.sorter, self.word_attribute_docids)?; + Self::write_word_docids(&mut self.sorter, self.word_docids)?; Self::write_documents_ids(&mut self.sorter, self.documents_ids)?; let mut wtr = tempfile::tempfile().map(Writer::new)?; @@ -339,7 +261,8 @@ impl Store { let mut iter = self.sorter.into_iter()?; while let Some(result) = iter.next() { let (key, val) = result?; - if let Some((&1, word)) = key.split_first() { + if let Some((&1, bytes)) = key.split_first() { + let (word, _docid) = StrBEU32Codec::bytes_decode(bytes).unwrap(); // This is a lexicographically ordered word position // we use the key to construct the words fst. builder.insert(word)?; @@ -389,12 +312,7 @@ fn merge(key: &[u8], values: &[Vec]) -> Result, ()> { Ok(values[0].to_vec()) }, key => match key[0] { - DOCUMENTS_IDS_BYTE - | WORD_POSITIONS_BYTE - | WORD_POSITION_DOCIDS_BYTE - | WORD_FOUR_POSITIONS_DOCIDS_BYTE - | WORD_ATTRIBUTE_DOCIDS_BYTE => - { + DOCUMENTS_IDS_BYTE | WORD_DOCIDS_BYTE | WORD_DOCID_POSITIONS_BYTE => { let (head, tail) = values.split_first().unwrap(); let mut head = RoaringBitmap::deserialize_from(head.as_slice()).unwrap(); @@ -427,24 +345,14 @@ fn lmdb_writer(wtxn: &mut heed::RwTxn, index: &Index, key: &[u8], val: &[u8]) -> // Write the documents ids list index.main.put::<_, Str, ByteSlice>(wtxn, "documents-ids", val)?; } - else if key.starts_with(&[WORD_POSITIONS_BYTE]) { + else if key.starts_with(&[WORD_DOCIDS_BYTE]) { // Write the postings lists - index.word_positions.as_polymorph() + index.word_docids.as_polymorph() .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; } - else if key.starts_with(&[WORD_POSITION_DOCIDS_BYTE]) { + else if key.starts_with(&[WORD_DOCID_POSITIONS_BYTE]) { // Write the postings lists - index.word_position_docids.as_polymorph() - .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; - } - else if key.starts_with(&[WORD_FOUR_POSITIONS_DOCIDS_BYTE]) { - // Write the postings lists - index.word_four_positions_docids.as_polymorph() - .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; - } - else if key.starts_with(&[WORD_ATTRIBUTE_DOCIDS_BYTE]) { - // Write the attribute postings lists - index.word_attribute_docids.as_polymorph() + index.word_docid_positions.as_polymorph() .put::<_, ByteSlice, ByteSlice>(wtxn, &key[1..], val)?; } @@ -499,6 +407,7 @@ fn index_csv( let mut before = Instant::now(); let mut document_id: usize = 0; let mut document = csv::StringRecord::new(); + let mut word_positions = HashMap::new(); while rdr.read_record(&mut document)? { // We skip documents that must not be indexed by this thread. @@ -512,14 +421,15 @@ fn index_csv( let document_id = DocumentId::try_from(document_id).context("generated id is too big")?; for (attr, content) in document.iter().enumerate().take(MAX_ATTRIBUTES) { for (pos, (_, token)) in simple_tokenizer(&content).filter(only_words).enumerate().take(MAX_POSITION) { - let word = token.cow_to_lowercase(); + let word = token.to_lowercase(); let position = (attr * MAX_POSITION + pos) as u32; - store.insert_word_position_docid(&word, position, document_id)?; + store.insert_word_docid(&word, document_id)?; + word_positions.entry(word).or_insert_with(RoaringBitmap::new).insert(position); } } // We write the document in the database. - store.write_document(document_id, &document)?; + store.write_document(document_id, word_positions.drain(), &document)?; } // Compute the document id of the the next document. diff --git a/src/heed_codec/str_beu32_codec.rs b/src/heed_codec/str_beu32_codec.rs index 95836ec4e..23c52c09c 100644 --- a/src/heed_codec/str_beu32_codec.rs +++ b/src/heed_codec/str_beu32_codec.rs @@ -8,8 +8,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU32Codec { type DItem = (&'a str, u32); fn bytes_decode(bytes: &'a [u8]) -> Option { - let str_len = bytes.len().checked_sub(4)?; - let (str_bytes, n_bytes) = bytes.split_at(str_len); + let (str_bytes, n_bytes) = bytes.split_at(bytes.len() - 4); let s = str::from_utf8(str_bytes).ok()?; let n = n_bytes.try_into().map(u32::from_be_bytes).ok()?; Some((s, n)) diff --git a/src/lib.rs b/src/lib.rs index 5ab8c5769..8e1d174dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ mod criterion; -mod node; mod query_tokens; mod search; pub mod heed_codec; @@ -16,7 +15,7 @@ use heed::{PolyDatabase, Database}; pub use self::search::{Search, SearchResult}; pub use self::criterion::{Criterion, default_criteria}; -use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec, CsvStringRecordCodec}; +pub use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec, CsvStringRecordCodec}; pub type FastMap4 = HashMap>; pub type FastMap8 = HashMap>; @@ -36,14 +35,10 @@ const DOCUMENTS_IDS_KEY: &str = "documents-ids"; pub struct Index { /// Contains many different types (e.g. the documents CSV headers). pub main: PolyDatabase, - /// A word and all the positions where it appears in the whole dataset. - pub word_positions: Database, - /// Maps a word at a position (u32) and all the documents ids where the given word appears. - pub word_position_docids: Database, - /// Maps a word and a range of 4 positions, i.e. 0..4, 4..8, 12..16. - pub word_four_positions_docids: Database, - /// Maps a word and an attribute (u32) to all the documents ids where the given word appears. - pub word_attribute_docids: Database, + /// A word and all the documents ids containing the word. + pub word_docids: Database, + /// Maps a word and a document id (u32) to all the positions where the given word appears. + pub word_docid_positions: Database, /// Maps the document id to the document as a CSV line. pub documents: Database, ByteSlice>, } @@ -52,10 +47,8 @@ impl Index { pub fn new(env: &heed::Env) -> anyhow::Result { Ok(Index { main: env.create_poly_database(None)?, - word_positions: env.create_database(Some("word-positions"))?, - word_position_docids: env.create_database(Some("word-position-docids"))?, - word_four_positions_docids: env.create_database(Some("word-four-positions-docids"))?, - word_attribute_docids: env.create_database(Some("word-attribute-docids"))?, + word_docids: env.create_database(Some("word-docids"))?, + word_docid_positions: env.create_database(Some("word-docid-positions"))?, documents: env.create_database(Some("documents"))?, }) } diff --git a/src/node.rs b/src/node.rs deleted file mode 100644 index cbe6cbc59..000000000 --- a/src/node.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::cmp; -use roaring::RoaringBitmap; - -const ONE_ATTRIBUTE: u32 = 1000; -const MAX_DISTANCE: u32 = 8; - -fn index_proximity(lhs: u32, rhs: u32) -> u32 { - if lhs <= rhs { - cmp::min(rhs - lhs, MAX_DISTANCE) - } else { - cmp::min((lhs - rhs) + 1, MAX_DISTANCE) - } -} - -pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 { - let (lhs_attr, lhs_index) = extract_position(lhs); - let (rhs_attr, rhs_index) = extract_position(rhs); - if lhs_attr != rhs_attr { MAX_DISTANCE } - else { index_proximity(lhs_index, rhs_index) } -} - -// Returns the attribute and index parts. -pub fn extract_position(position: u32) -> (u32, u32) { - (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE) -} - -// Returns the group of four positions in which this position reside (i.e. 0, 4, 12). -pub fn group_of_four(position: u32) -> u32 { - position - position % 4 -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Node { - // Is this node is the first node. - Uninit, - Init { - // The layer where this node located. - layer: usize, - // The position where this node is located. - position: u32, - // The parent position from the above layer. - parent_position: u32, - }, -} - -impl Node { - // TODO we must skip the successors that have already been seen - // TODO we must skip the successors that doesn't return any documents - // this way we are able to skip entire paths - pub fn successors(&self, positions: &[RoaringBitmap], contains_documents: &mut F) -> Vec<(Node, u32)> - where F: FnMut((usize, u32), (usize, u32)) -> bool, - { - match self { - Node::Uninit => { - positions[0].iter().map(|position| { - (Node::Init { layer: 0, position, parent_position: 0 }, 0) - }).collect() - }, - // We reached the highest layer - n @ Node::Init { .. } if n.is_complete(positions) => vec![], - Node::Init { layer, position, .. } => { - positions[layer + 1].iter().filter_map(|p| { - let proximity = positions_proximity(*position, p); - let node = Node::Init { - layer: layer + 1, - position: p, - parent_position: *position, - }; - // We do not produce the nodes we have already seen in previous iterations loops. - if node.is_reachable(contains_documents) { - Some((node, proximity)) - } else { - None - } - }).collect() - } - } - } - - pub fn is_complete(&self, positions: &[RoaringBitmap]) -> bool { - match self { - Node::Uninit => false, - Node::Init { layer, .. } => *layer == positions.len() - 1, - } - } - - pub fn position(&self) -> Option { - match self { - Node::Uninit => None, - Node::Init { position, .. } => Some(*position), - } - } - - pub fn is_reachable(&self, contains_documents: &mut F) -> bool - where F: FnMut((usize, u32), (usize, u32)) -> bool, - { - match self { - Node::Uninit => true, - Node::Init { layer, position, parent_position, .. } => { - match layer.checked_sub(1) { - Some(parent_layer) => { - (contains_documents)((parent_layer, *parent_position), (*layer, *position)) - }, - None => true, - } - }, - } - } -} diff --git a/src/search.rs b/src/search.rs index 2e53ca5ed..7f0831d1c 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,8 +1,5 @@ -use std::cell::RefCell; use std::collections::{HashMap, HashSet}; -use std::rc::Rc; -use astar_iter::AstarBagIter; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::DFA; use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; @@ -10,7 +7,6 @@ use log::debug; use once_cell::sync::Lazy; use roaring::RoaringBitmap; -use crate::node::{self, Node}; use crate::query_tokens::{QueryTokens, QueryToken}; use crate::{Index, DocumentId, Position, Attribute}; @@ -86,69 +82,52 @@ impl<'a> Search<'a> { .collect() } - /// Fetch the words from the given FST related to the given DFAs along with the associated - /// positions and the unions of those positions where the words found appears in the documents. - fn fetch_words_positions( + /// Fetch the words from the given FST related to the + /// given DFAs along with the associated documents ids. + fn fetch_words_docids( rtxn: &heed::RoTxn, index: &Index, fst: &fst::Set<&[u8]>, dfas: Vec<(String, bool, DFA)>, - ) -> anyhow::Result<(Vec>, Vec)> + ) -> anyhow::Result, RoaringBitmap)>> { // A Vec storing all the derived words from the original query words, associated - // with the distance from the original word and the positions it appears at. - // The index the derived words appears in the Vec corresponds to the original query - // word position. - let mut derived_words = Vec::>::with_capacity(dfas.len()); - // A Vec storing the unions of all of each of the derived words positions. The index - // the union appears in the Vec corresponds to the original query word position. - let mut union_positions = Vec::::with_capacity(dfas.len()); + // with the distance from the original word and the docids where the words appears. + let mut derived_words = Vec::<(HashMap::, RoaringBitmap)>::with_capacity(dfas.len()); for (_word, _is_prefix, dfa) in dfas { - let mut acc_derived_words = Vec::new(); - let mut acc_union_positions = RoaringBitmap::new(); + let mut acc_derived_words = HashMap::new(); + let mut unions_docids = RoaringBitmap::new(); let mut stream = fst.search_with_state(&dfa).into_stream(); while let Some((word, state)) = stream.next() { let word = std::str::from_utf8(word)?; - let positions = index.word_positions.get(rtxn, word)?.unwrap(); + let docids = index.word_docids.get(rtxn, word)?.unwrap(); let distance = dfa.distance(state); - acc_union_positions.union_with(&positions); - acc_derived_words.push((word.to_string(), distance.to_u8(), positions)); + unions_docids.union_with(&docids); + acc_derived_words.insert(word.to_string(), (distance.to_u8(), docids)); } - derived_words.push(acc_derived_words); - union_positions.push(acc_union_positions); + derived_words.push((acc_derived_words, unions_docids)); } - Ok((derived_words, union_positions)) + Ok(derived_words) } /// Returns the set of docids that contains all of the query words. fn compute_candidates( rtxn: &heed::RoTxn, index: &Index, - derived_words: &[Vec<(String, u8, RoaringBitmap)>], + derived_words: &[(HashMap, RoaringBitmap)], ) -> anyhow::Result { // we do a union between all the docids of each of the derived words, // we got N unions (the number of original query words), we then intersect them. - // TODO we must store the words documents ids to avoid these unions. let mut candidates = RoaringBitmap::new(); - let number_of_attributes = index.number_of_attributes(rtxn)?.map_or(0, |n| n as u32); - - for (i, derived_words) in derived_words.iter().enumerate() { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, _positions) in derived_words { - for attr in 0..number_of_attributes { - if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attr))? { - union_docids.union_with(&docids); - } - } - } + for (i, (_, union_docids)) in derived_words.iter().enumerate() { if i == 0 { - candidates = union_docids; + candidates = union_docids.clone(); } else { candidates.intersect_with(&union_docids); } @@ -157,161 +136,6 @@ impl<'a> Search<'a> { Ok(candidates) } - /// Returns the union of the same position for all the given words. - fn union_word_position( - rtxn: &heed::RoTxn, - index: &Index, - words: &[(String, u8, RoaringBitmap)], - position: Position, - ) -> anyhow::Result - { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, positions) in words { - if positions.contains(position) { - if let Some(docids) = index.word_position_docids.get(rtxn, &(word, position))? { - union_docids.union_with(&docids); - } - } - } - Ok(union_docids) - } - - /// Returns the union of the same gorup of four positions for all the given words. - fn union_word_four_positions( - rtxn: &heed::RoTxn, - index: &Index, - words: &[(String, u8, RoaringBitmap)], - group: Position, - ) -> anyhow::Result - { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, _positions) in words { - // TODO would be better to check if the group exist - if let Some(docids) = index.word_four_positions_docids.get(rtxn, &(word, group))? { - union_docids.union_with(&docids); - } - } - Ok(union_docids) - } - - /// Returns the union of the same attribute for all the given words. - fn union_word_attribute( - rtxn: &heed::RoTxn, - index: &Index, - words: &[(String, u8, RoaringBitmap)], - attribute: Attribute, - ) -> anyhow::Result - { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, _positions) in words { - if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attribute))? { - union_docids.union_with(&docids); - } - } - Ok(union_docids) - } - - // Returns `true` if there is documents in common between the two words and positions given. - fn contains_documents( - rtxn: &heed::RoTxn, - index: &Index, - (lword, lpos): (usize, u32), - (rword, rpos): (usize, u32), - candidates: &RoaringBitmap, - derived_words: &[Vec<(String, u8, RoaringBitmap)>], - union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, - non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, - group_four_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, - group_four_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, - attribute_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, - attribute_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, - ) -> bool - { - if lpos == rpos { return false } - - // TODO move this function to a better place. - let (lattr, _) = node::extract_position(lpos); - let (rattr, _) = node::extract_position(rpos); - - if lattr == rattr { - // TODO move this function to a better place. - let lgroup = node::group_of_four(lpos); - let rgroup = node::group_of_four(rpos); - - // We can't compute a disjunction on a group of four positions if those - // two positions are in the same group, we must go down to the position. - if lgroup == rgroup { - // We retrieve or compute the intersection between the two given words and positions. - *non_disjoint_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - union_cache.entry((lword, lpos)).or_insert_with(|| { - let words = &derived_words[lword]; - Self::union_word_position(rtxn, index, words, lpos).unwrap() - }); - union_cache.entry((rword, rpos)).or_insert_with(|| { - let words = &derived_words[rword]; - Self::union_word_position(rtxn, index, words, rpos).unwrap() - }); - - // TODO is there a way to avoid this double gets? - let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); - let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } - - !lunion_docids.is_disjoint(&runion_docids) - }) - } else { - // We retrieve or compute the intersection between the two given words and positions. - *group_four_non_disjoint_cache.entry(((lword, lgroup), (rword, rgroup))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and group of four positions. - group_four_union_cache.entry((lword, lgroup)).or_insert_with(|| { - let words = &derived_words[lword]; - Self::union_word_four_positions(rtxn, index, words, lgroup).unwrap() - }); - group_four_union_cache.entry((rword, rgroup)).or_insert_with(|| { - let words = &derived_words[rword]; - Self::union_word_four_positions(rtxn, index, words, rgroup).unwrap() - }); - - // TODO is there a way to avoid this double gets? - let lunion_group_docids = group_four_union_cache.get(&(lword, lgroup)).unwrap(); - let runion_group_docids = group_four_union_cache.get(&(rword, rgroup)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_group_docids.is_disjoint(candidates) { return false } - if runion_group_docids.is_disjoint(candidates) { return false } - - !lunion_group_docids.is_disjoint(&runion_group_docids) - }) - } - } else { - *attribute_non_disjoint_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - attribute_union_cache.entry((lword, lattr)).or_insert_with(|| { - let words = &derived_words[lword]; - Self::union_word_attribute(rtxn, index, words, lattr).unwrap() - }); - attribute_union_cache.entry((rword, rattr)).or_insert_with(|| { - let words = &derived_words[rword]; - Self::union_word_attribute(rtxn, index, words, rattr).unwrap() - }); - - // TODO is there a way to avoid this double gets? - let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); - let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } - - !lunion_docids.is_disjoint(&runion_docids) - }) - } - } - pub fn execute(&self) -> anyhow::Result { let rtxn = self.rtxn; let index = self.index; @@ -333,111 +157,14 @@ impl<'a> Search<'a> { return Ok(Default::default()); } - let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?; + let derived_words = Self::fetch_words_docids(rtxn, index, &fst, dfas)?; let candidates = Self::compute_candidates(rtxn, index, &derived_words)?; debug!("candidates: {:?}", candidates); - let union_cache = HashMap::new(); - let mut non_disjoint_cache = HashMap::new(); + let documents = vec![candidates]; - let mut group_four_union_cache = HashMap::new(); - let mut group_four_non_disjoint_cache = HashMap::new(); - - let mut attribute_union_cache = HashMap::new(); - let mut attribute_non_disjoint_cache = HashMap::new(); - - let candidates = Rc::new(RefCell::new(candidates)); - let union_cache = Rc::new(RefCell::new(union_cache)); - - let candidates_cloned = candidates.clone(); - let union_cache_cloned = union_cache.clone(); - let mut contains_documents = |left, right| { - Self::contains_documents( - rtxn, index, - left, right, - &candidates_cloned.borrow(), - &derived_words, - &mut union_cache_cloned.borrow_mut(), - &mut non_disjoint_cache, - &mut group_four_union_cache, - &mut group_four_non_disjoint_cache, - &mut attribute_union_cache, - &mut attribute_non_disjoint_cache, - ) - }; - - let astar_iter = AstarBagIter::new( - Node::Uninit, // start - |n| n.successors(&union_positions, &mut contains_documents), // successors - |_| 0, // heuristic - |n| n.is_complete(&union_positions), // success - ); - - let mut documents = Vec::new(); - for (paths, proximity) in astar_iter { - let mut union_cache = union_cache.borrow_mut(); - let mut candidates = candidates.borrow_mut(); - - let mut positions: Vec> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect(); - positions.sort_unstable(); - - debug!("Found {} positions with a proximity of {}", positions.len(), proximity); - - let mut same_proximity_union = RoaringBitmap::default(); - for positions in positions { - // Precompute the potentially missing unions - positions.iter().enumerate().for_each(|(word, pos)| { - union_cache.entry((word, *pos)).or_insert_with(|| { - let words = &&derived_words[word]; - Self::union_word_position(rtxn, index, words, *pos).unwrap() - }); - }); - - // Retrieve the unions along with the popularity of it. - let mut to_intersect = Vec::new(); - for (word, pos) in positions.into_iter().enumerate() { - let docids = union_cache.get(&(word, pos)).unwrap(); - to_intersect.push((docids.len(), docids)); - } - - // Sort the unions by popularity to help reduce - // the number of documents as soon as possible. - to_intersect.sort_unstable_by_key(|(l, _)| *l); - - // Intersect all the unions in the inverse popularity order. - let mut intersect_docids = RoaringBitmap::new(); - for (i, (_, union_docids)) in to_intersect.into_iter().enumerate() { - if i == 0 { - intersect_docids = union_docids.clone(); - } else { - intersect_docids.intersect_with(union_docids); - } - } - - same_proximity_union.union_with(&intersect_docids); - } - - // We achieve to find valid documents ids so we remove them from the candidates list. - candidates.difference_with(&same_proximity_union); - - // We remove documents we have already been seen in previous - // fetches from this set of documents we just fetched. - for previous_documents in &documents { - same_proximity_union.difference_with(previous_documents); - } - - if !same_proximity_union.is_empty() { - documents.push(same_proximity_union); - } - - // We found enough documents we can stop here. - if documents.iter().map(RoaringBitmap::len).sum::() >= limit as u64 { - break; - } - } - - let found_words = derived_words.into_iter().flatten().map(|(w, _, _)| w).collect(); + let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect(); let documents_ids = documents.iter().flatten().take(limit).collect(); Ok(SearchResult { found_words, documents_ids })