From fb8fa071694fbca8625675128eff5e7b89d5399b Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Tue, 13 Jun 2023 17:37:35 +0200 Subject: [PATCH] Restrict field ids in search context --- milli/src/search/new/db_cache.rs | 74 +++++++++++++++++++++++++------- milli/src/update/mod.rs | 4 +- 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/milli/src/search/new/db_cache.rs b/milli/src/search/new/db_cache.rs index 2b2cd4d79..6c1b0f7f8 100644 --- a/milli/src/search/new/db_cache.rs +++ b/milli/src/search/new/db_cache.rs @@ -10,7 +10,7 @@ use roaring::RoaringBitmap; use super::interner::Interned; use super::Word; use crate::heed_codec::{BytesDecodeOwned, StrBEU16Codec}; -use crate::update::MergeFn; +use crate::update::{merge_cbo_roaring_bitmaps, MergeFn}; use crate::{ CboRoaringBitmapCodec, CboRoaringBitmapLenCodec, Result, RoaringBitmapCodec, SearchContext, }; @@ -79,7 +79,7 @@ impl<'ctx> DatabaseCache<'ctx> { fn get_value_from_keys<'v, K1, KC, DC>( txn: &'ctx RoTxn, cache_key: K1, - db_keys: &[&'v KC::EItem], + db_keys: &'v [KC::EItem], cache: &mut FxHashMap>>, db: Database, merger: MergeFn, @@ -88,6 +88,7 @@ impl<'ctx> DatabaseCache<'ctx> { K1: Copy + Eq + Hash, KC: BytesEncode<'v>, DC: BytesDecodeOwned, + KC::EItem: Sized, { match cache.entry(cache_key) { Entry::Occupied(_) => {} @@ -125,6 +126,7 @@ impl<'ctx> DatabaseCache<'ctx> { } } } + impl<'ctx> SearchContext<'ctx> { pub fn get_words_fst(&mut self) -> Result>> { if let Some(fst) = self.db_cache.words_fst.clone() { @@ -158,13 +160,28 @@ impl<'ctx> SearchContext<'ctx> { /// Retrieve or insert the given value in the `word_docids` database. fn get_db_word_docids(&mut self, word: Interned) -> Result> { - DatabaseCache::get_value::<_, _, RoaringBitmapCodec>( - self.txn, - word, - self.word_interner.get(word).as_str(), - &mut self.db_cache.word_docids, - self.index.word_docids.remap_data_type::(), - ) + match &self.restricted_fids { + Some(restricted_fids) => { + let interned = self.word_interner.get(word).as_str(); + let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect(); + + DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>( + self.txn, + word, + &keys[..], + &mut self.db_cache.word_docids, + self.index.word_fid_docids.remap_data_type::(), + merge_cbo_roaring_bitmaps, + ) + } + None => DatabaseCache::get_value::<_, _, RoaringBitmapCodec>( + self.txn, + word, + self.word_interner.get(word).as_str(), + &mut self.db_cache.word_docids, + self.index.word_docids.remap_data_type::(), + ), + } } fn get_db_exact_word_docids( @@ -205,13 +222,28 @@ impl<'ctx> SearchContext<'ctx> { &mut self, prefix: Interned, ) -> Result> { - DatabaseCache::get_value::<_, _, RoaringBitmapCodec>( - self.txn, - prefix, - self.word_interner.get(prefix).as_str(), - &mut self.db_cache.word_prefix_docids, - self.index.word_prefix_docids.remap_data_type::(), - ) + match &self.restricted_fids { + Some(restricted_fids) => { + let interned = self.word_interner.get(prefix).as_str(); + let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect(); + + DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>( + self.txn, + prefix, + &keys[..], + &mut self.db_cache.word_prefix_docids, + self.index.word_prefix_fid_docids.remap_data_type::(), + merge_cbo_roaring_bitmaps, + ) + } + None => DatabaseCache::get_value::<_, _, RoaringBitmapCodec>( + self.txn, + prefix, + self.word_interner.get(prefix).as_str(), + &mut self.db_cache.word_prefix_docids, + self.index.word_prefix_docids.remap_data_type::(), + ), + } } fn get_db_exact_word_prefix_docids( @@ -307,6 +339,11 @@ impl<'ctx> SearchContext<'ctx> { word: Interned, fid: u16, ) -> Result> { + // if the requested fid isn't in the restricted list, return None. + if self.restricted_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) { + return Ok(None); + } + DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>( self.txn, (word, fid), @@ -321,6 +358,11 @@ impl<'ctx> SearchContext<'ctx> { word_prefix: Interned, fid: u16, ) -> Result> { + // if the requested fid isn't in the restricted list, return None. + if self.restricted_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) { + return Ok(None); + } + DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>( self.txn, (word_prefix, fid), diff --git a/milli/src/update/mod.rs b/milli/src/update/mod.rs index 011a2eb60..32584825b 100644 --- a/milli/src/update/mod.rs +++ b/milli/src/update/mod.rs @@ -4,8 +4,8 @@ pub use self::delete_documents::{DeleteDocuments, DeletionStrategy, DocumentDele pub use self::facet::bulk::FacetsUpdateBulk; pub use self::facet::incremental::FacetsUpdateIncrementalInner; pub use self::index_documents::{ - merge_roaring_bitmaps, DocumentAdditionResult, DocumentId, IndexDocuments, - IndexDocumentsConfig, IndexDocumentsMethod, MergeFn, + merge_cbo_roaring_bitmaps, merge_roaring_bitmaps, DocumentAdditionResult, DocumentId, + IndexDocuments, IndexDocumentsConfig, IndexDocumentsMethod, MergeFn, }; pub use self::indexer_config::IndexerConfig; pub use self::prefix_word_pairs::{