From 3f7a500f3b108f2e2b34688e9b6f16ff1bc5b2cb Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Wed, 25 Sep 2024 14:15:18 +0200 Subject: [PATCH] Build prefix fst --- milli/src/update/new/merger.rs | 146 +----------------- milli/src/update/new/mod.rs | 1 + milli/src/update/new/word_fst_builder.rs | 187 +++++++++++++++++++++++ 3 files changed, 192 insertions(+), 142 deletions(-) create mode 100644 milli/src/update/new/word_fst_builder.rs diff --git a/milli/src/update/new/merger.rs b/milli/src/update/new/merger.rs index ca6b213c1..7e1a80888 100644 --- a/milli/src/update/new/merger.rs +++ b/milli/src/update/new/merger.rs @@ -1,20 +1,18 @@ use std::fs::File; -use std::io::{self, BufWriter}; +use std::io::{self}; use bincode::ErrorKind; -use fst::{Set, SetBuilder, Streamer}; use grenad::Merger; use heed::types::Bytes; use heed::{Database, RoTxn}; -use memmap2::Mmap; use roaring::RoaringBitmap; -use tempfile::tempfile; use super::channel::*; use super::extract::FacetKind; use super::{Deletion, DocumentChange, Insertion, KvReaderDelAdd, KvReaderFieldId, Update}; use crate::update::del_add::DelAdd; use crate::update::new::channel::MergerOperation; +use crate::update::new::word_fst_builder::WordFstBuilder; use crate::update::MergeDeladdCboRoaringBitmaps; use crate::{CboRoaringBitmapCodec, Error, GeoPoint, GlobalFieldsIdsMap, Index, Result}; @@ -82,8 +80,8 @@ pub fn merge_grenad_entries( tracing::trace_span!(target: "indexing::documents::merge", "words_fst"); let _entered = span.enter(); - let mmap = word_fst_builder.build()?; - sender.main().write_words_fst(mmap).unwrap(); + let (word_fst_mmap, prefix_fst_mmap) = word_fst_builder.build()?; + sender.main().write_words_fst(word_fst_mmap).unwrap(); } } MergerOperation::WordFidDocidsMerger(merger) => { @@ -190,142 +188,6 @@ pub fn merge_grenad_entries( Ok(()) } -struct WordFstBuilder<'a> { - stream: fst::set::Stream<'a>, - word_fst_builder: SetBuilder>, - prefix_fst_builders: Vec>>, - max_prefix_length: usize, - last_word: Vec, -} - -impl<'a> WordFstBuilder<'a> { - pub fn new( - words_fst: &'a Set>, - max_prefix_length: usize, - ) -> Result { - let mut prefix_fst_builders = Vec::new(); - for _ in 0..max_prefix_length { - prefix_fst_builders.push(SetBuilder::new(BufWriter::new(tempfile()?))?); - } - - Ok(Self { - stream: words_fst.stream(), - word_fst_builder: SetBuilder::new(BufWriter::new(tempfile()?))?, - prefix_fst_builders, - max_prefix_length, - last_word: Vec::new(), - }) - } - - pub fn register_word(&mut self, deladd: DelAdd, key: &[u8]) -> Result<()> { - match deladd { - DelAdd::Addition => self.add_word(key), - DelAdd::Deletion => self.del_word(key), - } - } - - pub fn add_word(&mut self, word: &[u8]) -> Result<()> { - if !self.last_word.is_empty() { - let next = self.last_word.as_slice(); - match next.cmp(word) { - std::cmp::Ordering::Less => { - // We need to insert the last word from the current fst - self.word_fst_builder.insert(next)?; - self.last_word.clear(); - } - std::cmp::Ordering::Equal => { - // We insert the word and drop the last word - self.word_fst_builder.insert(next)?; - self.last_word.clear(); - return Ok(()); - } - std::cmp::Ordering::Greater => { - // We insert the word and keep the last word - self.word_fst_builder.insert(word)?; - - return Ok(()); - } - } - } - - while let Some(next) = self.stream.next() { - match next.cmp(word) { - std::cmp::Ordering::Less => { - // We need to insert the last word from the current fst - self.word_fst_builder.insert(next)?; - } - std::cmp::Ordering::Equal => { - // We insert the word - self.word_fst_builder.insert(next)?; - - return Ok(()); - } - std::cmp::Ordering::Greater => { - // We insert the word and keep the last word - self.word_fst_builder.insert(word)?; - self.last_word.clear(); - self.last_word.extend_from_slice(next); - - return Ok(()); - } - } - } - - Ok(()) - } - - pub fn del_word(&mut self, word: &[u8]) -> Result<()> { - if !self.last_word.is_empty() { - let next = self.last_word.as_slice(); - match next.cmp(word) { - std::cmp::Ordering::Less => { - // We insert the word from the current fst because the next word to delete is greater - self.word_fst_builder.insert(next)?; - self.last_word.clear(); - } - std::cmp::Ordering::Equal => { - // We delete the word by not inserting it in the new fst and drop the last word - self.last_word.clear(); - return Ok(()); - } - std::cmp::Ordering::Greater => { - // keep the current word until the next word to delete is greater or equal - return Ok(()); - } - } - } - - while let Some(next) = self.stream.next() { - match next.cmp(word) { - std::cmp::Ordering::Less => { - // We insert the word from the current fst because the next word to delete is greater - self.word_fst_builder.insert(next)?; - } - std::cmp::Ordering::Equal => { - // We delete the word by not inserting it in the new fst and drop the last word - return Ok(()); - } - std::cmp::Ordering::Greater => { - // keep the current word until the next word to delete is greater or equal - self.last_word.clear(); - self.last_word.extend_from_slice(next); - - return Ok(()); - } - } - } - - Ok(()) - } - - pub fn build(mut self) -> Result { - let words_fst_file = self.word_fst_builder.into_inner()?.into_inner().unwrap(); - let words_fst_mmap = unsafe { Mmap::map(&words_fst_file)? }; - - Ok(words_fst_mmap) - } -} - pub struct GeoExtractor { rtree: Option>, } diff --git a/milli/src/update/new/mod.rs b/milli/src/update/new/mod.rs index 6389a53c4..dedd89497 100644 --- a/milli/src/update/new/mod.rs +++ b/milli/src/update/new/mod.rs @@ -12,6 +12,7 @@ pub mod indexer; mod items_pool; mod merger; mod top_level_map; +mod word_fst_builder; /// TODO move them elsewhere pub type StdResult = std::result::Result; diff --git a/milli/src/update/new/word_fst_builder.rs b/milli/src/update/new/word_fst_builder.rs new file mode 100644 index 000000000..227a81d9d --- /dev/null +++ b/milli/src/update/new/word_fst_builder.rs @@ -0,0 +1,187 @@ +use std::{fs::File, io::BufWriter}; + +use fst::{Set, SetBuilder, Streamer}; +use memmap2::Mmap; +use tempfile::tempfile; + +use crate::{update::del_add::DelAdd, Result, SmallString32}; + +pub struct WordFstBuilder<'a> { + stream: Option>, + word_fst_builder: SetBuilder>, + /// TODO: Replace the full memory allocation + prefix_fst_builders: Vec>>, + max_prefix_length: usize, + last_word: Option>, + current_prefix: Vec, + current_prefix_count: Vec, + prefix_count_threshold: u64, +} + +impl<'a> WordFstBuilder<'a> { + pub fn new( + words_fst: &'a Set>, + max_prefix_length: usize, + ) -> Result { + let mut prefix_fst_builders = Vec::new(); + for _ in 0..max_prefix_length { + prefix_fst_builders.push(SetBuilder::memory()); + } + + Ok(Self { + stream: Some(words_fst.stream()), + word_fst_builder: SetBuilder::new(BufWriter::new(tempfile()?))?, + prefix_fst_builders, + max_prefix_length, + last_word: None, + current_prefix: vec![SmallString32::new(); max_prefix_length], + current_prefix_count: vec![0; max_prefix_length], + prefix_count_threshold: 100, + }) + } + + pub fn register_word(&mut self, deladd: DelAdd, right: &[u8]) -> Result<()> { + if let Some(left) = self.last_word.take() { + let (left_inserted, right_inserted) = + self.compare_and_insert(deladd, left.as_slice(), right)?; + + // left was not inserted, so we keep it for the next iteration + if !left_inserted { + self.last_word = Some(left); + } + + // right was inserted, so we can stop + if right_inserted { + return Ok(()); + } + } + + if let Some(mut stream) = self.stream.take() { + while let Some(left) = stream.next() { + let (left_inserted, right_inserted) = + self.compare_and_insert(deladd, left, right)?; + + // left was not inserted, so we keep it for the next iteration + if !left_inserted { + self.last_word = Some(left.to_vec()); + } + + // right was inserted, so we can stop + if right_inserted { + break; + } + } + + self.stream = Some(stream); + } + + Ok(()) + } + + pub fn compare_and_insert( + &mut self, + deladd: DelAdd, + left: &[u8], + right: &[u8], + ) -> Result<(bool, bool)> { + let mut left_inserted = false; + let mut right_inserted = false; + match left.cmp(right) { + std::cmp::Ordering::Less => { + // We need to insert the last word from the current fst + self.insert_word(left)?; + + left_inserted = true; + } + std::cmp::Ordering::Equal => { + // Addition: We insert the word + // Deletion: We delete the word by not inserting it + if deladd == DelAdd::Addition { + self.insert_word(right)?; + } + + left_inserted = true; + right_inserted = true; + } + std::cmp::Ordering::Greater => { + // Addition: We insert the word and keep the last word + // Deletion: We keep the current word until the left word to delete is greater or equal + if deladd == DelAdd::Addition { + self.insert_word(right)?; + } + + right_inserted = true; + } + } + + Ok((left_inserted, right_inserted)) + } + + fn insert_word(&mut self, bytes: &[u8]) -> Result<()> { + self.word_fst_builder.insert(bytes)?; + + for n in 0..self.max_prefix_length { + let current_prefix = &mut self.current_prefix[n]; + let current_prefix_count = &mut self.current_prefix_count[n]; + let builder = &mut self.prefix_fst_builders[n]; + + // We try to get the first n bytes out of this string but we only want + // to split at valid characters bounds. If we try to split in the middle of + // a character we ignore this word and go to the next one. + let word = std::str::from_utf8(bytes)?; + let prefix = match word.get(..=n) { + Some(prefix) => prefix, + None => continue, + }; + + // This is the first iteration of the loop, + // or the current word doesn't starts with the current prefix. + if *current_prefix_count == 0 || prefix != current_prefix.as_str() { + *current_prefix = SmallString32::from(prefix); + *current_prefix_count = 0; + } + + *current_prefix_count += 1; + + // There is enough words corresponding to this prefix to add it to the cache. + /// TODO: (LEGACY) Replace this by `==` to avoid inserting several times the same prefix? + if *current_prefix_count >= self.prefix_count_threshold { + builder.insert(prefix)?; + } + } + + Ok(()) + } + + fn drain_stream(&mut self) -> Result<()> { + if let Some(mut stream) = self.stream.take() { + while let Some(current) = stream.next() { + self.insert_word(current)?; + } + } + + Ok(()) + } + + pub fn build(mut self) -> Result<(Mmap, Mmap)> { + self.drain_stream()?; + + /// TODO: ugly unwrap + let words_fst_file = self.word_fst_builder.into_inner()?.into_inner().unwrap(); + let words_fst_mmap = unsafe { Mmap::map(&words_fst_file)? }; + + // We merge all of the previously computed prefixes into on final set. + let mut prefix_fsts = Vec::new(); + for builder in self.prefix_fst_builders { + prefix_fsts.push(builder.into_set()); + } + let op = fst::set::OpBuilder::from_iter(prefix_fsts.iter()); + let mut builder = SetBuilder::new(BufWriter::new(tempfile()?))?; + builder.extend_stream(op.r#union())?; + /// TODO: ugly unwrap + let prefix_fst_file = builder.into_inner()?.into_inner().unwrap(); + let prefix_fst_mmap = unsafe { Mmap::map(&prefix_fst_file)? }; + + Ok((words_fst_mmap, prefix_fst_mmap)) + } +}