diff --git a/src/search/facet/mod.rs b/src/search/facet/mod.rs index 72e6426d0..6b2d90c8e 100644 --- a/src/search/facet/mod.rs +++ b/src/search/facet/mod.rs @@ -1,4 +1,149 @@ +use std::fmt::Debug; +use std::ops::Bound::{self, Included, Excluded, Unbounded}; + +use heed::types::DecodeIgnore; +use heed::{BytesEncode, BytesDecode}; +use heed::{Database, RoRange, LazyDecode}; +use num_traits::Bounded; +use roaring::RoaringBitmap; + +use crate::heed_codec::CboRoaringBitmapCodec; +use crate::{Index, FieldId}; + +pub use self::facet_condition::{FacetCondition, FacetNumberOperator, FacetStringOperator}; + mod facet_condition; mod parser; -pub use self::facet_condition::{FacetCondition, FacetNumberOperator, FacetStringOperator}; +struct FacetRange<'t, T: 't, KC> { + iter: RoRange<'t, KC, LazyDecode>, + end: Bound, +} + +impl<'t, T: 't, KC> FacetRange<'t, T, KC> +where + KC: for<'a> BytesEncode<'a, EItem = (FieldId, u8, T, T)>, + T: PartialOrd + Copy + Bounded, +{ + fn new( + rtxn: &'t heed::RoTxn, + db: Database, + field_id: FieldId, + level: u8, + left: Bound, + right: Bound, + ) -> heed::Result> + { + let left_bound = match left { + Included(left) => Included((field_id, level, left, T::min_value())), + Excluded(left) => Excluded((field_id, level, left, T::min_value())), + Unbounded => Included((field_id, level, T::min_value(), T::min_value())), + }; + let right_bound = Included((field_id, level, T::max_value(), T::max_value())); + let iter = db.lazily_decode_data().range(rtxn, &(left_bound, right_bound))?; + Ok(FacetRange { iter, end: right }) + } +} + +impl<'t, T, KC> Iterator for FacetRange<'t, T, KC> +where + KC: for<'a> BytesEncode<'a, EItem = (FieldId, u8, T, T)>, + KC: BytesDecode<'t, DItem = (FieldId, u8, T, T)>, + T: PartialOrd + Copy, +{ + type Item = heed::Result<((FieldId, u8, T, T), RoaringBitmap)>; + + fn next(&mut self) -> Option { + match self.iter.next() { + Some(Ok(((fid, level, left, right), docids))) => { + let must_be_returned = match self.end { + Included(end) => right <= end, + Excluded(end) => right < end, + Unbounded => true, + }; + if must_be_returned { + match docids.decode() { + Ok(docids) => Some(Ok(((fid, level, left, right), docids))), + Err(e) => Some(Err(e)), + } + } else { + None + } + }, + Some(Err(e)) => Some(Err(e)), + None => None, + } + } +} + +pub struct FacetIter<'t, T: 't, KC> { + rtxn: &'t heed::RoTxn<'t>, + db: Database, + field_id: FieldId, + documents_ids: RoaringBitmap, + level_iters: Vec>, +} + +impl<'t, T, KC> FacetIter<'t, T, KC> +where + KC: for<'a> BytesEncode<'a, EItem = (FieldId, u8, T, T)>, + T: PartialOrd + Copy + Bounded, +{ + pub fn new( + rtxn: &'t heed::RoTxn, + index: &'t Index, + field_id: FieldId, + documents_ids: RoaringBitmap, + ) -> heed::Result> + { + let db = index.facet_field_id_value_docids.remap_key_type::(); + let level_0_iter = FacetRange::new(rtxn, db, field_id, 0, Unbounded, Unbounded)?; + Ok(FacetIter { rtxn, db, field_id, documents_ids, level_iters: vec![level_0_iter] }) + } +} + +impl<'t, T: 't, KC> Iterator for FacetIter<'t, T, KC> +where + KC: heed::BytesDecode<'t, DItem = (FieldId, u8, T, T)>, + KC: for<'x> heed::BytesEncode<'x, EItem = (FieldId, u8, T, T)>, + T: PartialOrd + Copy + Bounded, +{ + type Item = heed::Result<(T, RoaringBitmap)>; + + fn next(&mut self) -> Option { + loop { + let last = self.level_iters.last_mut()?; + for result in last { + match result { + Ok(((_fid, level, left, right), mut docids)) => { + if level == 0 { + docids.intersect_with(&self.documents_ids); + if !docids.is_empty() { + self.documents_ids.difference_with(&docids); + return Some(Ok((left, docids))); + } + } else if !docids.is_disjoint(&self.documents_ids) { + let result = FacetRange::new( + self.rtxn, + self.db, + self.field_id, + level - 1, + Included(left), + Included(right), + ); + match result { + Ok(iter) => { + self.level_iters.push(iter); + break; + }, + Err(e) => return Some(Err(e)), + } + } + }, + Err(e) => return Some(Err(e)), + } + } + self.level_iters.pop(); + } + } +} diff --git a/src/search/mod.rs b/src/search/mod.rs index b56f5a345..078cf2dab 100644 --- a/src/search/mod.rs +++ b/src/search/mod.rs @@ -16,8 +16,7 @@ use crate::mdfs::Mdfs; use crate::query_tokens::{QueryTokens, QueryToken}; use crate::{Index, FieldId, DocumentId, Criterion}; -pub use self::facet::{FacetCondition, FacetNumberOperator, FacetStringOperator, Order}; -pub use self::facet::facet_number_recurse; +pub use self::facet::{FacetCondition, FacetNumberOperator, FacetStringOperator}; // Building these factories is not free. static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); @@ -157,6 +156,7 @@ impl<'a> Search<'a> { limit: usize, ) -> anyhow::Result> { + let mut limit_tmp = limit; let mut output = Vec::new(); match facet_type { FacetType::Float => { @@ -167,9 +167,10 @@ impl<'a> Search<'a> { order, documents_ids, |_val, docids| { + limit_tmp = limit_tmp.saturating_sub(docids.len() as usize); + debug!("Facet ordered iteration find {:?}", docids); output.push(docids); - // Returns `true` if we must continue iterating - output.iter().map(|ids| ids.len()).sum::() < limit as u64 + limit_tmp != 0 // Returns `true` if we must continue iterating } )?; }, @@ -181,9 +182,10 @@ impl<'a> Search<'a> { order, documents_ids, |_val, docids| { + limit_tmp = limit_tmp.saturating_sub(docids.len() as usize); + debug!("Facet ordered iteration find {:?}", docids); output.push(docids); - // Returns `true` if we must continue iterating - output.iter().map(|ids| ids.len()).sum::() < limit as u64 + limit_tmp != 0 // Returns `true` if we must continue iterating } )?; },