From 34f11e33808d5e7413bf0752b5dde028a9447a04 Mon Sep 17 00:00:00 2001
From: ManyTheFish
Date: Thu, 5 Sep 2024 10:30:39 +0200
Subject: [PATCH] Implement word count and word pair proximity extractors

---
 .../extract_fid_word_count_docids.rs | 135 +++++++++++++
 .../extract_word_pair_proximity_docids.rs | 182 ++++++++++++++++++
 .../src/update/new/extract/searchable/mod.rs | 4 +
 .../extract/searchable/tokenize_document.rs | 17 +-
 4 files changed, 331 insertions(+), 7 deletions(-)
 create mode 100644 milli/src/update/new/extract/searchable/extract_fid_word_count_docids.rs
 create mode 100644 milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs

diff --git a/milli/src/update/new/extract/searchable/extract_fid_word_count_docids.rs b/milli/src/update/new/extract/searchable/extract_fid_word_count_docids.rs
new file mode 100644
index 000000000..08160155e
--- /dev/null
+++ b/milli/src/update/new/extract/searchable/extract_fid_word_count_docids.rs
@@ -0,0 +1,135 @@
+use std::{borrow::Cow, collections::HashMap};
+
+use heed::RoTxn;
+
+use super::{tokenize_document::DocumentTokenizer, SearchableExtractor};
+use crate::{
+    update::{
+        new::{extract::cache::CboCachedSorter, DocumentChange},
+        MergeDeladdCboRoaringBitmaps,
+    },
+    FieldId, GlobalFieldsIdsMap, Index, Result,
+};
+
+const MAX_COUNTED_WORDS: usize = 30;
+
+pub struct FidWordCountDocidsExtractor;
+impl SearchableExtractor for FidWordCountDocidsExtractor {
+    fn attributes_to_extract<'a>(
+        rtxn: &'a RoTxn,
+        index: &'a Index,
+    ) -> Result<Option<Vec<&'a str>>> {
+        index.user_defined_searchable_fields(rtxn).map_err(Into::into)
+    }
+
+    fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result<Vec<&'a str>> {
+        Ok(vec![])
+    }
+
+    /// This case is unreachable because extract_document_change has been reimplemented to not call this function.
+    fn build_key<'a>(_field_id: FieldId, _position: u16, _word: &'a str) -> Cow<'a, [u8]> {
+        unreachable!()
+    }
+
+    // This method is reimplemented to count the number of words in each field of the document
+    // and to store the docids of the documents whose word count in a given field is less than or equal to MAX_COUNTED_WORDS.
+    fn extract_document_change(
+        rtxn: &RoTxn,
+        index: &Index,
+        document_tokenizer: &DocumentTokenizer,
+        fields_ids_map: &mut GlobalFieldsIdsMap,
+        cached_sorter: &mut CboCachedSorter<MergeDeladdCboRoaringBitmaps>,
+        document_change: DocumentChange,
+    ) -> Result<()> {
+        let mut key_buffer = Vec::new();
+        match document_change {
+            DocumentChange::Deletion(inner) => {
+                let mut fid_word_count = HashMap::new();
+                let mut token_fn = |fid: FieldId, _pos: u16, _word: &str| {
+                    fid_word_count.entry(fid).and_modify(|count| *count += 1).or_insert(1);
+                    Ok(())
+                };
+                document_tokenizer.tokenize_document(
+                    inner.current(rtxn, index)?.unwrap(),
+                    fields_ids_map,
+                    &mut token_fn,
+                )?;
+
+                // The docids of the documents whose word count in a given field is less than or equal to MAX_COUNTED_WORDS are deleted.
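+                // (The key is the field id in big-endian bytes followed by the word count; see build_key at the end of this file.)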
+                for (fid, count) in fid_word_count.iter() {
+                    if *count <= MAX_COUNTED_WORDS {
+                        let key = build_key(*fid, *count as u8, &mut key_buffer);
+                        // TODO: manage the error.
+                        cached_sorter.insert_del_u32(key, inner.docid()).unwrap();
+                    }
+                }
+            }
+            DocumentChange::Update(inner) => {
+                let mut fid_word_count = HashMap::new();
+                let mut token_fn = |fid: FieldId, _pos: u16, _word: &str| {
+                    fid_word_count
+                        .entry(fid)
+                        .and_modify(|(current_count, _new_count)| *current_count += 1)
+                        .or_insert((1, 0));
+                    Ok(())
+                };
+                document_tokenizer.tokenize_document(
+                    inner.current(rtxn, index)?.unwrap(),
+                    fields_ids_map,
+                    &mut token_fn,
+                )?;
+
+                let mut token_fn = |fid: FieldId, _pos: u16, _word: &str| {
+                    fid_word_count
+                        .entry(fid)
+                        .and_modify(|(_current_count, new_count)| *new_count += 1)
+                        .or_insert((0, 1));
+                    Ok(())
+                };
+                document_tokenizer.tokenize_document(inner.new(), fields_ids_map, &mut token_fn)?;
+
+                // Only the fields whose word count changed are updated.
+                for (fid, (current_count, new_count)) in fid_word_count.iter() {
+                    if *current_count != *new_count {
+                        if *current_count <= MAX_COUNTED_WORDS {
+                            let key = build_key(*fid, *current_count as u8, &mut key_buffer);
+                            // TODO: manage the error.
+                            cached_sorter.insert_del_u32(key, inner.docid()).unwrap();
+                        }
+                        if *new_count <= MAX_COUNTED_WORDS {
+                            let key = build_key(*fid, *new_count as u8, &mut key_buffer);
+                            // TODO: manage the error.
+                            cached_sorter.insert_add_u32(key, inner.docid()).unwrap();
+                        }
+                    }
+                }
+            }
+            DocumentChange::Insertion(inner) => {
+                let mut fid_word_count = HashMap::new();
+                let mut token_fn = |fid: FieldId, _pos: u16, _word: &str| {
+                    fid_word_count.entry(fid).and_modify(|count| *count += 1).or_insert(1);
+                    Ok(())
+                };
+                document_tokenizer.tokenize_document(inner.new(), fields_ids_map, &mut token_fn)?;
+
+                // The docids of the documents whose word count in a given field is less than or equal to MAX_COUNTED_WORDS are stored.
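+                // Counts strictly greater than MAX_COUNTED_WORDS are simply skipped and never indexed.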
+                for (fid, count) in fid_word_count.iter() {
+                    if *count <= MAX_COUNTED_WORDS {
+                        let key = build_key(*fid, *count as u8, &mut key_buffer);
+                        // TODO: manage the error.
+                        cached_sorter.insert_add_u32(key, inner.docid()).unwrap();
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+fn build_key(fid: FieldId, count: u8, key_buffer: &mut Vec<u8>) -> &[u8] {
+    key_buffer.clear();
+    key_buffer.extend_from_slice(&fid.to_be_bytes());
+    key_buffer.push(count);
+    key_buffer.as_slice()
+}
diff --git a/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs b/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs
new file mode 100644
index 000000000..e170a6486
--- /dev/null
+++ b/milli/src/update/new/extract/searchable/extract_word_pair_proximity_docids.rs
@@ -0,0 +1,182 @@
+use std::{
+    borrow::Cow,
+    collections::{BTreeMap, VecDeque},
+};
+
+use heed::RoTxn;
+use itertools::merge_join_by;
+use obkv::KvReader;
+
+use super::{tokenize_document::DocumentTokenizer, SearchableExtractor};
+use crate::{
+    proximity::{index_proximity, MAX_DISTANCE},
+    update::{
+        new::{extract::cache::CboCachedSorter, DocumentChange},
+        MergeDeladdCboRoaringBitmaps,
+    },
+    FieldId, GlobalFieldsIdsMap, Index, Result,
+};
+
+pub struct WordPairProximityDocidsExtractor;
+impl SearchableExtractor for WordPairProximityDocidsExtractor {
+    fn attributes_to_extract<'a>(
+        rtxn: &'a RoTxn,
+        index: &'a Index,
+    ) -> Result<Option<Vec<&'a str>>> {
+        index.user_defined_searchable_fields(rtxn).map_err(Into::into)
+    }
+
+    fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result<Vec<&'a str>> {
+        Ok(vec![])
+    }
+
+    /// This case is unreachable because extract_document_change has been reimplemented to not call this function.
+    fn build_key<'a>(_field_id: FieldId, _position: u16, _word: &'a str) -> Cow<'a, [u8]> {
+        unreachable!()
+    }
+
+    // This method is reimplemented to extract the pairs of words that appear close to each other in each document
+    // and to store the docids of the documents under the corresponding (proximity, word1, word2) keys.
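+    // Deleted and added word pairs are first accumulated in two BTreeMaps, then merged so that
+    // only the pairs whose proximity actually changed produce a deletion and/or an addition.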
+    fn extract_document_change(
+        rtxn: &RoTxn,
+        index: &Index,
+        document_tokenizer: &DocumentTokenizer,
+        fields_ids_map: &mut GlobalFieldsIdsMap,
+        cached_sorter: &mut CboCachedSorter<MergeDeladdCboRoaringBitmaps>,
+        document_change: DocumentChange,
+    ) -> Result<()> {
+        // TODO: share those buffers between calls.
+        let mut key_buffer = Vec::new();
+        let mut add_word_pair_proximity = BTreeMap::new();
+        let mut del_word_pair_proximity = BTreeMap::new();
+        let mut word_positions: VecDeque<(String, u16)> =
+            VecDeque::with_capacity(MAX_DISTANCE as usize);
+
+        let docid = document_change.docid();
+        match document_change {
+            DocumentChange::Deletion(inner) => {
+                let document = inner.current(rtxn, index)?.unwrap();
+                process_document_tokens(
+                    document,
+                    document_tokenizer,
+                    fields_ids_map,
+                    &mut word_positions,
+                    &mut del_word_pair_proximity,
+                )?;
+            }
+            DocumentChange::Update(inner) => {
+                let document = inner.current(rtxn, index)?.unwrap();
+                process_document_tokens(
+                    document,
+                    document_tokenizer,
+                    fields_ids_map,
+                    &mut word_positions,
+                    &mut del_word_pair_proximity,
+                )?;
+                let document = inner.new();
+                process_document_tokens(
+                    document,
+                    document_tokenizer,
+                    fields_ids_map,
+                    &mut word_positions,
+                    &mut add_word_pair_proximity,
+                )?;
+            }
+            DocumentChange::Insertion(inner) => {
+                let document = inner.new();
+                process_document_tokens(
+                    document,
+                    document_tokenizer,
+                    fields_ids_map,
+                    &mut word_positions,
+                    &mut add_word_pair_proximity,
+                )?;
+            }
+        }
+
+        use itertools::EitherOrBoth::*;
+        for eob in
+            merge_join_by(del_word_pair_proximity.iter(), add_word_pair_proximity.iter(), |d, a| {
+                d.cmp(a)
+            })
+        {
+            match eob {
+                Left(((w1, w2), prox)) => {
+                    let key = build_key(*prox, w1, w2, &mut key_buffer);
+                    cached_sorter.insert_del_u32(key, docid).unwrap();
+                }
+                Right(((w1, w2), prox)) => {
+                    let key = build_key(*prox, w1, w2, &mut key_buffer);
+                    cached_sorter.insert_add_u32(key, docid).unwrap();
+                }
+                Both(((w1, w2), del_prox), (_, add_prox)) => {
+                    if del_prox != add_prox {
+                        let key = build_key(*del_prox, w1, w2, &mut key_buffer);
+                        cached_sorter.insert_del_u32(key, docid).unwrap();
+                        let key = build_key(*add_prox, w1, w2, &mut key_buffer);
+                        cached_sorter.insert_add_u32(key, docid).unwrap();
+                    }
+                }
+            };
+        }
+
+        Ok(())
+    }
+}
+
+fn build_key<'a>(prox: u8, w1: &str, w2: &str, key_buffer: &'a mut Vec<u8>) -> &'a [u8] {
+    key_buffer.clear();
+    key_buffer.push(prox);
+    key_buffer.extend_from_slice(w1.as_bytes());
+    key_buffer.push(0);
+    key_buffer.extend_from_slice(w2.as_bytes());
+    key_buffer.as_slice()
+}
+
+fn word_positions_into_word_pair_proximity(
+    word_positions: &mut VecDeque<(String, u16)>,
+    word_pair_proximity: &mut BTreeMap<(String, String), u8>,
+) -> Result<()> {
+    let (head_word, head_position) = word_positions.pop_front().unwrap();
+    for (word, position) in word_positions.iter() {
+        let prox = index_proximity(head_position as u32, *position as u32) as u8;
+        if prox > 0 && prox < MAX_DISTANCE as u8 {
+            word_pair_proximity
+                .entry((head_word.clone(), word.clone()))
+                .and_modify(|p| {
+                    *p = std::cmp::min(*p, prox);
+                })
+                .or_insert(prox);
+        }
+    }
+    Ok(())
+}
+
+fn process_document_tokens(
+    document: &KvReader<FieldId>,
+    document_tokenizer: &DocumentTokenizer,
+    fields_ids_map: &mut GlobalFieldsIdsMap,
+    word_positions: &mut VecDeque<(String, u16)>,
+    word_pair_proximity: &mut BTreeMap<(String, String), u8>,
+) -> Result<()> {
+    let mut token_fn = |_fid: FieldId, pos: u16, word: &str| {
+        // Drain the proximity window until the head word is considered close to the word we are inserting.
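+        // word_positions_into_word_pair_proximity pops the head word and registers one
+        // (head, other) pair per remaining word of the window, keeping the smallest proximity.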
+        while word_positions
+            .front()
+            .map_or(false, |(_w, p)| index_proximity(*p as u32, pos as u32) >= MAX_DISTANCE)
+        {
+            word_positions_into_word_pair_proximity(word_positions, word_pair_proximity)?;
+        }
+
+        // Insert the new word.
+        word_positions.push_back((word.to_string(), pos));
+        Ok(())
+    };
+    document_tokenizer.tokenize_document(document, fields_ids_map, &mut token_fn)?;
+
+    while !word_positions.is_empty() {
+        word_positions_into_word_pair_proximity(word_positions, word_pair_proximity)?;
+    }
+
+    Ok(())
+}
diff --git a/milli/src/update/new/extract/searchable/mod.rs b/milli/src/update/new/extract/searchable/mod.rs
index 078d06150..ba4731d73 100644
--- a/milli/src/update/new/extract/searchable/mod.rs
+++ b/milli/src/update/new/extract/searchable/mod.rs
@@ -1,13 +1,17 @@
+mod extract_fid_word_count_docids;
 mod extract_word_docids;
+mod extract_word_pair_proximity_docids;
 mod tokenize_document;
 
 use std::borrow::Cow;
 use std::fs::File;
 
+pub use extract_fid_word_count_docids::FidWordCountDocidsExtractor;
 pub use extract_word_docids::{
     ExactWordDocidsExtractor, WordDocidsExtractor, WordFidDocidsExtractor,
     WordPositionDocidsExtractor,
 };
+pub use extract_word_pair_proximity_docids::WordPairProximityDocidsExtractor;
 use grenad::Merger;
 use heed::RoTxn;
 use rayon::iter::{IntoParallelIterator, ParallelIterator};
diff --git a/milli/src/update/new/extract/searchable/tokenize_document.rs b/milli/src/update/new/extract/searchable/tokenize_document.rs
index 1d19354db..7e23c9301 100644
--- a/milli/src/update/new/extract/searchable/tokenize_document.rs
+++ b/milli/src/update/new/extract/searchable/tokenize_document.rs
@@ -3,6 +3,7 @@ use std::collections::HashMap;
 use charabia::{SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
 use serde_json::Value;
 
+use crate::proximity::MAX_DISTANCE;
 use crate::update::new::extract::perm_json_p::{
     seek_leaf_values_in_array, seek_leaf_values_in_object, select_field,
 };
@@ -43,8 +44,10 @@ impl<'a> DocumentTokenizer<'a> {
                 return Err(UserError::AttributeLimitReached.into());
             };
 
-            let position =
-                field_position.entry(field_id).and_modify(|counter| *counter += 8).or_insert(0);
+            let position = field_position
+                .entry(field_id)
+                .and_modify(|counter| *counter += MAX_DISTANCE)
+                .or_insert(0);
             if *position as u32 >= self.max_positions_per_attributes {
                 return Ok(());
             }
@@ -116,19 +119,19 @@ impl<'a> DocumentTokenizer<'a> {
 }
 
 /// take an iterator on tokens and compute their relative position depending on separator kinds
-/// if it's an `Hard` separator we add an additional relative proximity of 8 between words,
+/// if it's a `Hard` separator we add an additional relative proximity of MAX_DISTANCE between words,
 /// else we keep the standard proximity of 1 between words.
 fn process_tokens<'a>(
-    start_offset: usize,
+    start_offset: u32,
     tokens: impl Iterator<Item = Token<'a>>,
-) -> impl Iterator<Item = (usize, Token<'a>)> {
+) -> impl Iterator<Item = (u32, Token<'a>)> {
     tokens
         .skip_while(|token| token.is_separator())
        .scan((start_offset, None), |(offset, prev_kind), mut token| {
            match token.kind {
                TokenKind::Word | TokenKind::StopWord if !token.lemma().is_empty() => {
                    *offset += match *prev_kind {
-                        Some(TokenKind::Separator(SeparatorKind::Hard)) => 8,
+                        Some(TokenKind::Separator(SeparatorKind::Hard)) => MAX_DISTANCE,
                        Some(_) => 1,
                        None => 0,
                    };
@@ -246,7 +249,7 @@ mod test {
             ]: "doggo",
             [
                 2,
-                8,
+                MAX_DISTANCE,
             ]: "doggo",
             [
                 2,
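Illustrative note (not part of the patch): a minimal standalone sketch of the two key layouts produced by the build_key helpers above, with made-up field id, word count, words, and proximity values.

    fn main() {
        // fid_word_count key: big-endian field id followed by the word count.
        let fid: u16 = 3; // hypothetical field id
        let word_count: u8 = 7; // hypothetical word count
        let mut fid_count_key: Vec<u8> = Vec::new();
        fid_count_key.extend_from_slice(&fid.to_be_bytes());
        fid_count_key.push(word_count);
        assert_eq!(fid_count_key, vec![0, 3, 7]);

        // word pair proximity key: proximity byte, first word, a 0x00 separator, second word.
        let (prox, w1, w2) = (2u8, "quick", "fox"); // hypothetical pair
        let mut pair_key: Vec<u8> = Vec::new();
        pair_key.push(prox);
        pair_key.extend_from_slice(w1.as_bytes());
        pair_key.push(0);
        pair_key.extend_from_slice(w2.as_bytes());
        assert_eq!(pair_key, b"\x02quick\x00fox");
    }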