From 362eb0de86e860612b6776b712ed41057f2df504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Lecrenier?= Date: Mon, 27 Feb 2023 16:45:07 +0100 Subject: [PATCH] Add support for filters --- milli/src/search/new/ranking_rules.rs | 117 +++++++++++--------------- 1 file changed, 48 insertions(+), 69 deletions(-) diff --git a/milli/src/search/new/ranking_rules.rs b/milli/src/search/new/ranking_rules.rs index ed51d3345..c7c9d5c97 100644 --- a/milli/src/search/new/ranking_rules.rs +++ b/milli/src/search/new/ranking_rules.rs @@ -11,7 +11,7 @@ use crate::new::graph_based_ranking_rule::GraphBasedRankingRule; use crate::new::ranking_rule_graph::proximity::ProximityGraph; use crate::new::words::Words; // use crate::search::new::sort::Sort; -use crate::{Index, Result, TermsMatchingStrategy}; +use crate::{Filter, Index, Result, TermsMatchingStrategy}; pub trait RankingRuleOutputIter<'transaction, Query> { fn next_bucket(&mut self) -> Result>>; @@ -111,16 +111,18 @@ pub fn get_start_universe<'transaction>( Ok(universe) } +// TODO: can make it generic over the query type (either query graph or placeholder) fairly easily +#[allow(clippy::too_many_arguments)] pub fn execute_search<'transaction>( index: &Index, txn: &'transaction heed::RoTxn, // TODO: ranking rules parameter db_cache: &mut DatabaseCache<'transaction>, - universe: &RoaringBitmap, query_graph: &QueryGraph, - logger: &mut dyn SearchLogger, + filters: Option, from: usize, length: usize, + logger: &mut dyn SearchLogger, ) -> Result> { let words = Words::new(TermsMatchingStrategy::Last); // let sort = Sort::new(index, txn, "sort1".to_owned(), true)?; @@ -131,9 +133,19 @@ pub fn execute_search<'transaction>( logger.ranking_rules(&ranking_rules); + let universe = if let Some(filters) = filters { + filters.evaluate(txn, index)? + } else { + index.documents_ids(txn)? + }; + + if universe.len() < from as u64 { + return Ok(vec![]); + } + let ranking_rules_len = ranking_rules.len(); - logger.start_iteration_ranking_rule(0, ranking_rules[0].as_ref(), query_graph, universe); - ranking_rules[0].start_iteration(index, txn, db_cache, logger, universe, query_graph)?; + logger.start_iteration_ranking_rule(0, ranking_rules[0].as_ref(), query_graph, &universe); + ranking_rules[0].start_iteration(index, txn, db_cache, logger, &universe, query_graph)?; let mut candidates = vec![RoaringBitmap::default(); ranking_rules_len]; candidates[0] = universe.clone(); @@ -160,23 +172,21 @@ pub fn execute_search<'transaction>( let mut results = vec![]; let mut cur_offset = 0usize; - macro_rules! add_to_results { + // Add the candidates to the results. Take the `from`, `limit`, and `cur_offset` into account. + macro_rules! maybe_add_to_results { ($candidates:expr) => { let candidates = $candidates; let len = candidates.len(); + // if the candidates are empty, there is nothing to do; if !candidates.is_empty() { - println!("cur_offset: {}, candidates_len: {}", cur_offset, candidates.len()); if cur_offset < from { - println!(" cur_offset < from"); if cur_offset + (candidates.len() as usize) < from { - println!(" cur_offset + candidates_len < from"); logger.skip_bucket_ranking_rule( cur_ranking_rule_index, ranking_rules[cur_ranking_rule_index].as_ref(), &candidates, ); } else { - println!(" cur_offset + candidates_len >= from"); let all_candidates = candidates.iter().collect::>(); let (skipped_candidates, candidates) = all_candidates.split_at(from - cur_offset); @@ -203,13 +213,12 @@ pub fn execute_search<'transaction>( cur_offset += len as usize; }; } - // TODO: skip buckets when we want to start from an offset while results.len() < length { // The universe for this bucket is zero or one element, so we don't need to sort // anything, just extend the results and go back to the parent ranking rule. if candidates[cur_ranking_rule_index].len() <= 1 { - add_to_results!(&candidates[cur_ranking_rule_index]); + maybe_add_to_results!(&candidates[cur_ranking_rule_index]); back!(); continue; } @@ -227,41 +236,30 @@ pub fn execute_search<'transaction>( candidates[cur_ranking_rule_index] -= &next_bucket.candidates; - if next_bucket.candidates.len() <= 1 { - // Only zero or one candidate, no need to sort through the child ranking rule. - add_to_results!(next_bucket.candidates); + if cur_ranking_rule_index == ranking_rules_len - 1 + || next_bucket.candidates.len() <= 1 + || cur_offset + (next_bucket.candidates.len() as usize) < from + { + maybe_add_to_results!(&next_bucket.candidates); continue; - } else { - // many candidates, give to next ranking rule, if any - if cur_ranking_rule_index == ranking_rules_len - 1 { - add_to_results!(next_bucket.candidates); - } else if cur_offset + (next_bucket.candidates.len() as usize) < from { - cur_offset += next_bucket.candidates.len() as usize; - logger.skip_bucket_ranking_rule( - cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), - &next_bucket.candidates, - ); - continue; - } else { - cur_ranking_rule_index += 1; - candidates[cur_ranking_rule_index] = next_bucket.candidates.clone(); - logger.start_iteration_ranking_rule( - cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), - &next_bucket.query, - &candidates[cur_ranking_rule_index], - ); - ranking_rules[cur_ranking_rule_index].start_iteration( - index, - txn, - db_cache, - logger, - &next_bucket.candidates, - &next_bucket.query, - )?; - } } + + cur_ranking_rule_index += 1; + candidates[cur_ranking_rule_index] = next_bucket.candidates.clone(); + logger.start_iteration_ranking_rule( + cur_ranking_rule_index, + ranking_rules[cur_ranking_rule_index].as_ref(), + &next_bucket.query, + &candidates[cur_ranking_rule_index], + ); + ranking_rules[cur_ranking_rule_index].start_iteration( + index, + txn, + db_cache, + logger, + &next_bucket.candidates, + &next_bucket.query, + )?; } Ok(results) @@ -325,28 +323,9 @@ mod tests { println!("{}", query_graph.graphviz()); logger.initial_query(&query_graph); - // TODO: filters + maybe distinct attributes? - let universe = get_start_universe( - &index, - &txn, - &mut db_cache, - &query_graph, - TermsMatchingStrategy::Last, - ) - .unwrap(); - println!("universe: {universe:?}"); - - let results = execute_search( - &index, - &txn, - &mut db_cache, - &universe, - &query_graph, - &mut logger, - 0, - 20, - ) - .unwrap(); + let results = + execute_search(&index, &txn, &mut db_cache, &query_graph, None, 0, 20, &mut logger) + .unwrap(); println!("{results:?}") } @@ -389,11 +368,11 @@ mod tests { &index, &txn, &mut db_cache, - &universe, &query_graph, - &mut logger, //&mut DefaultSearchLogger, + None, 500, 100, + &mut logger, //&mut DefaultSearchLogger, ) .unwrap();