Refactor the paths_of_cost algorithm

Support conditions that require certain nodes to be skipped
This commit is contained in:
Loïc Lecrenier 2023-03-30 12:11:11 +02:00
parent 01e24dd630
commit aa9592455c
2 changed files with 191 additions and 130 deletions

View File

@ -9,141 +9,202 @@ use crate::search::new::query_graph::QueryNode;
use crate::search::new::small_bitmap::SmallBitmap;
use crate::Result;
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
pub fn visit_paths_of_cost(
&mut self,
from: Interned<QueryNode>,
cost: u16,
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
dead_ends_cache: &mut DeadEndsCache<G::Condition>,
mut visit: impl FnMut(
&[Interned<G::Condition>],
&mut Self,
&mut DeadEndsCache<G::Condition>,
) -> Result<ControlFlow<()>>,
) -> Result<()> {
let _ = self.visit_paths_of_cost_rec(
from,
cost,
all_distances,
dead_ends_cache,
&mut visit,
&mut vec![],
&mut SmallBitmap::for_interned_values_in(&self.conditions_interner),
dead_ends_cache.forbidden.clone(),
)?;
type VisitFn<'f, G> = &'f mut dyn FnMut(
&[Interned<<G as RankingRuleGraphTrait>::Condition>],
&mut RankingRuleGraph<G>,
&mut DeadEndsCache<<G as RankingRuleGraphTrait>::Condition>,
) -> Result<ControlFlow<()>>;
struct VisitorContext<'a, G: RankingRuleGraphTrait> {
graph: &'a mut RankingRuleGraph<G>,
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
}
struct VisitorState<G: RankingRuleGraphTrait> {
remaining_cost: u64,
path: Vec<Interned<G::Condition>>,
visited_conditions: SmallBitmap<G::Condition>,
visited_nodes: SmallBitmap<QueryNode>,
forbidden_conditions: SmallBitmap<G::Condition>,
forbidden_conditions_to_nodes: SmallBitmap<QueryNode>,
}
pub struct PathVisitor<'a, G: RankingRuleGraphTrait> {
state: VisitorState<G>,
ctx: VisitorContext<'a, G>,
}
impl<'a, G: RankingRuleGraphTrait> PathVisitor<'a, G> {
pub fn new(
cost: u64,
graph: &'a mut RankingRuleGraph<G>,
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
) -> Self {
Self {
state: VisitorState {
remaining_cost: cost,
path: vec![],
visited_conditions: SmallBitmap::for_interned_values_in(&graph.conditions_interner),
visited_nodes: SmallBitmap::for_interned_values_in(&graph.query_graph.nodes),
forbidden_conditions: SmallBitmap::for_interned_values_in(
&graph.conditions_interner,
),
forbidden_conditions_to_nodes: SmallBitmap::for_interned_values_in(
&graph.query_graph.nodes,
),
},
ctx: VisitorContext { graph, all_costs_from_node, dead_ends_cache },
}
}
pub fn visit_paths(mut self, visit: VisitFn<G>) -> Result<()> {
let _ =
self.state.visit_node(self.ctx.graph.query_graph.root_node, visit, &mut self.ctx)?;
Ok(())
}
pub fn visit_paths_of_cost_rec(
}
impl<G: RankingRuleGraphTrait> VisitorState<G> {
fn visit_node(
&mut self,
from: Interned<QueryNode>,
cost: u16,
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
dead_ends_cache: &mut DeadEndsCache<G::Condition>,
visit: &mut impl FnMut(
&[Interned<G::Condition>],
&mut Self,
&mut DeadEndsCache<G::Condition>,
) -> Result<ControlFlow<()>>,
prev_conditions: &mut Vec<Interned<G::Condition>>,
cur_path: &mut SmallBitmap<G::Condition>,
mut forbidden_conditions: SmallBitmap<G::Condition>,
) -> Result<bool> {
from_node: Interned<QueryNode>,
visit: VisitFn<G>,
ctx: &mut VisitorContext<G>,
) -> Result<ControlFlow<(), bool>> {
let mut any_valid = false;
let edges = self.edges_of_node.get(from).clone();
'edges_loop: for edge_idx in edges.iter() {
let Some(edge) = self.edges_store.get(edge_idx).as_ref() else { continue };
if cost < edge.cost as u16 {
let edges = ctx.graph.edges_of_node.get(from_node).clone();
for edge_idx in edges.iter() {
let Some(edge) = ctx.graph.edges_store.get(edge_idx).clone() else { continue };
if self.remaining_cost < edge.cost as u64 {
continue;
}
let next_any_valid = match edge.condition {
None => {
if edge.dest_node == self.query_graph.end_node {
any_valid = true;
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
match control_flow {
ControlFlow::Continue(_) => {}
ControlFlow::Break(_) => return Ok(true),
}
true
} else {
self.visit_paths_of_cost_rec(
self.remaining_cost -= edge.cost as u64;
let cf = match edge.condition {
Some(condition) => self.visit_condition(
condition,
edge.dest_node,
cost - edge.cost as u16,
all_distances,
dead_ends_cache,
&edge.nodes_to_skip,
visit,
prev_conditions,
cur_path,
forbidden_conditions.clone(),
)?
}
}
Some(condition) => {
if forbidden_conditions.contains(condition)
|| all_distances
.get(edge.dest_node)
.iter()
.all(|next_cost| *next_cost != cost - edge.cost as u16)
{
continue;
}
cur_path.insert(condition);
prev_conditions.push(condition);
let mut new_forbidden_conditions = forbidden_conditions.clone();
if let Some(next_forbidden) =
dead_ends_cache.forbidden_conditions_after_prefix(prev_conditions)
{
new_forbidden_conditions.union(&next_forbidden);
}
let next_any_valid = if edge.dest_node == self.query_graph.end_node {
any_valid = true;
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
match control_flow {
ControlFlow::Continue(_) => {}
ControlFlow::Break(_) => return Ok(true),
}
true
} else {
self.visit_paths_of_cost_rec(
edge.dest_node,
cost - edge.cost as u16,
all_distances,
dead_ends_cache,
visit,
prev_conditions,
cur_path,
new_forbidden_conditions,
)?
ctx,
)?,
None => self.visit_no_condition(edge.dest_node, &edge.nodes_to_skip, visit, ctx)?,
};
cur_path.remove(condition);
prev_conditions.pop();
next_any_valid
}
};
any_valid |= next_any_valid;
self.remaining_cost += edge.cost as u64;
let ControlFlow::Continue(next_any_valid) = cf else {
return Ok(ControlFlow::Break(()));
};
if next_any_valid {
forbidden_conditions =
dead_ends_cache.forbidden_conditions_for_all_prefixes_up_to(prev_conditions);
if cur_path.intersects(&forbidden_conditions) {
break 'edges_loop;
self.forbidden_conditions = ctx
.dead_ends_cache
.forbidden_conditions_for_all_prefixes_up_to(self.path.iter().copied());
if self.visited_conditions.intersects(&self.forbidden_conditions) {
break;
}
}
any_valid |= next_any_valid;
}
Ok(ControlFlow::Continue(any_valid))
}
fn visit_no_condition(
&mut self,
dest_node: Interned<QueryNode>,
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
visit: VisitFn<G>,
ctx: &mut VisitorContext<G>,
) -> Result<ControlFlow<(), bool>> {
if ctx
.all_costs_from_node
.get(dest_node)
.iter()
.all(|next_cost| *next_cost != self.remaining_cost)
{
return Ok(ControlFlow::Continue(false));
}
if dest_node == ctx.graph.query_graph.end_node {
let control_flow = visit(&self.path, ctx.graph, ctx.dead_ends_cache)?;
match control_flow {
ControlFlow::Continue(_) => Ok(ControlFlow::Continue(true)),
ControlFlow::Break(_) => Ok(ControlFlow::Break(())),
}
} else {
let old_fbct = self.forbidden_conditions_to_nodes.clone();
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
let cf = self.visit_node(dest_node, visit, ctx)?;
self.forbidden_conditions_to_nodes = old_fbct;
Ok(cf)
}
}
fn visit_condition(
&mut self,
condition: Interned<G::Condition>,
dest_node: Interned<QueryNode>,
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
visit: VisitFn<G>,
ctx: &mut VisitorContext<G>,
) -> Result<ControlFlow<(), bool>> {
assert!(dest_node != ctx.graph.query_graph.end_node);
if self.forbidden_conditions_to_nodes.contains(dest_node)
|| edge_forbidden_nodes.intersects(&self.visited_nodes)
{
return Ok(ControlFlow::Continue(false));
}
if self.forbidden_conditions.contains(condition) {
return Ok(ControlFlow::Continue(false));
}
if ctx
.all_costs_from_node
.get(dest_node)
.iter()
.all(|next_cost| *next_cost != self.remaining_cost)
{
return Ok(ControlFlow::Continue(false));
}
self.path.push(condition);
self.visited_nodes.insert(dest_node);
self.visited_conditions.insert(condition);
let old_fc = self.forbidden_conditions.clone();
if let Some(next_forbidden) =
ctx.dead_ends_cache.forbidden_conditions_after_prefix(self.path.iter().copied())
{
self.forbidden_conditions.union(&next_forbidden);
}
let old_fctn = self.forbidden_conditions_to_nodes.clone();
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
let cf = self.visit_node(dest_node, visit, ctx)?;
self.forbidden_conditions_to_nodes = old_fctn;
self.forbidden_conditions = old_fc;
self.visited_conditions.remove(condition);
self.visited_nodes.remove(dest_node);
self.path.pop();
Ok(cf)
}
}
Ok(any_valid)
}
pub fn initialize_distances_with_necessary_edges(&self) -> MappedInterner<QueryNode, Vec<u16>> {
let mut distances_to_end = self.query_graph.nodes.map(|_| vec![]);
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
pub fn find_all_costs_to_end(&self) -> MappedInterner<QueryNode, Vec<u64>> {
let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]);
let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len());
let mut node_stack = VecDeque::new();
*distances_to_end.get_mut(self.query_graph.end_node) = vec![0];
*costs_to_end.get_mut(self.query_graph.end_node) = vec![0];
for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() {
node_stack.push_back(prev_node);
@ -151,22 +212,22 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
}
while let Some(cur_node) = node_stack.pop_front() {
let mut self_distances = BTreeSet::<u16>::new();
let mut self_costs = BTreeSet::<u64>::new();
let cur_node_edges = &self.edges_of_node.get(cur_node);
for edge_idx in cur_node_edges.iter() {
let edge = self.edges_store.get(edge_idx).as_ref().unwrap();
let succ_node = edge.dest_node;
let succ_distances = distances_to_end.get(succ_node);
for succ_distance in succ_distances {
self_distances.insert(edge.cost as u16 + succ_distance);
let succ_costs = costs_to_end.get(succ_node);
for succ_distance in succ_costs {
self_costs.insert(edge.cost as u64 + succ_distance);
}
}
let distances_to_end_cur_node = distances_to_end.get_mut(cur_node);
for cost in self_distances.iter() {
distances_to_end_cur_node.push(*cost);
let costs_to_end_cur_node = costs_to_end.get_mut(cur_node);
for cost in self_costs.iter() {
costs_to_end_cur_node.push(*cost);
}
*distances_to_end.get_mut(cur_node) = self_distances.into_iter().collect();
*costs_to_end.get_mut(cur_node) = self_costs.into_iter().collect();
for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() {
if !enqueued.contains(prev_node) {
node_stack.push_back(prev_node);
@ -174,6 +235,6 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
}
}
}
distances_to_end
costs_to_end
}
}

View File

@ -36,12 +36,12 @@ impl<T> DeadEndsCache<T> {
}
pub fn forbidden_conditions_for_all_prefixes_up_to(
&mut self,
prefix: &[Interned<T>],
prefix: impl Iterator<Item = Interned<T>>,
) -> SmallBitmap<T> {
let mut forbidden = self.forbidden.clone();
let mut cursor = self;
for c in prefix.iter() {
if let Some(next) = cursor.advance(*c) {
for c in prefix {
if let Some(next) = cursor.advance(c) {
cursor = next;
forbidden.union(&cursor.forbidden);
} else {
@ -52,11 +52,11 @@ impl<T> DeadEndsCache<T> {
}
pub fn forbidden_conditions_after_prefix(
&mut self,
prefix: &[Interned<T>],
prefix: impl Iterator<Item = Interned<T>>,
) -> Option<SmallBitmap<T>> {
let mut cursor = self;
for c in prefix.iter() {
if let Some(next) = cursor.advance(*c) {
for c in prefix {
if let Some(next) = cursor.advance(c) {
cursor = next;
} else {
return None;