From db2fb86b8bbb69cb79781d74dda885460ea45560 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 9 Nov 2023 14:19:16 +0100 Subject: [PATCH] Extract PrimaryKey logic to a type --- milli/src/documents/mod.rs | 10 ++ milli/src/documents/primary_key.rs | 168 +++++++++++++++++++++++++++++ milli/src/fields_ids_map.rs | 6 ++ 3 files changed, 184 insertions(+) create mode 100644 milli/src/documents/primary_key.rs diff --git a/milli/src/documents/mod.rs b/milli/src/documents/mod.rs index 7c037b3bf..4429f083d 100644 --- a/milli/src/documents/mod.rs +++ b/milli/src/documents/mod.rs @@ -1,5 +1,6 @@ mod builder; mod enriched; +mod primary_key; mod reader; mod serde_impl; @@ -11,6 +12,9 @@ use bimap::BiHashMap; pub use builder::DocumentsBatchBuilder; pub use enriched::{EnrichedDocument, EnrichedDocumentsBatchCursor, EnrichedDocumentsBatchReader}; use obkv::KvReader; +pub use primary_key::{ + DocumentIdExtractionError, FieldDistribution, PrimaryKey, DEFAULT_PRIMARY_KEY, +}; pub use reader::{DocumentsBatchCursor, DocumentsBatchCursorError, DocumentsBatchReader}; use serde::{Deserialize, Serialize}; @@ -87,6 +91,12 @@ impl DocumentsBatchIndex { } } +impl FieldDistribution for DocumentsBatchIndex { + fn id(&self, name: &str) -> Option<FieldId> { + self.id(name) + } +} + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Error parsing number {value:?} at line {line}: {error}")] diff --git a/milli/src/documents/primary_key.rs b/milli/src/documents/primary_key.rs new file mode 100644 index 000000000..dd97f2608 --- /dev/null +++ b/milli/src/documents/primary_key.rs @@ -0,0 +1,168 @@ +use std::iter; +use std::result::Result as StdResult; + +use serde_json::Value; + +use crate::{FieldId, InternalError, Object, Result, UserError}; + +/// The symbol used to define levels in a nested primary key. +const PRIMARY_KEY_SPLIT_SYMBOL: char = '.'; + +/// The default primary key that is used when not specified. 
+pub const DEFAULT_PRIMARY_KEY: &str = "id"; + +pub trait FieldDistribution { + fn id(&self, name: &str) -> Option<FieldId>; +} + +/// A type that represents the type of primary key that has been set +/// for this index, a classic flat one or a nested one. +#[derive(Debug, Clone, Copy)] +pub enum PrimaryKey<'a> { + Flat { name: &'a str, field_id: FieldId }, + Nested { name: &'a str }, +} + +pub enum DocumentIdExtractionError { + InvalidDocumentId(UserError), + MissingDocumentId, + TooManyDocumentIds(usize), +} + +impl<'a> PrimaryKey<'a> { + pub fn new(path: &'a str, fields: &impl FieldDistribution) -> Option<Self> { + Some(if path.contains(PRIMARY_KEY_SPLIT_SYMBOL) { + Self::Nested { name: path } + } else { + let field_id = fields.id(path)?; + Self::Flat { name: path, field_id } + }) + } + + pub fn name(&self) -> &str { + match self { + PrimaryKey::Flat { name, .. } => name, + PrimaryKey::Nested { name } => name, + } + } + + pub fn document_id( + &self, + document: &obkv::KvReader<FieldId>, + fields: &impl FieldDistribution, + ) -> Result<StdResult<String, DocumentIdExtractionError>> { + match self { + PrimaryKey::Flat { name: _, field_id } => match document.get(*field_id) { + Some(document_id_bytes) => { + let document_id = serde_json::from_slice(document_id_bytes) + .map_err(InternalError::SerdeJson)?; + match validate_document_id_value(document_id)? { + Ok(document_id) => Ok(Ok(document_id)), + Err(user_error) => { + Ok(Err(DocumentIdExtractionError::InvalidDocumentId(user_error))) + } + } + } + None => Ok(Err(DocumentIdExtractionError::MissingDocumentId)), + }, + nested @ PrimaryKey::Nested { .. 
} => { + let mut matching_documents_ids = Vec::new(); + for (first_level_name, right) in nested.possible_level_names() { + if let Some(field_id) = fields.id(first_level_name) { + if let Some(value_bytes) = document.get(field_id) { + let object = serde_json::from_slice(value_bytes) + .map_err(InternalError::SerdeJson)?; + fetch_matching_values(object, right, &mut matching_documents_ids); + + if matching_documents_ids.len() >= 2 { + return Ok(Err(DocumentIdExtractionError::TooManyDocumentIds( + matching_documents_ids.len(), + ))); + } + } + } + } + + match matching_documents_ids.pop() { + Some(document_id) => match validate_document_id_value(document_id)? { + Ok(document_id) => Ok(Ok(document_id)), + Err(user_error) => { + Ok(Err(DocumentIdExtractionError::InvalidDocumentId(user_error))) + } + }, + None => Ok(Err(DocumentIdExtractionError::MissingDocumentId)), + } + } + } + } + + /// Returns an `Iterator` that gives all the possible field names the primary key + /// can have depending on the first level name and depth of the objects. 
+ pub fn possible_level_names(&self) -> impl Iterator<Item = (&str, &str)> + '_ { + let name = self.name(); + name.match_indices(PRIMARY_KEY_SPLIT_SYMBOL) + .map(move |(i, _)| (&name[..i], &name[i + PRIMARY_KEY_SPLIT_SYMBOL.len_utf8()..])) + .chain(iter::once((name, ""))) + } +} + +fn fetch_matching_values(value: Value, selector: &str, output: &mut Vec<Value>) { + match value { + Value::Object(object) => fetch_matching_values_in_object(object, selector, "", output), + otherwise => output.push(otherwise), + } +} + +fn fetch_matching_values_in_object( + object: Object, + selector: &str, + base_key: &str, + output: &mut Vec<Value>, +) { + for (key, value) in object { + let base_key = if base_key.is_empty() { + key.to_string() + } else { + format!("{}{}{}", base_key, PRIMARY_KEY_SPLIT_SYMBOL, key) + }; + + if starts_with(selector, &base_key) { + match value { + Value::Object(object) => { + fetch_matching_values_in_object(object, selector, &base_key, output) + } + value => output.push(value), + } + } + } +} + +fn starts_with(selector: &str, key: &str) -> bool { + selector.strip_prefix(key).map_or(false, |tail| { + tail.chars().next().map(|c| c == PRIMARY_KEY_SPLIT_SYMBOL).unwrap_or(true) + }) +} + +// FIXME: move to a DocumentId struct + +fn validate_document_id(document_id: &str) -> Option<&str> { + if !document_id.is_empty() + && document_id.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_')) + { + Some(document_id) + } else { + None + } +} + +pub fn validate_document_id_value(document_id: Value) -> Result<StdResult<String, UserError>> { + match document_id { + Value::String(string) => match validate_document_id(&string) { + Some(s) if s.len() == string.len() => Ok(Ok(string)), + Some(s) => Ok(Ok(s.to_string())), + None => Ok(Err(UserError::InvalidDocumentId { document_id: Value::String(string) })), + }, + Value::Number(number) if number.is_i64() => Ok(Ok(number.to_string())), + content => Ok(Err(UserError::InvalidDocumentId { document_id: content })), + } +} diff --git a/milli/src/fields_ids_map.rs 
b/milli/src/fields_ids_map.rs index 810ff755b..85320c168 100644 --- a/milli/src/fields_ids_map.rs +++ b/milli/src/fields_ids_map.rs @@ -81,6 +81,12 @@ impl Default for FieldsIdsMap { } } +impl crate::documents::FieldDistribution for FieldsIdsMap { + fn id(&self, name: &str) -> Option<FieldId> { + self.id(name) + } +} + #[cfg(test)] mod tests { use super::*;