// Patch overview (free-text preamble placed before the first `diff --git` header):
//   * http-ui/src/main.rs: imports `FacetCondition`; `QueryBody` additionally derives `Debug`,
//     denies unknown fields, is renamed to camelCase and gains a `facet_condition` field; when
//     that field is present it is parsed with `FacetCondition::from_str` and set on the search.
//   * src/lib.rs: re-exports `FacetCondition` alongside `Search` and `SearchResult`.
//   * src/search.rs: adds the `FacetOperator` and `FacetCondition` enums, a whitespace-split
//     parser (`FacetCondition::from_str`), the `Search::facet_condition` setter, facet-range
//     candidate collection inside `Search::execute`, and a manual `fmt::Debug` impl for `Search`.
// NOTE(review): the http-ui handler calls `.unwrap()` on the result of `FacetCondition::from_str`,
//   whose input comes from the request body; `from_str` returns an error for an unknown field, a
//   non-faceted field, a non-integer facet or an unparsable parameter, so such a request would
//   panic instead of producing an error response -- confirm and handle the error.
// NOTE(review): `from_str` returns `Ok(None)` / `Ok(Some(..))`, yet the handler passes the unwrapped
//   value directly to `search.facet_condition(..)`, which is declared to take a `FacetCondition`.
//   Generic parameters appear to be missing from this copy of the patch (`Option,`,
//   `anyhow::Result>`, `remap_types::()`), so verify the exact types against the repository.
// NOTE(review): `from_str` ignores any tokens left after a complete condition, and the `X TO Y`
//   form does not check that X <= Y -- confirm whether both are intended.
diff --git a/http-ui/src/main.rs b/http-ui/src/main.rs index b730344f2..d05b69f2c 100644 --- a/http-ui/src/main.rs +++ b/http-ui/src/main.rs @@ -28,7 +28,7 @@ use warp::{Filter, http::Response}; use milli::tokenizer::{simple_tokenizer, TokenType}; use milli::update::UpdateIndexingStep::*; use milli::update::{UpdateBuilder, IndexDocumentsMethod, UpdateFormat}; -use milli::{obkv_to_json, Index, UpdateStore, SearchResult}; +use milli::{obkv_to_json, Index, UpdateStore, SearchResult, FacetCondition}; static GLOBAL_THREAD_POOL: OnceCell = OnceCell::new(); @@ -550,9 +550,12 @@ async fn main() -> anyhow::Result<()> { .body(include_str!("../public/logo-black.svg")) ); - #[derive(Deserialize)] + #[derive(Debug, Deserialize)] + #[serde(deny_unknown_fields)] + #[serde(rename_all = "camelCase")] struct QueryBody { query: Option, + facet_condition: Option, } let disable_highlighting = opt.disable_highlighting; @@ -569,6 +572,10 @@ async fn main() -> anyhow::Result<()> { if let Some(query) = query.query { search.query(query); } + if let Some(condition) = query.facet_condition { + let condition = FacetCondition::from_str(&rtxn, &index, &condition).unwrap(); + search.facet_condition(condition); + } let SearchResult { found_words, documents_ids } = search.execute().unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 12a24a59c..ff578dd4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ pub use self::criterion::{Criterion, default_criteria}; pub use self::external_documents_ids::ExternalDocumentsIds; pub use self::fields_ids_map::FieldsIdsMap; pub use self::index::Index; -pub use self::search::{Search, SearchResult}; +pub use self::search::{Search, FacetCondition, SearchResult}; pub use self::heed_codec::{ RoaringBitmapCodec, BEU32StrCodec, StrStrU8Codec, ObkvCodec, BoRoaringBitmapCodec, CboRoaringBitmapCodec, diff --git a/src/search.rs b/src/search.rs index ae2b5d127..17f25edfc 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,6 +1,8 @@ use std::borrow::Cow; use 
std::collections::{HashMap, HashSet}; +use std::fmt; +use anyhow::{bail, ensure, Context}; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::DFA; use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; @@ -8,8 +10,10 @@ use log::debug; use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; -use crate::query_tokens::{QueryTokens, QueryToken}; +use crate::facet::FacetType; +use crate::heed_codec::{CboRoaringBitmapCodec, facet::FacetValueI64Codec}; use crate::mdfs::Mdfs; +use crate::query_tokens::{QueryTokens, QueryToken}; use crate::{Index, DocumentId}; // Building these factories is not free. @@ -17,8 +21,91 @@ static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); static LEVDIST1: Lazy = Lazy::new(|| LevBuilder::new(1, true)); static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true)); +// TODO support also floats +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum FacetOperator { + GreaterThan(i64), + GreaterThanOrEqual(i64), + LowerThan(i64), + LowerThanOrEqual(i64), + Equal(i64), + Between(i64, i64), +} + +// TODO also support ANDs, ORs, NOTs. 
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum FacetCondition { + Operator(u8, FacetOperator), +} + +impl FacetCondition { + pub fn from_str( + rtxn: &heed::RoTxn, + index: &Index, + string: &str, + ) -> anyhow::Result> + { + use FacetCondition::*; + use FacetOperator::*; + + let fields_ids_map = index.fields_ids_map(rtxn)?; + let faceted_fields = index.faceted_fields(rtxn)?; + + // TODO use a better parsing technique + let mut iter = string.split_whitespace(); + + let field_name = match iter.next() { + Some(field_name) => field_name, + None => return Ok(None), + }; + + let field_id = fields_ids_map.id(&field_name).with_context(|| format!("field {} not found", field_name))?; + let field_type = faceted_fields.get(&field_id).with_context(|| format!("field {} is not faceted", field_name))?; + + ensure!(*field_type == FacetType::Integer, "Only conditions on integer facets"); + + match iter.next() { + Some(">") => { + let param = iter.next().context("missing parameter")?; + let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; + Ok(Some(Operator(field_id, GreaterThan(value)))) + }, + Some(">=") => { + let param = iter.next().context("missing parameter")?; + let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; + Ok(Some(Operator(field_id, GreaterThanOrEqual(value)))) + }, + Some("<") => { + let param = iter.next().context("missing parameter")?; + let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; + Ok(Some(Operator(field_id, LowerThan(value)))) + }, + Some("<=") => { + let param = iter.next().context("missing parameter")?; + let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; + Ok(Some(Operator(field_id, LowerThanOrEqual(value)))) + }, + Some("=") => { + let param = iter.next().context("missing parameter")?; + let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; + 
Ok(Some(Operator(field_id, Equal(value)))) + }, + Some(otherwise) => { + // BETWEEN or X TO Y (both inclusive) + let left_param = otherwise.parse().with_context(|| format!("invalid first TO parameter ({:?})", otherwise))?; + ensure!(iter.next().map_or(false, |s| s.eq_ignore_ascii_case("to")), "TO keyword missing or invalid"); + let next = iter.next().context("missing second TO parameter")?; + let right_param = next.parse().with_context(|| format!("invalid second TO parameter ({:?})", next))?; + Ok(Some(Operator(field_id, Between(left_param, right_param)))) + }, + None => bail!("missing facet filter first parameter"), + } + } +} + pub struct Search<'a> { query: Option, + facet_condition: Option, offset: usize, limit: usize, rtxn: &'a heed::RoTxn<'a>, @@ -27,7 +114,7 @@ pub struct Search<'a> { impl<'a> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { - Search { query: None, offset: 0, limit: 20, rtxn, index } + Search { query: None, facet_condition: None, offset: 0, limit: 20, rtxn, index } } pub fn query(&mut self, query: impl Into) -> &mut Search<'a> { @@ -45,6 +132,11 @@ impl<'a> Search<'a> { self } + pub fn facet_condition(&mut self, condition: FacetCondition) -> &mut Search<'a> { + self.facet_condition = Some(condition); + self + } + /// Extracts the query words from the query string and returns the DFAs accordingly. /// TODO introduce settings for the number of typos regarding the words lengths. fn generate_query_dfas(query: &str) -> Vec<(String, bool, DFA)> { @@ -135,22 +227,66 @@ impl<'a> Search<'a> { pub fn execute(&self) -> anyhow::Result { let limit = self.limit; - let fst = self.index.words_fst(self.rtxn)?; // Construct the DFAs related to the query words. 
- let dfas = match self.query.as_deref().map(Self::generate_query_dfas) { - Some(dfas) if !dfas.is_empty() => dfas, - _ => { + let derived_words = match self.query.as_deref().map(Self::generate_query_dfas) { + Some(dfas) if !dfas.is_empty() => Some(self.fetch_words_docids(&fst, dfas)?), + _otherwise => None, + }; + + // We create the original candidates with the facet conditions results. + let facet_candidates = match self.facet_condition { + Some(FacetCondition::Operator(fid, operator)) => { + use std::ops::Bound::{Included, Excluded}; + use FacetOperator::*; + // Make sure we always bound the ranges with the field id, as the facets + // values are all in the same database and prefixed by the field id. + let range = match operator { + GreaterThan(val) => (Excluded((fid, val)), Included((fid, i64::MAX))), + GreaterThanOrEqual(val) => (Included((fid, val)), Included((fid, i64::MAX))), + LowerThan(val) => (Included((fid, i64::MIN)), Excluded((fid, val))), + LowerThanOrEqual(val) => (Included((fid, i64::MIN)), Included((fid, val))), + Equal(val) => (Included((fid, val)), Included((fid, val))), + Between(left, right) => (Included((fid, left)), Included((fid, right))), + }; + + let mut candidates = RoaringBitmap::new(); + + let db = self.index.facet_field_id_value_docids; + let db = db.remap_types::(); + for result in db.range(self.rtxn, &range)? 
{ + let ((_fid, _value), docids) = result?; + candidates.union_with(&docids); + } + + Some(candidates) + }, + None => None, + }; + + let (candidates, derived_words) = match (facet_candidates, derived_words) { + (Some(mut facet_candidates), Some(derived_words)) => { + let words_candidates = Self::compute_candidates(&derived_words); + facet_candidates.intersect_with(&words_candidates); + (facet_candidates, derived_words) + }, + (None, Some(derived_words)) => { + (Self::compute_candidates(&derived_words), derived_words) + }, + (Some(facet_candidates), None) => { + // If the query is not set or results in no DFAs but + // there are some facet conditions we return a placeholder. + let documents_ids = facet_candidates.iter().take(limit).collect(); + return Ok(SearchResult { documents_ids, ..Default::default() }) + }, + (None, None) => { // If the query is not set or results in no DFAs we return a placeholder. let documents_ids = self.index.documents_ids(self.rtxn)?.iter().take(limit).collect(); return Ok(SearchResult { documents_ids, ..Default::default() }) }, }; - let derived_words = self.fetch_words_docids(&fst, dfas)?; - let candidates = Self::compute_candidates(&derived_words); - debug!("candidates: {:?}", candidates); // The mana depth first search is a revised DFS that explore @@ -175,6 +311,17 @@ impl<'a> Search<'a> { } } +impl fmt::Debug for Search<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Search") + .field("query", &self.query) + .field("facet_condition", &self.facet_condition) + .field("offset", &self.offset) + .field("limit", &self.limit) + .finish() + } +} + #[derive(Default)] pub struct SearchResult { pub found_words: HashSet,