Refactor the paths_of_cost algorithm

Support conditions that require certain nodes to be skipped
This commit is contained in:
Loïc Lecrenier 2023-03-30 12:11:11 +02:00
parent 01e24dd630
commit aa9592455c
2 changed files with 191 additions and 130 deletions

View File

@ -9,141 +9,202 @@ use crate::search::new::query_graph::QueryNode;
use crate::search::new::small_bitmap::SmallBitmap; use crate::search::new::small_bitmap::SmallBitmap;
use crate::Result; use crate::Result;
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> { type VisitFn<'f, G> = &'f mut dyn FnMut(
pub fn visit_paths_of_cost( &[Interned<<G as RankingRuleGraphTrait>::Condition>],
&mut self, &mut RankingRuleGraph<G>,
from: Interned<QueryNode>, &mut DeadEndsCache<<G as RankingRuleGraphTrait>::Condition>,
cost: u16, ) -> Result<ControlFlow<()>>;
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
dead_ends_cache: &mut DeadEndsCache<G::Condition>, struct VisitorContext<'a, G: RankingRuleGraphTrait> {
mut visit: impl FnMut( graph: &'a mut RankingRuleGraph<G>,
&[Interned<G::Condition>], all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
&mut Self, dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
&mut DeadEndsCache<G::Condition>, }
) -> Result<ControlFlow<()>>,
) -> Result<()> { struct VisitorState<G: RankingRuleGraphTrait> {
let _ = self.visit_paths_of_cost_rec( remaining_cost: u64,
from,
cost, path: Vec<Interned<G::Condition>>,
all_distances,
dead_ends_cache, visited_conditions: SmallBitmap<G::Condition>,
&mut visit, visited_nodes: SmallBitmap<QueryNode>,
&mut vec![],
&mut SmallBitmap::for_interned_values_in(&self.conditions_interner), forbidden_conditions: SmallBitmap<G::Condition>,
dead_ends_cache.forbidden.clone(), forbidden_conditions_to_nodes: SmallBitmap<QueryNode>,
)?; }
pub struct PathVisitor<'a, G: RankingRuleGraphTrait> {
state: VisitorState<G>,
ctx: VisitorContext<'a, G>,
}
impl<'a, G: RankingRuleGraphTrait> PathVisitor<'a, G> {
pub fn new(
cost: u64,
graph: &'a mut RankingRuleGraph<G>,
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
) -> Self {
Self {
state: VisitorState {
remaining_cost: cost,
path: vec![],
visited_conditions: SmallBitmap::for_interned_values_in(&graph.conditions_interner),
visited_nodes: SmallBitmap::for_interned_values_in(&graph.query_graph.nodes),
forbidden_conditions: SmallBitmap::for_interned_values_in(
&graph.conditions_interner,
),
forbidden_conditions_to_nodes: SmallBitmap::for_interned_values_in(
&graph.query_graph.nodes,
),
},
ctx: VisitorContext { graph, all_costs_from_node, dead_ends_cache },
}
}
pub fn visit_paths(mut self, visit: VisitFn<G>) -> Result<()> {
let _ =
self.state.visit_node(self.ctx.graph.query_graph.root_node, visit, &mut self.ctx)?;
Ok(()) Ok(())
} }
pub fn visit_paths_of_cost_rec( }
impl<G: RankingRuleGraphTrait> VisitorState<G> {
fn visit_node(
&mut self, &mut self,
from: Interned<QueryNode>, from_node: Interned<QueryNode>,
cost: u16, visit: VisitFn<G>,
all_distances: &MappedInterner<QueryNode, Vec<u16>>, ctx: &mut VisitorContext<G>,
dead_ends_cache: &mut DeadEndsCache<G::Condition>, ) -> Result<ControlFlow<(), bool>> {
visit: &mut impl FnMut(
&[Interned<G::Condition>],
&mut Self,
&mut DeadEndsCache<G::Condition>,
) -> Result<ControlFlow<()>>,
prev_conditions: &mut Vec<Interned<G::Condition>>,
cur_path: &mut SmallBitmap<G::Condition>,
mut forbidden_conditions: SmallBitmap<G::Condition>,
) -> Result<bool> {
let mut any_valid = false; let mut any_valid = false;
let edges = self.edges_of_node.get(from).clone(); let edges = ctx.graph.edges_of_node.get(from_node).clone();
'edges_loop: for edge_idx in edges.iter() { for edge_idx in edges.iter() {
let Some(edge) = self.edges_store.get(edge_idx).as_ref() else { continue }; let Some(edge) = ctx.graph.edges_store.get(edge_idx).clone() else { continue };
if cost < edge.cost as u16 {
if self.remaining_cost < edge.cost as u64 {
continue; continue;
} }
let next_any_valid = match edge.condition { self.remaining_cost -= edge.cost as u64;
None => { let cf = match edge.condition {
if edge.dest_node == self.query_graph.end_node { Some(condition) => self.visit_condition(
any_valid = true; condition,
let control_flow = visit(prev_conditions, self, dead_ends_cache)?; edge.dest_node,
match control_flow { &edge.nodes_to_skip,
ControlFlow::Continue(_) => {} visit,
ControlFlow::Break(_) => return Ok(true), ctx,
} )?,
true None => self.visit_no_condition(edge.dest_node, &edge.nodes_to_skip, visit, ctx)?,
} else {
self.visit_paths_of_cost_rec(
edge.dest_node,
cost - edge.cost as u16,
all_distances,
dead_ends_cache,
visit,
prev_conditions,
cur_path,
forbidden_conditions.clone(),
)?
}
}
Some(condition) => {
if forbidden_conditions.contains(condition)
|| all_distances
.get(edge.dest_node)
.iter()
.all(|next_cost| *next_cost != cost - edge.cost as u16)
{
continue;
}
cur_path.insert(condition);
prev_conditions.push(condition);
let mut new_forbidden_conditions = forbidden_conditions.clone();
if let Some(next_forbidden) =
dead_ends_cache.forbidden_conditions_after_prefix(prev_conditions)
{
new_forbidden_conditions.union(&next_forbidden);
}
let next_any_valid = if edge.dest_node == self.query_graph.end_node {
any_valid = true;
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
match control_flow {
ControlFlow::Continue(_) => {}
ControlFlow::Break(_) => return Ok(true),
}
true
} else {
self.visit_paths_of_cost_rec(
edge.dest_node,
cost - edge.cost as u16,
all_distances,
dead_ends_cache,
visit,
prev_conditions,
cur_path,
new_forbidden_conditions,
)?
};
cur_path.remove(condition);
prev_conditions.pop();
next_any_valid
}
}; };
any_valid |= next_any_valid; self.remaining_cost += edge.cost as u64;
let ControlFlow::Continue(next_any_valid) = cf else {
return Ok(ControlFlow::Break(()));
};
if next_any_valid { if next_any_valid {
forbidden_conditions = self.forbidden_conditions = ctx
dead_ends_cache.forbidden_conditions_for_all_prefixes_up_to(prev_conditions); .dead_ends_cache
if cur_path.intersects(&forbidden_conditions) { .forbidden_conditions_for_all_prefixes_up_to(self.path.iter().copied());
break 'edges_loop; if self.visited_conditions.intersects(&self.forbidden_conditions) {
break;
} }
} }
any_valid |= next_any_valid;
} }
Ok(any_valid) Ok(ControlFlow::Continue(any_valid))
} }
pub fn initialize_distances_with_necessary_edges(&self) -> MappedInterner<QueryNode, Vec<u16>> { fn visit_no_condition(
let mut distances_to_end = self.query_graph.nodes.map(|_| vec![]); &mut self,
dest_node: Interned<QueryNode>,
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
visit: VisitFn<G>,
ctx: &mut VisitorContext<G>,
) -> Result<ControlFlow<(), bool>> {
if ctx
.all_costs_from_node
.get(dest_node)
.iter()
.all(|next_cost| *next_cost != self.remaining_cost)
{
return Ok(ControlFlow::Continue(false));
}
if dest_node == ctx.graph.query_graph.end_node {
let control_flow = visit(&self.path, ctx.graph, ctx.dead_ends_cache)?;
match control_flow {
ControlFlow::Continue(_) => Ok(ControlFlow::Continue(true)),
ControlFlow::Break(_) => Ok(ControlFlow::Break(())),
}
} else {
let old_fbct = self.forbidden_conditions_to_nodes.clone();
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
let cf = self.visit_node(dest_node, visit, ctx)?;
self.forbidden_conditions_to_nodes = old_fbct;
Ok(cf)
}
}
fn visit_condition(
&mut self,
condition: Interned<G::Condition>,
dest_node: Interned<QueryNode>,
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
visit: VisitFn<G>,
ctx: &mut VisitorContext<G>,
) -> Result<ControlFlow<(), bool>> {
assert!(dest_node != ctx.graph.query_graph.end_node);
if self.forbidden_conditions_to_nodes.contains(dest_node)
|| edge_forbidden_nodes.intersects(&self.visited_nodes)
{
return Ok(ControlFlow::Continue(false));
}
if self.forbidden_conditions.contains(condition) {
return Ok(ControlFlow::Continue(false));
}
if ctx
.all_costs_from_node
.get(dest_node)
.iter()
.all(|next_cost| *next_cost != self.remaining_cost)
{
return Ok(ControlFlow::Continue(false));
}
self.path.push(condition);
self.visited_nodes.insert(dest_node);
self.visited_conditions.insert(condition);
let old_fc = self.forbidden_conditions.clone();
if let Some(next_forbidden) =
ctx.dead_ends_cache.forbidden_conditions_after_prefix(self.path.iter().copied())
{
self.forbidden_conditions.union(&next_forbidden);
}
let old_fctn = self.forbidden_conditions_to_nodes.clone();
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
let cf = self.visit_node(dest_node, visit, ctx)?;
self.forbidden_conditions_to_nodes = old_fctn;
self.forbidden_conditions = old_fc;
self.visited_conditions.remove(condition);
self.visited_nodes.remove(dest_node);
self.path.pop();
Ok(cf)
}
}
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
pub fn find_all_costs_to_end(&self) -> MappedInterner<QueryNode, Vec<u64>> {
let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]);
let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len()); let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len());
let mut node_stack = VecDeque::new(); let mut node_stack = VecDeque::new();
*distances_to_end.get_mut(self.query_graph.end_node) = vec![0]; *costs_to_end.get_mut(self.query_graph.end_node) = vec![0];
for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() { for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() {
node_stack.push_back(prev_node); node_stack.push_back(prev_node);
@ -151,22 +212,22 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
} }
while let Some(cur_node) = node_stack.pop_front() { while let Some(cur_node) = node_stack.pop_front() {
let mut self_distances = BTreeSet::<u16>::new(); let mut self_costs = BTreeSet::<u64>::new();
let cur_node_edges = &self.edges_of_node.get(cur_node); let cur_node_edges = &self.edges_of_node.get(cur_node);
for edge_idx in cur_node_edges.iter() { for edge_idx in cur_node_edges.iter() {
let edge = self.edges_store.get(edge_idx).as_ref().unwrap(); let edge = self.edges_store.get(edge_idx).as_ref().unwrap();
let succ_node = edge.dest_node; let succ_node = edge.dest_node;
let succ_distances = distances_to_end.get(succ_node); let succ_costs = costs_to_end.get(succ_node);
for succ_distance in succ_distances { for succ_distance in succ_costs {
self_distances.insert(edge.cost as u16 + succ_distance); self_costs.insert(edge.cost as u64 + succ_distance);
} }
} }
let distances_to_end_cur_node = distances_to_end.get_mut(cur_node); let costs_to_end_cur_node = costs_to_end.get_mut(cur_node);
for cost in self_distances.iter() { for cost in self_costs.iter() {
distances_to_end_cur_node.push(*cost); costs_to_end_cur_node.push(*cost);
} }
*distances_to_end.get_mut(cur_node) = self_distances.into_iter().collect(); *costs_to_end.get_mut(cur_node) = self_costs.into_iter().collect();
for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() {
if !enqueued.contains(prev_node) { if !enqueued.contains(prev_node) {
node_stack.push_back(prev_node); node_stack.push_back(prev_node);
@ -174,6 +235,6 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
} }
} }
} }
distances_to_end costs_to_end
} }
} }

