diff --git a/milli/src/search/facet/filter_parser.rs b/milli/src/search/facet/filter_parser.rs index 4d8a54987..cfa3cdae0 100644 --- a/milli/src/search/facet/filter_parser.rs +++ b/milli/src/search/facet/filter_parser.rs @@ -3,17 +3,19 @@ use std::fmt::Debug; use std::result::Result as StdResult; use nom::branch::alt; -use nom::bytes::complete::{tag, take_while1}; +use nom::bytes::complete::{tag, take_till, take_till1, take_while1}; use nom::character::complete::{char, multispace0}; use nom::combinator::map; use nom::error::{ContextError, ErrorKind, VerboseError}; use nom::multi::{many0, separated_list1}; +use nom::number::complete::recognize_float; use nom::sequence::{delimited, preceded, tuple}; use nom::IResult; use self::Operator::*; use super::FilterCondition; use crate::{FieldId, FieldsIdsMap}; + #[derive(Debug, Clone, PartialEq)] pub enum Operator { GreaterThan(f64), @@ -111,28 +113,33 @@ impl<'a> ParseContext<'a> { where E: FilterParserError<'a>, { - let operator = alt((tag("<="), tag(">="), tag(">"), tag("="), tag("<"), tag("!="))); - let k = tuple((self.ws(|c| self.parse_key(c)), operator, self.ws(|c| self.parse_key(c))))( + let operator = alt((tag("<="), tag(">="), tag("!="), tag("<"), tag(">"), tag("="))); + let k = tuple((self.ws(|c| self.parse_key(c)), operator, self.ws(|c| self.parse_value(c))))( input, ); let (input, (key, op, value)) = match k { Ok(o) => o, - Err(e) => { - return Err(e); - } + Err(e) => return Err(e), }; let fid = self.parse_fid(input, key)?; let r: StdResult>> = self.parse_numeric(value); - let k = match op { - "=" => FilterCondition::Operator(fid, Equal(r.ok(), value.to_string().to_lowercase())), - "!=" => { - FilterCondition::Operator(fid, NotEqual(r.ok(), value.to_string().to_lowercase())) + match op { + "=" => { + let k = + FilterCondition::Operator(fid, Equal(r.ok(), value.to_string().to_lowercase())); + Ok((input, k)) } - ">" | "<" | "<=" | ">=" => return self.parse_numeric_unary_condition(op, fid, value), + "!=" => { + let k = FilterCondition::Operator( + fid, + NotEqual(r.ok(), value.to_string().to_lowercase()), + ); + Ok((input, k)) + } + ">" | "<" | "<=" | ">=" => self.parse_numeric_unary_condition(op, fid, value), _ => unreachable!(), - }; - Ok((input, k)) + } } fn parse_numeric(&'a self, input: &'a str) -> StdResult> @@ -142,12 +149,10 @@ impl<'a> ParseContext<'a> { { match input.parse::() { Ok(n) => Ok(n), - Err(_) => { - return match input.chars().nth(0) { - Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), - None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), - }; - } + Err(_) => match input.chars().nth(0) { + Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), + None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), + }, } } @@ -194,9 +199,9 @@ impl<'a> ParseContext<'a> { { let (input, (key, from, _, to)) = tuple(( self.ws(|c| self.parse_key(c)), - self.ws(|c| self.parse_key(c)), + self.ws(|c| self.parse_value(c)), tag("TO"), - self.ws(|c| self.parse_key(c)), + self.ws(|c| self.parse_value(c)), ))(input)?; let fid = self.parse_fid(input, key)?; @@ -211,22 +216,23 @@ impl<'a> ParseContext<'a> { where E: FilterParserError<'a>, { - let err_msg_args_incomplete= "_geoRadius. The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`"; + let err_msg_args_incomplete = "_geoRadius. The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`"; let err_msg_latitude_invalid = "_geoRadius. Latitude must be contained between -90 and 90 degrees."; let err_msg_longitude_invalid = "_geoRadius. Longitude must be contained between -180 and 180 degrees."; - let (input, args): (&str, Vec<&str>) = match preceded( + let parsed = preceded::<_, _, _, E, _, _>( tag("_geoRadius"), delimited( char('('), - separated_list1(tag(","), self.ws(|c| self.parse_value::(c))), + separated_list1(tag(","), self.ws(|c| recognize_float(c))), char(')'), ), - )(input) - { + )(input); + + let (input, args): (&str, Vec<&str>) = match parsed { Ok(e) => e, Err(_e) => { return Err(nom::Err::Failure(E::add_context( @@ -293,15 +299,30 @@ impl<'a> ParseContext<'a> { E: FilterParserError<'a>, { let key = |input| take_while1(Self::is_key_component)(input); - alt((key, delimited(char('"'), key, char('"'))))(input) + let simple_quoted_key = |input| take_till(|c: char| c == '\'')(input); + let quoted_key = |input| take_till(|c: char| c == '"')(input); + + alt(( + delimited(char('\''), simple_quoted_key, char('\'')), + delimited(char('"'), quoted_key, char('"')), + key, + ))(input) } fn parse_value(&'a self, input: &'a str) -> IResult<&'a str, &'a str, E> where E: FilterParserError<'a>, { - let key = |input| take_while1(Self::is_key_component)(input); - alt((key, delimited(char('"'), key, char('"'))))(input) + let key = + |input| take_till1(|c: char| c.is_ascii_whitespace() || c == '(' || c == ')')(input); + let simple_quoted_key = |input| take_till(|c: char| c == '\'')(input); + let quoted_key = |input| take_till(|c: char| c == '"')(input); + + alt(( + delimited(char('\''), simple_quoted_key, char('\'')), + delimited(char('"'), quoted_key, char('"')), + key, + ))(input) } fn is_key_component(c: char) -> bool { @@ -312,7 +333,7 @@ impl<'a> ParseContext<'a> { where E: FilterParserError<'a>, { - self.parse_or(input) + alt((|input| self.parse_or(input), |input| self.parse_and(input)))(input) } } @@ -481,6 +502,90 @@ mod tests { builder.execute(|_, _| ()).unwrap(); wtxn.commit().unwrap(); + // Simple array with Left + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, _, _, &str>( + &rtxn, + &index, + vec![Either::Left(["channel = mv"])], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "channel = mv").unwrap(); + assert_eq!(condition, expected); + + // Simple array with Right + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, Option<&str>, _, _>( + &rtxn, + &index, + vec![Either::Right("channel = mv")], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "channel = mv").unwrap(); + assert_eq!(condition, expected); + + // Array with Left and escaped quote + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, _, _, &str>( + &rtxn, + &index, + vec![Either::Left(["channel = \"Mister Mv\""])], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "channel = \"Mister Mv\"").unwrap(); + assert_eq!(condition, expected); + + // Array with Right and escaped quote + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, Option<&str>, _, _>( + &rtxn, + &index, + vec![Either::Right("channel = \"Mister Mv\"")], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "channel = \"Mister Mv\"").unwrap(); + assert_eq!(condition, expected); + + // Array with Left and escaped simple quote + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, _, _, &str>( + &rtxn, + &index, + vec![Either::Left(["channel = 'Mister Mv'"])], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "channel = 'Mister Mv'").unwrap(); + assert_eq!(condition, expected); + + // Array with Right and escaped simple quote + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, Option<&str>, _, _>( + &rtxn, + &index, + vec![Either::Right("channel = 'Mister Mv'")], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "channel = 'Mister Mv'").unwrap(); + assert_eq!(condition, expected); + + // Simple with parenthesis + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array::<_, _, _, &str>( + &rtxn, + &index, + vec![Either::Left(["(channel = mv)"])], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str(&rtxn, &index, "(channel = mv)").unwrap(); + assert_eq!(condition, expected); + // Test that the facet condition is correctly generated. let rtxn = index.read_txn().unwrap(); let condition = FilterCondition::from_array( @@ -501,6 +606,7 @@ mod tests { .unwrap(); assert_eq!(condition, expected); } + #[test] fn geo_radius() { let path = tempfile::tempdir().unwrap(); @@ -591,9 +697,11 @@ mod tests { let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-100, 150, 10)"); assert!(result.is_err()); let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Latitude must be contained between -90 and 90 degrees.")); + assert!( + error.to_string().contains("Latitude must be contained between -90 and 90 degrees."), + "{}", + error.to_string() + ); // georadius have a bad latitude let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-90.0000001, 150, 10)");