diff --git a/src/search/mod.rs b/src/search/mod.rs index 8d190cf87..a7c83e79b 100644 --- a/src/search/mod.rs +++ b/src/search/mod.rs @@ -156,18 +156,16 @@ impl<'a> Search<'a> { field_id: FieldId, facet_type: FacetType, ascending: bool, - documents_ids: RoaringBitmap, + mut documents_ids: RoaringBitmap, limit: usize, ) -> anyhow::Result> { - let mut limit_tmp = limit; - let mut output = Vec::new(); - match facet_type { + let mut output: Vec<_> = match facet_type { FacetType::Float => { if documents_ids.len() <= 1000 { let db = self.index.field_id_docid_facet_values.remap_key_type::(); let mut docids_values = Vec::with_capacity(documents_ids.len() as usize); - for docid in documents_ids { + for docid in documents_ids.iter() { let left = (field_id, docid, f64::MIN); let right = (field_id, docid, f64::MAX); let mut iter = db.range(self.rtxn, &(left..=right))?; @@ -179,9 +177,9 @@ impl<'a> Search<'a> { docids_values.sort_unstable_by_key(|(_, value)| *value); let iter = docids_values.into_iter().map(|(id, _)| id); if ascending { - Ok(iter.take(limit).collect()) + iter.take(limit).collect() } else { - Ok(iter.rev().take(limit).collect()) + iter.rev().take(limit).collect() } } else { let facet_fn = if ascending { @@ -189,20 +187,22 @@ impl<'a> Search<'a> { } else { FacetIter::::new_reverse }; - for result in facet_fn(self.rtxn, self.index, field_id, documents_ids)? { + let mut limit_tmp = limit; + let mut output = Vec::new(); + for result in facet_fn(self.rtxn, self.index, field_id, documents_ids.clone())? { let (_val, docids) = result?; limit_tmp = limit_tmp.saturating_sub(docids.len() as usize); output.push(docids); if limit_tmp == 0 { break } } - Ok(output.into_iter().flatten().take(limit).collect()) + output.into_iter().flatten().take(limit).collect() } }, FacetType::Integer => { if documents_ids.len() <= 1000 { let db = self.index.field_id_docid_facet_values.remap_key_type::(); let mut docids_values = Vec::with_capacity(documents_ids.len() as usize); - for docid in documents_ids { + for docid in documents_ids.iter() { let left = (field_id, docid, i64::MIN); let right = (field_id, docid, i64::MAX); let mut iter = db.range(self.rtxn, &(left..=right))?; @@ -214,9 +214,9 @@ impl<'a> Search<'a> { docids_values.sort_unstable_by_key(|(_, value)| *value); let iter = docids_values.into_iter().map(|(id, _)| id); if ascending { - Ok(iter.take(limit).collect()) + iter.take(limit).collect() } else { - Ok(iter.rev().take(limit).collect()) + iter.rev().take(limit).collect() } } else { let facet_fn = if ascending { @@ -224,17 +224,30 @@ impl<'a> Search<'a> { } else { FacetIter::::new_reverse }; - for result in facet_fn(self.rtxn, self.index, field_id, documents_ids)? { + let mut limit_tmp = limit; + let mut output = Vec::new(); + for result in facet_fn(self.rtxn, self.index, field_id, documents_ids.clone())? { let (_val, docids) = result?; limit_tmp = limit_tmp.saturating_sub(docids.len() as usize); output.push(docids); if limit_tmp == 0 { break } } - Ok(output.into_iter().flatten().take(limit).collect()) + output.into_iter().flatten().take(limit).collect() } }, FacetType::String => bail!("criteria facet type must be a number"), + }; + + // if there isn't enough documents to return we try to complete that list + // with documents that are maybe not faceted under this field and therefore + // not returned by the previous facet iteration. + if output.len() < limit { + output.iter().for_each(|n| { documents_ids.remove(*n); }); + let remaining = documents_ids.iter().take(limit - output.len()); + output.extend(remaining); } + + Ok(output) } pub fn execute(&self) -> anyhow::Result {