diff --git a/milli/src/search/new/logger/detailed.rs b/milli/src/search/new/logger/detailed.rs index 81571c14a..a85d20ccc 100644 --- a/milli/src/search/new/logger/detailed.rs +++ b/milli/src/search/new/logger/detailed.rs @@ -68,7 +68,7 @@ impl SearchLogger for DetailedSearchLogger { fn initial_universe(&mut self, universe: &RoaringBitmap) { self.initial_universe = Some(universe.clone()); } - fn ranking_rules(&mut self, rr: &[Box>]) { + fn ranking_rules(&mut self, rr: &[&mut dyn RankingRule]) { self.ranking_rules_ids = Some(rr.iter().map(|rr| rr.id()).collect()); } diff --git a/milli/src/search/new/logger/mod.rs b/milli/src/search/new/logger/mod.rs index 6b1f95152..3b828f7cb 100644 --- a/milli/src/search/new/logger/mod.rs +++ b/milli/src/search/new/logger/mod.rs @@ -17,7 +17,7 @@ impl SearchLogger for DefaultSearchLogger { fn initial_universe(&mut self, _universe: &RoaringBitmap) {} - fn ranking_rules(&mut self, _rr: &[Box>]) {} + fn ranking_rules(&mut self, _rr: &[&mut dyn RankingRule]) {} fn start_iteration_ranking_rule<'transaction>( &mut self, _ranking_rule_idx: usize, @@ -67,7 +67,7 @@ pub trait SearchLogger { fn initial_query(&mut self, query: &Q); fn initial_universe(&mut self, universe: &RoaringBitmap); - fn ranking_rules(&mut self, rr: &[Box>]); + fn ranking_rules(&mut self, rr: &[&mut dyn RankingRule]); fn start_iteration_ranking_rule<'transaction>( &mut self, diff --git a/milli/src/search/new/ranking_rules.rs b/milli/src/search/new/ranking_rules.rs index e78bdff0c..f3f71ab4b 100644 --- a/milli/src/search/new/ranking_rules.rs +++ b/milli/src/search/new/ranking_rules.rs @@ -8,7 +8,7 @@ use super::QueryGraph; use crate::new::graph_based_ranking_rule::GraphBasedRankingRule; use crate::new::ranking_rule_graph::proximity::ProximityGraph; use crate::new::words::Words; -// use crate::search::new::sort::Sort; +use crate::search::new::sort::Sort; use crate::{Filter, Index, Result, TermsMatchingStrategy}; pub trait RankingRuleOutputIter<'transaction, Query> { @@ -122,12 +122,12 @@ pub fn execute_search<'transaction>( length: usize, logger: &mut dyn SearchLogger, ) -> Result> { - let words = Words::new(TermsMatchingStrategy::Last); - // let sort = Sort::new(index, txn, "sort1".to_owned(), true)?; - let proximity = GraphBasedRankingRule::::new("proximity".to_owned()); + let words = &mut Words::new(TermsMatchingStrategy::Last); + let sort = &mut Sort::new(index, txn, "release_date".to_owned(), true)?; + let proximity = &mut GraphBasedRankingRule::::new("proximity".to_owned()); // TODO: ranking rules given as argument - let mut ranking_rules: Vec>> = - vec![Box::new(words), Box::new(proximity) /* Box::new(sort) */]; + let mut ranking_rules: Vec<&mut dyn RankingRule<'transaction, QueryGraph>> = + vec![words, proximity, sort]; logger.ranking_rules(&ranking_rules); @@ -142,7 +142,7 @@ pub fn execute_search<'transaction>( } let ranking_rules_len = ranking_rules.len(); - logger.start_iteration_ranking_rule(0, ranking_rules[0].as_ref(), query_graph, &universe); + logger.start_iteration_ranking_rule(0, ranking_rules[0], query_graph, &universe); ranking_rules[0].start_iteration(index, txn, db_cache, logger, &universe, query_graph)?; let mut candidates = vec![RoaringBitmap::default(); ranking_rules_len]; @@ -152,9 +152,10 @@ pub fn execute_search<'transaction>( macro_rules! back { () => { + assert!(candidates[cur_ranking_rule_index].is_empty()); logger.end_iteration_ranking_rule( cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), + ranking_rules[cur_ranking_rule_index], &candidates[cur_ranking_rule_index], ); candidates[cur_ranking_rule_index].clear(); @@ -182,7 +183,7 @@ pub fn execute_search<'transaction>( if cur_offset + (candidates.len() as usize) < from { logger.skip_bucket_ranking_rule( cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), + ranking_rules[cur_ranking_rule_index], &candidates, ); } else { @@ -191,7 +192,7 @@ pub fn execute_search<'transaction>( all_candidates.split_at(from - cur_offset); logger.skip_bucket_ranking_rule( cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), + ranking_rules[cur_ranking_rule_index], &skipped_candidates.into_iter().collect(), ); let candidates = candidates @@ -216,6 +217,7 @@ pub fn execute_search<'transaction>( // The universe for this bucket is zero or one element, so we don't need to sort // anything, just extend the results and go back to the parent ranking rule. if candidates[cur_ranking_rule_index].len() <= 1 { + candidates[cur_ranking_rule_index].clear(); maybe_add_to_results!(&candidates[cur_ranking_rule_index]); back!(); continue; @@ -223,7 +225,7 @@ pub fn execute_search<'transaction>( logger.next_bucket_ranking_rule( cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), + ranking_rules[cur_ranking_rule_index], &candidates[cur_ranking_rule_index], ); @@ -232,6 +234,7 @@ pub fn execute_search<'transaction>( continue; }; + assert!(candidates[cur_ranking_rule_index].is_superset(&next_bucket.candidates)); candidates[cur_ranking_rule_index] -= &next_bucket.candidates; if cur_ranking_rule_index == ranking_rules_len - 1 @@ -246,7 +249,7 @@ pub fn execute_search<'transaction>( candidates[cur_ranking_rule_index] = next_bucket.candidates.clone(); logger.start_iteration_ranking_rule( cur_ranking_rule_index, - ranking_rules[cur_ranking_rule_index].as_ref(), + ranking_rules[cur_ranking_rule_index], &next_bucket.query, &candidates[cur_ranking_rule_index], ); diff --git a/milli/src/search/new/sort.rs b/milli/src/search/new/sort.rs index 29d244383..9ef01bd95 100644 --- a/milli/src/search/new/sort.rs +++ b/milli/src/search/new/sort.rs @@ -16,19 +16,16 @@ use crate::{ Result, }; -// TODO: The implementation of Sort is not correct: -// (1) it should not return documents it has already returned (does the current implementation have the same bug?) -// (2) at the end, it should return all the remaining documents (this could be ensured at the trait level?) - pub struct Sort<'transaction, Query> { field_name: String, field_id: Option, is_ascending: bool, + original_query: Option, iter: Option>, } impl<'transaction, Query> Sort<'transaction, Query> { pub fn new( - index: &'transaction Index, + index: &Index, rtxn: &'transaction heed::RoTxn, field_name: String, is_ascending: bool, @@ -36,7 +33,7 @@ impl<'transaction, Query> Sort<'transaction, Query> { let fields_ids_map = index.fields_ids_map(rtxn)?; let field_id = fields_ids_map.id(&field_name); - Ok(Self { field_name, field_id, is_ascending, iter: None }) + Ok(Self { field_name, field_id, is_ascending, original_query: None, iter: None }) } } @@ -87,6 +84,7 @@ impl<'transaction, Query: RankingRuleQueryTrait> RankingRule<'transaction, Query } None => RankingRuleOutputIterWrapper::new(Box::new(std::iter::empty())), }; + self.original_query = Some(parent_query_graph.clone()); self.iter = Some(iter); Ok(()) } @@ -97,11 +95,17 @@ impl<'transaction, Query: RankingRuleQueryTrait> RankingRule<'transaction, Query _txn: &'transaction RoTxn, _db_cache: &mut DatabaseCache<'transaction>, _logger: &mut dyn SearchLogger, - _universe: &RoaringBitmap, + universe: &RoaringBitmap, ) -> Result>> { let iter = self.iter.as_mut().unwrap(); // TODO: we should make use of the universe in the function below - iter.next_bucket() + if let Some(mut bucket) = iter.next_bucket()? { + bucket.candidates &= universe; + Ok(Some(bucket)) + } else { + let query = self.original_query.as_ref().unwrap().clone(); + Ok(Some(RankingRuleOutput { query, candidates: universe.clone() })) + } } fn end_iteration( @@ -111,6 +115,7 @@ impl<'transaction, Query: RankingRuleQueryTrait> RankingRule<'transaction, Query _db_cache: &mut DatabaseCache<'transaction>, _logger: &mut dyn SearchLogger, ) { + self.original_query = None; self.iter = None; } }