diff --git a/src/search/facet.rs b/src/search/facet.rs index 22352ab48..08daf5fbc 100644 --- a/src/search/facet.rs +++ b/src/search/facet.rs @@ -5,19 +5,21 @@ use std::str::FromStr; use anyhow::{bail, ensure, Context}; use heed::types::{ByteSlice, DecodeIgnore}; +use itertools::Itertools; use log::debug; use num_traits::Bounded; use roaring::RoaringBitmap; use crate::facet::FacetType; +use crate::heed_codec::facet::FacetValueStringCodec; use crate::heed_codec::facet::{FacetLevelValueI64Codec, FacetLevelValueF64Codec}; use crate::{Index, CboRoaringBitmapCodec}; use self::FacetCondition::*; -use self::FacetOperator::*; +use self::FacetNumberOperator::*; #[derive(Debug, Copy, Clone, PartialEq)] -pub enum FacetOperator { +pub enum FacetNumberOperator { GreaterThan(T), GreaterThanOrEqual(T), LowerThan(T), @@ -26,11 +28,17 @@ pub enum FacetOperator { Between(T, T), } +#[derive(Debug, Clone, PartialEq)] +pub enum FacetStringOperator { + Equal(String), +} + // TODO also support ANDs, ORs, NOTs. -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum FacetCondition { - OperatorI64(u8, FacetOperator), - OperatorF64(u8, FacetOperator), + OperatorI64(u8, FacetNumberOperator), + OperatorF64(u8, FacetNumberOperator), + OperatorString(u8, FacetStringOperator), } impl FacetCondition { @@ -55,15 +63,34 @@ impl FacetCondition { let field_type = faceted_fields.get(&field_id).with_context(|| format!("field {} is not faceted", field_name))?; match field_type { - FacetType::Integer => Self::parse_condition(iter).map(|op| Some(OperatorI64(field_id, op))), - FacetType::Float => Self::parse_condition(iter).map(|op| Some(OperatorF64(field_id, op))), - FacetType::String => bail!("invalid facet type"), + FacetType::Integer => Self::parse_number_condition(iter).map(|op| Some(OperatorI64(field_id, op))), + FacetType::Float => Self::parse_number_condition(iter).map(|op| Some(OperatorF64(field_id, op))), + FacetType::String => Self::parse_string_condition(iter).map(|op| Some(OperatorString(field_id, op))), } } - fn parse_condition<'a, T: FromStr>( + fn parse_string_condition<'a>( mut iter: impl Iterator, - ) -> anyhow::Result> + ) -> anyhow::Result + { + match iter.next() { + Some("=") | Some(":") => { + match iter.next() { + Some(q @ "\"") | Some(q @ "\'") => { + let string: String = iter.take_while(|&c| c != q).intersperse(" ").collect(); + Ok(FacetStringOperator::Equal(string.to_lowercase())) + }, + Some(param) => Ok(FacetStringOperator::Equal(param.to_lowercase())), + None => bail!("missing parameter"), + } + }, + _ => bail!("invalid facet string operator"), + } + } + + fn parse_number_condition<'a, T: FromStr>( + mut iter: impl Iterator, + ) -> anyhow::Result> where T::Err: Send + Sync + StdError + 'static, { match iter.next() { @@ -201,11 +228,11 @@ impl FacetCondition { Ok(()) } - fn evaluate_operator<'t, T: 't, KC>( + fn evaluate_number_operator<'t, T: 't, KC>( rtxn: &'t heed::RoTxn, db: heed::Database, field_id: u8, - operator: FacetOperator, + operator: FacetNumberOperator, ) -> anyhow::Result where T: Copy + PartialEq + PartialOrd + Bounded + Debug, @@ -241,19 +268,40 @@ impl FacetCondition { } } + fn evaluate_string_operator( + rtxn: &heed::RoTxn, + db: heed::Database, + field_id: u8, + operator: &FacetStringOperator, + ) -> anyhow::Result + { + match operator { + FacetStringOperator::Equal(string) => { + match db.get(rtxn, &(field_id, string))? { + Some(docids) => Ok(docids), + None => Ok(RoaringBitmap::new()) + } + } + } + } + pub fn evaluate( &self, rtxn: &heed::RoTxn, db: heed::Database, ) -> anyhow::Result { - match *self { - FacetCondition::OperatorI64(fid, operator) => { - Self::evaluate_operator::(rtxn, db, fid, operator) + match self { + OperatorI64(fid, op) => { + Self::evaluate_number_operator::(rtxn, db, *fid, *op) + }, + OperatorF64(fid, op) => { + Self::evaluate_number_operator::(rtxn, db, *fid, *op) + }, + OperatorString(fid, op) => { + let db = db.remap_key_type::(); + Self::evaluate_string_operator(rtxn, db, *fid, op) }, - FacetCondition::OperatorF64(fid, operator) => { - Self::evaluate_operator::(rtxn, db, fid, operator) - } } } } diff --git a/src/search/mod.rs b/src/search/mod.rs index 8ee8461a8..d236e396a 100644 --- a/src/search/mod.rs +++ b/src/search/mod.rs @@ -156,7 +156,7 @@ impl<'a> Search<'a> { // We create the original candidates with the facet conditions results. let facet_db = self.index.facet_field_id_value_docids; - let facet_candidates = match self.facet_condition { + let facet_candidates = match &self.facet_condition { Some(condition) => Some(condition.evaluate(self.rtxn, facet_db)?), None => None, }; diff --git a/src/update/index_documents/store.rs b/src/update/index_documents/store.rs index 289704b1a..6fb07b345 100644 --- a/src/update/index_documents/store.rs +++ b/src/update/index_documents/store.rs @@ -586,7 +586,7 @@ fn parse_facet_value(ftype: FacetType, value: &Value) -> anyhow::Result { - let string = string.trim(); + let string = string.trim().to_lowercase(); if string.is_empty() { return Ok(()) } match ftype { FacetType::String => {