diff --git a/milli/src/search/new/db_cache.rs b/milli/src/search/new/db_cache.rs index 6193f4c58..ad843b045 100644 --- a/milli/src/search/new/db_cache.rs +++ b/milli/src/search/new/db_cache.rs @@ -9,6 +9,7 @@ use roaring::RoaringBitmap; use super::interner::Interned; use super::Word; +use crate::heed_codec::StrBEU16Codec; use crate::{ CboRoaringBitmapCodec, CboRoaringBitmapLenCodec, Result, RoaringBitmapCodec, SearchContext, }; @@ -34,6 +35,9 @@ pub struct DatabaseCache<'ctx> { pub words_fst: Option>>, pub word_position_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, pub word_fid_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, + pub word_prefix_fid_docids: FxHashMap<(Interned, u16), Option<&'ctx [u8]>>, + pub word_fids: FxHashMap, Vec>, + pub word_prefix_fids: FxHashMap, Vec>, } impl<'ctx> DatabaseCache<'ctx> { fn get_value<'v, K1, KC>( @@ -284,4 +288,70 @@ impl<'ctx> SearchContext<'ctx> { .map(|bytes| CboRoaringBitmapCodec::bytes_decode(bytes).ok_or(heed::Error::Decoding.into())) .transpose() } + + pub fn get_db_word_prefix_fid_docids( + &mut self, + word_prefix: Interned, + fid: u16, + ) -> Result> { + DatabaseCache::get_value( + self.txn, + (word_prefix, fid), + &(self.word_interner.get(word_prefix).as_str(), fid), + &mut self.db_cache.word_prefix_fid_docids, + self.index.word_prefix_fid_docids.remap_data_type::(), + )? + .map(|bytes| CboRoaringBitmapCodec::bytes_decode(bytes).ok_or(heed::Error::Decoding.into())) + .transpose() + } + + pub fn get_db_word_fids(&mut self, word: Interned) -> Result> { + let fids = match self.db_cache.word_fids.entry(word) { + Entry::Occupied(fids) => fids.get().clone(), + Entry::Vacant(entry) => { + let key = self.word_interner.get(word).as_bytes(); + let mut fids = vec![]; + let remap_key_type = self + .index + .word_fid_docids + .remap_types::() + .prefix_iter(self.txn, key)? + .remap_key_type::(); + for result in remap_key_type { + let ((_, fid), value) = result?; + // filling other caches to avoid searching for them again + self.db_cache.word_fid_docids.insert((word, fid), Some(value)); + fids.push(fid); + } + entry.insert(fids.clone()); + fids + } + }; + Ok(fids) + } + + pub fn get_db_word_prefix_fids(&mut self, word_prefix: Interned) -> Result> { + let fids = match self.db_cache.word_prefix_fids.entry(word_prefix) { + Entry::Occupied(fids) => fids.get().clone(), + Entry::Vacant(entry) => { + let key = self.word_interner.get(word_prefix).as_bytes(); + let mut fids = vec![]; + let remap_key_type = self + .index + .word_prefix_fid_docids + .remap_types::() + .prefix_iter(self.txn, key)? + .remap_key_type::(); + for result in remap_key_type { + let ((_, fid), value) = result?; + // filling other caches to avoid searching for them again + self.db_cache.word_prefix_fid_docids.insert((word_prefix, fid), Some(value)); + fids.push(fid); + } + entry.insert(fids.clone()); + fids + } + }; + Ok(fids) + } } diff --git a/milli/src/search/new/graph_based_ranking_rule.rs b/milli/src/search/new/graph_based_ranking_rule.rs index 41a96dd9e..3ee16ed50 100644 --- a/milli/src/search/new/graph_based_ranking_rule.rs +++ b/milli/src/search/new/graph_based_ranking_rule.rs @@ -44,8 +44,8 @@ use super::interner::{Interned, MappedInterner}; use super::logger::SearchLogger; use super::query_graph::QueryNode; use super::ranking_rule_graph::{ - ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, ProximityGraph, RankingRuleGraph, - RankingRuleGraphTrait, TypoGraph, + AttributeGraph, ConditionDocIdsCache, DeadEndsCache, ExactnessGraph, ProximityGraph, + RankingRuleGraph, RankingRuleGraphTrait, TypoGraph, }; use super::small_bitmap::SmallBitmap; use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext}; @@ -59,6 +59,12 @@ impl GraphBasedRankingRule { Self::new_with_id("proximity".to_owned(), terms_matching_strategy) } } +pub type Attribute = GraphBasedRankingRule; +impl GraphBasedRankingRule { + pub fn new(terms_matching_strategy: Option) -> Self { + Self::new_with_id("attribute".to_owned(), terms_matching_strategy) + } +} pub type Typo = GraphBasedRankingRule; impl GraphBasedRankingRule { pub fn new(terms_matching_strategy: Option) -> Self { diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 9f8d8699f..16eccb393 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,7 +28,7 @@ use std::collections::HashSet; use bucket_sort::bucket_sort; use charabia::TokenizerBuilder; use db_cache::DatabaseCache; -use graph_based_ranking_rule::{Proximity, Typo}; +use graph_based_ranking_rule::{Attribute, Proximity, Typo}; use heed::RoTxn; use interner::DedupInterner; pub use logger::visual::VisualSearchLogger; @@ -174,7 +174,7 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( let mut typo = false; let mut proximity = false; let mut sort = false; - let attribute = false; + let mut attribute = false; let mut exactness = false; let mut asc = HashSet::new(); let mut desc = HashSet::new(); @@ -222,8 +222,8 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( if attribute { continue; } - // todo!(); - // attribute = false; + attribute = true; + ranking_rules.push(Box::new(Attribute::new(None))); } crate::Criterion::Sort => { if sort { diff --git a/milli/src/search/new/query_term/phrase.rs b/milli/src/search/new/query_term/phrase.rs index 2ea8e0d39..033c5cf12 100644 --- a/milli/src/search/new/query_term/phrase.rs +++ b/milli/src/search/new/query_term/phrase.rs @@ -13,4 +13,8 @@ impl Interned { let p = ctx.phrase_interner.get(self); p.words.iter().flatten().map(|w| ctx.word_interner.get(*w)).join(" ") } + pub fn words(self, ctx: &SearchContext) -> Vec>> { + let p = ctx.phrase_interner.get(self); + p.words.clone() + } } diff --git a/milli/src/search/new/ranking_rule_graph/attribute/mod.rs b/milli/src/search/new/ranking_rule_graph/attribute/mod.rs new file mode 100644 index 000000000..a2981c604 --- /dev/null +++ b/milli/src/search/new/ranking_rule_graph/attribute/mod.rs @@ -0,0 +1,85 @@ +use fxhash::FxHashSet; +use roaring::RoaringBitmap; + +use super::{ComputedCondition, RankingRuleGraphTrait}; +use crate::search::new::interner::{DedupInterner, Interned}; +use crate::search::new::query_term::LocatedQueryTermSubset; +use crate::search::new::resolve_query_graph::compute_query_term_subset_docids_within_field_id; +use crate::search::new::SearchContext; +use crate::Result; + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct AttributeCondition { + term: LocatedQueryTermSubset, + fid: u16, +} + +pub enum AttributeGraph {} + +impl RankingRuleGraphTrait for AttributeGraph { + type Condition = AttributeCondition; + + fn resolve_condition( + ctx: &mut SearchContext, + condition: &Self::Condition, + universe: &RoaringBitmap, + ) -> Result { + let AttributeCondition { term, .. } = condition; + // maybe compute_query_term_subset_docids_within_field_id should accept a universe as argument + let mut docids = compute_query_term_subset_docids_within_field_id( + ctx, + &term.term_subset, + condition.fid, + )?; + docids &= universe; + + Ok(ComputedCondition { + docids, + universe_len: universe.len(), + start_term_subset: None, + end_term_subset: term.clone(), + }) + } + + fn build_edges( + ctx: &mut SearchContext, + conditions_interner: &mut DedupInterner, + _from: Option<&LocatedQueryTermSubset>, + to_term: &LocatedQueryTermSubset, + ) -> Result)>> { + let term = to_term; + + let mut all_fields = FxHashSet::default(); + for word in term.term_subset.all_single_words_except_prefix_db(ctx)? { + let fields = ctx.get_db_word_fids(word.interned())?; + all_fields.extend(fields); + } + + for phrase in term.term_subset.all_phrases(ctx)? { + for &word in phrase.words(ctx).iter().flatten() { + let fields = ctx.get_db_word_fids(word)?; + all_fields.extend(fields); + } + } + + if let Some(word_prefix) = term.term_subset.use_prefix_db(ctx) { + let fields = ctx.get_db_word_prefix_fids(word_prefix.interned())?; + all_fields.extend(fields); + } + + let mut edges = vec![]; + for fid in all_fields { + // TODO: We can improve performances and relevancy by storing + // the term subsets associated to each field ids fetched. + edges.push(( + fid as u32 * term.term_ids.len() as u32, // TODO improve the fid score i.e. fid^10. + conditions_interner.insert(AttributeCondition { + term: term.clone(), // TODO remove this ugly clone + fid, + }), + )); + } + + Ok(edges) + } +} diff --git a/milli/src/search/new/ranking_rule_graph/mod.rs b/milli/src/search/new/ranking_rule_graph/mod.rs index 6a9bfff93..fe31029b4 100644 --- a/milli/src/search/new/ranking_rule_graph/mod.rs +++ b/milli/src/search/new/ranking_rule_graph/mod.rs @@ -10,6 +10,8 @@ mod cheapest_paths; mod condition_docids_cache; mod dead_ends_cache; +/// Implementation of the `attribute` ranking rule +mod attribute; /// Implementation of the `exactness` ranking rule mod exactness; /// Implementation of the `proximity` ranking rule @@ -19,6 +21,7 @@ mod typo; use std::hash::Hash; +pub use attribute::{AttributeCondition, AttributeGraph}; pub use cheapest_paths::PathVisitor; pub use condition_docids_cache::ConditionDocIdsCache; pub use dead_ends_cache::DeadEndsCache; diff --git a/milli/src/search/new/resolve_query_graph.rs b/milli/src/search/new/resolve_query_graph.rs index f4938ca12..a125caa39 100644 --- a/milli/src/search/new/resolve_query_graph.rs +++ b/milli/src/search/new/resolve_query_graph.rs @@ -33,6 +33,8 @@ pub fn compute_query_term_subset_docids( ctx: &mut SearchContext, term: &QueryTermSubset, ) -> Result { + // TODO Use the roaring::MultiOps trait + let mut docids = RoaringBitmap::new(); for word in term.all_single_words_except_prefix_db(ctx)? { if let Some(word_docids) = ctx.word_docids(word)? { @@ -52,6 +54,39 @@ pub fn compute_query_term_subset_docids( Ok(docids) } +pub fn compute_query_term_subset_docids_within_field_id( + ctx: &mut SearchContext, + term: &QueryTermSubset, + fid: u16, +) -> Result { + // TODO Use the roaring::MultiOps trait + + let mut docids = RoaringBitmap::new(); + for word in term.all_single_words_except_prefix_db(ctx)? { + if let Some(word_fid_docids) = ctx.get_db_word_fid_docids(word.interned(), fid)? { + docids |= word_fid_docids; + } + } + + for phrase in term.all_phrases(ctx)? { + for &word in phrase.words(ctx).iter().flatten() { + if let Some(word_fid_docids) = ctx.get_db_word_fid_docids(word, fid)? { + docids |= word_fid_docids; + } + } + } + + if let Some(word_prefix) = term.use_prefix_db(ctx) { + if let Some(word_fid_docids) = + ctx.get_db_word_prefix_fid_docids(word_prefix.interned(), fid)? + { + docids |= word_fid_docids; + } + } + + Ok(docids) +} + pub fn compute_query_graph_docids( ctx: &mut SearchContext, q: &QueryGraph, diff --git a/milli/src/search/new/tests/attribute.rs b/milli/src/search/new/tests/attribute.rs new file mode 100644 index 000000000..f9b29881b --- /dev/null +++ b/milli/src/search/new/tests/attribute.rs @@ -0,0 +1,58 @@ +use std::collections::HashMap; + +use crate::{ + index::tests::TempIndex, search::new::tests::collect_field_values, Criterion, Search, + SearchResult, TermsMatchingStrategy, +}; + +fn create_index() -> TempIndex { + let index = TempIndex::new(); + + index + .update_settings(|s| { + s.set_primary_key("id".to_owned()); + s.set_searchable_fields(vec![ + "title".to_owned(), + "description".to_owned(), + "plot".to_owned(), + ]); + s.set_criteria(vec![Criterion::Attribute]); + }) + .unwrap(); + + index + .add_documents(documents!([ + { + "id": 0, + "title": "the quick brown fox jumps over the lazy dog", + "description": "Pack my box with five dozen liquor jugs", + "plot": "How vexingly quick daft zebras jump", + }, + { + "id": 1, + "title": "Pack my box with five dozen liquor jugs", + "description": "the quick brown foxes jump over the lazy dog", + "plot": "How vexingly quick daft zebras jump", + }, + { + "id": 2, + "title": "How vexingly quick daft zebras jump", + "description": "Pack my box with five dozen liquor jugs", + "plot": "the quick brown fox jumps over the lazy dog", + } + ])) + .unwrap(); + index +} + +#[test] +fn test_attributes_are_ranked_correctly() { + let index = create_index(); + let txn = index.read_txn().unwrap(); + + let mut s = Search::new(&txn, &index); + s.terms_matching_strategy(TermsMatchingStrategy::All); + s.query("the quick brown fox"); + let SearchResult { documents_ids, .. } = s.execute().unwrap(); + insta::assert_snapshot!(format!("{documents_ids:?}"), @"[0, 1, 2]"); +} diff --git a/milli/src/search/new/tests/mod.rs b/milli/src/search/new/tests/mod.rs index 898276858..9d6d9e159 100644 --- a/milli/src/search/new/tests/mod.rs +++ b/milli/src/search/new/tests/mod.rs @@ -1,3 +1,4 @@ +pub mod attribute; pub mod distinct; #[cfg(feature = "default")] pub mod language;