From bd9aba4d7733af4a1ce2d3ac341b9a69421faf4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Lecrenier?= Date: Thu, 13 Apr 2023 10:46:09 +0200 Subject: [PATCH] Add "position" part of the attribute ranking rule --- milli/src/search/new/db_cache.rs | 77 +++++++++++++++ .../search/new/graph_based_ranking_rule.rs | 14 ++- milli/src/search/new/logger/visual.rs | 45 ++++++--- milli/src/search/new/mod.rs | 5 +- .../{attribute => fid}/mod.rs | 12 +-- .../src/search/new/ranking_rule_graph/mod.rs | 7 +- .../new/ranking_rule_graph/position/mod.rs | 93 +++++++++++++++++++ milli/src/search/new/resolve_query_graph.rs | 35 +++++++ .../tests/{attribute.rs => attribute_fid.rs} | 2 +- .../search/new/tests/attribute_position.rs | 52 +++++++++++ milli/src/search/new/tests/mod.rs | 3 +- 11 files changed, 314 insertions(+), 31 deletions(-) rename milli/src/search/new/ranking_rule_graph/{attribute => fid}/mod.rs (90%) create mode 100644 milli/src/search/new/ranking_rule_graph/position/mod.rs rename milli/src/search/new/tests/{attribute.rs => attribute_fid.rs} (99%) create mode 100644 milli/src/search/new/tests/attribute_position.rs diff --git a/milli/src/search/new/db_cache.rs b/milli/src/search/new/db_cache.rs index cf5332700..cf851a313 100644 --- a/milli/src/search/new/db_cache.rs +++ b/milli/src/search/new/db_cache.rs @@ -34,6 +34,10 @@ pub struct DatabaseCache<'ctx> { pub words_fst: Option>>, pub word_position_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, + pub word_prefix_position_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, + pub word_positions: FxHashMap, Vec>, + pub word_prefix_positions: FxHashMap, Vec>, + pub word_fid_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, pub word_prefix_fid_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, pub word_fids: FxHashMap, Vec>, @@ -356,4 +360,77 @@ impl<'ctx> SearchContext<'ctx> { }; Ok(fids) } + + pub fn get_db_word_prefix_position_docids( + &mut self, + word_prefix: Interned, + position: u16, + ) -> Result> { + DatabaseCache::get_value( + self.txn, + (word_prefix, position), + &(self.word_interner.get(word_prefix).as_str(), position), + &mut self.db_cache.word_prefix_position_docids, + self.index.word_prefix_position_docids.remap_data_type::(), + )? + .map(|bytes| CboRoaringBitmapCodec::bytes_decode(bytes).ok_or(heed::Error::Decoding.into())) + .transpose() + } + + pub fn get_db_word_positions(&mut self, word: Interned) -> Result> { + let positions = match self.db_cache.word_positions.entry(word) { + Entry::Occupied(positions) => positions.get().clone(), + Entry::Vacant(entry) => { + let mut key = self.word_interner.get(word).as_bytes().to_owned(); + key.push(0); + let mut positions = vec![]; + let remap_key_type = self + .index + .word_position_docids + .remap_types::() + .prefix_iter(self.txn, &key)? + .remap_key_type::(); + for result in remap_key_type { + let ((_, position), value) = result?; + // filling other caches to avoid searching for them again + self.db_cache.word_position_docids.insert((word, position), Some(value)); + positions.push(position); + } + entry.insert(positions.clone()); + positions + } + }; + Ok(positions) + } + + pub fn get_db_word_prefix_positions( + &mut self, + word_prefix: Interned, + ) -> Result> { + let positions = match self.db_cache.word_prefix_positions.entry(word_prefix) { + Entry::Occupied(positions) => positions.get().clone(), + Entry::Vacant(entry) => { + let mut key = self.word_interner.get(word_prefix).as_bytes().to_owned(); + key.push(0); + let mut positions = vec![]; + let remap_key_type = self + .index + .word_prefix_position_docids + .remap_types::() + .prefix_iter(self.txn, &key)? + .remap_key_type::(); + for result in remap_key_type { + let ((_, position), value) = result?; + // filling other caches to avoid searching for them again + self.db_cache + .word_prefix_position_docids + .insert((word_prefix, position), Some(value)); + positions.push(position); + } + entry.insert(positions.clone()); + positions + } + }; + Ok(positions) + } } diff --git a/milli/src/search/new/graph_based_ranking_rule.rs b/milli/src/search/new/graph_based_ranking_rule.rs index 3ee16ed50..0d22b5b1e 100644 --- a/milli/src/search/new/graph_based_ranking_rule.rs +++ b/milli/src/search/new/graph_based_ranking_rule.rs @@ -44,7 +44,7 @@ use super::interner::{Interned, MappedInterner}; use super::logger::SearchLogger; use super::query_graph::QueryNode; use super::ranking_rule_graph::{ - AttributeGraph, ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, ProximityGraph, + ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, FidGraph, PositionGraph, ProximityGraph, RankingRuleGraph, RankingRuleGraphTrait, TypoGraph, }; use super::small_bitmap::SmallBitmap; @@ -59,10 +59,16 @@ impl GraphBasedRankingRule { Self::new_with_id("proximity".to_owned(), terms_matching_strategy) } } -pub type Attribute = GraphBasedRankingRule; -impl GraphBasedRankingRule { +pub type Fid = GraphBasedRankingRule; +impl GraphBasedRankingRule { pub fn new(terms_matching_strategy: Option) -> Self { - Self::new_with_id("attribute".to_owned(), terms_matching_strategy) + Self::new_with_id("fid".to_owned(), terms_matching_strategy) + } +} +pub type Position = GraphBasedRankingRule; +impl GraphBasedRankingRule { + pub fn new(terms_matching_strategy: Option) -> Self { + Self::new_with_id("position".to_owned(), terms_matching_strategy) } } pub type Typo = GraphBasedRankingRule; diff --git a/milli/src/search/new/logger/visual.rs b/milli/src/search/new/logger/visual.rs index 7834f7e46..1cbe007d3 100644 --- a/milli/src/search/new/logger/visual.rs +++ b/milli/src/search/new/logger/visual.rs @@ -11,8 +11,8 @@ use crate::search::new::interner::Interned; use crate::search::new::query_graph::QueryNodeData; use crate::search::new::query_term::LocatedQueryTermSubset; use crate::search::new::ranking_rule_graph::{ - AttributeCondition, AttributeGraph, Edge, ProximityCondition, ProximityGraph, RankingRuleGraph, - RankingRuleGraphTrait, TypoCondition, TypoGraph, + Edge, FidCondition, FidGraph, PositionCondition, PositionGraph, ProximityCondition, + ProximityGraph, RankingRuleGraph, RankingRuleGraphTrait, TypoCondition, TypoGraph, }; use crate::search::new::ranking_rules::BoxRankingRule; use crate::search::new::{QueryGraph, QueryNode, RankingRule, SearchContext, SearchLogger}; @@ -29,15 +29,18 @@ pub enum SearchEvents { ProximityPaths { paths: Vec>> }, TypoGraph { graph: RankingRuleGraph }, TypoPaths { paths: Vec>> }, - AttributeGraph { graph: RankingRuleGraph }, - AttributePaths { paths: Vec>> }, + FidGraph { graph: RankingRuleGraph }, + FidPaths { paths: Vec>> }, + PositionGraph { graph: RankingRuleGraph }, + PositionPaths { paths: Vec>> }, } enum Location { Words, Typo, Proximity, - Attribute, + Fid, + Position, Other, } @@ -84,7 +87,8 @@ impl SearchLogger for VisualSearchLogger { "words" => Location::Words, "typo" => Location::Typo, "proximity" => Location::Proximity, - "attribute" => Location::Attribute, + "fid" => Location::Fid, + "position" => Location::Position, _ => Location::Other, }); } @@ -156,13 +160,20 @@ impl SearchLogger for VisualSearchLogger { self.events.push(SearchEvents::ProximityPaths { paths: paths.clone() }); } } - Location::Attribute => { - if let Some(graph) = state.downcast_ref::>() { - self.events.push(SearchEvents::AttributeGraph { graph: graph.clone() }); + Location::Fid => { + if let Some(graph) = state.downcast_ref::>() { + self.events.push(SearchEvents::FidGraph { graph: graph.clone() }); } - if let Some(paths) = state.downcast_ref::>>>() - { - self.events.push(SearchEvents::AttributePaths { paths: paths.clone() }); + if let Some(paths) = state.downcast_ref::>>>() { + self.events.push(SearchEvents::FidPaths { paths: paths.clone() }); + } + } + Location::Position => { + if let Some(graph) = state.downcast_ref::>() { + self.events.push(SearchEvents::PositionGraph { graph: graph.clone() }); + } + if let Some(paths) = state.downcast_ref::>>>() { + self.events.push(SearchEvents::PositionPaths { paths: paths.clone() }); } } Location::Other => {} @@ -327,9 +338,13 @@ impl<'ctx> DetailedLoggerFinish<'ctx> { SearchEvents::TypoPaths { paths } => { self.write_rr_graph_paths::(paths)?; } - SearchEvents::AttributeGraph { graph } => self.write_rr_graph(&graph)?, - SearchEvents::AttributePaths { paths } => { - self.write_rr_graph_paths::(paths)?; + SearchEvents::FidGraph { graph } => self.write_rr_graph(&graph)?, + SearchEvents::FidPaths { paths } => { + self.write_rr_graph_paths::(paths)?; + } + SearchEvents::PositionGraph { graph } => self.write_rr_graph(&graph)?, + SearchEvents::PositionPaths { paths } => { + self.write_rr_graph_paths::(paths)?; } } Ok(()) diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 16eccb393..b691e00e3 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,7 +28,7 @@ use std::collections::HashSet; use bucket_sort::bucket_sort; use charabia::TokenizerBuilder; use db_cache::DatabaseCache; -use graph_based_ranking_rule::{Attribute, Proximity, Typo}; +use graph_based_ranking_rule::{Fid, Position, Proximity, Typo}; use heed::RoTxn; use interner::DedupInterner; pub use logger::visual::VisualSearchLogger; @@ -223,7 +223,8 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( continue; } attribute = true; - ranking_rules.push(Box::new(Attribute::new(None))); + ranking_rules.push(Box::new(Fid::new(None))); + ranking_rules.push(Box::new(Position::new(None))); } crate::Criterion::Sort => { if sort { diff --git a/milli/src/search/new/ranking_rule_graph/attribute/mod.rs b/milli/src/search/new/ranking_rule_graph/fid/mod.rs similarity index 90% rename from milli/src/search/new/ranking_rule_graph/attribute/mod.rs rename to milli/src/search/new/ranking_rule_graph/fid/mod.rs index a2981c604..0f2cceaec 100644 --- a/milli/src/search/new/ranking_rule_graph/attribute/mod.rs +++ b/milli/src/search/new/ranking_rule_graph/fid/mod.rs @@ -9,22 +9,22 @@ use crate::search::new::SearchContext; use crate::Result; #[derive(Clone, PartialEq, Eq, Hash)] -pub struct AttributeCondition { +pub struct FidCondition { term: LocatedQueryTermSubset, fid: u16, } -pub enum AttributeGraph {} +pub enum FidGraph {} -impl RankingRuleGraphTrait for AttributeGraph { - type Condition = AttributeCondition; +impl RankingRuleGraphTrait for FidGraph { + type Condition = FidCondition; fn resolve_condition( ctx: &mut SearchContext, condition: &Self::Condition, universe: &RoaringBitmap, ) -> Result { - let AttributeCondition { term, .. } = condition; + let FidCondition { term, .. } = condition; // maybe compute_query_term_subset_docids_within_field_id should accept a universe as argument let mut docids = compute_query_term_subset_docids_within_field_id( ctx, @@ -73,7 +73,7 @@ impl RankingRuleGraphTrait for AttributeGraph { // the term subsets associated to each field ids fetched. edges.push(( fid as u32 * term.term_ids.len() as u32, // TODO improve the fid score i.e. fid^10. - conditions_interner.insert(AttributeCondition { + conditions_interner.insert(FidCondition { term: term.clone(), // TODO remove this ugly clone fid, }), diff --git a/milli/src/search/new/ranking_rule_graph/mod.rs b/milli/src/search/new/ranking_rule_graph/mod.rs index fe31029b4..db65afdd7 100644 --- a/milli/src/search/new/ranking_rule_graph/mod.rs +++ b/milli/src/search/new/ranking_rule_graph/mod.rs @@ -11,9 +11,11 @@ mod condition_docids_cache; mod dead_ends_cache; /// Implementation of the `attribute` ranking rule -mod attribute; +mod fid; /// Implementation of the `exactness` ranking rule mod exactness; +/// Implementation of the `position` ranking rule +mod position; /// Implementation of the `proximity` ranking rule mod proximity; /// Implementation of the `typo` ranking rule @@ -21,11 +23,12 @@ mod typo; use std::hash::Hash; -pub use attribute::{AttributeCondition, AttributeGraph}; +pub use fid::{FidCondition, FidGraph}; pub use cheapest_paths::PathVisitor; pub use condition_docids_cache::ConditionDocIdsCache; pub use dead_ends_cache::DeadEndsCache; pub use exactness::{ExactnessCondition, ExactnessGraph}; +pub use position::{PositionCondition, PositionGraph}; pub use proximity::{ProximityCondition, ProximityGraph}; use roaring::RoaringBitmap; pub use typo::{TypoCondition, TypoGraph}; diff --git a/milli/src/search/new/ranking_rule_graph/position/mod.rs b/milli/src/search/new/ranking_rule_graph/position/mod.rs new file mode 100644 index 000000000..81d013141 --- /dev/null +++ b/milli/src/search/new/ranking_rule_graph/position/mod.rs @@ -0,0 +1,93 @@ +use fxhash::FxHashSet; +use roaring::RoaringBitmap; + +use super::{ComputedCondition, RankingRuleGraphTrait}; +use crate::search::new::interner::{DedupInterner, Interned}; +use crate::search::new::query_term::LocatedQueryTermSubset; +use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_position; +use crate::search::new::SearchContext; +use crate::Result; + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct PositionCondition { + term: LocatedQueryTermSubset, + position: u16, +} + +pub enum PositionGraph {} + +impl RankingRuleGraphTrait for PositionGraph { + type Condition = PositionCondition; + + fn resolve_condition( + ctx: &mut SearchContext, + condition: &Self::Condition, + universe: &RoaringBitmap, + ) -> Result { + let PositionCondition { term, .. } = condition; + // maybe compute_query_term_subset_docids_within_position_id should accept a universe as argument + let mut docids = compute_query_term_subset_docids_within_position( + ctx, + &term.term_subset, + condition.position, + )?; + docids &= universe; + + Ok(ComputedCondition { + docids, + universe_len: universe.len(), + start_term_subset: None, + end_term_subset: term.clone(), + }) + } + + fn build_edges( + ctx: &mut SearchContext, + conditions_interner: &mut DedupInterner, + _from: Option<&LocatedQueryTermSubset>, + to_term: &LocatedQueryTermSubset, + ) -> Result)>> { + let term = to_term; + + let mut all_positions = FxHashSet::default(); + for word in term.term_subset.all_single_words_except_prefix_db(ctx)? { + let positions = ctx.get_db_word_positions(word.interned())?; + all_positions.extend(positions); + } + + for phrase in term.term_subset.all_phrases(ctx)? { + for &word in phrase.words(ctx).iter().flatten() { + let positions = ctx.get_db_word_positions(word)?; + all_positions.extend(positions); + } + } + + if let Some(word_prefix) = term.term_subset.use_prefix_db(ctx) { + let positions = ctx.get_db_word_prefix_positions(word_prefix.interned())?; + all_positions.extend(positions); + } + + let mut edges = vec![]; + for position in all_positions { + let cost = { + let mut cost = 0; + for i in 0..term.term_ids.len() { + cost += position as u32 + i as u32; + } + cost + }; + + // TODO: We can improve performances and relevancy by storing + // the term subsets associated to each position fetched. + edges.push(( + cost, + conditions_interner.insert(PositionCondition { + term: term.clone(), // TODO remove this ugly clone + position, + }), + )); + } + + Ok(edges) + } +} diff --git a/milli/src/search/new/resolve_query_graph.rs b/milli/src/search/new/resolve_query_graph.rs index a125caa39..b8eb623bb 100644 --- a/milli/src/search/new/resolve_query_graph.rs +++ b/milli/src/search/new/resolve_query_graph.rs @@ -87,6 +87,41 @@ pub fn compute_query_term_subset_docids_within_field_id( Ok(docids) } +pub fn compute_query_term_subset_docids_within_position( + ctx: &mut SearchContext, + term: &QueryTermSubset, + position: u16, +) -> Result { + // TODO Use the roaring::MultiOps trait + + let mut docids = RoaringBitmap::new(); + for word in term.all_single_words_except_prefix_db(ctx)? { + if let Some(word_position_docids) = + ctx.get_db_word_position_docids(word.interned(), position)? + { + docids |= word_position_docids; + } + } + + for phrase in term.all_phrases(ctx)? { + for &word in phrase.words(ctx).iter().flatten() { + if let Some(word_position_docids) = ctx.get_db_word_position_docids(word, position)? { + docids |= word_position_docids; + } + } + } + + if let Some(word_prefix) = term.use_prefix_db(ctx) { + if let Some(word_position_docids) = + ctx.get_db_word_prefix_position_docids(word_prefix.interned(), position)? + { + docids |= word_position_docids; + } + } + + Ok(docids) +} + pub fn compute_query_graph_docids( ctx: &mut SearchContext, q: &QueryGraph, diff --git a/milli/src/search/new/tests/attribute.rs b/milli/src/search/new/tests/attribute_fid.rs similarity index 99% rename from milli/src/search/new/tests/attribute.rs rename to milli/src/search/new/tests/attribute_fid.rs index b248f7953..ec7b7a69e 100644 --- a/milli/src/search/new/tests/attribute.rs +++ b/milli/src/search/new/tests/attribute_fid.rs @@ -95,7 +95,7 @@ fn create_index() -> TempIndex { } #[test] -fn test_attributes_simple() { +fn test_attribute_fid_simple() { let index = create_index(); let txn = index.read_txn().unwrap(); diff --git a/milli/src/search/new/tests/attribute_position.rs b/milli/src/search/new/tests/attribute_position.rs new file mode 100644 index 000000000..0eafedb97 --- /dev/null +++ b/milli/src/search/new/tests/attribute_position.rs @@ -0,0 +1,52 @@ +use crate::{index::tests::TempIndex, Criterion, Search, SearchResult, TermsMatchingStrategy}; + +fn create_index() -> TempIndex { + let index = TempIndex::new(); + + index + .update_settings(|s| { + s.set_primary_key("id".to_owned()); + s.set_searchable_fields(vec!["text".to_owned()]); + s.set_criteria(vec![Criterion::Attribute]); + }) + .unwrap(); + + index + .add_documents(documents!([ + { + "id": 0, + "text": "do you know about the quick and talented brown fox", + }, + { + "id": 1, + "text": "do you know about the quick brown fox", + }, + { + "id": 2, + "text": "the quick and talented brown fox", + }, + { + "id": 3, + "text": "fox brown quick the", + }, + { + "id": 4, + "text": "the quick brown fox", + }, + ])) + .unwrap(); + index +} + +#[test] +fn test_attribute_fid_simple() { + let index = create_index(); + + let txn = index.read_txn().unwrap(); + + let mut s = Search::new(&txn, &index); + s.terms_matching_strategy(TermsMatchingStrategy::All); + s.query("the quick brown fox"); + let SearchResult { documents_ids, .. } = s.execute().unwrap(); + insta::assert_snapshot!(format!("{documents_ids:?}"), @"[3, 4, 2, 1, 0]"); +} diff --git a/milli/src/search/new/tests/mod.rs b/milli/src/search/new/tests/mod.rs index 9d6d9e159..31b37933d 100644 --- a/milli/src/search/new/tests/mod.rs +++ b/milli/src/search/new/tests/mod.rs @@ -1,4 +1,5 @@ -pub mod attribute; +pub mod attribute_fid; +pub mod attribute_position; pub mod distinct; #[cfg(feature = "default")] pub mod language;