diff --git a/milli/src/search/new/db_cache.rs b/milli/src/search/new/db_cache.rs index d7ef031bb..ce846009a 100644 --- a/milli/src/search/new/db_cache.rs +++ b/milli/src/search/new/db_cache.rs @@ -154,10 +154,11 @@ impl<'ctx> SearchContext<'ctx> { /// Retrieve or insert the given value in the `word_docids` database. fn get_db_word_docids(&mut self, word: Interned) -> Result> { - match &self.restricted_tolerant_fids { + match &self.restricted_fids { Some(restricted_fids) => { let interned = self.word_interner.get(word).as_str(); - let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect(); + let keys: Vec<_> = + restricted_fids.tolerant.iter().map(|fid| (interned, *fid)).collect(); DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>( self.txn, @@ -182,10 +183,11 @@ impl<'ctx> SearchContext<'ctx> { &mut self, word: Interned, ) -> Result> { - match &self.restricted_exact_fids { + match &self.restricted_fids { Some(restricted_fids) => { let interned = self.word_interner.get(word).as_str(); - let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect(); + let keys: Vec<_> = + restricted_fids.exact.iter().map(|fid| (interned, *fid)).collect(); DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>( self.txn, @@ -231,10 +233,11 @@ impl<'ctx> SearchContext<'ctx> { &mut self, prefix: Interned, ) -> Result> { - match &self.restricted_tolerant_fids { + match &self.restricted_fids { Some(restricted_fids) => { let interned = self.word_interner.get(prefix).as_str(); - let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect(); + let keys: Vec<_> = + restricted_fids.tolerant.iter().map(|fid| (interned, *fid)).collect(); DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>( self.txn, @@ -259,10 +262,11 @@ impl<'ctx> SearchContext<'ctx> { &mut self, prefix: Interned, ) -> Result> { - match &self.restricted_exact_fids { + match &self.restricted_fids { Some(restricted_fids) => { let interned = self.word_interner.get(prefix).as_str(); - let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect(); + let keys: Vec<_> = + restricted_fids.exact.iter().map(|fid| (interned, *fid)).collect(); DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>( self.txn, @@ -364,9 +368,7 @@ impl<'ctx> SearchContext<'ctx> { fid: u16, ) -> Result> { // if the requested fid isn't in the restricted list, return None. - if self.restricted_tolerant_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) - && self.restricted_exact_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) - { + if self.restricted_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) { return Ok(None); } @@ -385,9 +387,7 @@ impl<'ctx> SearchContext<'ctx> { fid: u16, ) -> Result> { // if the requested fid isn't in the restricted list, return None. - if self.restricted_tolerant_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) - && self.restricted_exact_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) - { + if self.restricted_fids.as_ref().map_or(false, |fids| !fids.contains(&fid)) { return Ok(None); } diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 56c55d031..ba29dbd1f 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -51,7 +51,8 @@ use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; use crate::{ - AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, + AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, + BEU32, }; /// A structure used throughout the execution of a search query. @@ -63,8 +64,7 @@ pub struct SearchContext<'ctx> { pub phrase_interner: DedupInterner, pub term_interner: Interner, pub phrase_docids: PhraseDocIdsCache, - pub restricted_tolerant_fids: Option>, - pub restricted_exact_fids: Option>, + pub restricted_fids: Option, } impl<'ctx> SearchContext<'ctx> { @@ -77,8 +77,7 @@ impl<'ctx> SearchContext<'ctx> { phrase_interner: <_>::default(), term_interner: <_>::default(), phrase_docids: <_>::default(), - restricted_tolerant_fids: None, - restricted_exact_fids: None, + restricted_fids: None, } } @@ -87,8 +86,7 @@ impl<'ctx> SearchContext<'ctx> { let searchable_names = self.index.searchable_fields(self.txn)?; let exact_attributes_ids = self.index.exact_attributes_ids(self.txn)?; - let mut restricted_exact_fids = Vec::new(); - let mut restricted_tolerant_fids = Vec::new(); + let mut restricted_fids = RestrictedFids::default(); let mut contains_wildcard = false; for field_name in searchable_attributes { if field_name == "*" { @@ -128,14 +126,13 @@ impl<'ctx> SearchContext<'ctx> { }; if exact_attributes_ids.contains(&fid) { - restricted_exact_fids.push(fid); + restricted_fids.exact.push(fid); } else { - restricted_tolerant_fids.push(fid); + restricted_fids.tolerant.push(fid); }; } - self.restricted_exact_fids = (!contains_wildcard).then_some(restricted_exact_fids); - self.restricted_tolerant_fids = (!contains_wildcard).then_some(restricted_tolerant_fids); + self.restricted_fids = (!contains_wildcard).then_some(restricted_fids); Ok(()) } @@ -156,6 +153,18 @@ impl Word { } } +#[derive(Debug, Clone, Default)] +pub struct RestrictedFids { + pub tolerant: Vec, + pub exact: Vec, +} + +impl RestrictedFids { + pub fn contains(&self, fid: &FieldId) -> bool { + self.tolerant.contains(fid) || self.exact.contains(fid) + } +} + /// Apply the [`TermsMatchingStrategy`] to the query graph and resolve it. fn resolve_maximally_reduced_query_graph( ctx: &mut SearchContext,