diff --git a/Cargo.lock b/Cargo.lock index 884bc19d9..70128cfa9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,27 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +[[package]] +name = "block-buffer" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0940dc441f31689269e10ac70eb1002a3a1d3ad1390e030043662eb7fe4688b" +dependencies = [ + "block-padding", + "byte-tools", + "byteorder", + "generic-array", +] + +[[package]] +name = "block-padding" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa79dedbb091f449f1f39e53edf88d5dbe95f895dae6135a8d7b881fb5af73f5" +dependencies = [ + "byte-tools", +] + [[package]] name = "bstr" version = "0.2.13" @@ -63,6 +84,12 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e8c087f005730276d1096a652e92a8bacee2e2472bcc9715a74d2bec38b5820" +[[package]] +name = "byte-tools" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3b5ca7a04898ad4bcd41c90c5285445ff5b791899bb1b0abdd2a2aa791211d7" + [[package]] name = "byteorder" version = "1.3.4" @@ -285,12 +312,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "digest" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" +dependencies = [ + "generic-array", +] + [[package]] name = "either" version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" +[[package]] +name = "fake-simd" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" + [[package]] name = "flate2" version = "1.0.17" @@ -324,6 +366,15 @@ dependencies = [ "byteorder", ] +[[package]] +name = "generic-array" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c68f0274ae0e023facc3c97b2e00f076be70e254bc851d972503b328db79b2ec" +dependencies = [ + "typenum", +] + [[package]] name = "getrandom" version = "0.1.14" @@ -621,6 +672,8 @@ dependencies = [ "obkv", "once_cell", "ordered-float", + "pest 2.1.3 (git+https://github.com/pest-parser/pest.git?rev=51fd1d49f1041f7839975664ef71fe15c7dcaf67)", + "pest_derive", "rayon", "ringtail", "roaring", @@ -717,6 +770,12 @@ version = "11.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a170cebd8021a008ea92e4db85a72f80b35df514ec664b296fdcbb654eac0b2c" +[[package]] +name = "opaque-debug" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" + [[package]] name = "ordered-float" version = "2.0.0" @@ -742,6 +801,57 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "pest" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53" +dependencies = [ + "ucd-trie", +] + +[[package]] +name = "pest" +version = "2.1.3" +source = "git+https://github.com/pest-parser/pest.git?rev=51fd1d49f1041f7839975664ef71fe15c7dcaf67#51fd1d49f1041f7839975664ef71fe15c7dcaf67" +dependencies = [ + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "833d1ae558dc601e9a60366421196a8d94bc0ac980476d0b67e1d0988d72b2d0" +dependencies = [ + "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99b8db626e31e5b81787b9783425769681b347011cc59471e33ea46d2ea0cf55" +dependencies = [ + "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54be6e404f5317079812fc8f9f5279de376d8856929e21c184ecf6bbd692a11d" +dependencies = [ + "maplit", + "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "sha-1", +] + [[package]] name = "pkg-config" version = "0.3.19" @@ -1026,6 +1136,18 @@ dependencies = [ "serde", ] +[[package]] +name = "sha-1" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d94d0bede923b3cea61f3f1ff57ff8cdfd77b400fb8f9998949e0cf04163df" +dependencies = [ + "block-buffer", + "digest", + "fake-simd", + "opaque-debug", +] + [[package]] name = "slice-group-by" version = "0.2.6" @@ -1234,6 +1356,18 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +[[package]] +name = "typenum" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "373c8a200f9e67a0c95e62a4f52fbf80c23b4381c05a17845531982fa99e6b33" + +[[package]] +name = "ucd-trie" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" + [[package]] name = "unicode-bidi" version = "0.3.4" diff --git a/Cargo.toml b/Cargo.toml index b77cf4a44..37c83b4f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,10 @@ structopt = { version = "0.3.14", default-features = false, features = ["wrap_he tempfile = "3.1.0" uuid = { version = "0.8.1", features = ["v4"] } +# facet filter parser +pest = { git = "https://github.com/pest-parser/pest.git", rev = "51fd1d49f1041f7839975664ef71fe15c7dcaf67" } +pest_derive = "2.1.0" + # documents words self-join itertools = "0.9.0" diff --git a/http-ui/Cargo.lock b/http-ui/Cargo.lock index 162ca96b2..b15700ce5 100644 --- a/http-ui/Cargo.lock +++ b/http-ui/Cargo.lock @@ -654,9 +654,9 @@ dependencies = [ [[package]] name = "heed" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d2740ccbbfb2a6e6ff0c43e0fc14981ed668fb45be5a4e7b2bc03fc8cca3d3e" +checksum = "cddc0d0d20adfc803b3e57c2d84447e134cad636202e68e275c65e3cbe63c616" dependencies = [ "byteorder", "heed-traits", @@ -934,6 +934,12 @@ dependencies = [ "cfg-if 0.1.10", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matches" version = "0.1.8" @@ -987,9 +993,12 @@ dependencies = [ "log", "memmap", "near-proximity", + "num-traits", "obkv", "once_cell", "ordered-float", + "pest 2.1.3 (git+https://github.com/pest-parser/pest.git?rev=51fd1d49f1041f7839975664ef71fe15c7dcaf67)", + "pest_derive", "rayon", "ringtail", "roaring", @@ -1231,6 +1240,57 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "pest" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f4872ae94d7b90ae48754df22fd42ad52ce740b8f370b03da4835417403e53" +dependencies = [ + "ucd-trie", +] + +[[package]] +name = "pest" +version = "2.1.3" +source = "git+https://github.com/pest-parser/pest.git?rev=51fd1d49f1041f7839975664ef71fe15c7dcaf67#51fd1d49f1041f7839975664ef71fe15c7dcaf67" +dependencies = [ + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "833d1ae558dc601e9a60366421196a8d94bc0ac980476d0b67e1d0988d72b2d0" +dependencies = [ + "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99b8db626e31e5b81787b9783425769681b347011cc59471e33ea46d2ea0cf55" +dependencies = [ + "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "pest_meta", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pest_meta" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54be6e404f5317079812fc8f9f5279de376d8856929e21c184ecf6bbd692a11d" +dependencies = [ + "maplit", + "pest 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "sha-1 0.8.2", +] + [[package]] name = "pin-project" version = "0.4.27" @@ -2024,6 +2084,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "373c8a200f9e67a0c95e62a4f52fbf80c23b4381c05a17845531982fa99e6b33" +[[package]] +name = "ucd-trie" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c" + [[package]] name = "unicase" version = "2.6.0" diff --git a/http-ui/src/main.rs b/http-ui/src/main.rs index e03261641..ca1ddcd45 100644 --- a/http-ui/src/main.rs +++ b/http-ui/src/main.rs @@ -614,7 +614,8 @@ async fn main() -> anyhow::Result<()> { search.query(query); } if let Some(condition) = query.facet_condition { - if let Some(condition) = FacetCondition::from_str(&rtxn, &index, &condition).unwrap() { + if !condition.trim().is_empty() { + let condition = FacetCondition::from_str(&rtxn, &index, &condition).unwrap(); search.facet_condition(condition); } } diff --git a/src/lib.rs b/src/lib.rs index ff578dd4b..320077b86 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#[macro_use] extern crate pest_derive; + mod criterion; mod external_documents_ids; mod fields_ids_map; diff --git a/src/search/facet.rs b/src/search/facet.rs deleted file mode 100644 index 08daf5fbc..000000000 --- a/src/search/facet.rs +++ /dev/null @@ -1,307 +0,0 @@ -use std::error::Error as StdError; -use std::fmt::Debug; -use std::ops::Bound::{self, Unbounded, Included, Excluded}; -use std::str::FromStr; - -use anyhow::{bail, ensure, Context}; -use heed::types::{ByteSlice, DecodeIgnore}; -use itertools::Itertools; -use log::debug; -use num_traits::Bounded; -use roaring::RoaringBitmap; - -use crate::facet::FacetType; -use crate::heed_codec::facet::FacetValueStringCodec; -use crate::heed_codec::facet::{FacetLevelValueI64Codec, FacetLevelValueF64Codec}; -use crate::{Index, CboRoaringBitmapCodec}; - -use self::FacetCondition::*; -use self::FacetNumberOperator::*; - -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum FacetNumberOperator { - GreaterThan(T), - GreaterThanOrEqual(T), - LowerThan(T), - LowerThanOrEqual(T), - Equal(T), - Between(T, T), -} - -#[derive(Debug, Clone, PartialEq)] -pub enum FacetStringOperator { - Equal(String), -} - -// TODO also support ANDs, ORs, NOTs. -#[derive(Debug, Clone, PartialEq)] -pub enum FacetCondition { - OperatorI64(u8, FacetNumberOperator), - OperatorF64(u8, FacetNumberOperator), - OperatorString(u8, FacetStringOperator), -} - -impl FacetCondition { - pub fn from_str( - rtxn: &heed::RoTxn, - index: &Index, - string: &str, - ) -> anyhow::Result> - { - let fields_ids_map = index.fields_ids_map(rtxn)?; - let faceted_fields = index.faceted_fields(rtxn)?; - - // TODO use a better parsing technic - let mut iter = string.split_whitespace(); - - let field_name = match iter.next() { - Some(field_name) => field_name, - None => return Ok(None), - }; - - let field_id = fields_ids_map.id(&field_name).with_context(|| format!("field {} not found", field_name))?; - let field_type = faceted_fields.get(&field_id).with_context(|| format!("field {} is not faceted", field_name))?; - - match field_type { - FacetType::Integer => Self::parse_number_condition(iter).map(|op| Some(OperatorI64(field_id, op))), - FacetType::Float => Self::parse_number_condition(iter).map(|op| Some(OperatorF64(field_id, op))), - FacetType::String => Self::parse_string_condition(iter).map(|op| Some(OperatorString(field_id, op))), - } - } - - fn parse_string_condition<'a>( - mut iter: impl Iterator, - ) -> anyhow::Result - { - match iter.next() { - Some("=") | Some(":") => { - match iter.next() { - Some(q @ "\"") | Some(q @ "\'") => { - let string: String = iter.take_while(|&c| c != q).intersperse(" ").collect(); - Ok(FacetStringOperator::Equal(string.to_lowercase())) - }, - Some(param) => Ok(FacetStringOperator::Equal(param.to_lowercase())), - None => bail!("missing parameter"), - } - }, - _ => bail!("invalid facet string operator"), - } - } - - fn parse_number_condition<'a, T: FromStr>( - mut iter: impl Iterator, - ) -> anyhow::Result> - where T::Err: Send + Sync + StdError + 'static, - { - match iter.next() { - Some(">") => { - let param = iter.next().context("missing parameter")?; - let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(GreaterThan(value)) - }, - Some(">=") => { - let param = iter.next().context("missing parameter")?; - let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(GreaterThanOrEqual(value)) - }, - Some("<") => { - let param = iter.next().context("missing parameter")?; - let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(LowerThan(value)) - }, - Some("<=") => { - let param = iter.next().context("missing parameter")?; - let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(LowerThanOrEqual(value)) - }, - Some("=") => { - let param = iter.next().context("missing parameter")?; - let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(Equal(value)) - }, - Some(otherwise) => { - // BETWEEN or X TO Y (both inclusive) - let left_param = otherwise.parse().with_context(|| format!("invalid first TO parameter ({:?})", otherwise))?; - ensure!(iter.next().map_or(false, |s| s.eq_ignore_ascii_case("to")), "TO keyword missing or invalid"); - let next = iter.next().context("missing second TO parameter")?; - let right_param = next.parse().with_context(|| format!("invalid second TO parameter ({:?})", next))?; - Ok(Between(left_param, right_param)) - }, - None => bail!("missing facet filter first parameter"), - } - } - - /// Aggregates the documents ids that are part of the specified range automatically - /// going deeper through the levels. - fn explore_facet_levels<'t, T: 't, KC>( - rtxn: &'t heed::RoTxn, - db: heed::Database, - field_id: u8, - level: u8, - left: Bound, - right: Bound, - output: &mut RoaringBitmap, - ) -> anyhow::Result<()> - where - T: Copy + PartialEq + PartialOrd + Bounded + Debug, - KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>, - KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, - { - match (left, right) { - // If the request is an exact value we must go directly to the deepest level. - (Included(l), Included(r)) if l == r && level > 0 => { - return Self::explore_facet_levels::(rtxn, db, field_id, 0, left, right, output); - }, - // lower TO upper when lower > upper must return no result - (Included(l), Included(r)) if l > r => return Ok(()), - (Included(l), Excluded(r)) if l >= r => return Ok(()), - (Excluded(l), Excluded(r)) if l >= r => return Ok(()), - (Excluded(l), Included(r)) if l >= r => return Ok(()), - (_, _) => (), - } - - let mut left_found = None; - let mut right_found = None; - - // We must create a custom iterator to be able to iterate over the - // requested range as the range iterator cannot express some conditions. - let left_bound = match left { - Included(left) => Included((field_id, level, left, T::min_value())), - Excluded(left) => Excluded((field_id, level, left, T::min_value())), - Unbounded => Unbounded, - }; - let right_bound = Included((field_id, level, T::max_value(), T::max_value())); - // We also make sure that we don't decode the data before we are sure we must return it. - let iter = db - .remap_key_type::() - .lazily_decode_data() - .range(rtxn, &(left_bound, right_bound))? - .take_while(|r| r.as_ref().map_or(true, |((.., r), _)| { - match right { - Included(right) => *r <= right, - Excluded(right) => *r < right, - Unbounded => true, - } - })) - .map(|r| r.and_then(|(key, lazy)| lazy.decode().map(|data| (key, data)))); - - debug!("Iterating between {:?} and {:?} (level {})", left, right, level); - - for (i, result) in iter.enumerate() { - let ((_fid, level, l, r), docids) = result?; - debug!("{:?} to {:?} (level {}) found {} documents", l, r, level, docids.len()); - output.union_with(&docids); - // We save the leftest and rightest bounds we actually found at this level. - if i == 0 { left_found = Some(l); } - right_found = Some(r); - } - - // Can we go deeper? - let deeper_level = match level.checked_sub(1) { - Some(level) => level, - None => return Ok(()), - }; - - // We must refine the left and right bounds of this range by retrieving the - // missing part in a deeper level. - match left_found.zip(right_found) { - Some((left_found, right_found)) => { - // If the bound is satisfied we avoid calling this function again. - if !matches!(left, Included(l) if l == left_found) { - let sub_right = Excluded(left_found); - debug!("calling left with {:?} to {:?} (level {})", left, sub_right, deeper_level); - Self::explore_facet_levels::(rtxn, db, field_id, deeper_level, left, sub_right, output)?; - } - if !matches!(right, Included(r) if r == right_found) { - let sub_left = Excluded(right_found); - debug!("calling right with {:?} to {:?} (level {})", sub_left, right, deeper_level); - Self::explore_facet_levels::(rtxn, db, field_id, deeper_level, sub_left, right, output)?; - } - }, - None => { - // If we found nothing at this level it means that we must find - // the same bounds but at a deeper, more precise level. - Self::explore_facet_levels::(rtxn, db, field_id, deeper_level, left, right, output)?; - }, - } - - Ok(()) - } - - fn evaluate_number_operator<'t, T: 't, KC>( - rtxn: &'t heed::RoTxn, - db: heed::Database, - field_id: u8, - operator: FacetNumberOperator, - ) -> anyhow::Result - where - T: Copy + PartialEq + PartialOrd + Bounded + Debug, - KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>, - KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, - { - // Make sure we always bound the ranges with the field id and the level, - // as the facets values are all in the same database and prefixed by the - // field id and the level. - let (left, right) = match operator { - GreaterThan(val) => (Excluded(val), Included(T::max_value())), - GreaterThanOrEqual(val) => (Included(val), Included(T::max_value())), - LowerThan(val) => (Included(T::min_value()), Excluded(val)), - LowerThanOrEqual(val) => (Included(T::min_value()), Included(val)), - Equal(val) => (Included(val), Included(val)), - Between(left, right) => (Included(left), Included(right)), - }; - - // Ask for the biggest value that can exist for this specific field, if it exists - // that's fine if it don't, the value just before will be returned instead. - let biggest_level = db - .remap_types::() - .get_lower_than_or_equal_to(rtxn, &(field_id, u8::MAX, T::max_value(), T::max_value()))? - .and_then(|((id, level, _, _), _)| if id == field_id { Some(level) } else { None }); - - match biggest_level { - Some(level) => { - let mut output = RoaringBitmap::new(); - Self::explore_facet_levels::(rtxn, db, field_id, level, left, right, &mut output)?; - Ok(output) - }, - None => Ok(RoaringBitmap::new()), - } - } - - fn evaluate_string_operator( - rtxn: &heed::RoTxn, - db: heed::Database, - field_id: u8, - operator: &FacetStringOperator, - ) -> anyhow::Result - { - match operator { - FacetStringOperator::Equal(string) => { - match db.get(rtxn, &(field_id, string))? { - Some(docids) => Ok(docids), - None => Ok(RoaringBitmap::new()) - } - } - } - } - - pub fn evaluate( - &self, - rtxn: &heed::RoTxn, - db: heed::Database, - ) -> anyhow::Result - { - match self { - OperatorI64(fid, op) => { - Self::evaluate_number_operator::(rtxn, db, *fid, *op) - }, - OperatorF64(fid, op) => { - Self::evaluate_number_operator::(rtxn, db, *fid, *op) - }, - OperatorString(fid, op) => { - let db = db.remap_key_type::(); - Self::evaluate_string_operator(rtxn, db, *fid, op) - }, - } - } -} diff --git a/src/search/facet/grammar.pest b/src/search/facet/grammar.pest new file mode 100644 index 000000000..2096517d3 --- /dev/null +++ b/src/search/facet/grammar.pest @@ -0,0 +1,29 @@ +key = _{quoted | word} +value = _{quoted | word} +quoted = _{ (PUSH("'") | PUSH("\"")) ~ string ~ POP } +string = {char*} +word = ${(LETTER | NUMBER | "_" | "-" | ".")+} + +char = _{ !(PEEK | "\\") ~ ANY + | "\\" ~ (PEEK | "\\" | "/" | "b" | "f" | "n" | "r" | "t") + | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4})} + +condition = _{between | eq | greater | less | geq | leq | neq} +between = {key ~ value ~ "TO" ~ value} +geq = {key ~ ">=" ~ value} +leq = {key ~ "<=" ~ value} +neq = {key ~ "!=" ~ value} +eq = {key ~ "=" ~ value} +greater = {key ~ ">" ~ value} +less = {key ~ "<" ~ value} + +prgm = {SOI ~ expr ~ EOI} +expr = _{ ( term ~ (operation ~ term)* ) } +term = { ("(" ~ expr ~ ")") | condition | not } +operation = _{ and | or } +and = {"AND"} +or = {"OR"} + +not = {"NOT" ~ term} + +WHITESPACE = _{ " " } diff --git a/src/search/facet/mod.rs b/src/search/facet/mod.rs new file mode 100644 index 000000000..b1d527337 --- /dev/null +++ b/src/search/facet/mod.rs @@ -0,0 +1,476 @@ +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::Bound::{self, Unbounded, Included, Excluded}; + +use heed::types::{ByteSlice, DecodeIgnore}; +use log::debug; +use num_traits::Bounded; +use parser::{PREC_CLIMBER, FilterParser}; +use pest::error::{Error as PestError, ErrorVariant}; +use pest::iterators::{Pair, Pairs}; +use pest::Parser; +use roaring::RoaringBitmap; + +use crate::facet::FacetType; +use crate::heed_codec::facet::FacetValueStringCodec; +use crate::heed_codec::facet::{FacetLevelValueI64Codec, FacetLevelValueF64Codec}; +use crate::{Index, FieldsIdsMap, CboRoaringBitmapCodec}; + +use self::FacetCondition::*; +use self::FacetNumberOperator::*; +use self::parser::Rule; + +mod parser; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum FacetNumberOperator { + GreaterThan(T), + GreaterThanOrEqual(T), + LowerThan(T), + LowerThanOrEqual(T), + Equal(T), + Between(T, T), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum FacetStringOperator { + Equal(String), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum FacetCondition { + OperatorI64(u8, FacetNumberOperator), + OperatorF64(u8, FacetNumberOperator), + OperatorString(u8, FacetStringOperator), + Or(Box, Box), + And(Box, Box), + Not(Box), +} + +fn get_field_id_facet_type<'a>( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + items: &mut Pairs<'a, Rule>, +) -> Result<(u8, FacetType), PestError> +{ + // lexing ensures that we at least have a key + let key = items.next().unwrap(); + let field_id = fields_ids_map + .id(key.as_str()) + .ok_or_else(|| { + PestError::new_from_span( + ErrorVariant::CustomError { + message: format!( + "attribute `{}` not found, available attributes are: {}", + key.as_str(), + fields_ids_map.iter().map(|(_, n)| n).collect::>().join(", ") + ), + }, + key.as_span(), + ) + })?; + + let facet_type = faceted_fields + .get(&field_id) + .copied() + .ok_or_else(|| { + PestError::new_from_span( + ErrorVariant::CustomError { + message: format!( + "attribute `{}` is not faceted, available faceted attributes are: {}", + key.as_str(), + faceted_fields.keys().flat_map(|id| fields_ids_map.name(*id)).collect::>().join(", ") + ), + }, + key.as_span(), + ) + })?; + + Ok((field_id, facet_type)) +} + +impl FacetCondition { + pub fn from_str( + rtxn: &heed::RoTxn, + index: &Index, + expression: &str, + ) -> anyhow::Result + { + let fields_ids_map = index.fields_ids_map(rtxn)?; + let faceted_fields = index.faceted_fields(rtxn)?; + let lexed = FilterParser::parse(Rule::prgm, expression)?; + FacetCondition::from_pairs(&fields_ids_map, &faceted_fields, lexed) + } + + fn from_pairs( + fim: &FieldsIdsMap, + ff: &HashMap, + expression: Pairs, + ) -> anyhow::Result + { + PREC_CLIMBER.climb( + expression, + |pair: Pair| match pair.as_rule() { + Rule::between => Ok(FacetCondition::between(fim, ff, pair)?), + Rule::eq => Ok(FacetCondition::equal(fim, ff, pair)?), + Rule::neq => Ok(Not(Box::new(FacetCondition::equal(fim, ff, pair)?))), + Rule::greater => Ok(FacetCondition::greater_than(fim, ff, pair)?), + Rule::geq => Ok(FacetCondition::greater_than_or_equal(fim, ff, pair)?), + Rule::less => Ok(FacetCondition::lower_than(fim, ff, pair)?), + Rule::leq => Ok(FacetCondition::lower_than_or_equal(fim, ff, pair)?), + Rule::prgm => Self::from_pairs(fim, ff, pair.into_inner()), + Rule::term => Self::from_pairs(fim, ff, pair.into_inner()), + Rule::not => Ok(Not(Box::new(Self::from_pairs(fim, ff, pair.into_inner())?))), + _ => unreachable!(), + }, + |lhs: anyhow::Result, op: Pair, rhs: anyhow::Result| { + match op.as_rule() { + Rule::or => Ok(Or(Box::new(lhs?), Box::new(rhs?))), + Rule::and => Ok(And(Box::new(lhs?), Box::new(rhs?))), + _ => unreachable!(), + } + }, + ) + } + + fn between( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + item: Pair, + ) -> anyhow::Result + { + let item_span = item.as_span(); + let mut items = item.into_inner(); + let (fid, ftype) = get_field_id_facet_type(fields_ids_map, faceted_fields, &mut items)?; + let lvalue = items.next().unwrap(); + let rvalue = items.next().unwrap(); + match ftype { + FacetType::Integer => { + let lvalue = lvalue.as_str().parse()?; + let rvalue = rvalue.as_str().parse()?; + Ok(OperatorI64(fid, Between(lvalue, rvalue))) + }, + FacetType::Float => { + let lvalue = lvalue.as_str().parse()?; + let rvalue = rvalue.as_str().parse()?; + Ok(OperatorF64(fid, Between(lvalue, rvalue))) + }, + FacetType::String => { + Err(PestError::::new_from_span( + ErrorVariant::CustomError { + message: format!("invalid operator on a faceted string"), + }, + item_span, + ).into()) + }, + } + } + + fn equal( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + item: Pair, + ) -> anyhow::Result + { + let mut items = item.into_inner(); + let (fid, ftype) = get_field_id_facet_type(fields_ids_map, faceted_fields, &mut items)?; + let value = items.next().unwrap(); + match ftype { + FacetType::Integer => Ok(OperatorI64(fid, Equal(value.as_str().parse()?))), + FacetType::Float => Ok(OperatorF64(fid, Equal(value.as_str().parse()?))), + FacetType::String => { + Ok(OperatorString(fid, FacetStringOperator::Equal(value.as_str().to_string()))) + }, + } + } + + fn greater_than( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + item: Pair, + ) -> anyhow::Result + { + let item_span = item.as_span(); + let mut items = item.into_inner(); + let (fid, ftype) = get_field_id_facet_type(fields_ids_map, faceted_fields, &mut items)?; + let value = items.next().unwrap(); + match ftype { + FacetType::Integer => Ok(OperatorI64(fid, GreaterThan(value.as_str().parse()?))), + FacetType::Float => Ok(OperatorF64(fid, GreaterThan(value.as_str().parse()?))), + FacetType::String => { + Err(PestError::::new_from_span( + ErrorVariant::CustomError { + message: format!("invalid operator on a faceted string"), + }, + item_span, + ).into()) + }, + } + } + + fn greater_than_or_equal( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + item: Pair, + ) -> anyhow::Result + { + let item_span = item.as_span(); + let mut items = item.into_inner(); + let (fid, ftype) = get_field_id_facet_type(fields_ids_map, faceted_fields, &mut items)?; + let value = items.next().unwrap(); + match ftype { + FacetType::Integer => Ok(OperatorI64(fid, GreaterThanOrEqual(value.as_str().parse()?))), + FacetType::Float => Ok(OperatorF64(fid, GreaterThanOrEqual(value.as_str().parse()?))), + FacetType::String => { + Err(PestError::::new_from_span( + ErrorVariant::CustomError { + message: format!("invalid operator on a faceted string"), + }, + item_span, + ).into()) + }, + } + } + + fn lower_than( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + item: Pair, + ) -> anyhow::Result + { + let item_span = item.as_span(); + let mut items = item.into_inner(); + let (fid, ftype) = get_field_id_facet_type(fields_ids_map, faceted_fields, &mut items)?; + let value = items.next().unwrap(); + match ftype { + FacetType::Integer => Ok(OperatorI64(fid, LowerThan(value.as_str().parse()?))), + FacetType::Float => Ok(OperatorF64(fid, LowerThan(value.as_str().parse()?))), + FacetType::String => { + Err(PestError::::new_from_span( + ErrorVariant::CustomError { + message: format!("invalid operator on a faceted string"), + }, + item_span, + ).into()) + }, + } + } + + fn lower_than_or_equal( + fields_ids_map: &FieldsIdsMap, + faceted_fields: &HashMap, + item: Pair, + ) -> anyhow::Result + { + let item_span = item.as_span(); + let mut items = item.into_inner(); + let (fid, ftype) = get_field_id_facet_type(fields_ids_map, faceted_fields, &mut items)?; + let value = items.next().unwrap(); + match ftype { + FacetType::Integer => Ok(OperatorI64(fid, LowerThanOrEqual(value.as_str().parse()?))), + FacetType::Float => Ok(OperatorF64(fid, LowerThanOrEqual(value.as_str().parse()?))), + FacetType::String => { + Err(PestError::::new_from_span( + ErrorVariant::CustomError { + message: format!("invalid operator on a faceted string"), + }, + item_span, + ).into()) + }, + } + } +} + +impl FacetCondition { + /// Aggregates the documents ids that are part of the specified range automatically + /// going deeper through the levels. + fn explore_facet_levels<'t, T: 't, KC>( + rtxn: &'t heed::RoTxn, + db: heed::Database, + field_id: u8, + level: u8, + left: Bound, + right: Bound, + output: &mut RoaringBitmap, + ) -> anyhow::Result<()> + where + T: Copy + PartialEq + PartialOrd + Bounded + Debug, + KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>, + KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, + { + match (left, right) { + // If the request is an exact value we must go directly to the deepest level. + (Included(l), Included(r)) if l == r && level > 0 => { + return Self::explore_facet_levels::(rtxn, db, field_id, 0, left, right, output); + }, + // lower TO upper when lower > upper must return no result + (Included(l), Included(r)) if l > r => return Ok(()), + (Included(l), Excluded(r)) if l >= r => return Ok(()), + (Excluded(l), Excluded(r)) if l >= r => return Ok(()), + (Excluded(l), Included(r)) if l >= r => return Ok(()), + (_, _) => (), + } + + let mut left_found = None; + let mut right_found = None; + + // We must create a custom iterator to be able to iterate over the + // requested range as the range iterator cannot express some conditions. + let left_bound = match left { + Included(left) => Included((field_id, level, left, T::min_value())), + Excluded(left) => Excluded((field_id, level, left, T::min_value())), + Unbounded => Unbounded, + }; + let right_bound = Included((field_id, level, T::max_value(), T::max_value())); + // We also make sure that we don't decode the data before we are sure we must return it. + let iter = db + .remap_key_type::() + .lazily_decode_data() + .range(rtxn, &(left_bound, right_bound))? + .take_while(|r| r.as_ref().map_or(true, |((.., r), _)| { + match right { + Included(right) => *r <= right, + Excluded(right) => *r < right, + Unbounded => true, + } + })) + .map(|r| r.and_then(|(key, lazy)| lazy.decode().map(|data| (key, data)))); + + debug!("Iterating between {:?} and {:?} (level {})", left, right, level); + + for (i, result) in iter.enumerate() { + let ((_fid, level, l, r), docids) = result?; + debug!("{:?} to {:?} (level {}) found {} documents", l, r, level, docids.len()); + output.union_with(&docids); + // We save the leftest and rightest bounds we actually found at this level. + if i == 0 { left_found = Some(l); } + right_found = Some(r); + } + + // Can we go deeper? + let deeper_level = match level.checked_sub(1) { + Some(level) => level, + None => return Ok(()), + }; + + // We must refine the left and right bounds of this range by retrieving the + // missing part in a deeper level. + match left_found.zip(right_found) { + Some((left_found, right_found)) => { + // If the bound is satisfied we avoid calling this function again. + if !matches!(left, Included(l) if l == left_found) { + let sub_right = Excluded(left_found); + debug!("calling left with {:?} to {:?} (level {})", left, sub_right, deeper_level); + Self::explore_facet_levels::(rtxn, db, field_id, deeper_level, left, sub_right, output)?; + } + if !matches!(right, Included(r) if r == right_found) { + let sub_left = Excluded(right_found); + debug!("calling right with {:?} to {:?} (level {})", sub_left, right, deeper_level); + Self::explore_facet_levels::(rtxn, db, field_id, deeper_level, sub_left, right, output)?; + } + }, + None => { + // If we found nothing at this level it means that we must find + // the same bounds but at a deeper, more precise level. + Self::explore_facet_levels::(rtxn, db, field_id, deeper_level, left, right, output)?; + }, + } + + Ok(()) + } + + fn evaluate_number_operator<'t, T: 't, KC>( + rtxn: &'t heed::RoTxn, + db: heed::Database, + field_id: u8, + operator: FacetNumberOperator, + ) -> anyhow::Result + where + T: Copy + PartialEq + PartialOrd + Bounded + Debug, + KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>, + KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, + { + // Make sure we always bound the ranges with the field id and the level, + // as the facets values are all in the same database and prefixed by the + // field id and the level. + let (left, right) = match operator { + GreaterThan(val) => (Excluded(val), Included(T::max_value())), + GreaterThanOrEqual(val) => (Included(val), Included(T::max_value())), + LowerThan(val) => (Included(T::min_value()), Excluded(val)), + LowerThanOrEqual(val) => (Included(T::min_value()), Included(val)), + Equal(val) => (Included(val), Included(val)), + Between(left, right) => (Included(left), Included(right)), + }; + + // Ask for the biggest value that can exist for this specific field, if it exists + // that's fine if it don't, the value just before will be returned instead. + let biggest_level = db + .remap_types::() + .get_lower_than_or_equal_to(rtxn, &(field_id, u8::MAX, T::max_value(), T::max_value()))? + .and_then(|((id, level, _, _), _)| if id == field_id { Some(level) } else { None }); + + match biggest_level { + Some(level) => { + let mut output = RoaringBitmap::new(); + Self::explore_facet_levels::(rtxn, db, field_id, level, left, right, &mut output)?; + Ok(output) + }, + None => Ok(RoaringBitmap::new()), + } + } + + fn evaluate_string_operator( + rtxn: &heed::RoTxn, + db: heed::Database, + field_id: u8, + operator: &FacetStringOperator, + ) -> anyhow::Result + { + match operator { + FacetStringOperator::Equal(string) => { + match db.get(rtxn, &(field_id, string))? { + Some(docids) => Ok(docids), + None => Ok(RoaringBitmap::new()) + } + } + } + } + + pub fn evaluate( + &self, + rtxn: &heed::RoTxn, + index: &Index, + ) -> anyhow::Result + { + let db = index.facet_field_id_value_docids; + match self { + OperatorI64(fid, op) => { + Self::evaluate_number_operator::(rtxn, db, *fid, *op) + }, + OperatorF64(fid, op) => { + Self::evaluate_number_operator::(rtxn, db, *fid, *op) + }, + OperatorString(fid, op) => { + let db = db.remap_key_type::(); + Self::evaluate_string_operator(rtxn, db, *fid, op) + }, + Or(lhs, rhs) => { + let lhs = lhs.evaluate(rtxn, index)?; + let rhs = rhs.evaluate(rtxn, index)?; + Ok(lhs | rhs) + }, + And(lhs, rhs) => { + let lhs = lhs.evaluate(rtxn, index)?; + let rhs = rhs.evaluate(rtxn, index)?; + Ok(lhs & rhs) + }, + Not(op) => { + // TODO is this right or is this wrong? because all documents ids are not faceted + // so doing that can return documents that are not faceted at all. + let all_documents_ids = index.documents_ids(rtxn)?; + let documents_ids = op.evaluate(rtxn, index)?; + Ok(all_documents_ids - documents_ids) + }, + } + } +} diff --git a/src/search/facet/parser.rs b/src/search/facet/parser.rs new file mode 100644 index 000000000..0e8bd23ac --- /dev/null +++ b/src/search/facet/parser.rs @@ -0,0 +1,12 @@ +use once_cell::sync::Lazy; +use pest::prec_climber::{Operator, Assoc, PrecClimber}; + +pub static PREC_CLIMBER: Lazy> = Lazy::new(|| { + use Assoc::*; + use Rule::*; + pest::prec_climber::PrecClimber::new(vec![Operator::new(or, Left), Operator::new(and, Left)]) +}); + +#[derive(Parser)] +#[grammar = "search/facet/grammar.pest"] +pub struct FilterParser; diff --git a/src/search/mod.rs b/src/search/mod.rs index d236e396a..7020fa838 100644 --- a/src/search/mod.rs +++ b/src/search/mod.rs @@ -155,9 +155,8 @@ impl<'a> Search<'a> { }; // We create the original candidates with the facet conditions results. - let facet_db = self.index.facet_field_id_value_docids; let facet_candidates = match &self.facet_condition { - Some(condition) => Some(condition.evaluate(self.rtxn, facet_db)?), + Some(condition) => Some(condition.evaluate(self.rtxn, self.index)?), None => None, };