From 04fa44e7eb6568cf76fe52f56c8c0c3270bf32e9 Mon Sep 17 00:00:00 2001
From: ManyTheFish
Date: Tue, 23 Jul 2024 14:51:36 +0200
Subject: [PATCH] Implement localized attributes settings

---
 dump/src/lib.rs                               |   1 +
 dump/src/reader/compat/v5_to_v6.rs            |   1 +
 meilisearch-types/src/error.rs                |   1 +
 meilisearch-types/src/locales.rs              |  26 ++++
 meilisearch-types/src/settings.rs             |  26 +++-
 .../src/routes/indexes/facet_search.rs        |   7 ++
 meilisearch/src/routes/indexes/settings.rs    |  23 ++++
 meilisearch/src/search/mod.rs                 |  22 +++-
 milli/src/heed_codec/mod.rs                   |   2 -
 milli/src/heed_codec/script_language_codec.rs |  39 ------
 milli/src/index.rs                            |  92 ++++----------
 milli/src/lib.rs                              |   4 +
 milli/src/localized_attributes_rules.rs       | 114 ++++++++++++++++++
 milli/src/update/clear_documents.rs           |   2 -
 .../extract/extract_docid_word_positions.rs   |  40 +++---
 .../extract/extract_facet_string_docids.rs    | 110 ++++++++++-----
 milli/src/update/index_documents/mod.rs       |  38 ------
 milli/src/update/settings.rs                  |  66 +++++++++-
 18 files changed, 405 insertions(+), 209 deletions(-)
 delete mode 100644 milli/src/heed_codec/script_language_codec.rs
 create mode 100644 milli/src/localized_attributes_rules.rs

diff --git a/dump/src/lib.rs b/dump/src/lib.rs
index 722633ec6..a17fcf941 100644
--- a/dump/src/lib.rs
+++ b/dump/src/lib.rs
@@ -286,6 +286,7 @@ pub(crate) mod test {
             pagination: Setting::NotSet,
             embedders: Setting::NotSet,
             search_cutoff_ms: Setting::NotSet,
+            localized_attributes: Setting::NotSet,
             _kind: std::marker::PhantomData,
         };
         settings.check()
diff --git a/dump/src/reader/compat/v5_to_v6.rs b/dump/src/reader/compat/v5_to_v6.rs
index e6e030186..40a055465 100644
--- a/dump/src/reader/compat/v5_to_v6.rs
+++ b/dump/src/reader/compat/v5_to_v6.rs
@@ -379,6 +379,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> {
                 v5::Setting::NotSet => v6::Setting::NotSet,
             },
             embedders: v6::Setting::NotSet,
+            localized_attributes: v6::Setting::NotSet,
             search_cutoff_ms: v6::Setting::NotSet,
             _kind: std::marker::PhantomData,
         }
diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs
index e56949b57..4d80fe9c9 100644
--- a/meilisearch-types/src/error.rs
+++ b/meilisearch-types/src/error.rs
@@ -298,6 +298,7 @@ InvalidSettingsSeparatorTokens        , InvalidRequest       , BAD_REQUEST ;
 InvalidSettingsDictionary             , InvalidRequest       , BAD_REQUEST ;
 InvalidSettingsSynonyms               , InvalidRequest       , BAD_REQUEST ;
 InvalidSettingsTypoTolerance          , InvalidRequest       , BAD_REQUEST ;
+InvalidSettingsLocalizedAttributes    , InvalidRequest       , BAD_REQUEST ;
 InvalidState                          , Internal             , INTERNAL_SERVER_ERROR ;
 InvalidStoreFile                      , Internal             , INTERNAL_SERVER_ERROR ;
 InvalidSwapDuplicateIndexFound        , InvalidRequest       , BAD_REQUEST ;
diff --git a/meilisearch-types/src/locales.rs b/meilisearch-types/src/locales.rs
index 14972fc33..6f7fb3a40 100644
--- a/meilisearch-types/src/locales.rs
+++ b/meilisearch-types/src/locales.rs
@@ -130,3 +130,29 @@ make_locale! {
     Tgl,
     Hye
 }
+
+#[derive(Debug, Clone, PartialEq, Eq, Deserr, Serialize, Deserialize)]
+#[deserr(rename_all = camelCase)]
+#[serde(rename_all = "camelCase")]
+pub struct LocalizedAttributesRuleView {
+    pub attribute_patterns: Vec<String>,
+    pub locales: Vec<Locale>,
+}
+
+impl From<LocalizedAttributesRule> for LocalizedAttributesRuleView {
+    fn from(rule: LocalizedAttributesRule) -> Self {
+        Self {
+            attribute_patterns: rule.attribute_patterns,
+            locales: rule.locales.into_iter().map(|l| l.into()).collect(),
+        }
+    }
+}
+
+impl From<LocalizedAttributesRuleView> for LocalizedAttributesRule {
+    fn from(view: LocalizedAttributesRuleView) -> Self {
+        Self {
+            attribute_patterns: view.attribute_patterns,
+            locales: view.locales.into_iter().map(|l| l.into()).collect(),
+        }
+    }
+}
diff --git a/meilisearch-types/src/settings.rs b/meilisearch-types/src/settings.rs
index 8a9708d29..9e7a2bc15 100644
--- a/meilisearch-types/src/settings.rs
+++ b/meilisearch-types/src/settings.rs
@@ -17,6 +17,7 @@ use serde::{Deserialize, Serialize, Serializer};
 use crate::deserr::DeserrJsonError;
 use crate::error::deserr_codes::*;
 use crate::facet_values_sort::FacetValuesSort;
+use crate::locales::LocalizedAttributesRuleView;
 
 /// The maximum number of results that the engine
 /// will be able to return in one search call.
@@ -198,6 +199,9 @@ pub struct Settings<T> {
     #[serde(default, skip_serializing_if = "Setting::is_not_set")]
     #[deserr(default, error = DeserrJsonError<InvalidSettingsSearchCutoffMs>)]
     pub search_cutoff_ms: Setting<u64>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default, error = DeserrJsonError<InvalidSettingsLocalizedAttributes>)]
+    pub localized_attributes: Setting<Vec<LocalizedAttributesRuleView>>,
 
     #[serde(skip)]
     #[deserr(skip)]
@@ -261,6 +265,7 @@ impl Settings<Checked> {
             pagination: Setting::Reset,
             embedders: Setting::Reset,
             search_cutoff_ms: Setting::Reset,
+            localized_attributes: Setting::Reset,
             _kind: PhantomData,
         }
     }
@@ -284,7 +289,8 @@ impl Settings<Checked> {
             pagination,
             embedders,
             search_cutoff_ms,
-            ..
+            localized_attributes: localized_attributes_rules,
+            _kind,
         } = self;
 
         Settings {
@@ -305,6 +311,7 @@ impl Settings<Checked> {
             pagination,
             embedders,
             search_cutoff_ms,
+            localized_attributes: localized_attributes_rules,
             _kind: PhantomData,
         }
     }
@@ -352,6 +359,7 @@ impl Settings<Unchecked> {
             pagination: self.pagination,
             embedders: self.embedders,
             search_cutoff_ms: self.search_cutoff_ms,
+            localized_attributes: self.localized_attributes,
             _kind: PhantomData,
         }
     }
@@ -402,6 +410,7 @@ pub fn apply_settings_to_builder(
         pagination,
         embedders,
         search_cutoff_ms,
+        localized_attributes: localized_attributes_rules,
         _kind,
     } = settings;
 
@@ -485,6 +494,13 @@ pub fn apply_settings_to_builder(
         Setting::NotSet => (),
     }
 
+    match localized_attributes_rules {
+        Setting::Set(ref rules) => builder
+            .set_localized_attributes_rules(rules.iter().cloned().map(|r| r.into()).collect()),
+        Setting::Reset => builder.reset_localized_attributes_rules(),
+        Setting::NotSet => (),
+    }
+
     match typo_tolerance {
         Setting::Set(ref value) => {
             match value.enabled {
@@ -679,6 +695,8 @@ pub fn settings(
 
     let search_cutoff_ms = index.search_cutoff(rtxn)?;
 
+    let localized_attributes_rules = index.localized_attributes_rules(rtxn)?;
+
     let mut settings = Settings {
         displayed_attributes: match displayed_attributes {
             Some(attrs) => Setting::Set(attrs),
@@ -711,6 +729,10 @@ pub fn settings(
             Some(cutoff) => Setting::Set(cutoff),
             None => Setting::Reset,
         },
+        localized_attributes: match localized_attributes_rules {
+            Some(rules) => Setting::Set(rules.into_iter().map(|r| r.into()).collect()),
+            None => Setting::Reset,
+        },
         _kind: PhantomData,
     };
 
@@ -902,6 +924,7 @@ pub(crate) mod test {
             faceting: Setting::NotSet,
             pagination: Setting::NotSet,
             embedders: Setting::NotSet,
+            localized_attributes: Setting::NotSet,
             search_cutoff_ms: Setting::NotSet,
             _kind: PhantomData::<Unchecked>,
         };
@@ -930,6 +953,7 @@ pub(crate) mod test {
             faceting: Setting::NotSet,
             pagination: Setting::NotSet,
             embedders: Setting::NotSet,
+            localized_attributes: Setting::NotSet,
             search_cutoff_ms: Setting::NotSet,
             _kind: PhantomData::<Unchecked>,
         };
diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs
index ecb7757af..da575fdc4 100644
--- a/meilisearch/src/routes/indexes/facet_search.rs
+++ b/meilisearch/src/routes/indexes/facet_search.rs
@@ -6,6 +6,7 @@ use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::index_uid::IndexUid;
+use meilisearch_types::locales::Locale;
 use serde_json::Value;
 use tracing::debug;
 
@@ -48,6 +49,8 @@ pub struct FacetSearchQuery {
     pub attributes_to_search_on: Option<Vec<String>>,
     #[deserr(default, error = DeserrJsonError<InvalidSearchRankingScoreThreshold>, default)]
     pub ranking_score_threshold: Option<RankingScoreThreshold>,
+    #[deserr(default, error = DeserrJsonError<InvalidSearchLocales>, default)]
+    pub locales: Option<Vec<Locale>>,
 }
 
 pub async fn search(
@@ -67,6 +70,7 @@ pub async fn search(
 
     let facet_query = query.facet_query.clone();
     let facet_name = query.facet_name.clone();
+    let locales = query.locales.clone().map(|l| l.into_iter().map(Into::into).collect());
     let mut search_query = SearchQuery::from(query);
 
     // Tenant token search_rules.
@@ -86,6 +90,7 @@ pub async fn search(
             facet_name,
             search_kind,
             index_scheduler.features(),
+            locales,
         )
     })
     .await?;
@@ -113,6 +118,7 @@ impl From<FacetSearchQuery> for SearchQuery {
             attributes_to_search_on,
             hybrid,
             ranking_score_threshold,
+            locales,
         } = value;
 
         SearchQuery {
@@ -141,6 +147,7 @@ impl From<FacetSearchQuery> for SearchQuery {
             attributes_to_search_on,
             hybrid,
             ranking_score_threshold,
+            locales,
         }
     }
 }
diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs
index e35ebc930..b62690295 100644
--- a/meilisearch/src/routes/indexes/settings.rs
+++ b/meilisearch/src/routes/indexes/settings.rs
@@ -474,6 +474,28 @@ make_setting_route!(
     }
 );
 
+make_setting_route!(
+    "/localized-attributes",
+    put,
+    Vec<meilisearch_types::locales::LocalizedAttributesRuleView>,
+    meilisearch_types::deserr::DeserrJsonError<
+        meilisearch_types::error::deserr_codes::InvalidSettingsLocalizedAttributes,
+    >,
+    localized_attributes,
+    "localizedAttributes",
+    analytics,
+    |rules: &Option<Vec<LocalizedAttributesRuleView>>, req: &HttpRequest| {
+        use serde_json::json;
+        analytics.publish(
+            "LocalizedAttributesRules Updated".to_string(),
+            json!({
+                "locales": rules.as_ref().map(|rules| rules.iter().map(|rule| rule.locales.iter().cloned()).flatten().collect::<Vec<_>>())
+            }),
+            Some(req),
+        );
+    }
+);
+
 make_setting_route!(
     "/ranking-rules",
     put,
@@ -786,6 +808,7 @@ pub async fn update_all(
             },
             "embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set()),
             "search_cutoff_ms": new_settings.search_cutoff_ms.as_ref().set(),
+            "locales": new_settings.localized_attributes.as_ref().set().map(|rules| rules.into_iter().map(|rule| rule.locales.iter().cloned()).flatten().collect::<Vec<_>>()),
         }),
         Some(&req),
     );
diff --git a/meilisearch/src/search/mod.rs b/meilisearch/src/search/mod.rs
index d28d888aa..11bf4f84e 100644
--- a/meilisearch/src/search/mod.rs
+++ b/meilisearch/src/search/mod.rs
@@ -1290,6 +1290,9 @@ impl<'a> HitMaker<'a> {
             document.insert("_vectors".into(), vectors.into());
         }
 
+        let localized_attributes =
+            self.index.localized_attributes_rules(self.rtxn)?.unwrap_or_default();
+
         let (matches_position, formatted) = format_fields(
             &displayed_document,
             &self.fields_ids_map,
@@ -1298,6 +1301,7 @@ impl<'a> HitMaker<'a> {
             self.show_matches_position,
             &self.displayed_ids,
             self.locales.as_deref(),
+            &localized_attributes,
         )?;
 
         if let Some(sort) = self.sort.as_ref() {
@@ -1365,6 +1369,14 @@ pub fn perform_facet_search(
         None => TimeBudget::default(),
     };
 
+    let localized_attributes = index.localized_attributes_rules(&rtxn)?.unwrap_or_default();
+    let locales = locales.or_else(|| {
+        localized_attributes
+            .into_iter()
+            .find(|attr| attr.match_str(&facet_name))
+            .map(|attr| attr.locales)
+    });
+
     let (search, _, _, _) =
         prepare_search(index, &rtxn, &search_query, &search_kind, time_budget, features)?;
     let mut facet_search = SearchForFacetValues::new(
@@ -1653,6 +1665,7 @@ fn format_fields(
     compute_matches: bool,
     displayable_ids: &BTreeSet<FieldId>,
     locales: Option<&[Language]>,
+    localized_attributes: &[LocalizedAttributesRule],
 ) -> Result<(Option<MatchesPosition>, Document), MeilisearchHttpError> {
     let mut matches_position = compute_matches.then(BTreeMap::new);
     let mut document = document.clone();
@@ -1685,7 +1698,14 @@ fn format_fields(
         .reduce(|acc, option| acc.merge(option));
     let mut infos = Vec::new();
 
-    *value = format_value(std::mem::take(value), builder, format, &mut infos, compute_matches);
+    // if no locales have been provided, we try to find the locales in the localized_attributes.
+    let locales = locales.or_else(|| {
+        localized_attributes
+            .iter()
+            .find(|rule| rule.match_str(key))
+            .map(LocalizedAttributesRule::locales)
+    });
+
+    *value = format_value(
+        std::mem::take(value),
+        builder,
+        format,
+        &mut infos,
+        compute_matches,
+        locales,
+    );
diff --git a/milli/src/heed_codec/mod.rs b/milli/src/heed_codec/mod.rs
index 449d1955c..575b886bd 100644
--- a/milli/src/heed_codec/mod.rs
+++ b/milli/src/heed_codec/mod.rs
@@ -7,7 +7,6 @@ mod fst_set_codec;
 mod obkv_codec;
 mod roaring_bitmap;
 mod roaring_bitmap_length;
-mod script_language_codec;
 mod str_beu32_codec;
 mod str_ref;
 mod str_str_u8_codec;
@@ -26,7 +25,6 @@ pub use self::roaring_bitmap::{BoRoaringBitmapCodec, CboRoaringBitmapCodec, RoaringBitmapCodec};
 pub use self::roaring_bitmap_length::{
     BoRoaringBitmapLenCodec, CboRoaringBitmapLenCodec, RoaringBitmapLenCodec,
 };
-pub use self::script_language_codec::ScriptLanguageCodec;
 pub use self::str_beu32_codec::{StrBEU16Codec, StrBEU32Codec};
 pub use self::str_str_u8_codec::{U8StrStrCodec, UncheckedU8StrStrCodec};
diff --git a/milli/src/heed_codec/script_language_codec.rs b/milli/src/heed_codec/script_language_codec.rs
deleted file mode 100644
index 35f7af3c7..000000000
--- a/milli/src/heed_codec/script_language_codec.rs
+++ /dev/null
@@ -1,39 +0,0 @@
-use std::borrow::Cow;
-use std::ffi::CStr;
-use std::str;
-
-use charabia::{Language, Script};
-use heed::BoxedError;
-
-pub struct ScriptLanguageCodec;
-
-impl<'a> heed::BytesDecode<'a> for ScriptLanguageCodec {
-    type DItem = (Script, Language);
-
-    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
-        let cstr = CStr::from_bytes_until_nul(bytes)?;
-        let script = cstr.to_str()?;
-        let script_name = Script::from_name(script);
-        // skip '\0' byte between the two strings.
-        let lan = str::from_utf8(&bytes[script.len() + 1..])?;
-        let lan_name = Language::from_name(lan);
-
-        Ok((script_name, lan_name))
-    }
-}
-
-impl<'a> heed::BytesEncode<'a> for ScriptLanguageCodec {
-    type EItem = (Script, Language);
-
-    fn bytes_encode((script, lan): &Self::EItem) -> Result<Cow<'a, [u8]>, BoxedError> {
-        let script_name = script.name().as_bytes();
-        let lan_name = lan.name().as_bytes();
-
-        let mut bytes = Vec::with_capacity(script_name.len() + lan_name.len() + 1);
-        bytes.extend_from_slice(script_name);
-        bytes.push(0);
-        bytes.extend_from_slice(lan_name);
-
-        Ok(Cow::Owned(bytes))
-    }
-}
diff --git a/milli/src/index.rs b/milli/src/index.rs
index 194f18faa..f5342f2c0 100644
--- a/milli/src/index.rs
+++ b/milli/src/index.rs
@@ -4,7 +4,6 @@ use std::convert::TryInto;
 use std::fs::File;
 use std::path::Path;
 
-use charabia::{Language, Script};
 use heed::types::*;
 use heed::{CompactionOption, Database, RoTxn, RwTxn, Unspecified};
 use roaring::RoaringBitmap;
@@ -19,9 +18,7 @@ use crate::heed_codec::facet::{
     FacetGroupKeyCodec, FacetGroupValueCodec, FieldDocIdFacetF64Codec, FieldDocIdFacetStringCodec,
     FieldIdCodec, OrderedF64Codec,
 };
-use crate::heed_codec::{
-    BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec,
-};
+use crate::heed_codec::{BEU16StrCodec, FstSetCodec, StrBEU16Codec, StrRefCodec};
 use crate::order_by_map::OrderByMap;
 use crate::proximity::ProximityPrecision;
 use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME;
@@ -29,8 +26,8 @@ use crate::vector::{Embedding, EmbeddingConfig};
 use crate::{
     default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
     FacetDistribution, FieldDistribution, FieldId, FieldIdMapMissingEntry, FieldIdWordCountCodec,
-    FieldidsWeightsMap, GeoPoint, ObkvCodec, Result, RoaringBitmapCodec, RoaringBitmapLenCodec,
-    Search, U8StrStrCodec, Weight, BEU16, BEU32, BEU64,
+    FieldidsWeightsMap, GeoPoint, LocalizedAttributesRule, ObkvCodec, Result, RoaringBitmapCodec,
+    RoaringBitmapLenCodec, Search, U8StrStrCodec, Weight, BEU16, BEU32, BEU64,
 };
 
 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
@@ -73,6 +70,7 @@ pub mod main_key {
     pub const PROXIMITY_PRECISION: &str = "proximity-precision";
     pub const EMBEDDING_CONFIGS: &str = "embedding_configs";
     pub const SEARCH_CUTOFF: &str = "search_cutoff";
+    pub const LOCALIZED_ATTRIBUTES_RULES: &str = "localized_attributes_rules";
 }
 
 pub mod db_name {
@@ -101,7 +99,6 @@ pub mod db_name {
     pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id";
     pub const VECTOR_ARROY: &str = "vector-arroy";
     pub const DOCUMENTS: &str = "documents";
-    pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
 }
 
 #[derive(Clone)]
@@ -142,9 +139,6 @@ pub struct Index {
     /// Maps the word prefix and a field id with all the docids where the prefix appears inside the field
     pub word_prefix_fid_docids: Database<StrBEU16Codec, CboRoaringBitmapCodec>,
 
-    /// Maps the script and language with all the docids that corresponds to it.
-    pub script_language_docids: Database<ScriptLanguageCodec, RoaringBitmapCodec>,
-
     /// Maps the facet field id and the docids for which this field exists
     pub facet_id_exists_docids: Database<FieldIdCodec, CboRoaringBitmapCodec>,
     /// Maps the facet field id and the docids for which this field is set as null
@@ -198,8 +192,6 @@ impl Index {
             env.create_database(&mut wtxn, Some(EXACT_WORD_PREFIX_DOCIDS))?;
         let word_pair_proximity_docids =
             env.create_database(&mut wtxn, Some(WORD_PAIR_PROXIMITY_DOCIDS))?;
-        let script_language_docids =
-            env.create_database(&mut wtxn, Some(SCRIPT_LANGUAGE_DOCIDS))?;
         let word_position_docids = env.create_database(&mut wtxn, Some(WORD_POSITION_DOCIDS))?;
         let word_fid_docids = env.create_database(&mut wtxn, Some(WORD_FIELD_ID_DOCIDS))?;
         let field_id_word_count_docids =
@@ -243,7 +235,6 @@ impl Index {
             word_prefix_docids,
             exact_word_prefix_docids,
             word_pair_proximity_docids,
-            script_language_docids,
             word_position_docids,
             word_fid_docids,
             word_prefix_position_docids,
@@ -1562,69 +1553,32 @@ impl Index {
         self.main.remap_key_type::<Str>().delete(txn, main_key::PROXIMITY_PRECISION)
     }
 
-    /* script language docids */
-    /// Retrieve all the documents ids that correspond with (Script, Language) key, `None` if it is any.
-    pub fn script_language_documents_ids(
+    pub fn localized_attributes_rules(
         &self,
         rtxn: &RoTxn<'_>,
-        key: &(Script, Language),
-    ) -> heed::Result<Option<RoaringBitmap>> {
-        self.script_language_docids.get(rtxn, key)
+    ) -> heed::Result<Option<Vec<LocalizedAttributesRule>>> {
+        self.main
+            .remap_types::<Str, SerdeBincode<Vec<LocalizedAttributesRule>>>()
+            .get(rtxn, main_key::LOCALIZED_ATTRIBUTES_RULES)
     }
 
-    pub fn script_language(
+    pub(crate) fn put_localized_attributes_rules(
         &self,
-        rtxn: &RoTxn<'_>,
-    ) -> heed::Result<HashMap<Script, Vec<Language>>> {
-        let mut script_language: HashMap<Script, Vec<Language>> = HashMap::new();
-        let mut script_language_doc_count: Vec<(Script, Language, u64)> = Vec::new();
-        let mut total = 0;
-        for sl in self.script_language_docids.iter(rtxn)? {
-            let ((script, language), docids) = sl?;
-
-            // keep only Languages that contains at least 1 document.
-            let remaining_documents_count = docids.len();
-            total += remaining_documents_count;
-            if remaining_documents_count > 0 {
-                script_language_doc_count.push((script, language, remaining_documents_count));
-            }
-        }
-
-        let threshold = total / 20; // 5% (arbitrary)
-        for (script, language, count) in script_language_doc_count {
-            if count > threshold {
-                if let Some(languages) = script_language.get_mut(&script) {
-                    (*languages).push(language);
-                } else {
-                    script_language.insert(script, vec![language]);
-                }
-            }
-        }
-
-        Ok(script_language)
+        txn: &mut RwTxn<'_>,
+        val: Vec<LocalizedAttributesRule>,
+    ) -> heed::Result<()> {
+        self.main.remap_types::<Str, SerdeBincode<Vec<LocalizedAttributesRule>>>().put(
+            txn,
+            main_key::LOCALIZED_ATTRIBUTES_RULES,
+            &val,
+        )
     }
 
-    pub fn languages(&self, rtxn: &RoTxn<'_>) -> heed::Result<Vec<Language>> {
-        let mut script_language_doc_count: Vec<(Language, u64)> = Vec::new();
-        let mut total = 0;
-        for sl in self.script_language_docids.iter(rtxn)? {
-            let ((_script, language), docids) = sl?;
-
-            // keep only Languages that contains at least 1 document.
-            let remaining_documents_count = docids.len();
-            total += remaining_documents_count;
-            if remaining_documents_count > 0 {
-                script_language_doc_count.push((language, remaining_documents_count));
-            }
-        }
-
-        let threshold = total / 20; // 5% (arbitrary)
-
-        Ok(script_language_doc_count
-            .into_iter()
-            .filter(|(_, count)| *count > threshold)
-            .map(|(language, _)| language)
-            .collect())
+    pub(crate) fn delete_localized_attributes_rules(
+        &self,
+        txn: &mut RwTxn<'_>,
+    ) -> heed::Result<bool> {
+        self.main.remap_key_type::<Str>().delete(txn, main_key::LOCALIZED_ATTRIBUTES_RULES)
     }
 
     /// Put the embedding configs:
diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index fcb0da19c..461971ddf 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -16,6 +16,7 @@ pub mod facet;
 mod fields_ids_map;
 pub mod heed_codec;
 pub mod index;
+mod localized_attributes_rules;
 pub mod order_by_map;
 pub mod prompt;
 pub mod proximity;
@@ -69,6 +70,9 @@ pub use self::search::{
     Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
 };
 
+pub use self::localized_attributes_rules::LocalizedAttributesRule;
+use self::localized_attributes_rules::LocalizedFieldIds;
+
 pub type Result<T = ()> = std::result::Result<T, Error>;
 
 pub type Attribute = u32;
diff --git a/milli/src/localized_attributes_rules.rs b/milli/src/localized_attributes_rules.rs
new file mode 100644
index 000000000..a3b3e820b
--- /dev/null
+++ b/milli/src/localized_attributes_rules.rs
@@ -0,0 +1,114 @@
+use std::collections::HashMap;
+
+use charabia::Language;
+use serde::{Deserialize, Serialize};
+
+use crate::fields_ids_map::FieldsIdsMap;
+use crate::FieldId;
+
+/// A rule that defines which locales are supported for a given attribute.
+///
+/// The rule is a list of attribute patterns and a list of locales.
+/// The attribute patterns are matched against the attribute name.
+/// The pattern `*` matches any attribute name.
+/// The pattern `attribute_name*` matches any attribute name that starts with `attribute_name`.
+/// The pattern `*attribute_name` matches any attribute name that ends with `attribute_name`.
+/// The pattern `*attribute_name*` matches any attribute name that contains `attribute_name`.
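+///
+/// Usage sketch (hypothetical attribute names, for illustration only):
+///
+/// ```ignore
+/// let rule = LocalizedAttributesRule::new(
+///     vec!["title_*".to_string()],
+///     vec![charabia::Language::Cmn],
+/// );
+/// assert!(rule.match_str("title_zh"));
+/// assert!(!rule.match_str("description"));
+/// ```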
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct LocalizedAttributesRule {
+    pub attribute_patterns: Vec<String>,
+    pub locales: Vec<Language>,
+}
+
+impl LocalizedAttributesRule {
+    pub fn new(attribute_patterns: Vec<String>, locales: Vec<Language>) -> Self {
+        Self { attribute_patterns, locales }
+    }
+
+    pub fn match_str(&self, str: &str) -> bool {
+        self.attribute_patterns.iter().any(|pattern| match_pattern(pattern.as_str(), str))
+    }
+
+    pub fn locales(&self) -> &[Language] {
+        &self.locales
+    }
+}
+
+fn match_pattern(pattern: &str, str: &str) -> bool {
+    let res = if pattern == "*" {
+        true
+    } else if pattern.starts_with('*') && pattern.ends_with('*') {
+        str.contains(&pattern[1..pattern.len() - 1])
+    } else if pattern.ends_with('*') {
+        str.starts_with(&pattern[..pattern.len() - 1])
+    } else if pattern.starts_with('*') {
+        str.ends_with(&pattern[1..])
+    } else {
+        pattern == str
+    };
+
+    res
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct LocalizedFieldIds {
+    field_id_to_locales: HashMap<FieldId, Vec<Language>>,
+}
+
+impl LocalizedFieldIds {
+    pub fn new<I: Iterator<Item = FieldId>>(
+        rules: &Option<Vec<LocalizedAttributesRule>>,
+        fields_ids_map: &FieldsIdsMap,
+        fields_ids: I,
+    ) -> Self {
+        let mut field_id_to_locales = HashMap::new();
+
+        if let Some(rules) = rules {
+            let fields = fields_ids.filter_map(|field_id| {
+                fields_ids_map.name(field_id).map(|field_name| (field_id, field_name))
+            });
+
+            for (field_id, field_name) in fields {
+                let mut locales = Vec::new();
+                for rule in rules {
+                    if rule.match_str(field_name) {
+                        locales.extend(rule.locales.iter());
+                    }
+                }
+
+                if !locales.is_empty() {
+                    locales.sort();
+                    locales.dedup();
+                    field_id_to_locales.insert(field_id, locales);
+                }
+            }
+        }
+
+        Self { field_id_to_locales }
+    }
+
+    pub fn locales<'a>(&'a self, fields_id: FieldId) -> Option<&'a [Language]> {
+        self.field_id_to_locales.get(&fields_id).map(Vec::as_slice)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_match_pattern() {
+        assert_eq!(match_pattern("*", "test"), true);
+        assert_eq!(match_pattern("test*", "test"), true);
+        assert_eq!(match_pattern("test*", "testa"), true);
+        assert_eq!(match_pattern("*test", "test"), true);
+        assert_eq!(match_pattern("*test", "atest"), true);
+        assert_eq!(match_pattern("*test*", "test"), true);
+        assert_eq!(match_pattern("*test*", "atesta"), true);
+        assert_eq!(match_pattern("*test*", "atest"), true);
+        assert_eq!(match_pattern("*test*", "testa"), true);
+        assert_eq!(match_pattern("test*test", "test"), false);
+        assert_eq!(match_pattern("*test", "testa"), false);
+        assert_eq!(match_pattern("test*", "atest"), false);
+    }
+}
diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs
index 9eca378a5..6c4efb859 100644
--- a/milli/src/update/clear_documents.rs
+++ b/milli/src/update/clear_documents.rs
@@ -36,7 +36,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
             field_id_word_count_docids,
             word_prefix_position_docids,
             word_prefix_fid_docids,
-            script_language_docids,
             facet_id_f64_docids,
             facet_id_string_docids,
             facet_id_normalized_string_strings,
@@ -83,7 +82,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
         field_id_word_count_docids.clear(self.wtxn)?;
         word_prefix_position_docids.clear(self.wtxn)?;
         word_prefix_fid_docids.clear(self.wtxn)?;
-        script_language_docids.clear(self.wtxn)?;
         facet_id_f64_docids.clear(self.wtxn)?;
         facet_id_normalized_string_strings.clear(self.wtxn)?;
         facet_id_string_fst.clear(self.wtxn)?;
diff --git a/milli/src/update/index_documents/extract/extract_docid_word_positions.rs b/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
index 748a3886a..ba11ceeb3 100644
--- a/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
+++ b/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
@@ -3,7 +3,7 @@ use std::fs::File;
 use std::io::BufReader;
 use std::{io, mem, str};
 
-use charabia::{Language, SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
+use charabia::{SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
 use obkv::{KvReader, KvWriterU16};
 use roaring::RoaringBitmap;
 use serde_json::Value;
@@ -11,7 +11,7 @@ use serde_json::Value;
 use super::helpers::{create_sorter, keep_latest_obkv, sorter_into_reader, GrenadParameters};
 use crate::error::{InternalError, SerializationError};
 use crate::update::del_add::{del_add_from_two_obkvs, DelAdd, KvReaderDelAdd};
-use crate::update::settings::InnerIndexSettingsDiff;
+use crate::update::settings::{InnerIndexSettings, InnerIndexSettingsDiff};
 use crate::{FieldId, Result, MAX_POSITION_PER_ATTRIBUTE, MAX_WORD_LENGTH};
 
 /// Extracts the word and positions where this word appear and
@@ -57,13 +57,9 @@ pub fn extract_docid_word_positions(
         .map(|s| s.iter().map(String::as_str).collect());
     let old_dictionary: Option<Vec<&str>> =
         settings_diff.old.dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect());
-    let mut del_builder = tokenizer_builder(
-        old_stop_words,
-        old_separators.as_deref(),
-        old_dictionary.as_deref(),
-        None,
-    );
-    let del_tokenizer = del_builder.build();
+    let del_builder =
+        tokenizer_builder(old_stop_words, old_separators.as_deref(), old_dictionary.as_deref());
+    let del_tokenizer = del_builder.into_tokenizer();
 
     let new_stop_words = settings_diff.new.stop_words.as_ref();
     let new_separators: Option<Vec<&str>> = settings_diff
         .new
         .allowed_separators
         .as_ref()
         .map(|s| s.iter().map(String::as_str).collect());
     let new_dictionary: Option<Vec<&str>> =
         settings_diff.new.dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect());
-    let mut add_builder = tokenizer_builder(
-        new_stop_words,
-        new_separators.as_deref(),
-        new_dictionary.as_deref(),
-        None,
-    );
-    let add_tokenizer = add_builder.build();
+    let add_builder =
+        tokenizer_builder(new_stop_words, new_separators.as_deref(), new_dictionary.as_deref());
+    let add_tokenizer = add_builder.into_tokenizer();
 
     // iterate over documents.
     let mut cursor = obkv_documents.into_cursor()?;
@@ -107,7 +99,7 @@ pub fn extract_docid_word_positions(
                 // deletions
                 tokens_from_document(
                     &obkv,
-                    &settings_diff.old.searchable_fields_ids,
+                    &settings_diff.old,
                     &del_tokenizer,
                     max_positions_per_attributes,
                     DelAdd::Deletion,
@@ -118,7 +110,7 @@ pub fn extract_docid_word_positions(
                 // additions
                 tokens_from_document(
                     &obkv,
-                    &settings_diff.new.searchable_fields_ids,
+                    &settings_diff.new,
                     &add_tokenizer,
                     max_positions_per_attributes,
                     DelAdd::Addition,
@@ -180,7 +172,6 @@ fn tokenizer_builder<'a>(
     stop_words: Option<&'a fst::Set<Vec<u8>>>,
     allowed_separators: Option<&'a [&str]>,
     dictionary: Option<&'a [&str]>,
-    languages: Option<&'a Vec<Language>>,
 ) -> TokenizerBuilder<'a, Vec<u8>> {
     let mut tokenizer_builder = TokenizerBuilder::new();
     if let Some(stop_words) = stop_words {
@@ -193,17 +184,13 @@ fn tokenizer_builder<'a>(
         tokenizer_builder.separators(separators);
     }
 
-    if let Some(languages) = languages {
-        tokenizer_builder.allow_list(languages);
-    }
-
     tokenizer_builder
 }
 
 /// Extract words mapped with their positions of a document.
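+/// When a field matches a `localizedAttributes` rule, the tokenizer's allow
+/// list is restricted to that rule's locales; a sketch of the lookup performed
+/// in the body below (same calls, repeated here for clarity):
+///
+/// ```ignore
+/// let locales = settings.localized_searchable_fields_ids.locales(field_id);
+/// let tokens = tokenizer.tokenize_with_allow_list(field, locales);
+/// ```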
 fn tokens_from_document<'a>(
     obkv: &KvReader<'a, FieldId>,
-    searchable_fields: &[FieldId],
+    settings: &InnerIndexSettings,
     tokenizer: &Tokenizer<'_>,
     max_positions_per_attributes: u32,
     del_add: DelAdd,
@@ -213,7 +200,7 @@ fn tokens_from_document<'a>(
     let mut document_writer = KvWriterU16::new(&mut buffers.obkv_buffer);
     for (field_id, field_bytes) in obkv.iter() {
         // if field is searchable.
-        if searchable_fields.as_ref().contains(&field_id) {
+        if settings.searchable_fields_ids.contains(&field_id) {
             // extract deletion or addition only.
             if let Some(field_bytes) = KvReaderDelAdd::new(field_bytes).get(del_add) {
                 // parse json.
@@ -228,7 +215,8 @@ fn tokens_from_document<'a>(
                     buffers.field_buffer.clear();
                     if let Some(field) = json_to_string(&value, &mut buffers.field_buffer) {
                         // create an iterator of token with their positions.
-                        let tokens = process_tokens(tokenizer.tokenize(field))
+                        let locales = settings.localized_searchable_fields_ids.locales(field_id);
+                        let tokens = process_tokens(tokenizer.tokenize_with_allow_list(field, locales))
                             .take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
 
                         for (index, token) in tokens {
diff --git a/milli/src/update/index_documents/extract/extract_facet_string_docids.rs b/milli/src/update/index_documents/extract/extract_facet_string_docids.rs
index 3deace127..6452a67a1 100644
--- a/milli/src/update/index_documents/extract/extract_facet_string_docids.rs
+++ b/milli/src/update/index_documents/extract/extract_facet_string_docids.rs
@@ -5,6 +5,7 @@ use std::iter::FromIterator;
 use std::{io, str};
 
 use charabia::normalizer::{Normalize, NormalizerOption};
+use charabia::{Language, StrDetection, Token};
 use heed::types::SerdeJson;
 use heed::BytesEncode;
 
@@ -26,10 +27,9 @@ use crate::{FieldId, Result, MAX_FACET_VALUE_LENGTH};
 pub fn extract_facet_string_docids(
     docid_fid_facet_string: grenad::Reader<BufReader<File>>,
     indexer: GrenadParameters,
-    _settings_diff: &InnerIndexSettingsDiff,
+    settings_diff: &InnerIndexSettingsDiff,
 ) -> Result<(grenad::Reader<BufReader<File>>, grenad::Reader<BufReader<File>>)> {
     let max_memory = indexer.max_memory_by_thread();
-    let options = NormalizerOption { lossy: true, ..Default::default() };
 
     let mut facet_string_docids_sorter = create_sorter(
         grenad::SortAlgorithm::Stable,
@@ -54,12 +54,8 @@ pub fn extract_facet_string_docids(
     while let Some((key, deladd_original_value_bytes)) = cursor.move_on_next()? {
         let deladd_reader = KvReaderDelAdd::new(deladd_original_value_bytes);
 
-        // nothing to do if we delete and re-add the value.
-        if deladd_reader.get(DelAdd::Deletion).is_some()
-            && deladd_reader.get(DelAdd::Addition).is_some()
-        {
-            continue;
-        }
+        let is_same_value = deladd_reader.get(DelAdd::Deletion).is_some()
+            && deladd_reader.get(DelAdd::Addition).is_some();
 
         let (field_id_bytes, bytes) = try_split_array_at(key).unwrap();
         let field_id = FieldId::from_be_bytes(field_id_bytes);
@@ -72,29 +68,66 @@ pub fn extract_facet_string_docids(
 
         // Facet search normalization
         {
-            let mut hyper_normalized_value = normalized_value.normalize(&options);
-            let normalized_truncated_facet: String;
-            if hyper_normalized_value.len() > MAX_FACET_VALUE_LENGTH {
-                normalized_truncated_facet = hyper_normalized_value
-                    .char_indices()
-                    .take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH)
-                    .map(|(_, c)| c)
-                    .collect();
-                hyper_normalized_value = normalized_truncated_facet.into();
-            }
+            let locales = settings_diff.old.localized_faceted_fields_ids.locales(field_id);
+            let old_hyper_normalized_value = normalize_facet_string(normalized_value, locales);
+            let locales = settings_diff.new.localized_faceted_fields_ids.locales(field_id);
+            let new_hyper_normalized_value = normalize_facet_string(normalized_value, locales);
+
+            let set = BTreeSet::from_iter(std::iter::once(normalized_value));
 
-            let set = BTreeSet::from_iter(std::iter::once(normalized_value));
-
-            buffer.clear();
-            let mut obkv = KvWriterDelAdd::new(&mut buffer);
-            for (deladd_key, _) in deladd_reader.iter() {
-                let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
-                obkv.insert(deladd_key, val)?;
-            }
-            obkv.finish()?;
-
-            let key = (field_id, hyper_normalized_value.as_ref());
-            let key_bytes = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
-            normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
+            // if the facet string is the same, we can put the deletion and addition in the same obkv.
+            if old_hyper_normalized_value == new_hyper_normalized_value {
+                // nothing to do if we delete and re-add the value.
+                if is_same_value {
+                    continue;
+                }
+
+                buffer.clear();
+                let mut obkv = KvWriterDelAdd::new(&mut buffer);
+                for (deladd_key, _) in deladd_reader.iter() {
+                    let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
+                    obkv.insert(deladd_key, val)?;
+                }
+                obkv.finish()?;
+
+                let key: (u16, &str) = (field_id, new_hyper_normalized_value.as_ref());
+                let key_bytes = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
+                normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
+            } else {
+                // if the facet string is different, we need to insert the deletion and addition
+                // in different obkvs because the related keys are different.
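+                // (hypothetical illustration, not part of the upstream change: if an update
+                // switches a field's locales, the same value may normalize to a different
+                // string, so the old form must be removed under its old key and the new
+                // form inserted under its new key)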
+                // deletion
+                if deladd_reader.get(DelAdd::Deletion).is_some() {
+                    // insert old value
+                    let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
+                    buffer.clear();
+                    let mut obkv = KvWriterDelAdd::new(&mut buffer);
+                    obkv.insert(DelAdd::Deletion, val)?;
+                    obkv.finish()?;
+                    let key: (u16, &str) = (field_id, old_hyper_normalized_value.as_ref());
+                    let key_bytes =
+                        BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
+                    normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
+                }
+
+                // addition
+                if deladd_reader.get(DelAdd::Addition).is_some() {
+                    // insert new value
+                    let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
+                    buffer.clear();
+                    let mut obkv = KvWriterDelAdd::new(&mut buffer);
+                    obkv.insert(DelAdd::Addition, val)?;
+                    obkv.finish()?;
+                    let key: (u16, &str) = (field_id, new_hyper_normalized_value.as_ref());
+                    let key_bytes =
+                        BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
+                    normalized_facet_string_docids_sorter.insert(key_bytes, &buffer)?;
+                }
+            }
+        }
+
+        // nothing to do if we delete and re-add the value.
+        if is_same_value {
+            continue;
+        }
 
         let key = FacetGroupKey { field_id, level: 0, left_bound: normalized_value };
@@ -112,3 +145,24 @@ pub fn extract_facet_string_docids(
     let normalized = sorter_into_reader(normalized_facet_string_docids_sorter, indexer)?;
     sorter_into_reader(facet_string_docids_sorter, indexer).map(|s| (s, normalized))
 }
+
+/// Normalizes the facet string and truncates it to the max length.
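+///
+/// Usage sketch (illustrative values only, not from the upstream patch):
+///
+/// ```ignore
+/// // Force Japanese detection for a localized facet value:
+/// let normalized = normalize_facet_string("東京タワー", Some(&[Language::Jpn]));
+/// // Fall back to automatic script/language detection:
+/// let detected = normalize_facet_string("東京タワー", None);
+/// ```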
+fn normalize_facet_string(facet_string: &str, locales: Option<&[Language]>) -> String {
+    let options = NormalizerOption { lossy: true, ..Default::default() };
+    let mut detection = StrDetection::new(facet_string, locales);
+    let token = Token {
+        lemma: std::borrow::Cow::Borrowed(facet_string),
+        script: detection.script(),
+        language: detection.language(),
+        ..Default::default()
+    };
+
+    // truncate the facet string to the max length
+    token
+        .normalize(&options)
+        .lemma
+        .char_indices()
+        .take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH)
+        .map(|(_, c)| c)
+        .collect()
+}
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index 2521b778f..1df31fff2 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -3388,44 +3388,6 @@ mod tests {
         wtxn.commit().unwrap();
     }
 
-    #[test]
-    #[cfg(feature = "all-tokenizations")]
-    fn stored_detected_script_and_language_should_not_return_deleted_documents() {
-        use charabia::{Language, Script};
-        let index = TempIndex::new();
-        let mut wtxn = index.write_txn().unwrap();
-        index
-            .add_documents_using_wtxn(
-                &mut wtxn,
-                documents!([
-                    { "id": "0", "title": "The quick (\"brown\") fox can't jump 32.3 feet, right? Brr, it's 29.3°F!" },
-                    { "id": "1", "title": "人人生而自由﹐在尊嚴和權利上一律平等。他們賦有理性和良心﹐並應以兄弟關係的精神互相對待。" },
-                    { "id": "2", "title": "הַשּׁוּעָל הַמָּהִיר (״הַחוּם״) לֹא יָכוֹל לִקְפֹּץ 9.94 מֶטְרִים, נָכוֹן? ברר, 1.5°C- בַּחוּץ!" },
-                    { "id": "3", "title": "関西国際空港限定トートバッグ すもももももももものうち" },
-                    { "id": "4", "title": "ภาษาไทยง่ายนิดเดียว" },
-                    { "id": "5", "title": "The quick 在尊嚴和權利上一律平等。" },
-                ]))
-            .unwrap();
-
-        let key_cmn = (Script::Cj, Language::Cmn);
-        let cj_cmn_docs =
-            index.script_language_documents_ids(&wtxn, &key_cmn).unwrap().unwrap_or_default();
-        let mut expected_cj_cmn_docids = RoaringBitmap::new();
-        expected_cj_cmn_docids.push(1);
-        expected_cj_cmn_docids.push(5);
-        assert_eq!(cj_cmn_docs, expected_cj_cmn_docids);
-
-        delete_documents(&mut wtxn, &index, &["1"]);
-        wtxn.commit().unwrap();
-
-        let rtxn = index.read_txn().unwrap();
-        let cj_cmn_docs =
-            index.script_language_documents_ids(&rtxn, &key_cmn).unwrap().unwrap_or_default();
-        let mut expected_cj_cmn_docids = RoaringBitmap::new();
-        expected_cj_cmn_docids.push(5);
-        assert_eq!(cj_cmn_docs, expected_cj_cmn_docids);
-    }
-
     #[test]
     fn delete_words_exact_attributes() {
         let index = TempIndex::new();
diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs
index 448c74fd8..2cac2777d 100644
--- a/milli/src/update/settings.rs
+++ b/milli/src/update/settings.rs
@@ -28,7 +28,7 @@ use crate::vector::settings::{
     WriteBackToDocuments,
 };
 use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
-use crate::{FieldId, FieldsIdsMap, Index, Result};
+use crate::{FieldId, FieldsIdsMap, Index, LocalizedAttributesRule, LocalizedFieldIds, Result};
 
 #[derive(Debug, Clone, PartialEq, Eq, Copy)]
 pub enum Setting<T> {
@@ -159,6 +159,7 @@ pub struct Settings<'a, 't, 'i> {
     proximity_precision: Setting<ProximityPrecision>,
     embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
     search_cutoff: Setting<u64>,
+    localized_attributes_rules: Setting<Vec<LocalizedAttributesRule>>,
 }
 
 impl<'a, 't, 'i> Settings<'a, 't, 'i> {
@@ -193,6 +194,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
             proximity_precision: Setting::NotSet,
             embedder_settings: Setting::NotSet,
             search_cutoff: Setting::NotSet,
+            localized_attributes_rules: Setting::NotSet,
             indexer_config,
         }
     }
@@ -391,6 +393,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
         self.search_cutoff = Setting::Reset;
     }
 
+    pub fn set_localized_attributes_rules(&mut self, value: Vec<LocalizedAttributesRule>) {
+        self.localized_attributes_rules = Setting::Set(value);
+    }
+
+    pub fn reset_localized_attributes_rules(&mut self) {
+        self.localized_attributes_rules = Setting::Reset;
+    }
+
     #[tracing::instrument(
         level = "trace",
         skip(self, progress_callback, should_abort, settings_diff),
@@ -1118,6 +1128,24 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
         Ok(changed)
     }
 
+    fn update_localized_attributes_rules(&mut self) -> Result<bool> {
+        let changed = match &self.localized_attributes_rules {
+            Setting::Set(new) => {
+                let old = self.index.localized_attributes_rules(self.wtxn)?;
+                if old.as_ref() == Some(new) {
+                    false
+                } else {
+                    self.index.put_localized_attributes_rules(self.wtxn, new.clone())?;
+                    true
+                }
+            }
+            Setting::Reset => self.index.delete_localized_attributes_rules(self.wtxn)?,
+            Setting::NotSet => false,
+        };
+
+        Ok(changed)
+    }
+
     pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
     where
         FP: Fn(UpdateIndexingStep) + Sync,
@@ -1151,6 +1179,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
         self.update_searchable()?;
         self.update_exact_attributes()?;
         self.update_proximity_precision()?;
+        self.update_localized_attributes_rules()?;
 
         let embedding_config_updates = self.update_embedding_configs()?;
 
@@ -1229,6 +1258,8 @@ impl InnerIndexSettingsDiff {
                 || old_settings.allowed_separators != new_settings.allowed_separators
                 || old_settings.dictionary != new_settings.dictionary
                 || old_settings.proximity_precision != new_settings.proximity_precision
+                || old_settings.localized_searchable_fields_ids
+                    != new_settings.localized_searchable_fields_ids
         };
 
         let cache_exact_attributes = old_settings.exact_attributes != new_settings.exact_attributes;
@@ -1304,6 +1335,7 @@ impl InnerIndexSettingsDiff {
         }
 
         (existing_fields - old_faceted_fields) != (existing_fields - new_faceted_fields)
+            || self.old.localized_faceted_fields_ids != self.new.localized_faceted_fields_ids
     }
 
     pub fn reindex_vectors(&self) -> bool {
@@ -1341,6 +1373,8 @@ pub(crate) struct InnerIndexSettings {
     pub geo_fields_ids: Option<(FieldId, FieldId)>,
     pub non_searchable_fields_ids: Vec<FieldId>,
     pub non_faceted_fields_ids: Vec<FieldId>,
+    pub localized_searchable_fields_ids: LocalizedFieldIds,
+    pub localized_faceted_fields_ids: LocalizedFieldIds,
 }
 
 impl InnerIndexSettings {
@@ -1382,6 +1416,17 @@ impl InnerIndexSettings {
             }
             None => None,
         };
+        let localized_attributes_rules = index.localized_attributes_rules(rtxn)?;
+        let localized_searchable_fields_ids = LocalizedFieldIds::new(
+            &localized_attributes_rules,
+            &fields_ids_map,
+            searchable_fields_ids.iter().cloned(),
+        );
+        let localized_faceted_fields_ids = LocalizedFieldIds::new(
+            &localized_attributes_rules,
+            &fields_ids_map,
+            faceted_fields_ids.iter().cloned(),
+        );
 
         let vectors_fids = fields_ids_map.nested_ids(RESERVED_VECTORS_FIELD_NAME);
         searchable_fields_ids.retain(|id| !vectors_fids.contains(id));
@@ -1403,6 +1448,8 @@ impl InnerIndexSettings {
             geo_fields_ids,
             non_searchable_fields_ids: vectors_fids.clone(),
             non_faceted_fields_ids: vectors_fids.clone(),
+            localized_searchable_fields_ids,
+            localized_faceted_fields_ids,
         })
     }
 
@@ -1418,6 +1465,12 @@ impl InnerIndexSettings {
         index.put_faceted_fields(wtxn, &new_facets)?;
 
         self.faceted_fields_ids = index.faceted_fields_ids(wtxn)?;
+        let localized_attributes_rules = index.localized_attributes_rules(wtxn)?;
+        self.localized_faceted_fields_ids = LocalizedFieldIds::new(
+            &localized_attributes_rules,
+            &self.fields_ids_map,
+            self.faceted_fields_ids.iter().cloned(),
+        );
         Ok(())
     }
 
@@ -1441,8 +1494,13 @@ impl InnerIndexSettings {
                 &self.fields_ids_map,
             )?;
         }
-        let searchable_fields_ids = index.searchable_fields_ids(wtxn)?;
-        self.searchable_fields_ids = searchable_fields_ids;
+        self.searchable_fields_ids = index.searchable_fields_ids(wtxn)?;
+        let localized_attributes_rules = index.localized_attributes_rules(wtxn)?;
+        self.localized_searchable_fields_ids = LocalizedFieldIds::new(
+            &localized_attributes_rules,
+            &self.fields_ids_map,
+            self.searchable_fields_ids.iter().cloned(),
+        );
 
         Ok(())
     }
@@ -2573,6 +2631,7 @@ mod tests {
             proximity_precision,
             embedder_settings,
             search_cutoff,
+            localized_attributes_rules,
         } = settings;
         assert!(matches!(searchable_fields, Setting::NotSet));
         assert!(matches!(displayed_fields, Setting::NotSet));
@@ -2597,6 +2656,7 @@ mod tests {
         assert!(matches!(proximity_precision, Setting::NotSet));
         assert!(matches!(embedder_settings, Setting::NotSet));
         assert!(matches!(search_cutoff, Setting::NotSet));
+        assert!(matches!(localized_attributes_rules, Setting::NotSet));
     })
     .unwrap();
 }