diff --git a/milli/src/documents/builder.rs b/milli/src/documents/builder.rs
index 19cc1ce53..15a22090a 100644
--- a/milli/src/documents/builder.rs
+++ b/milli/src/documents/builder.rs
@@ -180,24 +180,10 @@ fn parse_csv_header(header: &str) -> (&str, AllowedType) {
 mod test {
     use std::io::Cursor;
 
-    use serde_json::{json, Map};
+    use serde_json::json;
 
     use super::*;
-    use crate::documents::DocumentsBatchReader;
-    use crate::FieldId;
-
-    fn obkv_to_value(obkv: &obkv::KvReader<FieldId>, index: &DocumentsBatchIndex) -> Value {
-        let mut map = Map::new();
-
-        for (fid, value) in obkv.iter() {
-            let field_name = index.name(fid).unwrap().clone();
-            let value: Value = serde_json::from_slice(value).unwrap();
-
-            map.insert(field_name.to_string(), value);
-        }
-
-        Value::Object(map)
-    }
+    use crate::documents::{obkv_to_object, DocumentsBatchReader};
 
     #[test]
     fn add_single_documents_json() {
@@ -272,7 +258,7 @@ mod test {
             DocumentsBatchReader::from_reader(Cursor::new(vector)).unwrap().into_cursor();
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -301,7 +287,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -328,7 +314,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -355,7 +341,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -382,7 +368,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -409,7 +395,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -436,7 +422,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -463,7 +449,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
@@ -507,7 +493,7 @@ mod test {
 
         let index = cursor.documents_batch_index().clone();
         let doc = cursor.next_document().unwrap().unwrap();
-        let val = obkv_to_value(&doc, &index);
+        let val = obkv_to_object(&doc, &index).map(Value::from).unwrap();
 
         assert_eq!(
             val,
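Note on the refactor above: the tests drop their private `obkv_to_value` helper for the shared `obkv_to_object` introduced in `milli/src/documents/mod.rs` below, converting the result with `Value::from`. A minimal sketch of why that final conversion works, assuming `Object` aliases `serde_json::Map<String, Value>` in this crate:

```rust
use serde_json::{json, Map, Value};

fn main() {
    // serde_json provides `impl From<Map<String, Value>> for Value`, which is
    // what `.map(Value::from)` relies on in the tests above.
    let mut object = Map::new();
    object.insert("id".to_string(), json!(1));
    assert_eq!(Value::from(object), json!({ "id": 1 }));
}
```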
diff --git a/milli/src/documents/mod.rs b/milli/src/documents/mod.rs
index 7a34ae13b..ee3593bf8 100644
--- a/milli/src/documents/mod.rs
+++ b/milli/src/documents/mod.rs
@@ -6,15 +6,30 @@ use std::io;
 
 use bimap::BiHashMap;
 pub use builder::DocumentsBatchBuilder;
+use obkv::KvReader;
 pub use reader::{DocumentsBatchCursor, DocumentsBatchReader};
 use serde::{Deserialize, Serialize};
 
-use crate::FieldId;
+use crate::error::{FieldIdMapMissingEntry, InternalError};
+use crate::{FieldId, Object, Result};
 
 /// The key that is used to store the `DocumentsBatchIndex` datastructure,
 /// it is the absolute last key of the list.
 const DOCUMENTS_BATCH_INDEX_KEY: [u8; 8] = u64::MAX.to_be_bytes();
 
+/// Helper function to convert an obkv reader into a JSON object.
+pub fn obkv_to_object(obkv: &KvReader<FieldId>, index: &DocumentsBatchIndex) -> Result<Object> {
+    obkv.iter()
+        .map(|(field_id, value)| {
+            let field_name = index.name(field_id).ok_or_else(|| {
+                FieldIdMapMissingEntry::FieldId { field_id, process: "obkv_to_object" }
+            })?;
+            let value = serde_json::from_slice(value).map_err(InternalError::SerdeJson)?;
+            Ok((field_name.to_string(), value))
+        })
+        .collect()
+}
+
 /// A bidirectional map that links field ids to their name in a document batch.
 #[derive(Default, Clone, Debug, Serialize, Deserialize)]
 pub struct DocumentsBatchIndex(pub BiHashMap<FieldId, String>);
@@ -48,11 +63,12 @@ impl DocumentsBatchIndex {
         self.0.get_by_left(&id).map(AsRef::as_ref)
     }
 
-    pub fn recreate_json(
-        &self,
-        document: &obkv::KvReaderU16,
-    ) -> Result<serde_json::Map<String, serde_json::Value>, crate::Error> {
-        let mut map = serde_json::Map::new();
+    pub fn id(&self, name: &str) -> Option<FieldId> {
+        self.0.get_by_right(name).cloned()
+    }
+
+    pub fn recreate_json(&self, document: &obkv::KvReaderU16) -> Result<Object> {
+        let mut map = Object::new();
 
         for (k, v) in document.iter() {
             // TODO: TAMO: update the error type
diff --git a/milli/src/error.rs b/milli/src/error.rs
index 57ae1c85a..d34130210 100644
--- a/milli/src/error.rs
+++ b/milli/src/error.rs
@@ -141,10 +141,16 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
 
 #[derive(Error, Debug)]
 pub enum GeoError {
+    #[error("The `_geo` field in the document with the id: `{document_id}` is not an object. Was expecting an object with the `_geo.lat` and `_geo.lng` fields but instead got `{value}`.")]
+    NotAnObject { document_id: Value, value: Value },
+    #[error("Could not find latitude or longitude in the document with the id: `{document_id}`. Was expecting `_geo.lat` and `_geo.lng` fields.")]
+    MissingLatitudeAndLongitude { document_id: Value },
     #[error("Could not find latitude in the document with the id: `{document_id}`. Was expecting a `_geo.lat` field.")]
     MissingLatitude { document_id: Value },
     #[error("Could not find longitude in the document with the id: `{document_id}`. Was expecting a `_geo.lng` field.")]
     MissingLongitude { document_id: Value },
+    #[error("Could not parse latitude or longitude in the document with the id: `{document_id}`. Was expecting numbers but instead got `{lat}` and `{lng}`.")]
+    BadLatitudeAndLongitude { document_id: Value, lat: Value, lng: Value },
     #[error("Could not parse latitude in the document with the id: `{document_id}`. Was expecting a number but instead got `{value}`.")]
     BadLatitude { document_id: Value, value: Value },
     #[error("Could not parse longitude in the document with the id: `{document_id}`. Was expecting a number but instead got `{value}`.")]
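The six variants now cover every `_geo` failure mode produced by `validate_geo_from_json` (added later in this diff). A rough sketch, assuming crate-internal context, of how a variant renders through the `thiserror` derive:

```rust
use serde_json::json;

#[test]
fn geo_error_message() {
    // `#[error(...)]` implements `Display`, interpolating the named fields;
    // a `Value::String` id keeps its JSON quotes when interpolated.
    let error = GeoError::MissingLatitudeAndLongitude { document_id: json!("doc-1") };
    assert_eq!(
        error.to_string(),
        "Could not find latitude or longitude in the document with the id: `\"doc-1\"`. \
         Was expecting `_geo.lat` and `_geo.lng` fields."
    );
}
```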
diff --git a/milli/src/update/index_documents/extract/extract_geo_points.rs b/milli/src/update/index_documents/extract/extract_geo_points.rs
index fffae5e77..0f804b93b 100644
--- a/milli/src/update/index_documents/extract/extract_geo_points.rs
+++ b/milli/src/update/index_documents/extract/extract_geo_points.rs
@@ -7,6 +7,7 @@ use serde_json::Value;
 
 use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
 use crate::error::GeoError;
+use crate::update::index_documents::extract_float_from_value;
 use crate::{FieldId, InternalError, Result};
 
 /// Extracts the geographical coordinates contained in each document under the `_geo` field.
@@ -61,11 +62,3 @@ pub fn extract_geo_points(
 
     Ok(writer_into_reader(writer)?)
 }
-
-fn extract_float_from_value(value: Value) -> StdResult<f64, Value> {
-    match value {
-        Value::Number(ref n) => n.as_f64().ok_or(value),
-        Value::String(ref s) => s.parse::<f64>().map_err(|_| value),
-        value => Err(value),
-    }
-}
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index 7f6e00b11..2fb7cbcd9 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -2,11 +2,13 @@ mod extract;
 mod helpers;
 mod transform;
 mod typed_chunk;
+mod validate;
 
 use std::collections::HashSet;
 use std::io::{Cursor, Read, Seek};
 use std::iter::FromIterator;
 use std::num::{NonZeroU32, NonZeroUsize};
+use std::result::Result as StdResult;
 
 use crossbeam_channel::{Receiver, Sender};
 use heed::types::Str;
@@ -25,13 +27,19 @@ pub use self::helpers::{
 };
 use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
 pub use self::transform::{Transform, TransformOutput};
-use crate::documents::DocumentsBatchReader;
+use self::validate::validate_documents_batch;
+pub use self::validate::{
+    extract_float_from_value, validate_document_id, validate_document_id_from_json,
+    validate_geo_from_json,
+};
+use crate::documents::{obkv_to_object, DocumentsBatchReader};
+use crate::error::UserError;
 pub use crate::update::index_documents::helpers::CursorClonableMmap;
 use crate::update::{
     self, Facets, IndexerConfig, UpdateIndexingStep, WordPrefixDocids,
     WordPrefixPairProximityDocids, WordPrefixPositionDocids, WordsPrefixesFst,
 };
-use crate::{Index, Result, RoaringBitmapCodec, UserError};
+use crate::{Index, Result, RoaringBitmapCodec};
 
 static MERGED_DATABASE_COUNT: usize = 7;
 static PREFIX_DATABASE_COUNT: usize = 5;
@@ -117,19 +125,27 @@ where
 
     /// Adds a batch of documents to the current builder.
    ///
-    /// Since the documents are progressively added to the writer, a failure will cause a stale
-    /// builder, and the builder must be discarded.
+    /// Since the documents are progressively added to the writer, a failure will only return
+    /// an error and not the `IndexDocuments` struct, as it is invalid to use it afterward.
     ///
     /// Returns the number of documents added to the builder.
-    pub fn add_documents<R>(&mut self, reader: DocumentsBatchReader<R>) -> Result<u64>
-    where
-        R: Read + Seek,
-    {
+    pub fn add_documents<R: Read + Seek>(
+        mut self,
+        reader: DocumentsBatchReader<R>,
+    ) -> Result<(Self, StdResult<u64, UserError>)> {
         // Early return when there is no document to add
         if reader.is_empty() {
-            return Ok(0);
+            return Ok((self, Ok(0)));
         }
 
+        // We check for user errors in this validator and, if there is one, we can return
+        // the `IndexDocuments` struct as it is still valid to send more documents into it.
+        // However, if there is an internal error we throw it away!
+        let reader = match validate_documents_batch(self.wtxn, self.index, reader)? {
+            Ok(reader) => reader,
+            Err(user_error) => return Ok((self, Err(user_error))),
+        };
+
         let indexed_documents = self
             .transform
             .as_mut()
@@ -139,7 +155,7 @@ where
 
         self.added_documents += indexed_documents;
 
-        Ok(indexed_documents)
+        Ok((self, Ok(indexed_documents)))
     }
 
     #[logging_timer::time("IndexDocuments::{}")]
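The `add_documents` signature change is the heart of this diff: the builder is now moved in and handed back, with user errors carried in an inner `Result` so callers can keep using the builder, while internal errors consume it. A self-contained toy sketch of that calling convention (the types below are stand-ins for illustration, not milli's):

```rust
// Toy stand-ins for `IndexDocuments` and the milli error types.
#[derive(Debug)]
struct UserError(&'static str);
#[derive(Debug)]
struct InternalError(&'static str);

struct Builder {
    added: u64,
}

impl Builder {
    // Mirrors the new shape: outer `Result` = internal errors (builder consumed),
    // inner `Result` = user errors (builder handed back, still usable).
    fn add_documents(
        mut self,
        batch: &[&str],
    ) -> Result<(Self, Result<u64, UserError>), InternalError> {
        if batch.iter().any(|id| id.contains('!')) {
            return Ok((self, Err(UserError("invalid document id"))));
        }
        self.added += batch.len() as u64;
        let added = self.added;
        Ok((self, Ok(added)))
    }
}

fn main() -> Result<(), InternalError> {
    let builder = Builder { added: 0 };
    let (builder, user_result) = builder.add_documents(&["doc-1", "doc-2"])?;
    assert_eq!(user_result.unwrap(), 2);

    // A user error does not poison the builder: we get it back and can retry.
    let (_builder, user_result) = builder.add_documents(&["bad!id"])?;
    assert!(user_result.is_err());
    Ok(())
}
```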
diff --git a/milli/src/update/index_documents/transform.rs b/milli/src/update/index_documents/transform.rs
index 42187fc1e..bc7eefd33 100644
--- a/milli/src/update/index_documents/transform.rs
+++ b/milli/src/update/index_documents/transform.rs
@@ -17,6 +17,7 @@ use super::{validate_document_id, IndexDocumentsMethod, IndexerConfig};
 use crate::documents::{DocumentsBatchIndex, DocumentsBatchReader};
 use crate::error::{Error, InternalError, UserError};
 use crate::index::db_name;
+use crate::update::index_documents::validate_document_id_from_json;
 use crate::update::{AvailableDocumentsIds, UpdateIndexingStep};
 use crate::{
     ExternalDocumentsIds, FieldDistribution, FieldId, FieldIdMapMissingEntry, FieldsIdsMap, Index,
@@ -782,14 +783,6 @@ fn compute_primary_key_pair(
     }
 }
 
-fn validate_document_id(document_id: &str) -> Option<&str> {
-    let document_id = document_id.trim();
-    Some(document_id).filter(|id| {
-        !id.is_empty()
-            && id.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
-    })
-}
-
 /// Drops all the value of type `U` in vec, and reuses the allocation to create a `Vec<T>`.
 ///
 /// The size and alignment of T and U must match.
@@ -813,22 +806,7 @@ fn update_primary_key<'a>(
 ) -> Result<Cow<'a, str>> {
     match field_buffer_cache.iter_mut().find(|(id, _)| *id == primary_key_id) {
         Some((_, bytes)) => {
-            let value = match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
-                Value::String(string) => match validate_document_id(&string) {
-                    Some(s) if s.len() == string.len() => string,
-                    Some(s) => s.to_string(),
-                    None => {
-                        return Err(UserError::InvalidDocumentId {
-                            document_id: Value::String(string),
-                        }
-                        .into())
-                    }
-                },
-                Value::Number(number) => number.to_string(),
-                content => {
-                    return Err(UserError::InvalidDocumentId { document_id: content.clone() }.into())
-                }
-            };
+            let value = validate_document_id_from_json(bytes)??;
             serde_json::to_writer(external_id_buffer, &value).map_err(InternalError::SerdeJson)?;
             Ok(Cow::Owned(value))
         }
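The helper removed above moves, lightly restructured, into the new `milli/src/update/index_documents/validate.rs` shown next. A quick behavior sketch of the rules it enforces (illustration only, assuming crate-internal context):

```rust
#[test]
fn document_id_validation() {
    // Ids are trimmed first, must be non-empty, and may only contain
    // a-z, A-Z, 0-9, hyphens (-) and underscores (_).
    assert_eq!(validate_document_id(" hello-42_ "), Some("hello-42_"));
    assert_eq!(validate_document_id("hello!"), None); // `!` is not allowed
    assert_eq!(validate_document_id("   "), None); // empty once trimmed
}
```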
diff --git a/milli/src/update/index_documents/validate.rs b/milli/src/update/index_documents/validate.rs
new file mode 100644
index 000000000..b4c0cb68f
--- /dev/null
+++ b/milli/src/update/index_documents/validate.rs
@@ -0,0 +1,140 @@
+use std::io::{Read, Seek};
+use std::result::Result as StdResult;
+
+use serde_json::Value;
+
+use crate::error::{GeoError, InternalError, UserError};
+use crate::update::index_documents::{obkv_to_object, DocumentsBatchReader};
+use crate::{Index, Result};
+
+/// This function validates a batch of documents by checking that:
+/// - a primary key can be inferred,
+/// - every document has a document id,
+/// - each document id is valid, and
+/// - the `_geo` field is valid, depending on the settings.
+pub fn validate_documents_batch<R: Read + Seek>(
+    rtxn: &heed::RoTxn,
+    index: &Index,
+    reader: DocumentsBatchReader<R>,
+) -> Result<StdResult<DocumentsBatchReader<R>, UserError>> {
+    let mut cursor = reader.into_cursor();
+    let documents_batch_index = cursor.documents_batch_index().clone();
+
+    // The primary key *field id* that has already been set for this index or the one
+    // we will guess by searching for the first key that contains "id" as a substring.
+    let (primary_key, primary_key_id) = match index.primary_key(rtxn)? {
+        Some(primary_key) => match documents_batch_index.id(primary_key) {
+            Some(id) => (primary_key, id),
+            None => {
+                return match cursor.next_document()? {
+                    Some(first_document) => Ok(Err(UserError::MissingDocumentId {
+                        primary_key: primary_key.to_string(),
+                        document: obkv_to_object(&first_document, &documents_batch_index)?,
+                    })),
+                    // If there is no document in this batch the best we can do is to return this error.
+                    None => Ok(Err(UserError::MissingPrimaryKey)),
+                };
+            }
+        },
+        None => {
+            let guessed = documents_batch_index
+                .iter()
+                .filter(|(_, name)| name.contains("id"))
+                .min_by_key(|(fid, _)| *fid);
+            match guessed {
+                Some((id, name)) => (name.as_str(), *id),
+                None => return Ok(Err(UserError::MissingPrimaryKey)),
+            }
+        }
+    };
+
+    // If the settings specify that a `_geo` field must be used, we must check its validity
+    // in all the documents of this batch, and this is when we return `Some`.
+    let geo_field_id = match documents_batch_index.id("_geo") {
+        Some(geo_field_id) if index.sortable_fields(rtxn)?.contains("_geo") => Some(geo_field_id),
+        _otherwise => None,
+    };
+
+    while let Some(document) = cursor.next_document()? {
+        let document_id = match document.get(primary_key_id) {
+            Some(document_id_bytes) => match validate_document_id_from_json(document_id_bytes)? {
+                Ok(document_id) => document_id,
+                Err(user_error) => return Ok(Err(user_error)),
+            },
+            None => {
+                return Ok(Err(UserError::MissingDocumentId {
+                    primary_key: primary_key.to_string(),
+                    document: obkv_to_object(&document, &documents_batch_index)?,
+                }))
+            }
+        };
+
+        if let Some(geo_value) = geo_field_id.and_then(|fid| document.get(fid)) {
+            if let Err(user_error) = validate_geo_from_json(Value::from(document_id), geo_value)? {
+                return Ok(Err(UserError::from(user_error)));
+            }
+        }
+    }
+
+    Ok(Ok(cursor.into_reader()))
+}
+
+/// Returns a trimmed version of the document id or `None` if it is invalid.
+pub fn validate_document_id(document_id: &str) -> Option<&str> {
+    let id = document_id.trim();
+    if !id.is_empty()
+        && id.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_'))
+    {
+        Some(id)
+    } else {
+        None
+    }
+}
+
+/// Parses a JSON-encoded document id and validates it, returning a user error when the id is invalid.
+pub fn validate_document_id_from_json(bytes: &[u8]) -> Result<StdResult<String, UserError>> {
+    match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
+        Value::String(string) => match validate_document_id(&string) {
+            Some(s) if s.len() == string.len() => Ok(Ok(string)),
+            Some(s) => Ok(Ok(s.to_string())),
+            None => {
+                return Ok(Err(UserError::InvalidDocumentId { document_id: Value::String(string) }))
+            }
+        },
+        Value::Number(number) => Ok(Ok(number.to_string())),
+        content => return Ok(Err(UserError::InvalidDocumentId { document_id: content.clone() })),
+    }
+}
+
+/// Tries to extract an `f64` from a JSON `Value` and returns the `Value`
+/// in the `Err` variant if it failed.
+pub fn extract_float_from_value(value: Value) -> StdResult<f64, Value> {
+    match value {
+        Value::Number(ref n) => n.as_f64().ok_or(value),
+        Value::String(ref s) => s.parse::<f64>().map_err(|_| value),
+        value => Err(value),
+    }
+}
+
+pub fn validate_geo_from_json(document_id: Value, bytes: &[u8]) -> Result<StdResult<(), GeoError>> {
+    let result = match serde_json::from_slice(bytes).map_err(InternalError::SerdeJson)? {
+        Value::Object(mut object) => match (object.remove("lat"), object.remove("lng")) {
+            (Some(lat), Some(lng)) => {
+                match (extract_float_from_value(lat), extract_float_from_value(lng)) {
+                    (Ok(_), Ok(_)) => Ok(()),
+                    (Err(value), Ok(_)) => Err(GeoError::BadLatitude { document_id, value }),
+                    (Ok(_), Err(value)) => Err(GeoError::BadLongitude { document_id, value }),
+                    (Err(lat), Err(lng)) => {
+                        Err(GeoError::BadLatitudeAndLongitude { document_id, lat, lng })
+                    }
+                }
+            }
+            (None, Some(_)) => Err(GeoError::MissingLatitude { document_id }),
+            (Some(_), None) => Err(GeoError::MissingLongitude { document_id }),
+            (None, None) => Err(GeoError::MissingLatitudeAndLongitude { document_id }),
+        },
+        value => Err(GeoError::NotAnObject { document_id, value }),
+    };
+
+    Ok(result)
+}
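Together, `extract_float_from_value` and `validate_geo_from_json` define what a valid `_geo` field looks like. A usage sketch (illustration only, assuming crate-internal context):

```rust
#[test]
fn geo_validation() {
    use serde_json::json;

    // Coordinates may be JSON numbers or number-like strings.
    assert_eq!(extract_float_from_value(json!(12.5)), Ok(12.5));
    assert_eq!(extract_float_from_value(json!("12.5")), Ok(12.5));
    assert_eq!(extract_float_from_value(json!(true)), Err(json!(true)));

    // A well-formed `_geo` object passes; a missing `lng` maps to a `GeoError`.
    let ok = validate_geo_from_json(json!("doc-1"), br#"{ "lat": 12, "lng": 42 }"#).unwrap();
    assert!(ok.is_ok());
    let err = validate_geo_from_json(json!("doc-1"), br#"{ "lat": 12 }"#).unwrap();
    assert!(matches!(err, Err(GeoError::MissingLongitude { .. })));
}
```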