From 8ebf5eed0d80beb623a056d4e70a5d0535cd7181 Mon Sep 17 00:00:00 2001
From: Kerollmops <clement@meilisearch.com>
Date: Wed, 15 Jun 2022 17:58:52 +0200
Subject: [PATCH] Make the nested primary key work

---
 milli/src/error.rs                            |   2 +
 milli/src/update/index_documents/mod.rs       |   2 +-
 milli/src/update/index_documents/transform.rs |   5 +-
 milli/src/update/index_documents/validate.rs  | 228 ++++++++++++++----
 4 files changed, 191 insertions(+), 46 deletions(-)

diff --git a/milli/src/error.rs b/milli/src/error.rs
index a23472951..d05acbe1c 100644
--- a/milli/src/error.rs
+++ b/milli/src/error.rs
@@ -121,6 +121,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
     MaxDatabaseSizeReached,
     #[error("Document doesn't have a `{}` attribute: `{}`.", .primary_key, serde_json::to_string(.document).unwrap())]
     MissingDocumentId { primary_key: String, document: Object },
+    #[error("Document have too many matching `{}` attribute: `{}`.", .primary_key, serde_json::to_string(.document).unwrap())]
+    TooManyDocumentIds { primary_key: String, document: Object },
     #[error("The primary key inference process failed because the engine did not find any fields containing `id` substring in their name. If your document identifier does not contain any `id` substring, you can set the primary key of the index.")]
     MissingPrimaryKey,
     #[error("There is no more space left on the device. Consider increasing the size of the disk/partition.")]
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index 5bce3b851..ba1064684 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -29,7 +29,7 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
 pub use self::transform::{Transform, TransformOutput};
 use self::validate::validate_documents_batch;
 pub use self::validate::{
-    extract_float_from_value, validate_document_id, validate_document_id_from_json,
+    extract_float_from_value, validate_document_id, validate_document_id_value,
     validate_geo_from_json,
 };
 use crate::documents::{obkv_to_object, DocumentsBatchReader};
diff --git a/milli/src/update/index_documents/transform.rs b/milli/src/update/index_documents/transform.rs
index 4ece58509..38f6dc8ff 100644
--- a/milli/src/update/index_documents/transform.rs
+++ b/milli/src/update/index_documents/transform.rs
@@ -17,7 +17,7 @@ use super::{IndexDocumentsMethod, IndexerConfig};
 use crate::documents::{DocumentsBatchIndex, DocumentsBatchReader};
 use crate::error::{Error, InternalError, UserError};
 use crate::index::db_name;
-use crate::update::index_documents::validate_document_id_from_json;
+use crate::update::index_documents::validate_document_id_value;
 use crate::update::{AvailableDocumentsIds, UpdateIndexingStep};
 use crate::{
     ExternalDocumentsIds, FieldDistribution, FieldId, FieldIdMapMissingEntry, FieldsIdsMap, Index,
@@ -806,7 +806,8 @@ fn update_primary_key<'a>(
 ) -> Result<Cow<'a, str>> {
     match field_buffer_cache.iter_mut().find(|(id, _)| *id == primary_key_id) {
         Some((_, bytes)) => {
-            let value = validate_document_id_from_json(bytes)??;
+            let document_id = serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)?;
+            let value = validate_document_id_value(document_id)??;
             serde_json::to_writer(external_id_buffer, &value).map_err(InternalError::SerdeJson)?;
             Ok(Cow::Owned(value))
         }
diff --git a/milli/src/update/index_documents/validate.rs b/milli/src/update/index_documents/validate.rs
index c69c754ac..32e8de03f 100644
--- a/milli/src/update/index_documents/validate.rs
+++ b/milli/src/update/index_documents/validate.rs
@@ -1,11 +1,16 @@
 use std::io::{Read, Seek};
+use std::iter;
 use std::result::Result as StdResult;
 
 use serde_json::Value;
 
+use crate::documents::{DocumentsBatchIndex, DocumentsBatchReader};
 use crate::error::{GeoError, InternalError, UserError};
