diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 99126f60e..6f14977ec 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -57,8 +57,9 @@ pub use self::heed_codec::{ }; pub use self::index::Index; pub use self::search::{ - FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, Search, - SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, + FacetDistribution, FacetSearchResult, Filter, FormatOptions, MatchBounds, MatcherBuilder, + MatchingWords, Search, SearchForFacetValue, SearchResult, TermsMatchingStrategy, + DEFAULT_VALUES_PER_FACET, }; pub type Result = std::result::Result; diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index dc25c0f23..1e648d241 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -1,5 +1,7 @@ use std::fmt; +use fst::automaton::{Complement, Intersection, StartsWith, Str, Union}; +use fst::Streamer; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; @@ -7,9 +9,11 @@ use roaring::bitmap::RoaringBitmap; pub use self::facet::{FacetDistribution, Filter, DEFAULT_VALUES_PER_FACET}; pub use self::new::matches::{FormatOptions, MatchBounds, Matcher, MatcherBuilder, MatchingWords}; use self::new::PartialSearchResult; +use crate::error::UserError; +use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::{ - execute_search, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, SearchContext, + execute_search, AscDesc, DefaultSearchLogger, DocumentId, Index, Result, SearchContext, BEU16, }; // Building these factories is not free. @@ -234,6 +238,103 @@ pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA { } } +pub struct SearchForFacetValue<'a> { + query: Option, + facet: String, + search_query: Search<'a>, +} + +impl<'a> SearchForFacetValue<'a> { + fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValue<'a> { + SearchForFacetValue { query: None, facet, search_query } + } + + fn query(&mut self, query: impl Into) -> &mut Self { + self.query = Some(query.into()); + self + } + + fn execute(&self) -> Result> { + let index = self.search_query.index; + let rtxn = self.search_query.rtxn; + + let sortable_fields = index.sortable_fields(rtxn)?; + if !sortable_fields.contains(&self.facet) { + // TODO create a new type of error + return Err(UserError::InvalidSortableAttribute { + field: self.facet.clone(), + valid_fields: sortable_fields.into_iter().collect(), + })?; + } + + let fields_ids_map = index.fields_ids_map(rtxn)?; + let (field_id, fst) = match fields_ids_map.id(&self.facet) { + Some(fid) => { + match self.search_query.index.facet_id_string_fst.get(rtxn, &BEU16::new(fid))? { + Some(fst) => (fid, fst), + None => todo!("return an error, is the user trying to search in numbers?"), + } + } + None => todo!("return an internal error bug"), + }; + + let search_candidates = self.search_query.execute()?.candidates; + + match self.query.as_ref() { + Some(query) => { + let is_prefix = true; + let starts = StartsWith(Str::new(get_first(query))); + let first = Intersection(build_dfa(query, 1, is_prefix), Complement(&starts)); + let second_dfa = build_dfa(query, 2, is_prefix); + let second = Intersection(&second_dfa, &starts); + let automaton = Union(first, &second); + + let mut stream = fst.search(automaton).into_stream(); + let mut result = vec![]; + while let Some(facet_value) = stream.next() { + let value = std::str::from_utf8(facet_value)?; + let key = FacetGroupKey { field_id, level: 0, left_bound: value }; + let docids = match index.facet_id_string_docids.get(rtxn, &key)? { + Some(FacetGroupValue { bitmap, .. }) => bitmap, + None => todo!("return an internal error"), + }; + let count = search_candidates.intersection_len(&docids); + if count != 0 { + result.push(FacetSearchResult { value: value.to_string(), count }); + } + } + + Ok(result) + } + None => { + let mut stream = fst.stream(); + let mut result = vec![]; + while let Some(facet_value) = stream.next() { + let value = std::str::from_utf8(facet_value)?; + let key = FacetGroupKey { field_id, level: 0, left_bound: value }; + let docids = match index.facet_id_string_docids.get(rtxn, &key)? { + Some(FacetGroupValue { bitmap, .. }) => bitmap, + None => todo!("return an internal error"), + }; + let count = search_candidates.intersection_len(&docids); + if count != 0 { + result.push(FacetSearchResult { value: value.to_string(), count }); + } + } + + Ok(result) + } + } + } +} + +pub struct FacetSearchResult { + /// The original facet value + pub value: String, + /// The number of documents associated to this facet + pub count: u64, +} + #[cfg(test)] mod test { #[allow(unused_imports)] diff --git a/milli/src/update/facets.rs b/milli/src/update/facets.rs deleted file mode 100644 index 8b1378917..000000000 --- a/milli/src/update/facets.rs +++ /dev/null @@ -1 +0,0 @@ -