diff --git a/milli/src/search/facet/filter_condition.rs b/milli/src/search/facet/filter_condition.rs index e39687117..c728e0acd 100644 --- a/milli/src/search/facet/filter_condition.rs +++ b/milli/src/search/facet/filter_condition.rs @@ -1,29 +1,16 @@ use std::collections::HashSet; use std::fmt::Debug; use std::ops::Bound::{self, Excluded, Included}; -use std::result::Result as StdResult; use either::Either; use heed::types::DecodeIgnore; use log::debug; +use nom::error::VerboseError; use roaring::RoaringBitmap; -use nom::{ - branch::alt, - bytes::complete::{tag, take_while1}, - character::complete::{char, multispace0}, - combinator::map, - error::ErrorKind, - error::ParseError, - error::VerboseError, - multi::many0, - sequence::{delimited, preceded, tuple}, - IResult, -}; - use self::FilterCondition::*; -use self::Operator::*; +use super::filter_parser::{Operator, ParseContext}; use super::FacetNumberRange; use crate::error::{Error, UserError}; use crate::heed_codec::facet::{ @@ -33,37 +20,6 @@ use crate::{ distance_between_two_points, CboRoaringBitmapCodec, FieldId, FieldsIdsMap, Index, Result, }; -#[derive(Debug, Clone, PartialEq)] -pub enum Operator { - GreaterThan(f64), - GreaterThanOrEqual(f64), - Equal(Option, String), - NotEqual(Option, String), - LowerThan(f64), - LowerThanOrEqual(f64), - Between(f64, f64), - GeoLowerThan([f64; 2], f64), - GeoGreaterThan([f64; 2], f64), -} - -impl Operator { - /// This method can return two operations in case it must express - /// an OR operation for the between case (i.e. `TO`). - fn negate(self) -> (Self, Option) { - match self { - GreaterThan(n) => (LowerThanOrEqual(n), None), - GreaterThanOrEqual(n) => (LowerThan(n), None), - Equal(n, s) => (NotEqual(n, s), None), - NotEqual(n, s) => (Equal(n, s), None), - LowerThan(n) => (GreaterThanOrEqual(n), None), - LowerThanOrEqual(n) => (GreaterThan(n), None), - Between(n, m) => (LowerThan(n), Some(GreaterThan(m))), - GeoLowerThan(point, distance) => (GeoGreaterThan(point, distance), None), - GeoGreaterThan(point, distance) => (GeoLowerThan(point, distance), None), - } - } -} - #[derive(Debug, Clone, PartialEq)] pub enum FilterCondition { Operator(FieldId, Operator), @@ -72,190 +28,8 @@ pub enum FilterCondition { Empty, } -struct ParseContext<'a> { - fields_ids_map: &'a FieldsIdsMap, - filterable_fields: &'a HashSet, -} // impl From -impl<'a> ParseContext<'a> { - fn parse_or_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - let (input, lhs) = self.parse_and_nom(input)?; - let (input, ors) = many0(preceded(tag("OR"), |c| Self::parse_or_nom(self, c)))(input)?; - let expr = ors - .into_iter() - .fold(lhs, |acc, branch| FilterCondition::Or(Box::new(acc), Box::new(branch))); - Ok((input, expr)) - } - - fn parse_and_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - let (input, lhs) = self.parse_not_nom(input)?; - let (input, ors) = many0(preceded(tag("AND"), |c| Self::parse_and_nom(self, c)))(input)?; - let expr = ors - .into_iter() - .fold(lhs, |acc, branch| FilterCondition::And(Box::new(acc), Box::new(branch))); - Ok((input, expr)) - } - - fn parse_not_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - alt(( - map( - preceded(alt((self.ws(tag("!")), self.ws(tag("NOT")))), |c| { - Self::parse_condition_expression(self, c) - }), - |e| e.negate(), - ), - |c| Self::parse_condition_expression(self, c), - ))(input) - } - - fn ws(&'a self, inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O, E> - where - F: Fn(&'a str) -> IResult<&'a str, O, E>, - E: ParseError<&'a str>, - { - delimited(multispace0, inner, multispace0) - } - - fn parse_simple_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - let operator = alt((tag(">"), tag(">="), tag("="), tag("<"), tag("!="), tag("<="))); - let (input, (key, op, value)) = - tuple((self.ws(|c| self.parse_key(c)), operator, self.ws(|c| self.parse_key(c))))( - input, - )?; - let fid = self.parse_fid(input, key)?; - let r: StdResult>> = self.parse_numeric(value); - let k = match op { - "=" => Operator(fid, Equal(r.ok(), value.to_string().to_lowercase())), - "!=" => Operator(fid, NotEqual(r.ok(), value.to_string().to_lowercase())), - ">" | "<" | "<=" | ">=" => return self.parse_numeric_unary_condition(op, fid, value), - _ => unreachable!(), - }; - Ok((input, k)) - } - - fn parse_numeric(&'a self, input: &'a str) -> StdResult> - where - E: ParseError<&'a str>, - T: std::str::FromStr, - { - match input.parse::() { - Ok(n) => Ok(n), - Err(_) => { - return match input.chars().nth(0) { - Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), - None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), - }; - } - } - } - - fn parse_numeric_unary_condition( - &'a self, - input: &'a str, - fid: u16, - value: &'a str, - ) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - let numeric: f64 = self.parse_numeric(value)?; - let k = match input { - ">" => Operator(fid, GreaterThan(numeric)), - "<" => Operator(fid, LowerThan(numeric)), - "<=" => Operator(fid, LowerThanOrEqual(numeric)), - ">=" => Operator(fid, GreaterThanOrEqual(numeric)), - _ => unreachable!(), - }; - Ok((input, k)) - } - - fn parse_fid(&'a self, input: &'a str, key: &'a str) -> StdResult> - where - E: ParseError<&'a str>, - { - let error = match input.chars().nth(0) { - Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), - None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), - }; - if !self.filterable_fields.contains(key) { - return error; - } - match self.fields_ids_map.id(key) { - Some(fid) => Ok(fid), - None => error, - } - } - - fn parse_range_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - let (input, (key, from, _, to)) = tuple(( - self.ws(|c| self.parse_key(c)), - self.ws(|c| self.parse_key(c)), - tag("TO"), - self.ws(|c| self.parse_key(c)), - ))(input)?; - - let fid = self.parse_fid(input, key)?; - let numeric_from: f64 = self.parse_numeric(from)?; - let numeric_to: f64 = self.parse_numeric(to)?; - let res = Operator(fid, Between(numeric_from, numeric_to)); - Ok((input, res)) - } - - fn parse_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - let l1 = |c| self.parse_simple_condition(c); - let l2 = |c| self.parse_range_condition(c); - let (input, condition) = alt((l1, l2))(input)?; - Ok((input, condition)) - } - - fn parse_condition_expression(&'a self, input: &'a str) -> IResult<&str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - return alt(( - delimited(self.ws(char('(')), |c| Self::parse_expression(self, c), self.ws(char(')'))), - |c| Self::parse_condition(self, c), - ))(input); - } - - fn parse_key(&'a self, input: &'a str) -> IResult<&'a str, &'a str, E> - where - E: ParseError<&'a str>, - { - let key = |input| take_while1(Self::is_key_component)(input); - alt((key, delimited(char('"'), key, char('"'))))(input) - } - fn is_key_component(c: char) -> bool { - c.is_alphanumeric() || ['_', '-', '.'].contains(&c) - } - - pub fn parse_expression(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: ParseError<&'a str>, - { - self.parse_or_nom(input) - } -} - //for nom impl FilterCondition { pub fn from_array( @@ -269,7 +43,7 @@ impl FilterCondition { A: AsRef, B: AsRef, { - let mut ands = None; + let mut ands: Option = None; for either in array { match either { @@ -316,10 +90,7 @@ impl FilterCondition { Err(e) => Err(Error::UserError(UserError::InvalidFilterNom { input: e.to_string() })), } } -} - -impl FilterCondition { - fn negate(self) -> FilterCondition { + pub fn negate(self) -> FilterCondition { match self { Operator(fid, op) => match op.negate() { (op, None) => Operator(fid, op), @@ -389,7 +160,7 @@ impl FilterCondition { lng.1.clone(), )))?; } - Ok(Operator(fid, GeoLowerThan([lat.0, lng.0], distance))) + Ok(Operator(fid, Operator::GeoLowerThan([lat.0, lng.0], distance))) } } @@ -514,9 +285,9 @@ impl FilterCondition { // as the facets values are all in the same database and prefixed by the // field id and the level. let (left, right) = match operator { - GreaterThan(val) => (Excluded(*val), Included(f64::MAX)), - GreaterThanOrEqual(val) => (Included(*val), Included(f64::MAX)), - Equal(number, string) => { + Operator::GreaterThan(val) => (Excluded(*val), Included(f64::MAX)), + Operator::GreaterThanOrEqual(val) => (Included(*val), Included(f64::MAX)), + Operator::Equal(number, string) => { let (_original_value, string_docids) = strings_db.get(rtxn, &(field_id, &string))?.unwrap_or_default(); let number_docids = match number { @@ -538,23 +309,23 @@ impl FilterCondition { }; return Ok(string_docids | number_docids); } - NotEqual(number, string) => { + Operator::NotEqual(number, string) => { let all_numbers_ids = if number.is_some() { index.number_faceted_documents_ids(rtxn, field_id)? } else { RoaringBitmap::new() }; let all_strings_ids = index.string_faceted_documents_ids(rtxn, field_id)?; - let operator = Equal(*number, string.clone()); + let operator = Operator::Equal(*number, string.clone()); let docids = Self::evaluate_operator( rtxn, index, numbers_db, strings_db, field_id, &operator, )?; return Ok((all_numbers_ids | all_strings_ids) - docids); } - LowerThan(val) => (Included(f64::MIN), Excluded(*val)), - LowerThanOrEqual(val) => (Included(f64::MIN), Included(*val)), - Between(left, right) => (Included(*left), Included(*right)), - GeoLowerThan(base_point, distance) => { + Operator::LowerThan(val) => (Included(f64::MIN), Excluded(*val)), + Operator::LowerThanOrEqual(val) => (Included(f64::MIN), Included(*val)), + Operator::Between(left, right) => (Included(*left), Included(*right)), + Operator::GeoLowerThan(base_point, distance) => { let rtree = match index.geo_rtree(rtxn)? { Some(rtree) => rtree, None => return Ok(RoaringBitmap::new()), @@ -570,7 +341,7 @@ impl FilterCondition { return Ok(result); } - GeoGreaterThan(point, distance) => { + Operator::GeoGreaterThan(point, distance) => { let result = Self::evaluate_operator( rtxn, index, @@ -631,361 +402,3 @@ impl FilterCondition { } } } - -/// Retrieve the field id base on the pest value. -/// -/// Returns an error if the given value is not filterable. -/// -/// Returns Ok(None) if the given value is filterable, but is not yet ascociated to a field_id. -/// -/// The pest pair is simply a string associated with a span, a location to highlight in -/// the error message. -#[cfg(test)] -mod tests { - use big_s::S; - use heed::EnvOpenOptions; - use maplit::hashset; - - use super::*; - use crate::update::Settings; - - #[test] - fn string() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut map = index.fields_ids_map(&wtxn).unwrap(); - map.insert("channel"); - index.put_fields_ids_map(&mut wtxn, &map).unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_filterable_fields(hashset! { S("channel") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "channel = Ponce").unwrap(); - let expected = Operator(0, Operator::Equal(None, S("ponce"))); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str(&rtxn, &index, "channel != ponce").unwrap(); - let expected = Operator(0, Operator::NotEqual(None, S("ponce"))); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str(&rtxn, &index, "NOT channel = ponce").unwrap(); - let expected = Operator(0, Operator::NotEqual(None, S("ponce"))); - assert_eq!(condition, expected); - } - - #[test] - fn number() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut map = index.fields_ids_map(&wtxn).unwrap(); - map.insert("timestamp"); - index.put_fields_ids_map(&mut wtxn, &map).unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_filterable_fields(hashset! { "timestamp".into() }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "timestamp 22 TO 44").unwrap(); - let expected = Operator(0, Between(22.0, 44.0)); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str(&rtxn, &index, "NOT timestamp 22 TO 44").unwrap(); - let expected = - Or(Box::new(Operator(0, LowerThan(22.0))), Box::new(Operator(0, GreaterThan(44.0)))); - assert_eq!(condition, expected); - } - - #[test] - fn compare() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "channel < 20").unwrap(); - let expected = Operator(0, LowerThan(20.0)); - - assert_eq!(condition, expected); - } - - #[test] - fn parentheses() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str( - &rtxn, - &index, - "channel = gotaga OR (timestamp 22 TO 44 AND channel != ponce)", - ) - .unwrap(); - let expected = Or( - Box::new(Operator(0, Operator::Equal(None, S("gotaga")))), - Box::new(And( - Box::new(Operator(1, Between(22.0, 44.0))), - Box::new(Operator(0, Operator::NotEqual(None, S("ponce")))), - )), - ); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str( - &rtxn, - &index, - "channel = gotaga OR NOT (timestamp 22 TO 44 AND channel != ponce)", - ) - .unwrap(); - let expected = Or( - Box::new(Operator(0, Operator::Equal(None, S("gotaga")))), - Box::new(Or( - Box::new(Or( - Box::new(Operator(1, LowerThan(22.0))), - Box::new(Operator(1, GreaterThan(44.0))), - )), - Box::new(Operator(0, Operator::Equal(None, S("ponce")))), - )), - ); - assert_eq!(condition, expected); - } - - #[test] - fn reserved_field_names() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - let rtxn = index.read_txn().unwrap(); - - let error = FilterCondition::from_str(&rtxn, &index, "_geo = 12").unwrap_err(); - assert!(error - .to_string() - .contains("`_geo` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates."), - "{}", - error.to_string() - ); - - let error = - FilterCondition::from_str(&rtxn, &index, r#"_geoDistance <= 1000"#).unwrap_err(); - assert!(error - .to_string() - .contains("`_geoDistance` is a reserved keyword and thus can't be used as a filter expression."), - "{}", - error.to_string() - ); - - let error = FilterCondition::from_str(&rtxn, &index, r#"_geoPoint > 5"#).unwrap_err(); - assert!(error - .to_string() - .contains("`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates."), - "{}", - error.to_string() - ); - - let error = - FilterCondition::from_str(&rtxn, &index, r#"_geoPoint(12, 16) > 5"#).unwrap_err(); - assert!(error - .to_string() - .contains("`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates."), - "{}", - error.to_string() - ); - } - - #[test] - fn geo_radius() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("_geo"), S("price")]); // to keep the fields order - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - let rtxn = index.read_txn().unwrap(); - // _geo is not filterable - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(12, 12, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("attribute `_geo` is not filterable, available filterable attributes are:"),); - - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_filterable_fields(hashset! { S("_geo"), S("price") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - let rtxn = index.read_txn().unwrap(); - // basic test - let condition = - FilterCondition::from_str(&rtxn, &index, "_geoRadius(12, 13.0005, 2000)").unwrap(); - let expected = Operator(0, GeoLowerThan([12., 13.0005], 2000.)); - assert_eq!(condition, expected); - - // basic test with latitude and longitude at the max angle - let condition = - FilterCondition::from_str(&rtxn, &index, "_geoRadius(90, 180, 2000)").unwrap(); - let expected = Operator(0, GeoLowerThan([90., 180.], 2000.)); - assert_eq!(condition, expected); - - // basic test with latitude and longitude at the min angle - let condition = - FilterCondition::from_str(&rtxn, &index, "_geoRadius(-90, -180, 2000)").unwrap(); - let expected = Operator(0, GeoLowerThan([-90., -180.], 2000.)); - assert_eq!(condition, expected); - - // test the negation of the GeoLowerThan - let condition = - FilterCondition::from_str(&rtxn, &index, "NOT _geoRadius(50, 18, 2000.500)").unwrap(); - let expected = Operator(0, GeoGreaterThan([50., 18.], 2000.500)); - assert_eq!(condition, expected); - - // composition of multiple operations - let condition = FilterCondition::from_str( - &rtxn, - &index, - "(NOT _geoRadius(1, 2, 300) AND _geoRadius(1.001, 2.002, 1000.300)) OR price <= 10", - ) - .unwrap(); - let expected = Or( - Box::new(And( - Box::new(Operator(0, GeoGreaterThan([1., 2.], 300.))), - Box::new(Operator(0, GeoLowerThan([1.001, 2.002], 1000.300))), - )), - Box::new(Operator(1, LowerThanOrEqual(10.))), - ); - assert_eq!(condition, expected); - - // georadius don't have any parameters - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius don't have any parameters - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius()"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius don't have enough parameters - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius have too many parameters - let result = - FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius have a bad latitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-100, 150, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Latitude must be contained between -90 and 90 degrees.")); - - // georadius have a bad latitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-90.0000001, 150, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Latitude must be contained between -90 and 90 degrees.")); - - // georadius have a bad longitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 250, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Longitude must be contained between -180 and 180 degrees.")); - - // georadius have a bad longitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 180.000001, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Longitude must be contained between -180 and 180 degrees.")); - } - - #[test] - fn from_array() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_array( - &rtxn, - &index, - vec![ - Either::Right("channel = gotaga"), - Either::Left(vec!["timestamp = 44", "channel != ponce"]), - ], - ) - .unwrap() - .unwrap(); - let expected = FilterCondition::from_str( - &rtxn, - &index, - "channel = gotaga AND (timestamp = 44 OR channel != ponce)", - ) - .unwrap(); - assert_eq!(condition, expected); - } -} diff --git a/milli/src/search/facet/filter_parser.rs b/milli/src/search/facet/filter_parser.rs new file mode 100644 index 000000000..53a51ca49 --- /dev/null +++ b/milli/src/search/facet/filter_parser.rs @@ -0,0 +1,500 @@ +use std::collections::HashSet; +use std::fmt::Debug; +use std::result::Result as StdResult; + +use super::FilterCondition; +use crate::{FieldId, FieldsIdsMap}; +use nom::{ + branch::alt, + bytes::complete::{tag, take_while1}, + character::complete::{char, multispace0}, + combinator::map, + error::ErrorKind, + error::ParseError, + error::VerboseError, + multi::many0, + sequence::{delimited, preceded, tuple}, + IResult, +}; + +use self::Operator::*; +#[derive(Debug, Clone, PartialEq)] +pub enum Operator { + GreaterThan(f64), + GreaterThanOrEqual(f64), + Equal(Option, String), + NotEqual(Option, String), + LowerThan(f64), + LowerThanOrEqual(f64), + Between(f64, f64), + GeoLowerThan([f64; 2], f64), + GeoGreaterThan([f64; 2], f64), +} + +impl Operator { + /// This method can return two operations in case it must express + /// an OR operation for the between case (i.e. `TO`). + pub fn negate(self) -> (Self, Option) { + match self { + GreaterThan(n) => (LowerThanOrEqual(n), None), + GreaterThanOrEqual(n) => (LowerThan(n), None), + Equal(n, s) => (NotEqual(n, s), None), + NotEqual(n, s) => (Equal(n, s), None), + LowerThan(n) => (GreaterThanOrEqual(n), None), + LowerThanOrEqual(n) => (GreaterThan(n), None), + Between(n, m) => (LowerThan(n), Some(GreaterThan(m))), + } + } +} + +pub struct ParseContext<'a> { + pub fields_ids_map: &'a FieldsIdsMap, + pub filterable_fields: &'a HashSet, +} + +impl<'a> ParseContext<'a> { + fn parse_or_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let (input, lhs) = self.parse_and_nom(input)?; + let (input, ors) = many0(preceded(tag("OR"), |c| Self::parse_or_nom(self, c)))(input)?; + let expr = ors + .into_iter() + .fold(lhs, |acc, branch| FilterCondition::Or(Box::new(acc), Box::new(branch))); + Ok((input, expr)) + } + + fn parse_and_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let (input, lhs) = self.parse_not_nom(input)?; + let (input, ors) = many0(preceded(tag("AND"), |c| Self::parse_and_nom(self, c)))(input)?; + let expr = ors + .into_iter() + .fold(lhs, |acc, branch| FilterCondition::And(Box::new(acc), Box::new(branch))); + Ok((input, expr)) + } + + fn parse_not_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + alt(( + map( + preceded(alt((self.ws(tag("!")), self.ws(tag("NOT")))), |c| { + Self::parse_condition_expression(self, c) + }), + |e| e.negate(), + ), + |c| Self::parse_condition_expression(self, c), + ))(input) + } + + fn ws(&'a self, inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O, E> + where + F: Fn(&'a str) -> IResult<&'a str, O, E>, + E: ParseError<&'a str>, + { + delimited(multispace0, inner, multispace0) + } + + fn parse_simple_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let operator = alt((tag(">"), tag(">="), tag("="), tag("<"), tag("!="), tag("<="))); + let (input, (key, op, value)) = + tuple((self.ws(|c| self.parse_key(c)), operator, self.ws(|c| self.parse_key(c))))( + input, + )?; + let fid = self.parse_fid(input, key)?; + let r: StdResult>> = self.parse_numeric(value); + let k = match op { + "=" => FilterCondition::Operator(fid, Equal(r.ok(), value.to_string().to_lowercase())), + "!=" => { + FilterCondition::Operator(fid, NotEqual(r.ok(), value.to_string().to_lowercase())) + } + ">" | "<" | "<=" | ">=" => return self.parse_numeric_unary_condition(op, fid, value), + _ => unreachable!(), + }; + Ok((input, k)) + } + + fn parse_numeric(&'a self, input: &'a str) -> StdResult> + where + E: ParseError<&'a str>, + T: std::str::FromStr, + { + match input.parse::() { + Ok(n) => Ok(n), + Err(_) => { + return match input.chars().nth(0) { + Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), + None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), + }; + } + } + } + + fn parse_numeric_unary_condition( + &'a self, + input: &'a str, + fid: u16, + value: &'a str, + ) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let numeric: f64 = self.parse_numeric(value)?; + let k = match input { + ">" => FilterCondition::Operator(fid, GreaterThan(numeric)), + "<" => FilterCondition::Operator(fid, LowerThan(numeric)), + "<=" => FilterCondition::Operator(fid, LowerThanOrEqual(numeric)), + ">=" => FilterCondition::Operator(fid, GreaterThanOrEqual(numeric)), + _ => unreachable!(), + }; + Ok((input, k)) + } + + fn parse_fid(&'a self, input: &'a str, key: &'a str) -> StdResult> + where + E: ParseError<&'a str>, + { + let error = match input.chars().nth(0) { + Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), + None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), + }; + if !self.filterable_fields.contains(key) { + return error; + } + match self.fields_ids_map.id(key) { + Some(fid) => Ok(fid), + None => error, + } + } + + fn parse_range_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let (input, (key, from, _, to)) = tuple(( + self.ws(|c| self.parse_key(c)), + self.ws(|c| self.parse_key(c)), + tag("TO"), + self.ws(|c| self.parse_key(c)), + ))(input)?; + + let fid = self.parse_fid(input, key)?; + let numeric_from: f64 = self.parse_numeric(from)?; + let numeric_to: f64 = self.parse_numeric(to)?; + let res = FilterCondition::Operator(fid, Between(numeric_from, numeric_to)); + Ok((input, res)) + } + + fn parse_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let l1 = |c| self.parse_simple_condition(c); + let l2 = |c| self.parse_range_condition(c); + let (input, condition) = alt((l1, l2))(input)?; + Ok((input, condition)) + } + + fn parse_condition_expression(&'a self, input: &'a str) -> IResult<&str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + return alt(( + delimited(self.ws(char('(')), |c| Self::parse_expression(self, c), self.ws(char(')'))), + |c| Self::parse_condition(self, c), + ))(input); + } + + fn parse_key(&'a self, input: &'a str) -> IResult<&'a str, &'a str, E> + where + E: ParseError<&'a str>, + { + let key = |input| take_while1(Self::is_key_component)(input); + alt((key, delimited(char('"'), key, char('"'))))(input) + } + fn is_key_component(c: char) -> bool { + c.is_alphanumeric() || ['_', '-', '.'].contains(&c) + } + + pub fn parse_expression(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + self.parse_or_nom(input) + } +} + +#[cfg(test)] +mod tests { + use big_s::S; + use either::Either; + use heed::EnvOpenOptions; + use maplit::hashset; + + use super::*; + use crate::{update::Settings, Index}; + + #[test] + fn string() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut map = index.fields_ids_map(&wtxn).unwrap(); + map.insert("channel"); + index.put_fields_ids_map(&mut wtxn, &map).unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_filterable_fields(hashset! { S("channel") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str(&rtxn, &index, "channel = Ponce").unwrap(); + let expected = FilterCondition::Operator(0, Operator::Equal(None, S("ponce"))); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str(&rtxn, &index, "channel != ponce").unwrap(); + let expected = FilterCondition::Operator(0, Operator::NotEqual(None, S("ponce"))); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str(&rtxn, &index, "NOT channel = ponce").unwrap(); + let expected = FilterCondition::Operator(0, Operator::NotEqual(None, S("ponce"))); + assert_eq!(condition, expected); + } + + #[test] + fn number() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut map = index.fields_ids_map(&wtxn).unwrap(); + map.insert("timestamp"); + index.put_fields_ids_map(&mut wtxn, &map).unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_filterable_fields(hashset! { "timestamp".into() }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str(&rtxn, &index, "timestamp 22 TO 44").unwrap(); + let expected = FilterCondition::Operator(0, Between(22.0, 44.0)); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str(&rtxn, &index, "NOT timestamp 22 TO 44").unwrap(); + let expected = FilterCondition::Or( + Box::new(FilterCondition::Operator(0, LowerThan(22.0))), + Box::new(FilterCondition::Operator(0, GreaterThan(44.0))), + ); + assert_eq!(condition, expected); + } + + #[test] + fn compare() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order + builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str(&rtxn, &index, "channel < 20").unwrap(); + let expected = FilterCondition::Operator(0, LowerThan(20.0)); + + assert_eq!(condition, expected); + } + + #[test] + fn parentheses() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order + builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str( + &rtxn, + &index, + "channel = gotaga OR (timestamp 22 TO 44 AND channel != ponce)", + ) + .unwrap(); + let expected = FilterCondition::Or( + Box::new(FilterCondition::Operator(0, Operator::Equal(None, S("gotaga")))), + Box::new(FilterCondition::And( + Box::new(FilterCondition::Operator(1, Between(22.0, 44.0))), + Box::new(FilterCondition::Operator(0, Operator::NotEqual(None, S("ponce")))), + )), + ); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str( + &rtxn, + &index, + "channel = gotaga OR NOT (timestamp 22 TO 44 AND channel != ponce)", + ) + .unwrap(); + let expected = FilterCondition::Or( + Box::new(FilterCondition::Operator(0, Operator::Equal(None, S("gotaga")))), + Box::new(FilterCondition::Or( + Box::new(FilterCondition::Or( + Box::new(FilterCondition::Operator(1, LowerThan(22.0))), + Box::new(FilterCondition::Operator(1, GreaterThan(44.0))), + )), + Box::new(FilterCondition::Operator(0, Operator::Equal(None, S("ponce")))), + )), + ); + assert_eq!(condition, expected); + } + + #[test] + fn from_array() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order + builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array( + &rtxn, + &index, + vec![ + Either::Right("channel = gotaga"), + Either::Left(vec!["timestamp = 44", "channel != ponce"]), + ], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str( + &rtxn, + &index, + "channel = gotaga AND (timestamp = 44 OR channel != ponce)", + ) + .unwrap(); + assert_eq!(condition, expected); + } + #[test] + fn geo_radius() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("_geo"), S("price")]); // to keep the fields order + builder.set_filterable_fields(hashset! { S("_geo"), S("price") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + let rtxn = index.read_txn().unwrap(); + // basic test + let condition = + FilterCondition::from_str(&rtxn, &index, "_geoRadius(12, 13.0005, 2000)").unwrap(); + let expected = Operator(0, GeoLowerThan([12., 13.0005], 2000.)); + assert_eq!(condition, expected); + + // test the negation of the GeoLowerThan + let condition = + FilterCondition::from_str(&rtxn, &index, "NOT _geoRadius(50, 18, 2000.500)").unwrap(); + let expected = Operator(0, GeoGreaterThan([50., 18.], 2000.500)); + assert_eq!(condition, expected); + + // composition of multiple operations + let condition = FilterCondition::from_str( + &rtxn, + &index, + "(NOT _geoRadius(1, 2, 300) AND _geoRadius(1.001, 2.002, 1000.300)) OR price <= 10", + ) + .unwrap(); + let expected = Or( + Box::new(And( + Box::new(Operator(0, GeoGreaterThan([1., 2.], 300.))), + Box::new(Operator(0, GeoLowerThan([1.001, 2.002], 1000.300))), + )), + Box::new(Operator(1, LowerThanOrEqual(10.))), + ); + assert_eq!(condition, expected); + + // georadius don't have any parameters + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius don't have any parameters + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius()"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius don't have enough parameters + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius have too many parameters + let result = + FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius have a bad latitude + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-200, 150, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("Latitude and longitude must be contained between -180 to 180 degrees.")); + + // georadius have a bad longitude + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 181, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("Latitude and longitude must be contained between -180 to 180 degrees.")); + } +} diff --git a/milli/src/search/facet/mod.rs b/milli/src/search/facet/mod.rs index a5c041dd5..3efa0262f 100644 --- a/milli/src/search/facet/mod.rs +++ b/milli/src/search/facet/mod.rs @@ -1,9 +1,10 @@ pub use self::facet_distribution::FacetDistribution; pub use self::facet_number::{FacetNumberIter, FacetNumberRange, FacetNumberRevRange}; pub use self::facet_string::FacetStringIter; -pub use self::filter_condition::{FilterCondition, Operator}; +pub use self::filter_condition::FilterCondition; mod facet_distribution; mod facet_number; mod facet_string; mod filter_condition; +mod filter_parser; diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 85d5dc8a7..9b76ca851 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -14,7 +14,7 @@ use meilisearch_tokenizer::{Analyzer, AnalyzerConfig}; use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; -pub use self::facet::{FacetDistribution, FacetNumberIter, FilterCondition, Operator}; +pub use self::facet::{FacetDistribution, FacetNumberIter, FilterCondition}; pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; use crate::error::UserError;