diff --git a/milli/src/search/new/query_graph.rs b/milli/src/search/new/query_graph.rs
index 726a1460c..821c1a226 100644
--- a/milli/src/search/new/query_graph.rs
+++ b/milli/src/search/new/query_graph.rs
@@ -1,7 +1,8 @@
-use std::collections::HashSet;
 use std::fmt::Debug;
+use std::{collections::HashSet, fmt};
 
 use heed::RoTxn;
+use roaring::RoaringBitmap;
 
 use super::{
     db_cache::DatabaseCache,
@@ -19,21 +20,31 @@ pub enum QueryNode {
 
 #[derive(Debug, Clone)]
 pub struct Edges {
-    pub incoming: HashSet<usize>,
-    pub outgoing: HashSet<usize>,
+    // TODO: use a tiny bitset instead
+    // something like a simple Vec where most queries will see a vector of one element
+    pub predecessors: RoaringBitmap,
+    pub successors: RoaringBitmap,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct NodeIndex(pub u32);
+impl fmt::Display for NodeIndex {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt::Display::fmt(&self.0, f)
+    }
 }
 
 #[derive(Debug, Clone)]
 pub struct QueryGraph {
-    pub root_node: usize,
-    pub end_node: usize,
+    pub root_node: NodeIndex,
+    pub end_node: NodeIndex,
     pub nodes: Vec<QueryNode>,
     pub edges: Vec<Edges>,
 }
 
 fn _assert_sizes() {
     let _: [u8; 112] = [0; std::mem::size_of::<QueryNode>()];
-    let _: [u8; 96] = [0; std::mem::size_of::<Edges>()];
+    let _: [u8; 48] = [0; std::mem::size_of::<Edges>()];
 }
 
 impl Default for QueryGraph {
@@ -41,32 +52,32 @@ impl Default for QueryGraph {
     fn default() -> Self {
         let nodes = vec![QueryNode::Start, QueryNode::End];
         let edges = vec![
-            Edges { incoming: HashSet::new(), outgoing: HashSet::new() },
-            Edges { incoming: HashSet::new(), outgoing: HashSet::new() },
+            Edges { predecessors: RoaringBitmap::new(), successors: RoaringBitmap::new() },
+            Edges { predecessors: RoaringBitmap::new(), successors: RoaringBitmap::new() },
         ];
 
-        Self { root_node: 0, end_node: 1, nodes, edges }
+        Self { root_node: NodeIndex(0), end_node: NodeIndex(1), nodes, edges }
     }
 }
 
 impl QueryGraph {
-    fn connect_to_node(&mut self, from_nodes: &[usize], end_node: usize) {
+    fn connect_to_node(&mut self, from_nodes: &[NodeIndex], to_node: NodeIndex) {
         for &from_node in from_nodes {
-            self.edges[from_node].outgoing.insert(end_node);
-            self.edges[end_node].incoming.insert(from_node);
+            self.edges[from_node.0 as usize].successors.insert(to_node.0);
+            self.edges[to_node.0 as usize].predecessors.insert(from_node.0);
         }
     }
-    fn add_node(&mut self, from_nodes: &[usize], node: QueryNode) -> usize {
-        let new_node_idx = self.nodes.len();
+    fn add_node(&mut self, from_nodes: &[NodeIndex], node: QueryNode) -> NodeIndex {
+        let new_node_idx = self.nodes.len() as u32;
         self.nodes.push(node);
         self.edges.push(Edges {
-            incoming: from_nodes.iter().copied().collect(),
-            outgoing: HashSet::new(),
+            predecessors: from_nodes.iter().map(|x| x.0).collect(),
+            successors: RoaringBitmap::new(),
         });
         for from_node in from_nodes {
-            self.edges[*from_node].outgoing.insert(new_node_idx);
+            self.edges[from_node.0 as usize].successors.insert(new_node_idx);
        }
-        new_node_idx
+        NodeIndex(new_node_idx)
     }
 }
 
@@ -88,7 +99,7 @@ impl QueryGraph {
         let word_set = index.words_fst(txn)?;
         let mut graph = QueryGraph::default();
 
-        let (mut prev2, mut prev1, mut prev0): (Vec<usize>, Vec<usize>, Vec<usize>) =
+        let (mut prev2, mut prev1, mut prev0): (Vec<NodeIndex>, Vec<NodeIndex>, Vec<NodeIndex>) =
             (vec![], vec![], vec![graph.root_node]);
 
         // TODO: add all the word derivations found in the fst
@@ -162,38 +173,41 @@ impl QueryGraph {
         Ok(graph)
     }
 
-    pub fn remove_nodes(&mut self, nodes: &[usize]) {
+    pub fn remove_nodes(&mut self, nodes: &[NodeIndex]) {
         for &node in nodes {
-            self.nodes[node] = QueryNode::Deleted;
-            let edges = self.edges[node].clone();
-            for &pred in edges.incoming.iter() {
-                self.edges[pred].outgoing.remove(&node);
+            self.nodes[node.0 as usize] = QueryNode::Deleted;
+            let edges = self.edges[node.0 as usize].clone();
+            for pred in edges.predecessors.iter() {
+                self.edges[pred as usize].successors.remove(node.0);
             }
-            for succ in edges.outgoing {
-                self.edges[succ].incoming.remove(&node);
+            for succ in edges.successors {
+                self.edges[succ as usize].predecessors.remove(node.0);
             }
-            self.edges[node] = Edges { incoming: HashSet::new(), outgoing: HashSet::new() };
+            self.edges[node.0 as usize] =
+                Edges { predecessors: RoaringBitmap::new(), successors: RoaringBitmap::new() };
         }
     }
-    pub fn remove_nodes_keep_edges(&mut self, nodes: &[usize]) {
+    pub fn remove_nodes_keep_edges(&mut self, nodes: &[NodeIndex]) {
         for &node in nodes {
-            self.nodes[node] = QueryNode::Deleted;
-            let edges = self.edges[node].clone();
-            for &pred in edges.incoming.iter() {
-                self.edges[pred].outgoing.remove(&node);
-                self.edges[pred].outgoing.extend(edges.outgoing.iter());
+            self.nodes[node.0 as usize] = QueryNode::Deleted;
+            let edges = self.edges[node.0 as usize].clone();
+            for pred in edges.predecessors.iter() {
+                self.edges[pred as usize].successors.remove(node.0);
+                self.edges[pred as usize].successors |= &edges.successors;
             }
-            for succ in edges.outgoing {
-                self.edges[succ].incoming.remove(&node);
-                self.edges[succ].incoming.extend(edges.incoming.iter());
+            for succ in edges.successors {
+                self.edges[succ as usize].predecessors.remove(node.0);
+                self.edges[succ as usize].predecessors |= &edges.predecessors;
             }
-            self.edges[node] = Edges { incoming: HashSet::new(), outgoing: HashSet::new() };
+            self.edges[node.0 as usize] =
+                Edges { predecessors: RoaringBitmap::new(), successors: RoaringBitmap::new() };
         }
     }
     pub fn remove_words_at_position(&mut self, position: i8) {
         let mut nodes_to_remove_keeping_edges = vec![];
         let mut nodes_to_remove = vec![];
         for (node_idx, node) in self.nodes.iter().enumerate() {
+            let node_idx = NodeIndex(node_idx as u32);
             let QueryNode::Term(LocatedQueryTerm { value: _, positions }) = node else { continue };
             if positions.contains(&position) {
                 nodes_to_remove_keeping_edges.push(node_idx)
@@ -213,11 +227,11 @@ impl QueryGraph {
         let mut nodes_to_remove = vec![];
         for (node_idx, node) in self.nodes.iter().enumerate() {
             if (!matches!(node, QueryNode::End | QueryNode::Deleted)
-                && self.edges[node_idx].outgoing.is_empty())
+                && self.edges[node_idx].successors.is_empty())
                 || (!matches!(node, QueryNode::Start | QueryNode::Deleted)
-                    && self.edges[node_idx].incoming.is_empty())
+                    && self.edges[node_idx].predecessors.is_empty())
             {
-                nodes_to_remove.push(node_idx);
+                nodes_to_remove.push(NodeIndex(node_idx as u32));
             }
         }
         if nodes_to_remove.is_empty() {
@@ -301,14 +315,14 @@ node [shape = "record"]
                 continue;
             }
             desc.push_str(&format!("{node} [label = {:?}]", &self.nodes[node],));
-            if node == self.root_node {
+            if node == self.root_node.0 as usize {
                 desc.push_str("[color = blue]");
-            } else if node == self.end_node {
+            } else if node == self.end_node.0 as usize {
                 desc.push_str("[color = red]");
             }
             desc.push_str(";\n");
 
-            for edge in self.edges[node].outgoing.iter() {
+            for edge in self.edges[node].successors.iter() {
                 desc.push_str(&format!("{node} -> {edge};\n"));
             }
             // for edge in self.edges[node].incoming.iter() {
diff --git a/milli/src/search/new/ranking_rule_graph/build.rs b/milli/src/search/new/ranking_rule_graph/build.rs
index 605fe82d1..45dda3c1f 100644
--- a/milli/src/search/new/ranking_rule_graph/build.rs
+++ b/milli/src/search/new/ranking_rule_graph/build.rs
@@ -1,10 +1,11 @@
 use std::collections::{BTreeSet, HashMap, HashSet};
 
 use heed::RoTxn;
+use roaring::RoaringBitmap;
 
 use super::{Edge, RankingRuleGraph, RankingRuleGraphTrait};
 use crate::new::db_cache::DatabaseCache;
-use crate::new::QueryGraph;
+use crate::new::{NodeIndex, QueryGraph};
 use crate::{Index, Result};
 
 impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
@@ -14,29 +15,38 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
         db_cache: &mut DatabaseCache<'transaction>,
         query_graph: QueryGraph,
     ) -> Result<Self> {
-        let mut ranking_rule_graph = Self { query_graph, all_edges: vec![], node_edges: vec![] };
+        let mut ranking_rule_graph =
+            Self { query_graph, all_edges: vec![], node_edges: vec![], successors: vec![] };
 
         for (node_idx, node) in ranking_rule_graph.query_graph.nodes.iter().enumerate() {
-            ranking_rule_graph.node_edges.push(BTreeSet::new());
+            ranking_rule_graph.node_edges.push(RoaringBitmap::new());
+            ranking_rule_graph.successors.push(RoaringBitmap::new());
             let new_edges = ranking_rule_graph.node_edges.last_mut().unwrap();
+            let new_successors = ranking_rule_graph.successors.last_mut().unwrap();
 
             let Some(from_node_data) = G::build_visit_from_node(index, txn, db_cache, node)? else { continue };
 
-            for &successor_idx in ranking_rule_graph.query_graph.edges[node_idx].outgoing.iter() {
-                let to_node = &ranking_rule_graph.query_graph.nodes[successor_idx];
-                let Some(edges) = G::build_visit_to_node(index, txn, db_cache, to_node, &from_node_data)? else { continue };
+            for successor_idx in ranking_rule_graph.query_graph.edges[node_idx].successors.iter() {
+                let to_node = &ranking_rule_graph.query_graph.nodes[successor_idx as usize];
+                let mut edges =
+                    G::build_visit_to_node(index, txn, db_cache, to_node, &from_node_data)?;
+                if edges.is_empty() {
+                    continue;
+                }
+                edges.sort_by_key(|e| e.0);
                 for (cost, details) in edges {
                     ranking_rule_graph.all_edges.push(Some(Edge {
-                        from_node: node_idx,
-                        to_node: successor_idx,
+                        from_node: NodeIndex(node_idx as u32),
+                        to_node: NodeIndex(successor_idx),
                         cost,
                         details,
                     }));
-                    new_edges.insert(ranking_rule_graph.all_edges.len() - 1);
+                    new_edges.insert(ranking_rule_graph.all_edges.len() as u32 - 1);
+                    new_successors.insert(successor_idx);
                 }
             }
         }
-        ranking_rule_graph.simplify();
+        // ranking_rule_graph.simplify();
 
         Ok(ranking_rule_graph)
     }
diff --git a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs
index 3bd43fd6f..f1c1035a3 100644
--- a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs
+++ b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs
@@ -1,6 +1,9 @@
 use std::collections::{BTreeMap, HashSet};
 
 use itertools::Itertools;
+use roaring::RoaringBitmap;
+
+use crate::new::NodeIndex;
 
 use super::{
     empty_paths_cache::EmptyPathsCache, paths_map::PathsMap, Edge, EdgeIndex, RankingRuleGraph,
@@ -14,18 +17,11 @@ pub struct Path {
 }
 
 struct DijkstraState {
-    unvisited: HashSet<usize>, // should be a small bitset
-    distances: Vec<u64>, // or binary heap (f64, usize)
+    unvisited: RoaringBitmap, // should be a small bitset?
+    distances: Vec<u64>, // or binary heap, or btreemap? (f64, usize)
     edges: Vec<EdgeIndex>,
     edge_costs: Vec<u8>,
-    paths: Vec<Option<usize>>,
-}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub struct PathEdgeId<Id> {
-    pub from: usize,
-    pub to: usize,
-    pub id: Id,
+    paths: Vec<Option<NodeIndex>>,
 }
 
 pub struct KCheapestPathsState {
@@ -127,9 +123,10 @@ impl KCheapestPathsState {
         // for all the paths already found that share a common prefix with the root path
         // we delete the edge from the spur node to the next one
         for edge_index_to_remove in self.cheapest_paths.edge_indices_after_prefix(root_path) {
-            let was_removed = graph.node_edges[*spur_node].remove(&edge_index_to_remove.0);
+            let was_removed =
+                graph.node_edges[spur_node.0 as usize].remove(edge_index_to_remove.0 as u32);
             if was_removed {
-                tmp_removed_edges.push(edge_index_to_remove.0);
+                tmp_removed_edges.push(edge_index_to_remove.0 as u32);
             }
         }
 
@@ -137,7 +134,7 @@ impl KCheapestPathsState {
         // we will combine it with the root path to get a potential kth cheapest path
         let spur_path = graph.cheapest_path_to_end(*spur_node);
         // restore the temporarily removed edges
-        graph.node_edges[*spur_node].extend(tmp_removed_edges);
+        graph.node_edges[spur_node.0 as usize].extend(tmp_removed_edges);
 
         let Some(spur_path) = spur_path else { continue; };
         let total_cost = root_cost + spur_path.cost;
@@ -182,68 +179,73 @@ impl KCheapestPathsState {
     }
 }
 
 impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
-    fn cheapest_path_to_end(&self, from: usize) -> Option<Path> {
+    fn cheapest_path_to_end(&self, from: NodeIndex) -> Option<Path> {
         let mut dijkstra = DijkstraState {
-            unvisited: (0..self.query_graph.nodes.len()).collect(),
+            unvisited: (0..self.query_graph.nodes.len() as u32).collect(),
             distances: vec![u64::MAX; self.query_graph.nodes.len()],
             edges: vec![EdgeIndex(usize::MAX); self.query_graph.nodes.len()],
             edge_costs: vec![u8::MAX; self.query_graph.nodes.len()],
             paths: vec![None; self.query_graph.nodes.len()],
         };
-        dijkstra.distances[from] = 0;
+        dijkstra.distances[from.0 as usize] = 0;
 
-        // TODO: could use a binary heap here to store the distances
-        while let Some(&cur_node) =
-            dijkstra.unvisited.iter().min_by_key(|&&n| dijkstra.distances[n])
+        // TODO: could use a binary heap here to store the distances, or a btreemap
+        while let Some(cur_node) =
+            dijkstra.unvisited.iter().min_by_key(|&n| dijkstra.distances[n as usize])
         {
-            let cur_node_dist = dijkstra.distances[cur_node];
+            let cur_node_dist = dijkstra.distances[cur_node as usize];
             if cur_node_dist == u64::MAX {
                 return None;
             }
-            if cur_node == self.query_graph.end_node {
+            if cur_node == self.query_graph.end_node.0 {
                 break;
            }
 
-            let succ_cur_node: HashSet<_> = self.node_edges[cur_node]
-                .iter()
-                .map(|e| self.all_edges[*e].as_ref().unwrap().to_node)
-                .collect();
+            // this is expensive, but shouldn't be
+            // ideally I could quickly get a bitmap of all a node's successors
+            // then take the intersection with unvisited
+            let succ_cur_node: &RoaringBitmap = &self.successors[cur_node as usize];
+            // .iter()
+            // .map(|e| self.all_edges[e as usize].as_ref().unwrap().to_node.0)
+            // .collect();
 
             // TODO: this intersection may be slow but shouldn't be,
             // can use a bitmap intersection instead
-            let unvisited_succ_cur_node = succ_cur_node.intersection(&dijkstra.unvisited);
-            for &succ in unvisited_succ_cur_node {
-                let Some((cheapest_edge, cheapest_edge_cost)) = self.cheapest_edge(cur_node, succ) else {
+            let unvisited_succ_cur_node = succ_cur_node & &dijkstra.unvisited;
+            for succ in unvisited_succ_cur_node {
+                // cheapest_edge() is also potentially too expensive
+                let Some((cheapest_edge, cheapest_edge_cost)) =
+                    self.cheapest_edge(NodeIndex(cur_node), NodeIndex(succ)) else {
                     continue
                 };
                 // println!("cur node dist {cur_node_dist}");
-                let old_dist_succ = &mut dijkstra.distances[succ];
+                let old_dist_succ = &mut dijkstra.distances[succ as usize];
                 let new_potential_distance = cur_node_dist + cheapest_edge_cost as u64;
                 if new_potential_distance < *old_dist_succ {
                     *old_dist_succ = new_potential_distance;
-                    dijkstra.edges[succ] = cheapest_edge;
-                    dijkstra.edge_costs[succ] = cheapest_edge_cost;
-                    dijkstra.paths[succ] = Some(cur_node);
+                    dijkstra.edges[succ as usize] = cheapest_edge;
+                    dijkstra.edge_costs[succ as usize] = cheapest_edge_cost;
+                    dijkstra.paths[succ as usize] = Some(NodeIndex(cur_node));
                 }
             }
-            dijkstra.unvisited.remove(&cur_node);
+            dijkstra.unvisited.remove(cur_node);
         }
 
         let mut cur = self.query_graph.end_node;
         // let mut edge_costs = vec![];
         // let mut distances = vec![];
         let mut path_edges = vec![];
-        while let Some(n) = dijkstra.paths[cur] {
-            path_edges.push(dijkstra.edges[cur]);
+        while let Some(n) = dijkstra.paths[cur.0 as usize] {
+            path_edges.push(dijkstra.edges[cur.0 as usize]);
             cur = n;
         }
         path_edges.reverse();
-        Some(Path { edges: path_edges, cost: dijkstra.distances[self.query_graph.end_node] })
+        Some(Path {
+            edges: path_edges,
+            cost: dijkstra.distances[self.query_graph.end_node.0 as usize],
+        })
     }
 
-    // TODO: this implementation is VERY fragile, as we assume that the edges are ordered by cost
-    // already. Change it.
-    pub fn cheapest_edge(&self, cur_node: usize, succ: usize) -> Option<(EdgeIndex, u8)> {
+    pub fn cheapest_edge(&self, cur_node: NodeIndex, succ: NodeIndex) -> Option<(EdgeIndex, u8)> {
         self.visit_edges(cur_node, succ, |edge_idx, edge| {
             std::ops::ControlFlow::Break((edge_idx, edge.cost))
         })
diff --git a/milli/src/search/new/ranking_rule_graph/edge_docids_cache.rs b/milli/src/search/new/ranking_rule_graph/edge_docids_cache.rs
index 301810847..0c9768f04 100644
--- a/milli/src/search/new/ranking_rule_graph/edge_docids_cache.rs
+++ b/milli/src/search/new/ranking_rule_graph/edge_docids_cache.rs
@@ -9,6 +9,12 @@
 use crate::new::db_cache::DatabaseCache;
 use crate::new::BitmapOrAllRef;
 use crate::{Index, Result};
 
+// TODO: the cache should have a G::EdgeDetails as key
+// but then it means that we should have a quick way of
+// computing their hash and comparing them
+// which can be done...
+// by using a pointer (real, Rc, bumpalo, or in a vector)???
+
 pub struct EdgeDocidsCache<G: RankingRuleGraphTrait> {
     pub cache: HashMap<EdgeIndex, RoaringBitmap>,
diff --git a/milli/src/search/new/ranking_rule_graph/mod.rs b/milli/src/search/new/ranking_rule_graph/mod.rs
index 12f397df3..f7a312240 100644
--- a/milli/src/search/new/ranking_rule_graph/mod.rs
+++ b/milli/src/search/new/ranking_rule_graph/mod.rs
@@ -13,7 +13,7 @@
 use heed::RoTxn;
 use roaring::RoaringBitmap;
 
 use super::db_cache::DatabaseCache;
-use super::{QueryGraph, QueryNode};
+use super::{NodeIndex, QueryGraph, QueryNode};
 use crate::{Index, Result};
 
 #[derive(Debug, Clone)]
@@ -24,8 +24,8 @@ pub enum EdgeDetails<E> {
 
 #[derive(Debug, Clone)]
 pub struct Edge<E> {
-    from_node: usize,
-    to_node: usize,
+    from_node: NodeIndex,
+    to_node: NodeIndex,
     cost: u8,
     details: EdgeDetails<E>,
 }
@@ -38,22 +38,20 @@ pub struct EdgePointer<'graph, E> {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub struct EdgeIndex(pub usize);
-// {
-//     // TODO: they could all be u16 instead
-//     // There may be a way to store all the edge indices in a u32 as well,
-//     // if the edges are in a vector
-//     // then we can store sets of edges in a bitmap efficiently
-//     pub from: usize,
-//     pub to: usize,
-//     pub edge_idx: usize,
-// }
 
 pub trait RankingRuleGraphTrait {
+    /// The details of an edge connecting two query nodes. These details
+    /// should be sufficient to compute the edge's cost and associated document ids
+    /// in [`compute_docids`](RankingRuleGraphTrait).
     type EdgeDetails: Sized;
+
     type BuildVisitedFromNode;
 
-    fn edge_details_dot_label(edge: &Self::EdgeDetails) -> String;
+    /// Return the label of the given edge details, to be used when visualising
+    /// the ranking rule graph using GraphViz.
+    fn graphviz_edge_details_label(edge: &Self::EdgeDetails) -> String;
 
+    /// Compute the document ids associated with the given edge.
     fn compute_docids<'transaction>(
         index: &Index,
         txn: &'transaction RoTxn,
@@ -61,6 +59,10 @@ pub trait RankingRuleGraphTrait {
         edge_details: &Self::EdgeDetails,
     ) -> Result<RoaringBitmap>;
 
+    /// Prepare to build the edges outgoing from `from_node`.
+    ///
+    /// This call is followed by zero, one or more calls to [`build_visit_to_node`](RankingRuleGraphTrait::build_visit_to_node),
+    /// which builds the actual edges.
     fn build_visit_from_node<'transaction>(
         index: &Index,
         txn: &'transaction RoTxn,
@@ -68,39 +70,59 @@ pub trait RankingRuleGraphTrait {
         from_node: &QueryNode,
     ) -> Result<Option<Self::BuildVisitedFromNode>>;
 
+    /// Return the cost and details of the edges going from the previously visited node
+    /// (with [`build_visit_from_node`](RankingRuleGraphTrait::build_visit_from_node)) to `to_node`.
     fn build_visit_to_node<'from_data, 'transaction: 'from_data>(
         index: &Index,
         txn: &'transaction RoTxn,
         db_cache: &mut DatabaseCache<'transaction>,
         to_node: &QueryNode,
         from_node_data: &'from_data Self::BuildVisitedFromNode,
-    ) -> Result<Option<Vec<(u8, EdgeDetails<Self::EdgeDetails>)>>>;
+    ) -> Result<Vec<(u8, EdgeDetails<Self::EdgeDetails>)>>;
 }
 
 pub struct RankingRuleGraph<G: RankingRuleGraphTrait> {
     pub query_graph: QueryGraph,
     // pub edges: Vec>>>,
     pub all_edges: Vec<Option<Edge<G::EdgeDetails>>>,
-    pub node_edges: Vec<BTreeSet<usize>>,
+
+    pub node_edges: Vec<RoaringBitmap>,
+
+    pub successors: Vec<RoaringBitmap>,
+    // to get the edges between two nodes:
+    // 1. get node_outgoing_edges[from]
+    // 2. get node_incoming_edges[to]
+    // 3. take intersection between the two
+
+    // TODO: node edges could be different I guess
+    // something like:
+    // pub node_edges: Vec<u32>
+    // where each index is the result of:
+    // the successor index in the top 16 bits, the edge index in the bottom 16 bits
+
+    // TODO:
+    // node_successors?
+
     // pub removed_edges: HashSet<EdgeIndex>,
     // pub tmp_removed_edges: HashSet<EdgeIndex>,
 }
 
 impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
-    // NOTE: returns the edge even if it was removed
     pub fn get_edge(&self, edge_index: EdgeIndex) -> &Option<Edge<G::EdgeDetails>> {
         &self.all_edges[edge_index.0]
     }
+
+    // Visit all edges between the two given nodes in order of increasing cost.
     pub fn visit_edges<'graph, O>(
         &'graph self,
-        from: usize,
-        to: usize,
+        from: NodeIndex,
+        to: NodeIndex,
         mut visit: impl FnMut(EdgeIndex, &'graph Edge<G::EdgeDetails>) -> ControlFlow<O>,
     ) -> Option<O> {
-        let from_edges = &self.node_edges[from];
-        for &edge_idx in from_edges {
-            let edge = self.all_edges[edge_idx].as_ref().unwrap();
+        let from_edges = &self.node_edges[from.0 as usize];
+        for edge_idx in from_edges {
+            let edge = self.all_edges[edge_idx as usize].as_ref().unwrap();
             if edge.to_node == to {
-                let cf = visit(EdgeIndex(edge_idx), edge);
+                let cf = visit(EdgeIndex(edge_idx as usize), edge);
                 match cf {
                     ControlFlow::Continue(_) => continue,
                     ControlFlow::Break(o) => return Some(o),
@@ -113,54 +135,61 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
     fn remove_edge(&mut self, edge_index: EdgeIndex) {
         let edge_opt = &mut self.all_edges[edge_index.0];
-        let Some(Edge { from_node, to_node, cost, details }) = &edge_opt else { return };
-
-        let node_edges = &mut self.node_edges[*from_node];
-        node_edges.remove(&edge_index.0);
-
+        let Some(edge) = &edge_opt else { return };
+        let (from_node, to_node) = (edge.from_node, edge.to_node);
         *edge_opt = None;
-    }
-    pub fn remove_nodes(&mut self, nodes: &[usize]) {
-        for &node in nodes {
-            let edge_indices = &mut self.node_edges[node];
-            for edge_index in edge_indices.iter() {
-                self.all_edges[*edge_index] = None;
-            }
-            edge_indices.clear();
-
-            let preds = &self.query_graph.edges[node].incoming;
-            for pred in preds {
-                let edge_indices = &mut self.node_edges[*pred];
-                for edge_index in edge_indices.iter() {
-                    let edge_opt = &mut self.all_edges[*edge_index];
-                    let Some(edge) = edge_opt else { continue; };
-                    if edge.to_node == node {
-                        *edge_opt = None;
-                    }
-                }
-                panic!("remove nodes is incorrect at the moment");
-                edge_indices.clear();
-            }
-        }
-        self.query_graph.remove_nodes(nodes);
-    }
-    pub fn simplify(&mut self) {
-        loop {
-            let mut nodes_to_remove = vec![];
-            for (node_idx, node) in self.query_graph.nodes.iter().enumerate() {
-                if !matches!(node, QueryNode::End | QueryNode::Deleted)
-                    && self.node_edges[node_idx].is_empty()
-                {
-                    nodes_to_remove.push(node_idx);
-                }
-            }
-            if nodes_to_remove.is_empty() {
-                break;
-            } else {
-                self.remove_nodes(&nodes_to_remove);
-            }
+        let from_node_edges = &mut self.node_edges[from_node.0 as usize];
+        from_node_edges.remove(edge_index.0 as u32);
+
+        let mut new_successors_from_node = RoaringBitmap::new();
+        for edge in from_node_edges.iter() {
+            let Edge { to_node, .. } = &self.all_edges[edge as usize].as_ref().unwrap();
+            new_successors_from_node.insert(to_node.0);
         }
+        self.successors[from_node.0 as usize] = new_successors_from_node;
     }
+    // pub fn remove_nodes(&mut self, nodes: &[usize]) {
+    //     for &node in nodes {
+    //         let edge_indices = &mut self.node_edges[node];
+    //         for edge_index in edge_indices.iter() {
+    //             self.all_edges[*edge_index] = None;
+    //         }
+    //         edge_indices.clear();
+
+    //         let preds = &self.query_graph.edges[node].incoming;
+    //         for pred in preds {
+    //             let edge_indices = &mut self.node_edges[*pred];
+    //             for edge_index in edge_indices.iter() {
+    //                 let edge_opt = &mut self.all_edges[*edge_index];
+    //                 let Some(edge) = edge_opt else { continue; };
+    //                 if edge.to_node == node {
+    //                     *edge_opt = None;
+    //                 }
+    //             }
+    //             panic!("remove nodes is incorrect at the moment");
+    //             edge_indices.clear();
+    //         }
+    //     }
+    //     self.query_graph.remove_nodes(nodes);
+    // }
+    // pub fn simplify(&mut self) {
+    //     loop {
+    //         let mut nodes_to_remove = vec![];
+    //         for (node_idx, node) in self.query_graph.nodes.iter().enumerate() {
+    //             if !matches!(node, QueryNode::End | QueryNode::Deleted)
+    //                 && self.node_edges[node_idx].is_empty()
+    //             {
+    //                 nodes_to_remove.push(node_idx);
+    //             }
+    //         }
+    //         if nodes_to_remove.is_empty() {
+    //             break;
+    //         } else {
+    //             self.remove_nodes(&nodes_to_remove);
+    //         }
+    //     }
+    // }
     // fn is_removed_edge(&self, edge: EdgeIndex) -> bool {
     //     self.removed_edges.contains(&edge) || self.tmp_removed_edges.contains(&edge)
     // }
@@ -174,9 +203,9 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
                 continue;
             }
             desc.push_str(&format!("{node_idx} [label = {:?}]", node));
-            if node_idx == self.query_graph.root_node {
+            if node_idx == self.query_graph.root_node.0 as usize {
                 desc.push_str("[color = blue]");
-            } else if node_idx == self.query_graph.end_node {
+            } else if node_idx == self.query_graph.end_node.0 as usize {
                 desc.push_str("[color = red]");
             }
             desc.push_str(";\n");
@@ -195,7 +224,7 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
                 desc.push_str(&format!(
                     "{from_node} -> {to_node} [label = \"cost {cost} {edge_label}\"];\n",
                     cost = edge.cost,
-                    edge_label = G::edge_details_dot_label(details)
+                    edge_label = G::graphviz_edge_details_label(details)
                 ));
             }
         }
diff --git a/milli/src/search/new/ranking_rule_graph/paths_map.rs b/milli/src/search/new/ranking_rule_graph/paths_map.rs
index 589a1a52f..b1e4bb451 100644
--- a/milli/src/search/new/ranking_rule_graph/paths_map.rs
+++ b/milli/src/search/new/ranking_rule_graph/paths_map.rs
@@ -235,9 +235,9 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
                 continue;
             }
             desc.push_str(&format!("{node_idx} [label = {:?}]", node));
-            if node_idx == self.query_graph.root_node {
+            if node_idx == self.query_graph.root_node.0 as usize {
                 desc.push_str("[color = blue]");
-            } else if node_idx == self.query_graph.end_node {
+            } else if node_idx == self.query_graph.end_node.0 as usize {
                 desc.push_str("[color = red]");
             }
             desc.push_str(";\n");
@@ -262,7 +262,7 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
                 desc.push_str(&format!(
                     "{from_node} -> {to_node} [label = \"cost {cost} {edge_label}\", color = {color}];\n",
                     cost = edge.cost,
-                    edge_label = G::edge_details_dot_label(details),
+                    edge_label = G::graphviz_edge_details_label(details),
                 ));
             }
         }
diff --git a/milli/src/search/new/ranking_rule_graph/proximity/build.rs b/milli/src/search/new/ranking_rule_graph/proximity/build.rs
index 07ec3bb5e..7149f8bf6 100644
--- a/milli/src/search/new/ranking_rule_graph/proximity/build.rs
+++ b/milli/src/search/new/ranking_rule_graph/proximity/build.rs
@@ -51,11 +51,11 @@ pub fn visit_to_node<'transaction, 'from_data>(
     db_cache: &mut DatabaseCache<'transaction>,
     to_node: &QueryNode,
     from_node_data: &'from_data (WordDerivations, i8),
-) -> Result<Option<Vec<(u8, EdgeDetails<ProximityEdge>)>>> {
+) -> Result<Vec<(u8, EdgeDetails<ProximityEdge>)>> {
     let (derivations1, pos1) = from_node_data;
     let term2 = match &to_node {
-        QueryNode::End => return Ok(Some(vec![(0, EdgeDetails::Unconditional)])),
-        QueryNode::Deleted | QueryNode::Start => return Ok(None),
+        QueryNode::End => return Ok(vec![(0, EdgeDetails::Unconditional)]),
+        QueryNode::Deleted | QueryNode::Start => return Ok(vec![]),
         QueryNode::Term(term) => term,
     };
     let LocatedQueryTerm { value: value2, positions: pos2 } = term2;
@@ -86,7 +86,7 @@ pub fn visit_to_node<'transaction, 'from_data>(
         // We want to effectively ignore this pair of terms
         // Unconditionally walk through the edge without computing the docids
         // But also what should the cost be?
-        return Ok(Some(vec![(0, EdgeDetails::Unconditional)]));
+        return Ok(vec![(0, EdgeDetails::Unconditional)]);
     }
 
     let updb1 = derivations1.use_prefix_db;
@@ -161,5 +161,5 @@ pub fn visit_to_node<'transaction, 'from_data>(
         })
         .collect::<Vec<_>>();
     new_edges.push((8 + (ngram_len2 - 1) as u8, EdgeDetails::Unconditional));
-    Ok(Some(new_edges))
+    Ok(new_edges)
 }
diff --git a/milli/src/search/new/ranking_rule_graph/proximity/mod.rs b/milli/src/search/new/ranking_rule_graph/proximity/mod.rs
index 199a5eb4a..e4905ead9 100644
--- a/milli/src/search/new/ranking_rule_graph/proximity/mod.rs
+++ b/milli/src/search/new/ranking_rule_graph/proximity/mod.rs
@@ -26,7 +26,7 @@ impl RankingRuleGraphTrait for ProximityGraph {
     type EdgeDetails = ProximityEdge;
     type BuildVisitedFromNode = (WordDerivations, i8);
 
-    fn edge_details_dot_label(edge: &Self::EdgeDetails) -> String {
+    fn graphviz_edge_details_label(edge: &Self::EdgeDetails) -> String {
         let ProximityEdge { pairs, proximity } = edge;
         format!(", prox {proximity}, {} pairs", pairs.len())
     }
@@ -55,7 +55,7 @@ impl RankingRuleGraphTrait for ProximityGraph {
         db_cache: &mut DatabaseCache<'transaction>,
         to_node: &QueryNode,
         from_node_data: &'from_data Self::BuildVisitedFromNode,
-    ) -> Result<Option<Vec<(u8, EdgeDetails<Self::EdgeDetails>)>>> {
+    ) -> Result<Vec<(u8, EdgeDetails<Self::EdgeDetails>)>> {
         build::visit_to_node(index, txn, db_cache, to_node, from_node_data)
     }
 }
diff --git a/milli/src/search/new/ranking_rules.rs b/milli/src/search/new/ranking_rules.rs
index ce883ad6a..b980c1dc4 100644
--- a/milli/src/search/new/ranking_rules.rs
+++ b/milli/src/search/new/ranking_rules.rs
@@ -36,15 +36,17 @@ impl<'transaction, Query> RankingRuleOutputIter<'transaction, Query>
 }
 
 pub trait RankingRuleQueryTrait: Sized + Clone + 'static {}
+
 #[derive(Clone)]
 pub struct PlaceholderQuery;
 impl RankingRuleQueryTrait for PlaceholderQuery {}
 
 impl RankingRuleQueryTrait for QueryGraph {}
 
 pub trait RankingRule<'transaction, Query: RankingRuleQueryTrait> {
-    // TODO: add an update_candidates function to deal with distinct
-    // attributes?
-
+    /// Prepare the ranking rule such that it can start iterating over its
+    /// buckets using [`next_bucket`](RankingRule::next_bucket).
+    ///
+    /// The given universe is the universe that will be given to [`next_bucket`](RankingRule::next_bucket).
     fn start_iteration(
         &mut self,
         index: &Index,
@@ -54,6 +56,13 @@ pub trait RankingRule<'transaction, Query: RankingRuleQueryTrait> {
         query: &Query,
     ) -> Result<()>;
 
+    /// Return the next bucket of this ranking rule.
+    ///
+    /// The returned candidates MUST be a subset of the given universe.
+    ///
+    /// The universe given as argument is either:
+    /// - a subset of the universe given to the previous call to [`next_bucket`](RankingRule::next_bucket); OR
+    /// - the universe given to [`start_iteration`](RankingRule::start_iteration)
     fn next_bucket(
         &mut self,
         index: &Index,
@@ -62,6 +71,8 @@ pub trait RankingRule<'transaction, Query: RankingRuleQueryTrait> {
         universe: &RoaringBitmap,
     ) -> Result<Option<RankingRuleOutput<Query>>>;
 
+    /// Finish iterating over the buckets, which yields control to the parent ranking rule.
+    /// The next call to this ranking rule, if any, will be [`start_iteration`](RankingRule::start_iteration).
     fn end_iteration(
         &mut self,
         index: &Index,
@@ -72,7 +83,7 @@ pub trait RankingRule<'transaction, Query: RankingRuleQueryTrait> {
 
 #[derive(Debug)]
 pub struct RankingRuleOutput<Q> {
-    /// The query tree that must be used by the child ranking rule to fetch candidates.
+    /// The query corresponding to the current bucket for the child ranking rule
     pub query: Q,
     /// The allowed candidates for the child ranking rule
     pub candidates: RoaringBitmap,
 }
@@ -151,7 +162,6 @@ pub fn execute_search<'transaction>(
     let ranking_rules_len = ranking_rules.len();
     ranking_rules[0].start_iteration(index, txn, db_cache, universe, query_graph)?;
 
-    // TODO: parent_candidates could be used only during debugging?
     let mut candidates = vec![RoaringBitmap::default(); ranking_rules_len];
     candidates[0] = universe.clone();
 
@@ -296,43 +306,43 @@ mod tests {
         let primary_key = index.primary_key(&txn).unwrap().unwrap();
         let primary_key = index.fields_ids_map(&txn).unwrap().id(primary_key).unwrap();
 
-        // loop {
-        //     let start = Instant::now();
+        loop {
+            let start = Instant::now();
 
-        //     let mut db_cache = DatabaseCache::default();
+            let mut db_cache = DatabaseCache::default();
 
-        //     let query_graph = make_query_graph(
-        //         &index,
-        //         &txn,
-        //         &mut db_cache,
-        //         "released from prison by the government",
-        //     )
-        //     .unwrap();
-        //     // println!("{}", query_graph.graphviz());
+            let query_graph = make_query_graph(
+                &index,
+                &txn,
+                &mut db_cache,
+                "released from prison by the government",
+            )
+            .unwrap();
+            // println!("{}", query_graph.graphviz());
 
-        //     // TODO: filters + maybe distinct attributes?
-        //     let universe = get_start_universe(
-        //         &index,
-        //         &txn,
-        //         &mut db_cache,
-        //         &query_graph,
-        //         TermsMatchingStrategy::Last,
-        //     )
-        //     .unwrap();
-        //     // println!("universe: {universe:?}");
+            // TODO: filters + maybe distinct attributes?
+            let universe = get_start_universe(
+                &index,
+                &txn,
+                &mut db_cache,
+                &query_graph,
+                TermsMatchingStrategy::Last,
+            )
+            .unwrap();
+            // println!("universe: {universe:?}");
 
-        //     let results = execute_search(
-        //         &index,
-        //         &txn,
-        //         &mut db_cache,
-        //         &universe,
-        //         &query_graph, /* 0, 20 */
-        //     )
-        //     .unwrap();
+            let results = execute_search(
+                &index,
+                &txn,
+                &mut db_cache,
+                &universe,
+                &query_graph, /* 0, 20 */
+            )
+            .unwrap();
 
-        //     let elapsed = start.elapsed();
-        //     println!("{}us: {results:?}", elapsed.as_micros());
-        // }
+            let elapsed = start.elapsed();
+            println!("{}us: {results:?}", elapsed.as_micros());
+        }
 
         let start = Instant::now();
 
         let mut db_cache = DatabaseCache::default();
@@ -388,7 +398,7 @@ mod tests {
         let mut s = Search::new(&txn, &index);
         s.query("released from prison by the government");
         s.terms_matching_strategy(TermsMatchingStrategy::Last);
-        // s.criterion_implementation_strategy(crate::CriterionImplementationStrategy::OnlySetBased);
+        s.criterion_implementation_strategy(crate::CriterionImplementationStrategy::OnlySetBased);
         let docs = s.execute().unwrap();
 
         let elapsed = start.elapsed();
@@ -431,7 +441,7 @@ mod tests {
         builder.execute(|_| (), || false).unwrap();
     }
 
-    // #[test]
+    #[test]
     fn _index_movies() {
         let mut options = EnvOpenOptions::new();
         options.map_size(100 * 1024 * 1024 * 1024); // 100 GB
@@ -446,20 +456,14 @@ mod tests {
         let config = IndexerConfig::default();
         let mut builder = Settings::new(&mut wtxn, &index, &config);
 
-        builder.set_primary_key(primary_key.to_owned());
-
         let searchable_fields = searchable_fields.iter().map(|s| s.to_string()).collect();
         builder.set_searchable_fields(searchable_fields);
-
         let filterable_fields = filterable_fields.iter().map(|s| s.to_string()).collect();
         builder.set_filterable_fields(filterable_fields);
-
-        builder.set_criteria(vec![Criterion::Words]);
-
-        // let sortable_fields = sortable_fields.iter().map(|s| s.to_string()).collect();
-        // builder.set_sortable_fields(sortable_fields);
-
+        builder.set_min_word_len_one_typo(5);
+        builder.set_min_word_len_two_typos(100);
+        builder.set_criteria(vec![Criterion::Words, Criterion::Proximity]);
         builder.execute(|_| (), || false).unwrap();
 
         let config = IndexerConfig::default();
diff --git a/milli/src/search/new/resolve_query_graph.rs b/milli/src/search/new/resolve_query_graph.rs
index 748524492..8bc56bb23 100644
--- a/milli/src/search/new/resolve_query_graph.rs
+++ b/milli/src/search/new/resolve_query_graph.rs
@@ -4,11 +4,12 @@ use std::collections::{HashMap, HashSet, VecDeque};
 
 use super::db_cache::DatabaseCache;
 use super::query_term::{QueryTerm, WordDerivations};
-use super::QueryGraph;
+use super::{NodeIndex, QueryGraph};
 use crate::{Index, Result, RoaringBitmapCodec};
 
 // TODO: manual performance metrics: access to DB, bitmap deserializations/operations, etc.
+// TODO: reuse NodeDocidsCache in between calls to resolve_query_graph
 #[derive(Default)]
 pub struct NodeDocIdsCache {
     pub cache: HashMap<u32, RoaringBitmap>,
 }
@@ -26,7 +27,7 @@ pub fn resolve_query_graph<'transaction>(
 
     // resolve_query_graph_rec(index, txn, q, q.root_node, &mut docids, &mut cache)?;
 
-    let mut nodes_resolved = HashSet::new();
+    let mut nodes_resolved = RoaringBitmap::new();
     // TODO: should be given as an argument and kept between invocations of resolve query graph
     let mut nodes_docids = vec![RoaringBitmap::new(); q.nodes.len()];
 
     let mut next_nodes_to_visit = VecDeque::new();
     next_nodes_to_visit.push_front(q.root_node);
 
@@ -34,16 +35,16 @@ pub fn resolve_query_graph<'transaction>(
     while let Some(node) = next_nodes_to_visit.pop_front() {
-        let predecessors = &q.edges[node].incoming;
+        let predecessors = &q.edges[node.0 as usize].predecessors;
         if !predecessors.is_subset(&nodes_resolved) {
             next_nodes_to_visit.push_back(node);
             continue;
         }
         // Take union of all predecessors
-        let predecessors_iter = predecessors.iter().map(|p| &nodes_docids[*p]);
+        let predecessors_iter = predecessors.iter().map(|p| &nodes_docids[p as usize]);
         let predecessors_docids = MultiOps::union(predecessors_iter);
 
-        let n = &q.nodes[node];
+        let n = &q.nodes[node.0 as usize];
         // println!("resolving {node} {n:?}, predecessors: {predecessors:?}, their docids: {predecessors_docids:?}");
 
         let node_docids = match n {
             super::QueryNode::Term(located_term) => {
@@ -95,18 +96,18 @@ pub fn resolve_query_graph<'transaction>(
                 return Ok(predecessors_docids);
             }
         };
-        nodes_resolved.insert(node);
-        nodes_docids[node] = node_docids;
+        nodes_resolved.insert(node.0);
+        nodes_docids[node.0 as usize] = node_docids;
 
-        for &succ in q.edges[node].outgoing.iter() {
-            if !next_nodes_to_visit.contains(&succ) && !nodes_resolved.contains(&succ) {
-                next_nodes_to_visit.push_back(succ);
+        for succ in q.edges[node.0 as usize].successors.iter() {
+            if !next_nodes_to_visit.contains(&NodeIndex(succ)) && !nodes_resolved.contains(succ) {
+                next_nodes_to_visit.push_back(NodeIndex(succ));
             }
         }
 
         // This is currently slow but could easily be implemented very efficiently
-        for &prec in q.edges[node].incoming.iter() {
-            if q.edges[prec].outgoing.is_subset(&nodes_resolved) {
-                nodes_docids[prec].clear();
+        for prec in q.edges[node.0 as usize].predecessors.iter() {
+            if q.edges[prec as usize].successors.is_subset(&nodes_resolved) {
+                nodes_docids[prec as usize].clear();
             }
         }
         // println!("cached docids: {nodes_docids:?}");