View File

@ -36,12 +36,12 @@ impl<T> DeadEndsCache<T> {
} }
pub fn forbidden_conditions_for_all_prefixes_up_to( pub fn forbidden_conditions_for_all_prefixes_up_to(
&mut self, &mut self,
prefix: &[Interned<T>], prefix: impl Iterator<Item = Interned<T>>,
) -> SmallBitmap<T> { ) -> SmallBitmap<T> {
let mut forbidden = self.forbidden.clone(); let mut forbidden = self.forbidden.clone();
let mut cursor = self; let mut cursor = self;
for c in prefix.iter() { for c in prefix {
if let Some(next) = cursor.advance(*c) { if let Some(next) = cursor.advance(c) {
cursor = next; cursor = next;
forbidden.union(&cursor.forbidden); forbidden.union(&cursor.forbidden);
} else { } else {
@ -52,11 +52,11 @@ impl<T> DeadEndsCache<T> {
} }
pub fn forbidden_conditions_after_prefix( pub fn forbidden_conditions_after_prefix(
&mut self, &mut self,
prefix: &[Interned<T>], prefix: impl Iterator<Item = Interned<T>>,
) -> Option<SmallBitmap<T>> { ) -> Option<SmallBitmap<T>> {
let mut cursor = self; let mut cursor = self;
for c in prefix.iter() { for c in prefix {
if let Some(next) = cursor.advance(*c) { if let Some(next) = cursor.advance(c) {
cursor = next; cursor = next;
} else { } else {
return None; return None;