From da48506f151adf79d398d89bf8ac24e249cc0338 Mon Sep 17 00:00:00 2001
From: ManyTheFish <many@meilisearch.com>
Date: Tue, 7 Mar 2023 18:35:26 +0100
Subject: [PATCH] Rerun extraction when language detection might have failed

---
 .../extract/extract_docid_word_positions.rs   | 177 ++++++++++++++----
 1 file changed, 143 insertions(+), 34 deletions(-)

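Note for reviewers: below is a minimal standalone sketch of the second-pass
tokenization this patch introduces, i.e. rebuilding the tokenizer with an
allow list and re-extracting. It is an illustration, not part of the change:
the `Script::Latin` / `Language::Eng` variant names and the builder defaults
are assumptions about the charabia version pinned by milli; only
`TokenizerBuilder::{new, allow_list, build}`, `tokenize`, `lemma`, and
`token.language` are taken from the code in the diff.

    use std::collections::HashMap;
    use charabia::{Language, Script, TokenizerBuilder};

    fn main() {
        // allow list mapping each script to its most frequently detected
        // languages, as `most_frequent_languages` computes it in the diff.
        let mut allow_list = HashMap::new();
        allow_list.insert(Script::Latin, vec![Language::Eng]);

        // rebuild the tokenizer with the allow list: language detection is
        // now restricted to the listed languages for each script.
        let mut tokenizer_builder = TokenizerBuilder::new();
        tokenizer_builder.allow_list(&allow_list);
        let tokenizer = tokenizer_builder.build();

        for token in tokenizer.tokenize("The quick brown fox") {
            println!("{:?} -> {:?}", token.lemma(), token.language);
        }
    }
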
diff --git a/milli/src/update/index_documents/extract/extract_docid_word_positions.rs b/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
index 2d51fcc1a..5a103f1e0 100644
--- a/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
+++ b/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
@@ -3,12 +3,14 @@ use std::convert::TryInto;
 use std::fs::File;
 use std::{io, mem, str};
 
-use charabia::{Language, Script, SeparatorKind, Token, TokenKind, TokenizerBuilder};
+use charabia::{Language, Script, SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
+use obkv::KvReader;
 use roaring::RoaringBitmap;
 use serde_json::Value;
 
 use super::helpers::{concat_u32s_array, create_sorter, sorter_into_reader, GrenadParameters};
 use crate::error::{InternalError, SerializationError};
+use crate::update::index_documents::MergeFn;
 use crate::{
     absolute_from_relative_position, FieldId, Result, MAX_POSITION_PER_ATTRIBUTE, MAX_WORD_LENGTH,
 };
@@ -33,7 +35,7 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
     let max_memory = indexer.max_memory_by_thread();
 
     let mut documents_ids = RoaringBitmap::new();
-    let mut script_language_pair = HashMap::new();
+    let mut script_language_docids = HashMap::new();
     let mut docid_word_positions_sorter = create_sorter(
         grenad::SortAlgorithm::Stable,
         concat_u32s_array,
@@ -45,11 +47,11 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
 
     let mut key_buffer = Vec::new();
     let mut field_buffer = String::new();
-    let mut builder = TokenizerBuilder::new();
+    let mut tokenizer_builder = TokenizerBuilder::new();
     if let Some(stop_words) = stop_words {
-        builder.stop_words(stop_words);
+        tokenizer_builder.stop_words(stop_words);
     }
-    let tokenizer = builder.build();
+    let tokenizer = tokenizer_builder.build();
 
     let mut cursor = obkv_documents.into_cursor()?;
     while let Some((key, value)) = cursor.move_on_next()? {
@@ -57,49 +59,120 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
             .try_into()
             .map(u32::from_be_bytes)
             .map_err(|_| SerializationError::InvalidNumberSerialization)?;
-        let obkv = obkv::KvReader::<FieldId>::new(value);
+        let obkv = KvReader::<FieldId>::new(value);
 
         documents_ids.push(document_id);
         key_buffer.clear();
         key_buffer.extend_from_slice(&document_id.to_be_bytes());
 
-        for (field_id, field_bytes) in obkv.iter() {
-            if searchable_fields.as_ref().map_or(true, |sf| sf.contains(&field_id)) {
-                let value =
-                    serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
-                field_buffer.clear();
-                if let Some(field) = json_to_string(&value, &mut field_buffer) {
-                    let tokens = process_tokens(tokenizer.tokenize(field))
-                        .take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
+        let mut script_language_word_count = HashMap::new();
 
-                    for (index, token) in tokens {
-                        if let Some(language) = token.language {
-                            let script = token.script;
-                            let entry = script_language_pair
-                                .entry((script, language))
-                                .or_insert_with(RoaringBitmap::new);
-                            entry.push(document_id);
-                        }
-                        let token = token.lemma().trim();
-                        if !token.is_empty() && token.len() <= MAX_WORD_LENGTH {
-                            key_buffer.truncate(mem::size_of::<u32>());
-                            key_buffer.extend_from_slice(token.as_bytes());
+        extract_tokens_from_document(
+            &obkv,
+            searchable_fields,
+            &tokenizer,
+            max_positions_per_attributes,
+            &mut key_buffer,
+            &mut field_buffer,
+            &mut script_language_word_count,
+            &mut docid_word_positions_sorter,
+        )?;
 
-                            let position: u16 = index
-                                .try_into()
-                                .map_err(|_| SerializationError::InvalidNumberSerialization)?;
-                            let position = absolute_from_relative_position(field_id, position);
-                            docid_word_positions_sorter
-                                .insert(&key_buffer, position.to_ne_bytes())?;
+        // if we detect a potential mistake in the language detection,
+        // we rerun the extraction, forcing the tokenizer to use the most frequently detected languages.
+        // context: https://github.com/meilisearch/meilisearch/issues/3565
+        if script_language_word_count.values().any(potential_language_detection_error) {
+            // build an allow list with the most frequently detected languages in the document.
+            let script_language: HashMap<_, _> =
+                script_language_word_count.iter().filter_map(most_frequent_languages).collect();
+
+            // if the allow list is empty, meaning that no language is considered frequent,
+            // then we don't rerun the extraction.
+            if !script_language.is_empty() {
+                // build a new temporary tokenizer including the allow list.
+                let mut tokenizer_builder = TokenizerBuilder::new();
+                if let Some(stop_words) = stop_words {
+                    tokenizer_builder.stop_words(stop_words);
+                }
+                tokenizer_builder.allow_list(&script_language);
+                let tokenizer = tokenizer_builder.build();
+
+                script_language_word_count.clear();
+
+                // rerun the extraction.
+                extract_tokens_from_document(
+                    &obkv,
+                    searchable_fields,
+                    &tokenizer,
+                    max_positions_per_attributes,
+                    &mut key_buffer,
+                    &mut field_buffer,
+                    &mut script_language_word_count,
+                    &mut docid_word_positions_sorter,
+                )?;
+            }
+        }
+
+        for (script, languages_frequency) in script_language_word_count {
+            for (language, _) in languages_frequency {
+                let entry = script_language_docids
+                    .entry((script, language))
+                    .or_insert_with(RoaringBitmap::new);
+                entry.push(document_id);
+            }
+        }
+    }
+
+    sorter_into_reader(docid_word_positions_sorter, indexer)
+        .map(|reader| (documents_ids, reader, script_language_docids))
+}
+
+fn extract_tokens_from_document<T: AsRef<[u8]>>(
+    obkv: &KvReader<FieldId>,
+    searchable_fields: &Option<HashSet<FieldId>>,
+    tokenizer: &Tokenizer<T>,
+    max_positions_per_attributes: u32,
+    key_buffer: &mut Vec<u8>,
+    field_buffer: &mut String,
+    script_language_word_count: &mut HashMap<Script, Vec<(Language, usize)>>,
+    docid_word_positions_sorter: &mut grenad::Sorter<MergeFn>,
+) -> Result<()> {
+    for (field_id, field_bytes) in obkv.iter() {
+        if searchable_fields.as_ref().map_or(true, |sf| sf.contains(&field_id)) {
+            let value = serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
+            field_buffer.clear();
+            if let Some(field) = json_to_string(&value, field_buffer) {
+                let tokens = process_tokens(tokenizer.tokenize(field))
+                    .take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
+
+                for (index, token) in tokens {
+                    // if a language has been detected for the token, we update the counter.
+                    if let Some(language) = token.language {
+                        let script = token.script;
+                        let entry =
+                            script_language_word_count.entry(script).or_insert_with(Vec::new);
+                        match entry.iter_mut().find(|(l, _)| *l == language) {
+                            Some((_, n)) => *n += 1,
+                            None => entry.push((language, 1)),
                         }
                     }
+                    let token = token.lemma().trim();
+                    if !token.is_empty() && token.len() <= MAX_WORD_LENGTH {
+                        key_buffer.truncate(mem::size_of::<u32>());
+                        key_buffer.extend_from_slice(token.as_bytes());
+
+                        let position: u16 = index
+                            .try_into()
+                            .map_err(|_| SerializationError::InvalidNumberSerialization)?;
+                        let position = absolute_from_relative_position(field_id, position);
+                        docid_word_positions_sorter.insert(&key_buffer, position.to_ne_bytes())?;
+                    }
                 }
             }
         }
     }
 
-    sorter_into_reader(docid_word_positions_sorter, indexer)
-        .map(|reader| (documents_ids, reader, script_language_pair))
+    Ok(())
 }
 
 /// Transform a JSON value into a string that can be indexed.
@@ -183,3 +256,39 @@ fn process_tokens<'a>(
         })
         .filter(|(_, t)| t.is_word())
 }
+
+fn potential_language_detection_error(languages_frequency: &Vec<(Language, usize)>) -> bool {
+    if languages_frequency.len() > 1 {
+        let threshold = compute_language_frequency_threshold(languages_frequency);
+        languages_frequency.iter().any(|(_, c)| *c <= threshold)
+    } else {
+        false
+    }
+}
+
+fn most_frequent_languages(
+    (script, languages_frequency): (&Script, &Vec<(Language, usize)>),
+) -> Option<(Script, Vec<Language>)> {
+    if languages_frequency.len() > 1 {
+        let threshold = compute_language_frequency_threshold(languages_frequency);
+
+        let languages: Vec<_> = languages_frequency
+            .iter()
+            .filter(|(_, c)| *c > threshold)
+            .map(|(l, _)| l.clone())
+            .collect();
+
+        if languages.is_empty() {
+            None
+        } else {
+            Some((script.clone(), languages))
+        }
+    } else {
+        None
+    }
+}
+
+fn compute_language_frequency_threshold(languages_frequency: &Vec<(Language, usize)>) -> usize {
+    let total: usize = languages_frequency.iter().map(|(_, c)| c).sum();
+    total / 20 // 5% is a completely arbitrary value.
+}
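
For reference, a standalone sketch of the 5% heuristic implemented by
`potential_language_detection_error` and `compute_language_frequency_threshold`
above, using plain `&str` labels in place of `charabia::Language` so it runs
without the crate; the token counts are made up for the example:

    fn threshold(languages_frequency: &[(&str, usize)]) -> usize {
        // same arithmetic as compute_language_frequency_threshold: 5% of
        // the total number of tokens whose language was detected.
        let total: usize = languages_frequency.iter().map(|(_, c)| c).sum();
        total / 20
    }

    fn main() {
        // e.g. 90 English, 7 German and 3 French tokens detected in one document
        let freq = [("eng", 90), ("deu", 7), ("fra", 3)];
        let t = threshold(&freq); // (90 + 7 + 3) / 20 = 5

        // "fra" sits at or below the threshold (3 <= 5), so the detection is
        // considered suspicious and the extraction is rerun with an allow
        // list of the remaining, frequent languages.
        let allowed: Vec<_> =
            freq.iter().filter(|(_, c)| *c > t).map(|(l, _)| *l).collect();
        assert_eq!(allowed, ["eng", "deu"]);
        println!("threshold = {t}, allow list = {allowed:?}");
    }

When only a single language is detected for a script, the `len() > 1` guard in
both helpers skips the heuristic entirely and no second pass is made.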