From e4035ff3ec8882efe6bf3fb844f2ea0375a56e45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Lecrenier?= Date: Mon, 8 May 2023 11:52:43 +0200 Subject: [PATCH] Implement `words` as a graph-based ranking rule and fix some bugs --- .../search/new/graph_based_ranking_rule.rs | 8 +- milli/src/search/new/logger/visual.rs | 26 +++-- milli/src/search/new/mod.rs | 11 +- .../new/ranking_rule_graph/cheapest_paths.rs | 105 ++++++++++++------ .../src/search/new/ranking_rule_graph/mod.rs | 3 + .../new/ranking_rule_graph/words/mod.rs | 49 ++++++++ milli/src/search/new/tests/distinct.rs | 3 +- milli/src/search/new/words.rs | 87 --------------- 8 files changed, 154 insertions(+), 138 deletions(-) create mode 100644 milli/src/search/new/ranking_rule_graph/words/mod.rs delete mode 100644 milli/src/search/new/words.rs diff --git a/milli/src/search/new/graph_based_ranking_rule.rs b/milli/src/search/new/graph_based_ranking_rule.rs index d8f6836e7..dd25ddd4a 100644 --- a/milli/src/search/new/graph_based_ranking_rule.rs +++ b/milli/src/search/new/graph_based_ranking_rule.rs @@ -46,7 +46,7 @@ use super::logger::SearchLogger; use super::query_graph::QueryNode; use super::ranking_rule_graph::{ ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, FidGraph, PositionGraph, ProximityGraph, - RankingRuleGraph, RankingRuleGraphTrait, TypoGraph, + RankingRuleGraph, RankingRuleGraphTrait, TypoGraph, WordsGraph, }; use super::small_bitmap::SmallBitmap; use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext}; @@ -54,6 +54,12 @@ use crate::search::new::query_term::LocatedQueryTermSubset; use crate::search::new::ranking_rule_graph::PathVisitor; use crate::{Result, TermsMatchingStrategy}; +pub type Words = GraphBasedRankingRule; +impl GraphBasedRankingRule { + pub fn new(terms_matching_strategy: TermsMatchingStrategy) -> Self { + Self::new_with_id("words".to_owned(), Some(terms_matching_strategy)) + } +} pub type Proximity = GraphBasedRankingRule; impl GraphBasedRankingRule { pub fn new(terms_matching_strategy: Option) -> Self { diff --git a/milli/src/search/new/logger/visual.rs b/milli/src/search/new/logger/visual.rs index 1cbe007d3..e3f2b7c59 100644 --- a/milli/src/search/new/logger/visual.rs +++ b/milli/src/search/new/logger/visual.rs @@ -13,6 +13,7 @@ use crate::search::new::query_term::LocatedQueryTermSubset; use crate::search::new::ranking_rule_graph::{ Edge, FidCondition, FidGraph, PositionCondition, PositionGraph, ProximityCondition, ProximityGraph, RankingRuleGraph, RankingRuleGraphTrait, TypoCondition, TypoGraph, + WordsCondition, WordsGraph, }; use crate::search::new::ranking_rules::BoxRankingRule; use crate::search::new::{QueryGraph, QueryNode, RankingRule, SearchContext, SearchLogger}; @@ -24,11 +25,12 @@ pub enum SearchEvents { RankingRuleSkipBucket { ranking_rule_idx: usize, bucket_len: u64 }, RankingRuleEndIteration { ranking_rule_idx: usize, universe_len: u64 }, ExtendResults { new: Vec }, - WordsGraph { query_graph: QueryGraph }, ProximityGraph { graph: RankingRuleGraph }, ProximityPaths { paths: Vec>> }, TypoGraph { graph: RankingRuleGraph }, TypoPaths { paths: Vec>> }, + WordsGraph { graph: RankingRuleGraph }, + WordsPaths { paths: Vec>> }, FidGraph { graph: RankingRuleGraph }, FidPaths { paths: Vec>> }, PositionGraph { graph: RankingRuleGraph }, @@ -139,8 +141,11 @@ impl SearchLogger for VisualSearchLogger { let Some(location) = self.location.last() else { return }; match location { Location::Words => { - if let Some(query_graph) = state.downcast_ref::() { - self.events.push(SearchEvents::WordsGraph { query_graph: query_graph.clone() }); + if let Some(graph) = state.downcast_ref::>() { + self.events.push(SearchEvents::WordsGraph { graph: graph.clone() }); + } + if let Some(paths) = state.downcast_ref::>>>() { + self.events.push(SearchEvents::WordsPaths { paths: paths.clone() }); } } Location::Typo => { @@ -329,7 +334,6 @@ impl<'ctx> DetailedLoggerFinish<'ctx> { SearchEvents::ExtendResults { new } => { self.write_extend_results(new)?; } - SearchEvents::WordsGraph { query_graph } => self.write_words_graph(query_graph)?, SearchEvents::ProximityGraph { graph } => self.write_rr_graph(&graph)?, SearchEvents::ProximityPaths { paths } => { self.write_rr_graph_paths::(paths)?; @@ -338,6 +342,10 @@ impl<'ctx> DetailedLoggerFinish<'ctx> { SearchEvents::TypoPaths { paths } => { self.write_rr_graph_paths::(paths)?; } + SearchEvents::WordsGraph { graph } => self.write_rr_graph(&graph)?, + SearchEvents::WordsPaths { paths } => { + self.write_rr_graph_paths::(paths)?; + } SearchEvents::FidGraph { graph } => self.write_rr_graph(&graph)?, SearchEvents::FidPaths { paths } => { self.write_rr_graph_paths::(paths)?; @@ -482,13 +490,13 @@ fill: \"#B6E2D3\" } Ok(()) } - fn write_words_graph(&mut self, qg: QueryGraph) -> Result<()> { - self.make_new_file_for_internal_state_if_needed()?; + // fn write_words_graph(&mut self, qg: QueryGraph) -> Result<()> { + // self.make_new_file_for_internal_state_if_needed()?; - self.write_query_graph(&qg)?; + // self.write_query_graph(&qg)?; - Ok(()) - } + // Ok(()) + // } fn write_rr_graph( &mut self, graph: &RankingRuleGraph, diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index cbc085b12..a28f42f35 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -15,11 +15,7 @@ mod resolve_query_graph; mod small_bitmap; mod exact_attribute; -// TODO: documentation + comments -// implementation is currently an adaptation of the previous implementation to fit with the new model mod sort; -// TODO: documentation + comments -mod words; #[cfg(test)] mod tests; @@ -43,10 +39,10 @@ use ranking_rules::{ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use roaring::RoaringBitmap; use sort::Sort; -use words::Words; use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; +use self::graph_based_ranking_rule::Words; use self::interner::Interned; use crate::search::new::distinct::apply_distinct_rule; use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError}; @@ -202,6 +198,11 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( let mut sorted_fields = HashSet::new(); let mut geo_sorted = false; + // Don't add the `words` ranking rule if the term matching strategy is `All` + if matches!(terms_matching_strategy, TermsMatchingStrategy::All) { + words = true; + } + let mut ranking_rules: Vec> = vec![]; let settings_ranking_rules = ctx.index.criteria(ctx.txn)?; for rr in settings_ranking_rules { diff --git a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs index c065cc706..30caf0017 100644 --- a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs +++ b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs @@ -205,18 +205,12 @@ impl VisitorState { impl RankingRuleGraph { pub fn find_all_costs_to_end(&self) -> MappedInterner> { let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]); - let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len()); - let mut node_stack = VecDeque::new(); - - *costs_to_end.get_mut(self.query_graph.end_node) = vec![0]; - - for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() { - node_stack.push_back(prev_node); - enqueued.insert(prev_node); - } - - while let Some(cur_node) = node_stack.pop_front() { + self.traverse_breadth_first_backward(self.query_graph.end_node, |cur_node| { + if cur_node == self.query_graph.end_node { + *costs_to_end.get_mut(self.query_graph.end_node) = vec![0]; + return true; + } let mut self_costs = Vec::::new(); let cur_node_edges = &self.edges_of_node.get(cur_node); @@ -232,13 +226,8 @@ impl RankingRuleGraph { self_costs.dedup(); *costs_to_end.get_mut(cur_node) = self_costs; - for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { - if !enqueued.contains(prev_node) { - node_stack.push_back(prev_node); - enqueued.insert(prev_node); - } - } - } + true + }); costs_to_end } @@ -247,17 +236,9 @@ impl RankingRuleGraph { node_with_removed_outgoing_conditions: Interned, costs: &mut MappedInterner>, ) { - let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len()); - let mut node_stack = VecDeque::new(); - - enqueued.insert(node_with_removed_outgoing_conditions); - node_stack.push_back(node_with_removed_outgoing_conditions); - - 'main_loop: while let Some(cur_node) = node_stack.pop_front() { + self.traverse_breadth_first_backward(node_with_removed_outgoing_conditions, |cur_node| { let mut costs_to_remove = FxHashSet::default(); - for c in costs.get(cur_node) { - costs_to_remove.insert(*c); - } + costs_to_remove.extend(costs.get(cur_node).iter().copied()); let cur_node_edges = &self.edges_of_node.get(cur_node); for edge_idx in cur_node_edges.iter() { @@ -265,23 +246,79 @@ impl RankingRuleGraph { for cost in costs.get(edge.dest_node).iter() { costs_to_remove.remove(&(*cost + edge.cost as u64)); if costs_to_remove.is_empty() { - continue 'main_loop; + return false; } } } if costs_to_remove.is_empty() { - continue 'main_loop; + return false; } let mut new_costs = BTreeSet::from_iter(costs.get(cur_node).iter().copied()); for c in costs_to_remove { new_costs.remove(&c); } *costs.get_mut(cur_node) = new_costs.into_iter().collect(); + true + }); + } - for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { - if !enqueued.contains(prev_node) { - node_stack.push_back(prev_node); - enqueued.insert(prev_node); + /// Traverse the graph backwards from the given node such that every time + /// a node is visited, we are guaranteed that all its successors either: + /// 1. have already been visited; OR + /// 2. were not reachable from the given node + pub fn traverse_breadth_first_backward( + &self, + from: Interned, + mut visit: impl FnMut(Interned) -> bool, + ) { + let mut reachable = SmallBitmap::for_interned_values_in(&self.query_graph.nodes); + { + // go backward to get the set of all reachable nodes from the given node + // the nodes that are not reachable will be set as `visited` + let mut stack = VecDeque::new(); + let mut enqueued = SmallBitmap::for_interned_values_in(&self.query_graph.nodes); + enqueued.insert(from); + stack.push_back(from); + while let Some(n) = stack.pop_front() { + if reachable.contains(n) { + continue; + } + reachable.insert(n); + for prev_node in self.query_graph.nodes.get(n).predecessors.iter() { + if !enqueued.contains(prev_node) && !reachable.contains(prev_node) { + stack.push_back(prev_node); + enqueued.insert(prev_node); + } + } + } + }; + let mut unreachable_or_visited = + SmallBitmap::for_interned_values_in(&self.query_graph.nodes); + for (n, _) in self.query_graph.nodes.iter() { + if !reachable.contains(n) { + unreachable_or_visited.insert(n); + } + } + + let mut enqueued = SmallBitmap::for_interned_values_in(&self.query_graph.nodes); + let mut stack = VecDeque::new(); + + enqueued.insert(from); + stack.push_back(from); + + while let Some(cur_node) = stack.pop_front() { + if !self.query_graph.nodes.get(cur_node).successors.is_subset(&unreachable_or_visited) { + stack.push_back(cur_node); + continue; + } + unreachable_or_visited.insert(cur_node); + if visit(cur_node) { + for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { + if !enqueued.contains(prev_node) && !unreachable_or_visited.contains(prev_node) + { + stack.push_back(prev_node); + enqueued.insert(prev_node); + } } } } diff --git a/milli/src/search/new/ranking_rule_graph/mod.rs b/milli/src/search/new/ranking_rule_graph/mod.rs index f60c481de..8de455822 100644 --- a/milli/src/search/new/ranking_rule_graph/mod.rs +++ b/milli/src/search/new/ranking_rule_graph/mod.rs @@ -20,6 +20,8 @@ mod position; mod proximity; /// Implementation of the `typo` ranking rule mod typo; +/// Implementation of the `words` ranking rule +mod words; use std::collections::BTreeSet; use std::hash::Hash; @@ -33,6 +35,7 @@ pub use position::{PositionCondition, PositionGraph}; pub use proximity::{ProximityCondition, ProximityGraph}; use roaring::RoaringBitmap; pub use typo::{TypoCondition, TypoGraph}; +pub use words::{WordsCondition, WordsGraph}; use super::interner::{DedupInterner, FixedSizeInterner, Interned, MappedInterner}; use super::query_term::LocatedQueryTermSubset; diff --git a/milli/src/search/new/ranking_rule_graph/words/mod.rs b/milli/src/search/new/ranking_rule_graph/words/mod.rs new file mode 100644 index 000000000..0a0cc112b --- /dev/null +++ b/milli/src/search/new/ranking_rule_graph/words/mod.rs @@ -0,0 +1,49 @@ +use roaring::RoaringBitmap; + +use super::{ComputedCondition, RankingRuleGraphTrait}; +use crate::search::new::interner::{DedupInterner, Interned}; +use crate::search::new::query_term::LocatedQueryTermSubset; +use crate::search::new::resolve_query_graph::compute_query_term_subset_docids; +use crate::search::new::SearchContext; +use crate::Result; + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct WordsCondition { + term: LocatedQueryTermSubset, +} + +pub enum WordsGraph {} + +impl RankingRuleGraphTrait for WordsGraph { + type Condition = WordsCondition; + + fn resolve_condition( + ctx: &mut SearchContext, + condition: &Self::Condition, + universe: &RoaringBitmap, + ) -> Result { + let WordsCondition { term, .. } = condition; + // maybe compute_query_term_subset_docids should accept a universe as argument + let mut docids = compute_query_term_subset_docids(ctx, &term.term_subset)?; + docids &= universe; + + Ok(ComputedCondition { + docids, + universe_len: universe.len(), + start_term_subset: None, + end_term_subset: term.clone(), + }) + } + + fn build_edges( + _ctx: &mut SearchContext, + conditions_interner: &mut DedupInterner, + _from: Option<&LocatedQueryTermSubset>, + to_term: &LocatedQueryTermSubset, + ) -> Result)>> { + Ok(vec![( + to_term.term_ids.len() as u32, + conditions_interner.insert(WordsCondition { term: to_term.clone() }), + )]) + } +} diff --git a/milli/src/search/new/tests/distinct.rs b/milli/src/search/new/tests/distinct.rs index 2c147d514..ec835ba85 100644 --- a/milli/src/search/new/tests/distinct.rs +++ b/milli/src/search/new/tests/distinct.rs @@ -11,11 +11,10 @@ It doesn't test properly: - distinct attributes with arrays (because we know it's incorrect as well) */ -use std::collections::HashSet; - use big_s::S; use heed::RoTxn; use maplit::hashset; +use std::collections::HashSet; use super::collect_field_values; use crate::index::tests::TempIndex; diff --git a/milli/src/search/new/words.rs b/milli/src/search/new/words.rs deleted file mode 100644 index 72b7b5916..000000000 --- a/milli/src/search/new/words.rs +++ /dev/null @@ -1,87 +0,0 @@ -use roaring::RoaringBitmap; - -use super::logger::SearchLogger; -use super::query_graph::QueryNode; -use super::resolve_query_graph::compute_query_graph_docids; -use super::small_bitmap::SmallBitmap; -use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext}; -use crate::{Result, TermsMatchingStrategy}; - -pub struct Words { - exhausted: bool, // TODO: remove - query_graph: Option, - nodes_to_remove: Vec>, - terms_matching_strategy: TermsMatchingStrategy, -} -impl Words { - pub fn new(terms_matching_strategy: TermsMatchingStrategy) -> Self { - Self { - exhausted: true, - query_graph: None, - nodes_to_remove: vec![], - terms_matching_strategy, - } - } -} - -impl<'ctx> RankingRule<'ctx, QueryGraph> for Words { - fn id(&self) -> String { - "words".to_owned() - } - fn start_iteration( - &mut self, - ctx: &mut SearchContext<'ctx>, - _logger: &mut dyn SearchLogger, - _universe: &RoaringBitmap, - parent_query_graph: &QueryGraph, - ) -> Result<()> { - self.exhausted = false; - self.query_graph = Some(parent_query_graph.clone()); - self.nodes_to_remove = match self.terms_matching_strategy { - TermsMatchingStrategy::Last => { - let mut ns = parent_query_graph.removal_order_for_terms_matching_strategy_last(ctx); - ns.reverse(); - ns - } - TermsMatchingStrategy::All => { - vec![] - } - }; - Ok(()) - } - - fn next_bucket( - &mut self, - ctx: &mut SearchContext<'ctx>, - logger: &mut dyn SearchLogger, - universe: &RoaringBitmap, - ) -> Result>> { - if self.exhausted { - return Ok(None); - } - let Some(query_graph) = &mut self.query_graph else { panic!() }; - logger.log_internal_state(query_graph); - - let this_bucket = compute_query_graph_docids(ctx, query_graph, universe)?; - - let child_query_graph = query_graph.clone(); - - if self.nodes_to_remove.is_empty() { - self.exhausted = true; - } else { - let nodes_to_remove = self.nodes_to_remove.pop().unwrap(); - query_graph.remove_nodes_keep_edges(&nodes_to_remove.iter().collect::>()); - } - Ok(Some(RankingRuleOutput { query: child_query_graph, candidates: this_bucket })) - } - - fn end_iteration( - &mut self, - _ctx: &mut SearchContext<'ctx>, - _logger: &mut dyn SearchLogger, - ) { - self.exhausted = true; - self.nodes_to_remove = vec![]; - self.query_graph = None; - } -}