diff --git a/src/bin/infos.rs b/src/bin/infos.rs index a4513cce2..996e724d7 100644 --- a/src/bin/infos.rs +++ b/src/bin/infos.rs @@ -41,6 +41,15 @@ enum Command { /// The maximum number of frequencies to return. #[structopt(default_value = "10")] limit: usize, + }, + + /// Outputs a CSV with the frequencies of the specified words. + /// + /// Read the documentation of the `most-common-words` command + /// for more information about the CSV headers. + WordsFrequencies { + /// The words you want to retrieve frequencies of. + words: Vec, } } @@ -64,6 +73,7 @@ fn main() -> anyhow::Result<()> { match opt.command { Command::MostCommonWords { limit } => most_common_words(&index, &rtxn, limit), + Command::WordsFrequencies { words } => words_frequencies(&index, &rtxn, words), } } @@ -83,7 +93,7 @@ fn most_common_words(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyhow: match prev.as_mut() { Some((prev_word, freq, docids)) if prev_word == word => { - *freq += docids.len(); + *freq += postings.len(); docids.union_with(&postings); }, Some((prev_word, freq, docids)) => { @@ -110,3 +120,30 @@ fn most_common_words(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyhow: Ok(wtr.flush()?) } + +fn words_frequencies(index: &Index, rtxn: &heed::RoTxn, words: Vec) -> anyhow::Result<()> { + use roaring::RoaringBitmap; + + let stdout = io::stdout(); + let mut wtr = csv::Writer::from_writer(stdout.lock()); + wtr.write_record(&["word", "document_frequency", "frequency"])?; + + for word in words { + let mut document_frequency = RoaringBitmap::new(); + let mut frequency = 0; + for result in index.word_position_docids.prefix_iter(rtxn, word.as_bytes())? { + let (bytes, postings) = result?; + let (w, _position) = bytes.split_at(bytes.len() - 4); + + // if the word is not exactly the word we requested then it means + // we found a word that *starts with* the requested word and we must stop. + if word.as_bytes() != w { break } + + document_frequency.union_with(&postings); + frequency += postings.len(); + } + wtr.write_record(&[word, document_frequency.len().to_string(), frequency.to_string()])?; + } + + Ok(wtr.flush()?) +}