diff --git a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs index c09f6e5e0..443ab0ec4 100644 --- a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs +++ b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs @@ -9,141 +9,202 @@ use crate::search::new::query_graph::QueryNode; use crate::search::new::small_bitmap::SmallBitmap; use crate::Result; -impl RankingRuleGraph { - pub fn visit_paths_of_cost( - &mut self, - from: Interned, - cost: u16, - all_distances: &MappedInterner>, - dead_ends_cache: &mut DeadEndsCache, - mut visit: impl FnMut( - &[Interned], - &mut Self, - &mut DeadEndsCache, - ) -> Result>, - ) -> Result<()> { - let _ = self.visit_paths_of_cost_rec( - from, - cost, - all_distances, - dead_ends_cache, - &mut visit, - &mut vec![], - &mut SmallBitmap::for_interned_values_in(&self.conditions_interner), - dead_ends_cache.forbidden.clone(), - )?; +type VisitFn<'f, G> = &'f mut dyn FnMut( + &[Interned<::Condition>], + &mut RankingRuleGraph, + &mut DeadEndsCache<::Condition>, +) -> Result>; + +struct VisitorContext<'a, G: RankingRuleGraphTrait> { + graph: &'a mut RankingRuleGraph, + all_costs_from_node: &'a MappedInterner>, + dead_ends_cache: &'a mut DeadEndsCache, +} + +struct VisitorState { + remaining_cost: u64, + + path: Vec>, + + visited_conditions: SmallBitmap, + visited_nodes: SmallBitmap, + + forbidden_conditions: SmallBitmap, + forbidden_conditions_to_nodes: SmallBitmap, +} + +pub struct PathVisitor<'a, G: RankingRuleGraphTrait> { + state: VisitorState, + ctx: VisitorContext<'a, G>, +} +impl<'a, G: RankingRuleGraphTrait> PathVisitor<'a, G> { + pub fn new( + cost: u64, + graph: &'a mut RankingRuleGraph, + all_costs_from_node: &'a MappedInterner>, + dead_ends_cache: &'a mut DeadEndsCache, + ) -> Self { + Self { + state: VisitorState { + remaining_cost: cost, + path: vec![], + visited_conditions: SmallBitmap::for_interned_values_in(&graph.conditions_interner), + visited_nodes: SmallBitmap::for_interned_values_in(&graph.query_graph.nodes), + forbidden_conditions: SmallBitmap::for_interned_values_in( + &graph.conditions_interner, + ), + forbidden_conditions_to_nodes: SmallBitmap::for_interned_values_in( + &graph.query_graph.nodes, + ), + }, + ctx: VisitorContext { graph, all_costs_from_node, dead_ends_cache }, + } + } + + pub fn visit_paths(mut self, visit: VisitFn) -> Result<()> { + let _ = + self.state.visit_node(self.ctx.graph.query_graph.root_node, visit, &mut self.ctx)?; Ok(()) } - pub fn visit_paths_of_cost_rec( +} + +impl VisitorState { + fn visit_node( &mut self, - from: Interned, - cost: u16, - all_distances: &MappedInterner>, - dead_ends_cache: &mut DeadEndsCache, - visit: &mut impl FnMut( - &[Interned], - &mut Self, - &mut DeadEndsCache, - ) -> Result>, - prev_conditions: &mut Vec>, - cur_path: &mut SmallBitmap, - mut forbidden_conditions: SmallBitmap, - ) -> Result { + from_node: Interned, + visit: VisitFn, + ctx: &mut VisitorContext, + ) -> Result> { let mut any_valid = false; - let edges = self.edges_of_node.get(from).clone(); - 'edges_loop: for edge_idx in edges.iter() { - let Some(edge) = self.edges_store.get(edge_idx).as_ref() else { continue }; - if cost < edge.cost as u16 { + let edges = ctx.graph.edges_of_node.get(from_node).clone(); + for edge_idx in edges.iter() { + let Some(edge) = ctx.graph.edges_store.get(edge_idx).clone() else { continue }; + + if self.remaining_cost < edge.cost as u64 { continue; } - let next_any_valid = match edge.condition { - None => { - if edge.dest_node == self.query_graph.end_node { - any_valid = true; - let control_flow = visit(prev_conditions, self, dead_ends_cache)?; - match control_flow { - ControlFlow::Continue(_) => {} - ControlFlow::Break(_) => return Ok(true), - } - true - } else { - self.visit_paths_of_cost_rec( - edge.dest_node, - cost - edge.cost as u16, - all_distances, - dead_ends_cache, - visit, - prev_conditions, - cur_path, - forbidden_conditions.clone(), - )? - } - } - Some(condition) => { - if forbidden_conditions.contains(condition) - || all_distances - .get(edge.dest_node) - .iter() - .all(|next_cost| *next_cost != cost - edge.cost as u16) - { - continue; - } - cur_path.insert(condition); - prev_conditions.push(condition); - let mut new_forbidden_conditions = forbidden_conditions.clone(); - if let Some(next_forbidden) = - dead_ends_cache.forbidden_conditions_after_prefix(prev_conditions) - { - new_forbidden_conditions.union(&next_forbidden); - } - - let next_any_valid = if edge.dest_node == self.query_graph.end_node { - any_valid = true; - let control_flow = visit(prev_conditions, self, dead_ends_cache)?; - match control_flow { - ControlFlow::Continue(_) => {} - ControlFlow::Break(_) => return Ok(true), - } - true - } else { - self.visit_paths_of_cost_rec( - edge.dest_node, - cost - edge.cost as u16, - all_distances, - dead_ends_cache, - visit, - prev_conditions, - cur_path, - new_forbidden_conditions, - )? - }; - cur_path.remove(condition); - prev_conditions.pop(); - next_any_valid - } + self.remaining_cost -= edge.cost as u64; + let cf = match edge.condition { + Some(condition) => self.visit_condition( + condition, + edge.dest_node, + &edge.nodes_to_skip, + visit, + ctx, + )?, + None => self.visit_no_condition(edge.dest_node, &edge.nodes_to_skip, visit, ctx)?, }; - any_valid |= next_any_valid; + self.remaining_cost += edge.cost as u64; + let ControlFlow::Continue(next_any_valid) = cf else { + return Ok(ControlFlow::Break(())); + }; if next_any_valid { - forbidden_conditions = - dead_ends_cache.forbidden_conditions_for_all_prefixes_up_to(prev_conditions); - if cur_path.intersects(&forbidden_conditions) { - break 'edges_loop; + self.forbidden_conditions = ctx + .dead_ends_cache + .forbidden_conditions_for_all_prefixes_up_to(self.path.iter().copied()); + if self.visited_conditions.intersects(&self.forbidden_conditions) { + break; } } + any_valid |= next_any_valid; } - Ok(any_valid) + Ok(ControlFlow::Continue(any_valid)) } - pub fn initialize_distances_with_necessary_edges(&self) -> MappedInterner> { - let mut distances_to_end = self.query_graph.nodes.map(|_| vec![]); + fn visit_no_condition( + &mut self, + dest_node: Interned, + edge_forbidden_nodes: &SmallBitmap, + visit: VisitFn, + ctx: &mut VisitorContext, + ) -> Result> { + if ctx + .all_costs_from_node + .get(dest_node) + .iter() + .all(|next_cost| *next_cost != self.remaining_cost) + { + return Ok(ControlFlow::Continue(false)); + } + if dest_node == ctx.graph.query_graph.end_node { + let control_flow = visit(&self.path, ctx.graph, ctx.dead_ends_cache)?; + match control_flow { + ControlFlow::Continue(_) => Ok(ControlFlow::Continue(true)), + ControlFlow::Break(_) => Ok(ControlFlow::Break(())), + } + } else { + let old_fbct = self.forbidden_conditions_to_nodes.clone(); + self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes); + let cf = self.visit_node(dest_node, visit, ctx)?; + self.forbidden_conditions_to_nodes = old_fbct; + Ok(cf) + } + } + fn visit_condition( + &mut self, + condition: Interned, + dest_node: Interned, + edge_forbidden_nodes: &SmallBitmap, + visit: VisitFn, + ctx: &mut VisitorContext, + ) -> Result> { + assert!(dest_node != ctx.graph.query_graph.end_node); + + if self.forbidden_conditions_to_nodes.contains(dest_node) + || edge_forbidden_nodes.intersects(&self.visited_nodes) + { + return Ok(ControlFlow::Continue(false)); + } + if self.forbidden_conditions.contains(condition) { + return Ok(ControlFlow::Continue(false)); + } + + if ctx + .all_costs_from_node + .get(dest_node) + .iter() + .all(|next_cost| *next_cost != self.remaining_cost) + { + return Ok(ControlFlow::Continue(false)); + } + + self.path.push(condition); + self.visited_nodes.insert(dest_node); + self.visited_conditions.insert(condition); + + let old_fc = self.forbidden_conditions.clone(); + if let Some(next_forbidden) = + ctx.dead_ends_cache.forbidden_conditions_after_prefix(self.path.iter().copied()) + { + self.forbidden_conditions.union(&next_forbidden); + } + let old_fctn = self.forbidden_conditions_to_nodes.clone(); + self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes); + + let cf = self.visit_node(dest_node, visit, ctx)?; + + self.forbidden_conditions_to_nodes = old_fctn; + self.forbidden_conditions = old_fc; + + self.visited_conditions.remove(condition); + self.visited_nodes.remove(dest_node); + self.path.pop(); + + Ok(cf) + } +} + +impl RankingRuleGraph { + pub fn find_all_costs_to_end(&self) -> MappedInterner> { + let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]); let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len()); let mut node_stack = VecDeque::new(); - *distances_to_end.get_mut(self.query_graph.end_node) = vec![0]; + *costs_to_end.get_mut(self.query_graph.end_node) = vec![0]; for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() { node_stack.push_back(prev_node); @@ -151,22 +212,22 @@ impl RankingRuleGraph { } while let Some(cur_node) = node_stack.pop_front() { - let mut self_distances = BTreeSet::::new(); + let mut self_costs = BTreeSet::::new(); let cur_node_edges = &self.edges_of_node.get(cur_node); for edge_idx in cur_node_edges.iter() { let edge = self.edges_store.get(edge_idx).as_ref().unwrap(); let succ_node = edge.dest_node; - let succ_distances = distances_to_end.get(succ_node); - for succ_distance in succ_distances { - self_distances.insert(edge.cost as u16 + succ_distance); + let succ_costs = costs_to_end.get(succ_node); + for succ_distance in succ_costs { + self_costs.insert(edge.cost as u64 + succ_distance); } } - let distances_to_end_cur_node = distances_to_end.get_mut(cur_node); - for cost in self_distances.iter() { - distances_to_end_cur_node.push(*cost); + let costs_to_end_cur_node = costs_to_end.get_mut(cur_node); + for cost in self_costs.iter() { + costs_to_end_cur_node.push(*cost); } - *distances_to_end.get_mut(cur_node) = self_distances.into_iter().collect(); + *costs_to_end.get_mut(cur_node) = self_costs.into_iter().collect(); for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { if !enqueued.contains(prev_node) { node_stack.push_back(prev_node); @@ -174,6 +235,6 @@ impl RankingRuleGraph { } } } - distances_to_end + costs_to_end } } diff --git a/milli/src/search/new/ranking_rule_graph/dead_ends_cache.rs b/milli/src/search/new/ranking_rule_graph/dead_ends_cache.rs index d25c69c23..f3bb25d56 100644 --- a/milli/src/search/new/ranking_rule_graph/dead_ends_cache.rs +++ b/milli/src/search/new/ranking_rule_graph/dead_ends_cache.rs @@ -36,12 +36,12 @@ impl DeadEndsCache { } pub fn forbidden_conditions_for_all_prefixes_up_to( &mut self, - prefix: &[Interned], + prefix: impl Iterator>, ) -> SmallBitmap { let mut forbidden = self.forbidden.clone(); let mut cursor = self; - for c in prefix.iter() { - if let Some(next) = cursor.advance(*c) { + for c in prefix { + if let Some(next) = cursor.advance(c) { cursor = next; forbidden.union(&cursor.forbidden); } else { @@ -52,11 +52,11 @@ impl DeadEndsCache { } pub fn forbidden_conditions_after_prefix( &mut self, - prefix: &[Interned], + prefix: impl Iterator>, ) -> Option> { let mut cursor = self; - for c in prefix.iter() { - if let Some(next) = cursor.advance(*c) { + for c in prefix { + if let Some(next) = cursor.advance(c) { cursor = next; } else { return None;