// Mirror of https://github.com/meilisearch/meilisearch.git
// (synced 2025-03-03 04:14:15 +08:00) - 200 lines, 6.9 KiB, Rust.
|
use roaring::RoaringBitmap;
|
||
|
use std::ops::ControlFlow;
|
||
|
|
||
|
use crate::heed_codec::facet::new::{FacetGroupValueCodec, FacetKey, FacetKeyCodec, MyByteSlice};
|
||
|
|
||
|
use super::{get_first_facet_value, get_highest_level};
|
||
|
|
||
|
pub fn iterate_over_facet_distribution<'t, CB>(
|
||
|
rtxn: &'t heed::RoTxn<'t>,
|
||
|
db: &'t heed::Database<FacetKeyCodec<MyByteSlice>, FacetGroupValueCodec>,
|
||
|
field_id: u16,
|
||
|
candidates: &RoaringBitmap,
|
||
|
callback: CB,
|
||
|
) where
|
||
|
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
|
||
|
{
|
||
|
let mut fd = FacetDistribution { rtxn, db, field_id, callback };
|
||
|
let highest_level =
|
||
|
get_highest_level(rtxn, &db.remap_key_type::<FacetKeyCodec<MyByteSlice>>(), field_id);
|
||
|
|
||
|
if let Some(first_bound) = get_first_facet_value::<MyByteSlice>(rtxn, db, field_id) {
|
||
|
fd.iterate(candidates, highest_level, first_bound, usize::MAX);
|
||
|
return;
|
||
|
} else {
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
struct FacetDistribution<'t, CB>
|
||
|
where
|
||
|
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
|
||
|
{
|
||
|
rtxn: &'t heed::RoTxn<'t>,
|
||
|
db: &'t heed::Database<FacetKeyCodec<MyByteSlice>, FacetGroupValueCodec>,
|
||
|
field_id: u16,
|
||
|
callback: CB,
|
||
|
}
|
||
|
|
||
|
impl<'t, CB> FacetDistribution<'t, CB>
|
||
|
where
|
||
|
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
|
||
|
{
|
||
|
fn iterate_level_0(
|
||
|
&mut self,
|
||
|
candidates: &RoaringBitmap,
|
||
|
starting_bound: &'t [u8],
|
||
|
group_size: usize,
|
||
|
) -> ControlFlow<()> {
|
||
|
let starting_key =
|
||
|
FacetKey { field_id: self.field_id, level: 0, left_bound: starting_bound };
|
||
|
let iter = self.db.range(self.rtxn, &(starting_key..)).unwrap().take(group_size);
|
||
|
for el in iter {
|
||
|
let (key, value) = el.unwrap();
|
||
|
// The range is unbounded on the right and the group size for the highest level is MAX,
|
||
|
// so we need to check that we are not iterating over the next field id
|
||
|
if key.field_id != self.field_id {
|
||
|
return ControlFlow::Break(());
|
||
|
}
|
||
|
let docids_in_common = value.bitmap.intersection_len(candidates);
|
||
|
if docids_in_common > 0 {
|
||
|
match (self.callback)(key.left_bound, docids_in_common) {
|
||
|
ControlFlow::Continue(_) => {}
|
||
|
ControlFlow::Break(_) => return ControlFlow::Break(()),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return ControlFlow::Continue(());
|
||
|
}
|
||
|
fn iterate(
|
||
|
&mut self,
|
||
|
candidates: &RoaringBitmap,
|
||
|
level: u8,
|
||
|
starting_bound: &'t [u8],
|
||
|
group_size: usize,
|
||
|
) -> ControlFlow<()> {
|
||
|
if level == 0 {
|
||
|
return self.iterate_level_0(candidates, starting_bound, group_size);
|
||
|
}
|
||
|
let starting_key = FacetKey { field_id: self.field_id, level, left_bound: starting_bound };
|
||
|
let iter = self.db.range(&self.rtxn, &(&starting_key..)).unwrap().take(group_size);
|
||
|
|
||
|
for el in iter {
|
||
|
let (key, value) = el.unwrap();
|
||
|
// The range is unbounded on the right and the group size for the highest level is MAX,
|
||
|
// so we need to check that we are not iterating over the next field id
|
||
|
if key.field_id != self.field_id {
|
||
|
return ControlFlow::Break(());
|
||
|
}
|
||
|
let docids_in_common = value.bitmap & candidates;
|
||
|
if docids_in_common.len() > 0 {
|
||
|
let cf =
|
||
|
self.iterate(&docids_in_common, level - 1, key.left_bound, value.size as usize);
|
||
|
match cf {
|
||
|
ControlFlow::Continue(_) => {}
|
||
|
ControlFlow::Break(_) => return ControlFlow::Break(()),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ControlFlow::Continue(());
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[cfg(test)]
mod tests {
    use crate::{codec::U16Codec, Index};
    use heed::BytesDecode;
    use roaring::RoaringBitmap;
    use std::ops::ControlFlow;

    use super::iterate_over_facet_distribution;

    /// Builds an index where, for field 0, every key in `0..256` maps to a
    /// bitmap containing only the document id equal to the key.
    fn get_simple_index() -> Index<U16Codec> {
        let index = Index::<U16Codec>::new(4, 8);
        let mut txn = index.env.write_txn().unwrap();
        for i in 0..256u16 {
            let mut bitmap = RoaringBitmap::new();
            bitmap.insert(u32::from(i));
            index.insert(&mut txn, 0, &i, &bitmap);
        }
        txn.commit().unwrap();
        index
    }

    /// Builds an index for field 0 from 128 pseudo-random keys below 256
    /// (fixed seed, so the content is reproducible); each key maps to the
    /// document ids `key` and `key + 100`.
    fn get_random_looking_index() -> Index<U16Codec> {
        let index = Index::<U16Codec>::new(4, 8);
        let mut txn = index.env.write_txn().unwrap();

        let rng = fastrand::Rng::with_seed(0);
        for key in std::iter::repeat_with(|| rng.u32(..256)).take(128) {
            let mut bitmap = RoaringBitmap::new();
            bitmap.insert(key);
            bitmap.insert(key + 100);
            // `key` is always below 256, so the narrowing cast is lossless.
            index.insert(&mut txn, 0, &(key as u16), &bitmap);
        }
        txn.commit().unwrap();
        index
    }

    #[test]
    fn random_looking_index_snap() {
        let index = get_random_looking_index();
        insta::assert_display_snapshot!(index)
    }

    /// Snapshots the full distribution of documents `0..=255` for both indexes.
    #[test]
    fn filter_distribution_all() {
        let indexes = [get_simple_index(), get_random_looking_index()];
        for (i, index) in indexes.into_iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = (0..=255).collect::<RoaringBitmap>();
            let mut results = String::new();
            iterate_over_facet_distribution(
                &txn,
                &index.db.content,
                0,
                &candidates,
                |facet, count| {
                    let facet = U16Codec::bytes_decode(facet).unwrap();
                    results.push_str(&format!("{facet}: {count}\n"));
                    ControlFlow::Continue(())
                },
            );
            insta::assert_snapshot!(format!("filter_distribution_{i}_all"), results);

            txn.commit().unwrap();
        }
    }

    /// Same as `filter_distribution_all`, but the callback breaks once 100
    /// facet values have been recorded.
    #[test]
    fn filter_distribution_all_stop_early() {
        let indexes = [get_simple_index(), get_random_looking_index()];
        for (i, index) in indexes.into_iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates = (0..=255).collect::<RoaringBitmap>();
            let mut results = String::new();
            let mut nbr_facets = 0;
            iterate_over_facet_distribution(
                &txn,
                &index.db.content,
                0,
                &candidates,
                |facet, count| {
                    let facet = U16Codec::bytes_decode(facet).unwrap();
                    if nbr_facets == 100 {
                        return ControlFlow::Break(());
                    }
                    nbr_facets += 1;
                    results.push_str(&format!("{facet}: {count}\n"));

                    ControlFlow::Continue(())
                },
            );
            insta::assert_snapshot!(format!("filter_distribution_{i}_all_stop_early"), results);

            txn.commit().unwrap();
        }
    }
}
|