diff --git a/milli/src/documents/primary_key.rs b/milli/src/documents/primary_key.rs index accb270c9..904109033 100644 --- a/milli/src/documents/primary_key.rs +++ b/milli/src/documents/primary_key.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::iter; +use std::ops::ControlFlow; use std::result::Result as StdResult; use bumpalo::Bump; @@ -7,7 +8,7 @@ use serde_json::value::RawValue; use serde_json::{from_str, Value}; use crate::fields_ids_map::MutFieldIdMapper; -use crate::update::new::indexer::de::DeOrBumpStr; +use crate::update::new::indexer::de::{match_component, DeOrBumpStr}; use crate::update::new::{CowStr, KvReaderFieldId, TopLevelMap}; use crate::{FieldId, InternalError, Object, Result, UserError}; @@ -64,7 +65,7 @@ impl<'a> PrimaryKey<'a> { }) } - pub fn name(&self) -> &str { + pub fn name(&self) -> &'a str { match self { PrimaryKey::Flat { name, .. } => name, PrimaryKey::Nested { name } => name, @@ -154,7 +155,31 @@ impl<'a> PrimaryKey<'a> { Ok(external_document_id) } - PrimaryKey::Nested { name } => todo!(), + nested @ PrimaryKey::Nested { name: _ } => { + let mut docid = None; + for (first_level, right) in nested.possible_level_names() { + let Some(fid) = db_fields_ids_map.id(first_level) else { continue }; + + let Some(value) = document.get(fid) else { continue }; + let value: &RawValue = + serde_json::from_slice(value).map_err(InternalError::SerdeJson)?; + match match_component(first_level, right, value, indexer, &mut docid) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(Ok(_)) => { + return Err(InternalError::DocumentsError( + crate::documents::Error::InvalidDocumentFormat, + ) + .into()) + } + ControlFlow::Break(Err(err)) => { + return Err(InternalError::SerdeJson(err).into()) + } + } + } + Ok(docid.ok_or(InternalError::DocumentsError( + crate::documents::Error::InvalidDocumentFormat, + ))?) + } } } @@ -171,7 +196,7 @@ impl<'a> PrimaryKey<'a> { self, indexer, )) - .map_err(UserError::SerdeJson)?; + .map_err(UserError::SerdeJson)??; let external_document_id = match res { Ok(document_id) => Ok(document_id), @@ -234,7 +259,7 @@ impl<'a> PrimaryKey<'a> { /// Returns an `Iterator` that gives all the possible fields names the primary key /// can have depending of the first level name and depth of the objects. - pub fn possible_level_names(&self) -> impl Iterator + '_ { + pub fn possible_level_names(&self) -> impl Iterator + '_ { let name = self.name(); name.match_indices(PRIMARY_KEY_SPLIT_SYMBOL) .map(move |(i, _)| (&name[..i], &name[i + PRIMARY_KEY_SPLIT_SYMBOL.len_utf8()..])) diff --git a/milli/src/update/new/indexer/de.rs b/milli/src/update/new/indexer/de.rs index 9a664b5f8..fa6b5fa76 100644 --- a/milli/src/update/new/indexer/de.rs +++ b/milli/src/update/new/indexer/de.rs @@ -1,4 +1,7 @@ +use std::ops::ControlFlow; + use bumpalo::Bump; +use serde::de::{DeserializeSeed, Deserializer as _, Visitor}; use serde_json::value::RawValue; use crate::documents::{ @@ -14,22 +17,6 @@ pub struct FieldAndDocidExtractor<'p, 'indexer, Mapper: MutFieldIdMapper> { indexer: &'indexer Bump, } -pub struct DocidExtractor<'p, 'indexer, Mapper: FieldIdMapper> { - fields_ids_map: &'p Mapper, - primary_key: &'p PrimaryKey<'p>, - indexer: &'indexer Bump, -} - -impl<'p, 'indexer, Mapper: FieldIdMapper> DocidExtractor<'p, 'indexer, Mapper> { - pub fn new( - fields_ids_map: &'p Mapper, - primary_key: &'p PrimaryKey<'p>, - indexer: &'indexer Bump, - ) -> Self { - Self { fields_ids_map, primary_key, indexer } - } -} - impl<'p, 'indexer, Mapper: MutFieldIdMapper> FieldAndDocidExtractor<'p, 'indexer, Mapper> { pub fn new( fields_ids_map: &'p mut Mapper, @@ -40,63 +27,56 @@ impl<'p, 'indexer, Mapper: MutFieldIdMapper> FieldAndDocidExtractor<'p, 'indexer } } -impl<'de, 'p, 'indexer: 'de, Mapper: MutFieldIdMapper> serde::de::Visitor<'de> +impl<'de, 'p, 'indexer: 'de, Mapper: MutFieldIdMapper> Visitor<'de> for FieldAndDocidExtractor<'p, 'indexer, Mapper> { - type Value = std::result::Result, DocumentIdExtractionError>; + type Value = + Result, DocumentIdExtractionError>, crate::UserError>; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!(formatter, "a map") } - fn visit_map(mut self, mut map: A) -> std::result::Result + fn visit_map(mut self, mut map: A) -> Result where A: serde::de::MapAccess<'de>, { let mut docid = None; - while let Some((fid, fields_ids_map)) = - map.next_key_seed(MutFieldIdMapSeed(self.fields_ids_map))? - { - use serde::de::Deserializer as _; - self.fields_ids_map = fields_ids_map; - /// FIXME unwrap => too many fields - let fid = fid.unwrap(); - match self.primary_key { - PrimaryKey::Flat { name: _, field_id } => { - let value: &'de RawValue = map.next_value()?; - if fid == *field_id { - let value = match value - .deserialize_any(DocumentIdVisitor(self.indexer)) - .map_err(|_err| { - DocumentIdExtractionError::InvalidDocumentId( - UserError::InvalidDocumentId { - document_id: serde_json::to_value(value).unwrap(), - }, - ) - }) { - Ok(Ok(value)) => value, - Ok(Err(err)) | Err(err) => return Ok(Err(err)), - }; - if let Some(_previous_value) = docid.replace(value) { - return Ok(Err(DocumentIdExtractionError::TooManyDocumentIds(2))); - } - } - } - PrimaryKey::Nested { name } => todo!(), + while let Some(((level_name, right), (fid, fields_ids_map))) = + map.next_key_seed(ComponentsSeed { + name: self.primary_key.name(), + visitor: MutFieldIdMapVisitor(self.fields_ids_map), + })? + { + let Some(fid) = fid else { + return Ok(Err(crate::UserError::AttributeLimitReached)); + }; + self.fields_ids_map = fields_ids_map; + + let value: &'de RawValue = map.next_value()?; + + match match_component(level_name, right, value, self.indexer, &mut docid) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(Err(err)) => return Err(serde::de::Error::custom(err)), + ControlFlow::Break(Ok(err)) => return Ok(Ok(Err(err))), } } - Ok(match docid { + + Ok(Ok(match docid { Some(docid) => Ok(docid), None => Err(DocumentIdExtractionError::MissingDocumentId), - }) + })) } } -impl<'de, 'p, 'indexer: 'de, Mapper: FieldIdMapper> serde::de::Visitor<'de> - for DocidExtractor<'p, 'indexer, Mapper> -{ - type Value = std::result::Result, DocumentIdExtractionError>; +struct NestedPrimaryKeyVisitor<'a, 'bump> { + components: &'a str, + bump: &'bump Bump, +} + +impl<'de, 'a, 'bump: 'de> Visitor<'de> for NestedPrimaryKeyVisitor<'a, 'bump> { + type Value = std::result::Result>, DocumentIdExtractionError>; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!(formatter, "a map") @@ -107,142 +87,156 @@ impl<'de, 'p, 'indexer: 'de, Mapper: FieldIdMapper> serde::de::Visitor<'de> A: serde::de::MapAccess<'de>, { let mut docid = None; - while let Some(fid) = map.next_key_seed(FieldIdMapSeed(self.fields_ids_map))? { - use serde::de::Deserializer as _; + while let Some(((matched_component, right), _)) = map.next_key_seed(ComponentsSeed { + name: self.components, + visitor: serde::de::IgnoredAny, + })? { + let value: &'de RawValue = map.next_value()?; - let Some(fid) = fid else { - continue; - }; - - match self.primary_key { - PrimaryKey::Flat { name: _, field_id } => { - let value: &'de RawValue = map.next_value()?; - if fid == *field_id { - let value = match value - .deserialize_any(DocumentIdVisitor(self.indexer)) - .map_err(|_err| { - DocumentIdExtractionError::InvalidDocumentId( - UserError::InvalidDocumentId { - document_id: serde_json::to_value(value).unwrap(), - }, - ) - }) { - Ok(Ok(value)) => value, - Ok(Err(err)) | Err(err) => return Ok(Err(err)), - }; - if let Some(_previous_value) = docid.replace(value) { - return Ok(Err(DocumentIdExtractionError::TooManyDocumentIds(2))); - } - } - } - PrimaryKey::Nested { name } => todo!(), + match match_component(matched_component, right, value, self.bump, &mut docid) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(Err(err)) => return Err(serde::de::Error::custom(err)), + ControlFlow::Break(Ok(err)) => return Ok(Err(err)), } } - Ok(match docid { - Some(docid) => Ok(docid), - None => Err(DocumentIdExtractionError::MissingDocumentId), - }) + Ok(Ok(docid)) } } +/// Either a `&'de str` or a `&'bump str`. pub enum DeOrBumpStr<'de, 'bump: 'de> { + /// Lifetime of the deserializer De(&'de str), + /// Lifetime of the allocator Bump(&'bump str), } impl<'de, 'bump: 'de> DeOrBumpStr<'de, 'bump> { + /// Returns a `&'bump str`, possibly allocating to extend its lifetime. pub fn to_bump(&self, bump: &'bump Bump) -> &'bump str { match self { DeOrBumpStr::De(de) => bump.alloc_str(de), - DeOrBumpStr::Bump(bump) => *bump, + DeOrBumpStr::Bump(bump) => bump, } } + /// Returns a `&'de str`. + /// + /// This function never allocates because `'bump: 'de`. pub fn to_de(&self) -> &'de str { match self { - DeOrBumpStr::De(de) => *de, - DeOrBumpStr::Bump(bump) => *bump, + DeOrBumpStr::De(de) => de, + DeOrBumpStr::Bump(bump) => bump, } } } -struct MutFieldIdMapSeed<'a, Mapper: MutFieldIdMapper>(&'a mut Mapper); +struct ComponentsSeed<'a, V> { + name: &'a str, + visitor: V, +} -impl<'de, 'a, Mapper: MutFieldIdMapper> serde::de::DeserializeSeed<'de> - for MutFieldIdMapSeed<'a, Mapper> -{ +impl<'de, 'a, V: Visitor<'de>> DeserializeSeed<'de> for ComponentsSeed<'a, V> { + type Value = ((&'a str, &'a str), V::Value); + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct ComponentsSeedVisitor<'a, V> { + name: &'a str, + visitor: V, + } + + impl<'a, V> ComponentsSeedVisitor<'a, V> { + fn match_str(&self, v: &str) -> (&'a str, &'a str) { + let p = PrimaryKey::Nested { name: self.name }; + for (name, right) in p.possible_level_names() { + if name == v { + return (name, right); + } + } + ("", self.name) + } + } + + impl<'de, 'a, V: Visitor<'de>> Visitor<'de> for ComponentsSeedVisitor<'a, V> { + type Value = ((&'a str, &'a str), V::Value); + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting a string") + } + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + let matched = self.match_str(v); + let inner = self.visitor.visit_borrowed_str(v)?; + Ok((matched, inner)) + } + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + let matched = self.match_str(v); + let inner = self.visitor.visit_str(v)?; + + Ok((matched, inner)) + } + } + deserializer + .deserialize_str(ComponentsSeedVisitor { name: self.name, visitor: self.visitor }) + } +} + +struct MutFieldIdMapVisitor<'a, Mapper: MutFieldIdMapper>(&'a mut Mapper); + +impl<'de, 'a, Mapper: MutFieldIdMapper> Visitor<'de> for MutFieldIdMapVisitor<'a, Mapper> { type Value = (Option, &'a mut Mapper); - fn deserialize(self, deserializer: D) -> std::result::Result + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting a string") + } + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - struct MutFieldIdMapVisitor<'a, Mapper: MutFieldIdMapper>(&'a mut Mapper); - impl<'de, 'a, Mapper: MutFieldIdMapper> serde::de::Visitor<'de> - for MutFieldIdMapVisitor<'a, Mapper> - { - type Value = (Option, &'a mut Mapper); + Ok((self.0.insert(v), self.0)) + } - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "expecting a string") - } - fn visit_borrowed_str(self, v: &'de str) -> std::result::Result - where - E: serde::de::Error, - { - Ok((self.0.insert(v), self.0)) - } - - fn visit_str(self, v: &str) -> std::result::Result - where - E: serde::de::Error, - { - Ok((self.0.insert(v), self.0)) - } - } - deserializer.deserialize_str(MutFieldIdMapVisitor(self.0)) + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + Ok((self.0.insert(v), self.0)) } } -struct FieldIdMapSeed<'a, Mapper: FieldIdMapper>(&'a Mapper); +pub struct FieldIdMapVisitor<'a, Mapper: FieldIdMapper>(pub &'a Mapper); -impl<'de, 'a, Mapper: FieldIdMapper> serde::de::DeserializeSeed<'de> - for FieldIdMapSeed<'a, Mapper> -{ +impl<'de, 'a, Mapper: FieldIdMapper> Visitor<'de> for FieldIdMapVisitor<'a, Mapper> { type Value = Option; - fn deserialize(self, deserializer: D) -> std::result::Result + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting a string") + } + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result where - D: serde::Deserializer<'de>, + E: serde::de::Error, { - struct FieldIdMapVisitor<'a, Mapper: FieldIdMapper>(&'a Mapper); - impl<'de, 'a, Mapper: FieldIdMapper> serde::de::Visitor<'de> for FieldIdMapVisitor<'a, Mapper> { - type Value = Option; + Ok(self.0.id(v)) + } - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "expecting a string") - } - fn visit_borrowed_str(self, v: &'de str) -> std::result::Result - where - E: serde::de::Error, - { - Ok(self.0.id(v)) - } - - fn visit_str(self, v: &str) -> std::result::Result - where - E: serde::de::Error, - { - Ok(self.0.id(v)) - } - } - deserializer.deserialize_str(FieldIdMapVisitor(self.0)) + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + Ok(self.0.id(v)) } } - pub struct DocumentIdVisitor<'indexer>(pub &'indexer Bump); -impl<'de, 'indexer: 'de> serde::de::Visitor<'de> for DocumentIdVisitor<'indexer> { +impl<'de, 'indexer: 'de> Visitor<'de> for DocumentIdVisitor<'indexer> { type Value = std::result::Result, DocumentIdExtractionError>; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -262,13 +256,15 @@ impl<'de, 'indexer: 'de> serde::de::Visitor<'de> for DocumentIdVisitor<'indexer> .map(DeOrBumpStr::De)) } - fn visit_str(self, v: &str) -> std::result::Result + fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { let v = self.0.alloc_str(v); - self.visit_borrowed_str(v)?; - Ok(Ok(DeOrBumpStr::Bump(v))) + Ok(match self.visit_borrowed_str(v)? { + Ok(_) => Ok(DeOrBumpStr::Bump(v)), + Err(err) => Err(err), + }) } fn visit_u64(self, v: u64) -> std::result::Result @@ -288,8 +284,45 @@ impl<'de, 'indexer: 'de> serde::de::Visitor<'de> for DocumentIdVisitor<'indexer> { use std::fmt::Write as _; - let mut out = bumpalo::collections::String::new_in(&self.0); - write!(&mut out, "{v}"); + let mut out = bumpalo::collections::String::new_in(self.0); + write!(&mut out, "{v}").unwrap(); Ok(Ok(DeOrBumpStr::Bump(out.into_bump_str()))) } } + +pub fn match_component<'de, 'indexer: 'de>( + first_level_name: &str, + right: &str, + value: &'de RawValue, + bump: &'indexer Bump, + docid: &mut Option>, +) -> ControlFlow, ()> { + if first_level_name.is_empty() { + return ControlFlow::Continue(()); + } + + let value = if right.is_empty() { + match value.deserialize_any(DocumentIdVisitor(bump)).map_err(|_err| { + DocumentIdExtractionError::InvalidDocumentId(UserError::InvalidDocumentId { + document_id: serde_json::to_value(value).unwrap(), + }) + }) { + Ok(Ok(value)) => value, + Ok(Err(err)) | Err(err) => return ControlFlow::Break(Ok(err)), + } + } else { + // if right is not empty, recursively extract right components from value + let res = value.deserialize_map(NestedPrimaryKeyVisitor { components: right, bump }); + match res { + Ok(Ok(Some(value))) => value, + Ok(Ok(None)) => return ControlFlow::Continue(()), + Ok(Err(err)) => return ControlFlow::Break(Ok(err)), + Err(err) if err.is_data() => return ControlFlow::Continue(()), // we expected the field to be a map, but it was not and that's OK. + Err(err) => return ControlFlow::Break(Err(err)), + } + }; + if let Some(_previous_value) = docid.replace(value) { + return ControlFlow::Break(Ok(DocumentIdExtractionError::TooManyDocumentIds(2))); + } + ControlFlow::Continue(()) +}