-use crate::update::index_documents::{obkv_to_object, DocumentsBatchReader};
-use crate::{Index, Result};
+use crate::update::index_documents::obkv_to_object;
+use crate::{FieldId, Index, Object, Result};
+
+/// The symbol used to define levels in a nested primary key.
+const PRIMARY_KEY_SPLIT_SYMBOL: char = '.';
 
 /// This function validates a documents by checking that:
 ///  - we can infer a primary key,
@@ -23,10 +28,15 @@ pub fn validate_documents_batch<R: Read + Seek>(
 
     // The primary key *field id* that has already been set for this index or the one
     // we will guess by searching for the first key that contains "id" as a substring.
-    let (primary_key, primary_key_id) = match index.primary_key(rtxn)? {
+    let primary_key = match index.primary_key(rtxn)? {
+        Some(primary_key) if primary_key.contains(PRIMARY_KEY_SPLIT_SYMBOL) => {
+            PrimaryKey::nested(primary_key)
+        }
         Some(primary_key) => match documents_batch_index.id(primary_key) {
-            Some(id) => (primary_key, id),
-            None if autogenerate_docids => (primary_key, documents_batch_index.insert(primary_key)),
+            Some(id) => PrimaryKey::flat(primary_key, id),
+            None if autogenerate_docids => {
+                PrimaryKey::flat(primary_key, documents_batch_index.insert(primary_key))
+            }
             None => {
                 return match cursor.next_document()? {
                     Some(first_document) => Ok(Err(UserError::MissingDocumentId {
@@ -43,8 +53,10 @@ pub fn validate_documents_batch<R: Read + Seek>(
                 .filter(|(_, name)| name.to_lowercase().contains("id"))
                 .min_by_key(|(fid, _)| *fid);
             match guessed {
-                Some((id, name)) => (name.as_str(), *id),
-                None if autogenerate_docids => ("id", documents_batch_index.insert("id")),
+                Some((id, name)) => PrimaryKey::flat(name.as_str(), *id),
+                None if autogenerate_docids => {
+                    PrimaryKey::flat("id", documents_batch_index.insert("id"))
+                }
                 None => return Ok(Err(UserError::MissingPrimaryKey)),
             }
         }
@@ -59,20 +71,15 @@ pub fn validate_documents_batch<R: Read + Seek>(
 
     let mut count = 0;
     while let Some(document) = cursor.next_document()? {
-        let document_id = match document.get(primary_key_id) {
-            Some(document_id_bytes) => match validate_document_id_from_json(document_id_bytes)? {
-                Ok(document_id) => document_id,
-                Err(user_error) => return Ok(Err(user_error)),
-            },
-            None if autogenerate_docids => {
-                format!("{{auto-generated id of the {}nth document}}", count)
-            }
-            None => {
-                return Ok(Err(UserError::MissingDocumentId {
-                    primary_key: primary_key.to_string(),
-                    document: obkv_to_object(&document, &documents_batch_index)?,
-                }))
-            }
+        let document_id = match fetch_document_id(
+            &document,
+            &documents_batch_index,
+            primary_key,
+            autogenerate_docids,
+            count,
+        )? {
+            Ok(document_id) => document_id,
+            Err(user_error) => return Ok(Err(user_error)),
         };
 
         if let Some(geo_value) = geo_field_id.and_then(|fid| document.get(fid)) {
@@ -86,30 +93,167 @@ pub fn validate_documents_batch<R: Read + Seek>(
     Ok(Ok(cursor.into_reader()))
 }
 
+/// Retrieve the document id after validating it, returning a `UserError`
+/// if the id is invalid or can't be guessed.
+fn fetch_document_id(
+    document: &obkv::KvReader<FieldId>,
+    documents_batch_index: &DocumentsBatchIndex,
+    primary_key: PrimaryKey,
+    autogenerate_docids: bool,
+    count: usize,
+) -> Result<StdResult<String, UserError>> {
+    match primary_key {
+        PrimaryKey::Flat { name: primary_key, field_id: primary_key_id } => {
+            match document.get(primary_key_id) {
+                Some(document_id_bytes) => {
+                    let document_id = serde_json::from_slice(document_id_bytes)
+                        .map_err(InternalError::SerdeJson)?;
+                    match validate_document_id_value(document_id)? {
+                        Ok(document_id) => Ok(Ok(document_id)),
+                        Err(user_error) => Ok(Err(user_error)),
+                    }
+                }
+                None if autogenerate_docids => {
+                    Ok(Ok(format!("{{auto-generated id of the {}nth document}}", count)))
+                }
+                None => Ok(Err(UserError::MissingDocumentId {
+                    primary_key: primary_key.to_string(),
+                    document: obkv_to_object(&document, &documents_batch_index)?,
+                })),
+            }
+        }
+        nested @ PrimaryKey::Nested { .. } => {
+            let mut matching_documents_ids = Vec::new();
+            for (first_level_name, right) in nested.possible_level_names() {
+                if let Some(field_id) = documents_batch_index.id(first_level_name) {
+                    if let Some(value_bytes) = document.get(field_id) {
+                        let object = serde_json::from_slice(value_bytes)
+                            .map_err(InternalError::SerdeJson)?;
+                        fetch_matching_values(object, right, &mut matching_documents_ids);
+
+                        if matching_documents_ids.len() >= 2 {
+                            return Ok(Err(UserError::TooManyDocumentIds {
+                                primary_key: nested.primary_key().to_string(),
+                                document: obkv_to_object(&document, &documents_batch_index)?,
+                            }));
+                        }
+                    }
+                }
+            }
+
+            match matching_documents_ids.pop() {
+                Some(document_id) => match validate_document_id_value(document_id)? {
+                    Ok(document_id) => Ok(Ok(document_id)),
+                    Err(user_error) => Ok(Err(user_error)),
+                },
+                None => Ok(Err(UserError::MissingDocumentId {
+                    primary_key: nested.primary_key().to_string(),
+                    document: obkv_to_object(&document, &documents_batch_index)?,
+                })),
+            }
+        }
+    }
+}
+
+/// A type that represents the kind of primary key that has been set
+/// for this index: a classic flat one or a nested one.
+#[derive(Debug, Clone, Copy)]
+enum PrimaryKey<'a> {
+    Flat { name: &'a str, field_id: FieldId },
+    Nested { name: &'a str },
+}
+
+impl PrimaryKey<'_> {
+    fn flat(name: &str, field_id: FieldId) -> PrimaryKey {
+        PrimaryKey::Flat { name, field_id }
+    }
+
+    fn nested(name: &str) -> PrimaryKey {
+        PrimaryKey::Nested { name }
+    }
+
+    fn primary_key(&self) -> &str {
+        match self {
+            PrimaryKey::Flat { name, .. } => name,
+            PrimaryKey::Nested { name } => name,
+        }
+    }
+
+    /// Returns an `Iterator` that gives all the possible field names the primary key
+    /// can have, depending on the first level name and the depth of the objects.
+    fn possible_level_names(&self) -> impl Iterator<Item = (&str, &str)> + '_ {
+        let name = self.primary_key();
+        iter::successors(Some((name, "")), |(curr, _)| curr.rsplit_once(PRIMARY_KEY_SPLIT_SYMBOL))
+    }
+}
+
+fn contained_in(selector: &str, key: &str) -> bool {
+    selector.starts_with(key)
+        && selector[key.len()..]
+            .chars()
+            .next()
+            .map(|c| c == PRIMARY_KEY_SPLIT_SYMBOL)
+            .unwrap_or(true)
+}
+
+pub fn fetch_matching_values(value: Value, selector: &str, output: &mut Vec<Value>) {
+    match value {
+        Value::Object(object) => fetch_matching_values_in_object(object, selector, "", output),
+        otherwise => output.push(otherwise),
+    }
+}
+
+pub fn fetch_matching_values_in_object(
+    object: Object,
+    selector: &str,
+    base_key: &str,
+    output: &mut Vec<Value>,
+) {
+    for (key, value) in object {
+        let base_key = if base_key.is_empty() {
+            key.to_string()
+        } else {
+            format!("{}{}{}", base_key, PRIMARY_KEY_SPLIT_SYMBOL, key)
+        };
+
+        // Here, if the user only specified `doggo`, we need to iterate over all the fields of
+        // `doggo`, so we check `contained_in` in both directions.
+        let should_continue =
+            contained_in(selector, &base_key) || contained_in(&base_key, selector);
+
+        if should_continue {
+            match value {
+                Value::Object(object) => {
+                    fetch_matching_values_in_object(object, selector, &base_key, output)
+                }
+                value => output.push(value),
+            }
+        }
+    }
+}
+
 /// Returns a trimmed version of the document id or `None` if it is invalid.
 pub fn validate_document_id(document_id: &str) -> Option<&str> {
-    let id = document_id.trim();
-    if !id.is_empty()
-        && id.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
+    let document_id = document_id.trim();
+    if !document_id.is_empty()
+        && document_id.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
     {
-        Some(id)
+        Some(document_id)
     } else {
         None
     }
 }
 
 /// Parses a Json encoded document id and validate it, returning a user error when it is one.
-pub fn validate_document_id_from_json(bytes: &[u8]) -> Result<StdResult<String, UserError>> {
-    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
+pub fn validate_document_id_value(document_id: Value) -> Result<StdResult<String, UserError>> {
+    match document_id {
         Value::String(string) => match validate_document_id(&string) {
             Some(s) if s.len() == string.len() => Ok(Ok(string)),
             Some(s) => Ok(Ok(s.to_string())),
-            None => {
-                return Ok(Err(UserError::InvalidDocumentId { document_id: Value::String(string) }))
-            }
+            None => Ok(Err(UserError::InvalidDocumentId { document_id: Value::String(string) })),
         },
         Value::Number(number) if number.is_i64() => Ok(Ok(number.to_string())),
-        content => return Ok(Err(UserError::InvalidDocumentId { document_id: content.clone() })),
+        content => Ok(Err(UserError::InvalidDocumentId { document_id: content.clone() })),
     }
 }
 
@@ -124,24 +268,22 @@ pub fn extract_float_from_value(value: Value) -> StdResult<f64, Value> {
 }
 
 pub fn validate_geo_from_json(document_id: Value, bytes: &[u8]) -> Result<StdResult<(), GeoError>> {
-    let result = match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
+    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
         Value::Object(mut object) => match (object.remove("lat"), object.remove("lng")) {
             (Some(lat), Some(lng)) => {
                 match (extract_float_from_value(lat), extract_float_from_value(lng)) {
-                    (Ok(_), Ok(_)) => Ok(()),
-                    (Err(value), Ok(_)) => Err(GeoError::BadLatitude { document_id, value }),
-                    (Ok(_), Err(value)) => Err(GeoError::BadLongitude { document_id, value }),
+                    (Ok(_), Ok(_)) => Ok(Ok(())),
+                    (Err(value), Ok(_)) => Ok(Err(GeoError::BadLatitude { document_id, value })),
+                    (Ok(_), Err(value)) => Ok(Err(GeoError::BadLongitude { document_id, value })),
                     (Err(lat), Err(lng)) => {
-                        Err(GeoError::BadLatitudeAndLongitude { document_id, lat, lng })
+                        Ok(Err(GeoError::BadLatitudeAndLongitude { document_id, lat, lng }))
                     }
                 }
             }
-            (None, Some(_)) => Err(GeoError::MissingLatitude { document_id }),
-            (Some(_), None) => Err(GeoError::MissingLongitude { document_id }),
-            (None, None) => Err(GeoError::MissingLatitudeAndLongitude { document_id }),
+            (None, Some(_)) => Ok(Err(GeoError::MissingLatitude { document_id })),
+            (Some(_), None) => Ok(Err(GeoError::MissingLongitude { document_id })),
+            (None, None) => Ok(Err(GeoError::MissingLatitudeAndLongitude { document_id })),
         },
-        value => Err(GeoError::NotAnObject { document_id, value }),
-    };
-
-    Ok(result)
+        value => Ok(Err(GeoError::NotAnObject { document_id, value })),
+    }
 }