diff --git a/milli/src/error.rs b/milli/src/error.rs index be3fbfdef..a798539cd 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -64,6 +64,7 @@ pub enum UserError { InvalidGeoField { document_id: Value, object: Value }, InvalidFilterAttributeNom, InvalidFilterValue, + InvalidFilterNom { input: String }, InvalidSortName { name: String }, InvalidSortableAttribute { field: String, valid_fields: HashSet }, SortRankingRuleMissing, @@ -85,11 +86,6 @@ impl From for Error { Error::IoError(error) } } -impl From for UserError { - fn from(_: std::num::ParseFloatError) -> UserError { - UserError::InvalidFilterValue - } -} impl From for Error { fn from(error: fst::Error) -> Error { @@ -217,8 +213,7 @@ impl fmt::Display for UserError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { //TODO - Self::InvalidFilterAttributeNom => f.write_str("parser error "), - Self::InvalidFilterValue => f.write_str("parser error "), + Self::InvalidFilterNom { input } => write!(f, "parser error {}", input), Self::AttributeLimitReached => f.write_str("maximum number of attributes reached"), Self::CriterionError(error) => write!(f, "{}", error), Self::DocumentLimitReached => f.write_str("maximum number of documents reached"), diff --git a/milli/src/search/facet/filter_condition.rs b/milli/src/search/facet/filter_condition.rs index b14c3648f..e6ed79230 100644 --- a/milli/src/search/facet/filter_condition.rs +++ b/milli/src/search/facet/filter_condition.rs @@ -4,7 +4,6 @@ use std::ops::Bound::{self, Excluded, Included}; use std::result::Result as StdResult; use std::str::FromStr; -use crate::error::UserError as IError; use either::Either; use heed::types::DecodeIgnore; use itertools::Itertools; @@ -19,7 +18,9 @@ use nom::{ bytes::complete::{tag, take_while1}, character::complete::{char, multispace0}, combinator::map, + error::ErrorKind, error::ParseError, + error::VerboseError, multi::many0, sequence::{delimited, preceded, tuple}, IResult, @@ -29,7 +30,7 @@ use self::FilterCondition::*; use self::Operator::*; use super::parser::{FilterParser, Rule, PREC_CLIMBER}; use super::FacetNumberRange; -use crate::error::UserError; +use crate::error::{Error, UserError}; use crate::heed_codec::facet::{ FacetLevelValueF64Codec, FacetStringLevelZeroCodec, FacetStringLevelZeroValueCodec, }; @@ -83,7 +84,10 @@ struct ParseContext<'a> { // impl From impl<'a> ParseContext<'a> { - fn parse_or_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition> { + fn parse_or_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { let (input, lhs) = self.parse_and_nom(input)?; let (input, ors) = many0(preceded(tag("OR"), |c| Self::parse_or_nom(self, c)))(input)?; let expr = ors @@ -91,7 +95,11 @@ impl<'a> ParseContext<'a> { .fold(lhs, |acc, branch| FilterCondition::Or(Box::new(acc), Box::new(branch))); Ok((input, expr)) } - fn parse_and_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition> { + + fn parse_and_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { let (input, lhs) = self.parse_not_nom(input)?; let (input, ors) = many0(preceded(tag("AND"), |c| Self::parse_and_nom(self, c)))(input)?; let expr = ors @@ -100,119 +108,146 @@ impl<'a> ParseContext<'a> { Ok((input, expr)) } - fn parse_not_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition> { - let r = alt(( + fn parse_not_nom(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + alt(( map( - preceded(alt((Self::ws(tag("!")), Self::ws(tag("NOT")))), |c| { + preceded(alt((self.ws(tag("!")), self.ws(tag("NOT")))), |c| { Self::parse_condition_expression(self, c) }), |e| e.negate(), ), |c| Self::parse_condition_expression(self, c), - ))(input); - return r; + ))(input) } - fn ws<'b, F: 'b, O, E: ParseError<&'b str>>( - inner: F, - ) -> impl FnMut(&'b str) -> IResult<&'b str, O, E> + fn ws(&'a self, inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O, E> where - F: Fn(&'b str) -> IResult<&'b str, O, E>, + F: Fn(&'a str) -> IResult<&'a str, O, E>, + E: ParseError<&'a str>, { delimited(multispace0, inner, multispace0) } - fn parse_simple_condition( - &self, - input: &'a str, - ) -> StdResult<(&'a str, FilterCondition), UserError> { + fn parse_simple_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { let operator = alt((tag(">"), tag(">="), tag("="), tag("<"), tag("!="), tag("<="))); let (input, (key, op, value)) = - match tuple((Self::ws(Self::parse_key), operator, Self::ws(Self::parse_key)))(input) { - Ok((input, (key, op, value))) => (input, (key, op, value)), - Err(_) => return Err(UserError::InvalidFilterAttributeNom), - }; - - let fid = match field_id_by_key(self.fields_ids_map, self.filterable_fields, key)? { - Some(fid) => fid, - None => return Err(UserError::InvalidFilterAttributeNom), - }; - let r = nom_parse::(value); + tuple((self.ws(|c| self.parse_key(c)), operator, self.ws(|c| self.parse_key(c))))( + input, + )?; + let fid = self.parse_fid(input, key)?; + let r: StdResult>> = self.parse_numeric(value); let k = match op { - ">" => Operator(fid, GreaterThan(value.parse::()?)), - "<" => Operator(fid, LowerThan(value.parse::()?)), - "<=" => Operator(fid, LowerThanOrEqual(value.parse::()?)), - ">=" => Operator(fid, GreaterThanOrEqual(value.parse::()?)), - "=" => Operator(fid, Equal(r.0.ok(), value.to_string().to_lowercase())), - "!=" => Operator(fid, NotEqual(r.0.ok(), value.to_string().to_lowercase())), + "=" => Operator(fid, Equal(r.ok(), value.to_string().to_lowercase())), + "!=" => Operator(fid, NotEqual(r.ok(), value.to_string().to_lowercase())), + ">" | "<" | "<=" | ">=" => { + return self.parse_numeric_unary_condition(input, fid, value) + } _ => unreachable!(), }; Ok((input, k)) } - fn parse_range_condition( - &'a self, - input: &'a str, - ) -> StdResult<(&str, FilterCondition), UserError> { - let (input, (key, from, _, to)) = match tuple(( - Self::ws(Self::parse_key), - Self::ws(Self::parse_key), - tag("TO"), - Self::ws(Self::parse_key), - ))(input) - { - Ok((input, (key, from, tag, to))) => (input, (key, from, tag, to)), - Err(_) => return Err(UserError::InvalidFilterAttributeNom), - }; - let fid = match field_id_by_key(self.fields_ids_map, self.filterable_fields, key)? { - Some(fid) => fid, - None => return Err(UserError::InvalidFilterAttributeNom), - }; - let res = Operator(fid, Between(from.parse::()?, to.parse::()?)); - Ok((input, res)) - } - - fn parse_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition> { - let l1 = |c| self.wrap(|c| self.parse_simple_condition(c), c); - let l2 = |c| self.wrap(|c| self.parse_range_condition(c), c); - let (input, condition) = match alt((l1, l2))(input) { - Ok((i, c)) => (i, c), - Err(_) => { - return Err(nom::Err::Error(nom::error::Error::from_error_kind( - "foo", - nom::error::ErrorKind::Fail, - ))) - } - }; - Ok((input, condition)) - } - fn wrap(&'a self, inner: F, input: &'a str) -> IResult<&'a str, FilterCondition> + fn parse_numeric(&'a self, input: &'a str) -> StdResult> where - F: Fn(&'a str) -> StdResult<(&'a str, FilterCondition), E>, + E: ParseError<&'a str>, + T: std::str::FromStr, { - match inner(input) { - Ok(e) => Ok(e), + match input.parse::() { + Ok(n) => Ok(n), Err(_) => { - return Err(nom::Err::Error(nom::error::Error::from_error_kind( - "foo", - nom::error::ErrorKind::Fail, - ))) + return match input.chars().nth(0) { + Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), + None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), + }; } } } - fn parse_condition_expression(&'a self, input: &'a str) -> IResult<&str, FilterCondition> { + fn parse_numeric_unary_condition( + &'a self, + input: &'a str, + fid: u16, + value: &'a str, + ) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let numeric: f64 = self.parse_numeric(value)?; + let k = match input { + ">" => Operator(fid, GreaterThan(numeric)), + "<" => Operator(fid, LowerThan(numeric)), + "<=" => Operator(fid, LowerThanOrEqual(numeric)), + ">=" => Operator(fid, GreaterThanOrEqual(numeric)), + _ => unreachable!(), + }; + Ok((input, k)) + } + + fn parse_fid(&'a self, input: &'a str, key: &'a str) -> StdResult> + where + E: ParseError<&'a str>, + { + let error = match input.chars().nth(0) { + Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), + None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), + }; + if !self.filterable_fields.contains(key) { + return error; + } + match self.fields_ids_map.id(key) { + Some(fid) => Ok(fid), + None => error, + } + } + + fn parse_range_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let (input, (key, from, _, to)) = tuple(( + self.ws(|c| self.parse_key(c)), + self.ws(|c| self.parse_key(c)), + tag("TO"), + self.ws(|c| self.parse_key(c)), + ))(input)?; + + let fid = self.parse_fid(input, key)?; + let numeric_from: f64 = self.parse_numeric(from)?; + let numeric_to: f64 = self.parse_numeric(to)?; + let res = Operator(fid, Between(numeric_from, numeric_to)); + Ok((input, res)) + } + + fn parse_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { + let l1 = |c| self.parse_simple_condition(c); + let l2 = |c| self.parse_range_condition(c); + let (input, condition) = alt((l1, l2))(input)?; + Ok((input, condition)) + } + + fn parse_condition_expression(&'a self, input: &'a str) -> IResult<&str, FilterCondition, E> + where + E: ParseError<&'a str>, + { return alt(( - delimited( - Self::ws(char('(')), - |c| Self::parse_expression(self, c), - Self::ws(char(')')), - ), + delimited(self.ws(char('(')), |c| Self::parse_expression(self, c), self.ws(char(')'))), |c| Self::parse_condition(self, c), ))(input); } - fn parse_key(input: &str) -> IResult<&str, &str> { + fn parse_key(&'a self, input: &'a str) -> IResult<&'a str, &'a str, E> + where + E: ParseError<&'a str>, + { let key = |input| take_while1(Self::is_key_component)(input); alt((key, delimited(char('"'), key, char('"'))))(input) } @@ -220,7 +255,10 @@ impl<'a> ParseContext<'a> { c.is_alphanumeric() || ['_', '-', '.'].contains(&c) } - pub fn parse_expression(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition> { + pub fn parse_expression(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> + where + E: ParseError<&'a str>, + { self.parse_or_nom(input) } } @@ -280,12 +318,9 @@ impl FilterCondition { let filterable_fields = index.filterable_fields(rtxn)?; let ctx = ParseContext { fields_ids_map: &fields_ids_map, filterable_fields: &filterable_fields }; - match ctx.parse_expression(expression) { + match ctx.parse_expression::>(expression) { Ok((_, fc)) => Ok(fc), - Err(e) => { - println!("{:?}", e); - unreachable!() - } + Err(e) => Err(Error::UserError(UserError::InvalidFilterNom { input: e.to_string() })), } } } @@ -812,19 +847,6 @@ impl FilterCondition { } } -fn field_id_by_key( - fields_ids_map: &FieldsIdsMap, - filterable_fields: &HashSet, - key: &str, -) -> StdResult, IError> { - // lexing ensures that we at least have a key - if !filterable_fields.contains(key) { - return StdResult::Err(UserError::InvalidFilterAttributeNom); - } - - Ok(fields_ids_map.id(key)) -} - /// Retrieve the field id base on the pest value. /// /// Returns an error if the given value is not filterable. @@ -879,19 +901,6 @@ fn field_id( Ok(fields_ids_map.id(key.as_str())) } -fn nom_parse(input: &str) -> (StdResult, String) -where - T: FromStr, - T::Err: ToString, -{ - let result = match input.parse::() { - Ok(value) => Ok(value), - Err(e) => Err(UserError::InvalidFilterValue), - }; - - (result, input.to_string()) -} - /// Tries to parse the pest pair into the type `T` specified, always returns /// the original string that we tried to parse. ///