use std::io::{self, Write}; use grenad::{CompressionType, WriterBuilder}; use serde_json::{to_writer, Map, Value}; use super::{DocumentsBatchIndex, Error, DOCUMENTS_BATCH_INDEX_KEY}; /// The `DocumentsBatchBuilder` provides a way to build a documents batch in the intermediary /// format used by milli. /// /// The writer used by the `DocumentsBatchBuilder` can be read using a `DocumentsBatchReader` /// to iterate over the documents. /// /// ## example: /// ``` /// use serde_json::json; /// use milli::documents::DocumentsBatchBuilder; /// /// let json = json!({ "id": 1, "name": "foo" }); /// /// let mut builder = DocumentsBatchBuilder::new(Vec::new()); /// builder.append_json_object(json.as_object().unwrap()).unwrap(); /// let _vector = builder.into_inner().unwrap(); /// ``` pub struct DocumentsBatchBuilder { /// The inner grenad writer, the last value must always be the `DocumentsBatchIndex`. writer: grenad::Writer, /// A map that creates the relation between field ids and field names. fields_index: DocumentsBatchIndex, /// The number of documents that were added to this builder, /// it doesn't take the primary key of the documents into account at this point. documents_count: u32, /// A buffer to store a temporary obkv buffer and avoid reallocating. obkv_buffer: Vec, /// A buffer to serialize the values and avoid reallocating, /// serialized values are stored in an obkv. value_buffer: Vec, } impl DocumentsBatchBuilder { pub fn new(writer: W) -> DocumentsBatchBuilder { DocumentsBatchBuilder { writer: WriterBuilder::new().compression_type(CompressionType::None).build(writer), fields_index: DocumentsBatchIndex::default(), documents_count: 0, obkv_buffer: Vec::new(), value_buffer: Vec::new(), } } /// Returns the number of documents inserted into this builder. pub fn documents_count(&self) -> u32 { self.documents_count } /// Appends a new JSON object into the batch and updates the `DocumentsBatchIndex` accordingly. 
pub fn append_json_object(&mut self, object: &Map) -> io::Result<()> { // Make sure that we insert the fields ids in order as the obkv writer has this requirement. let mut fields_ids: Vec<_> = object.keys().map(|k| self.fields_index.insert(&k)).collect(); fields_ids.sort_unstable(); self.obkv_buffer.clear(); let mut writer = obkv::KvWriter::new(&mut self.obkv_buffer); for field_id in fields_ids { let key = self.fields_index.name(field_id).unwrap(); self.value_buffer.clear(); to_writer(&mut self.value_buffer, &object[key])?; writer.insert(field_id, &self.value_buffer)?; } let internal_id = self.documents_count.to_be_bytes(); let document_bytes = writer.into_inner()?; self.writer.insert(internal_id, &document_bytes)?; self.documents_count += 1; Ok(()) } /// Appends a new CSV file into the batch and updates the `DocumentsBatchIndex` accordingly. pub fn append_csv(&mut self, mut reader: csv::Reader) -> Result<(), Error> { // Make sure that we insert the fields ids in order as the obkv writer has this requirement. let mut typed_fields_ids: Vec<_> = reader .headers()? .into_iter() .map(parse_csv_header) .map(|(k, t)| (self.fields_index.insert(k), t)) .enumerate() .collect(); typed_fields_ids.sort_unstable_by_key(|(_, (fid, _))| *fid); let mut record = csv::StringRecord::new(); let mut line = 0; while reader.read_record(&mut record)? { // We increment here and not at the end of the while loop to take // the header offset into account. 
line += 1; self.obkv_buffer.clear(); let mut writer = obkv::KvWriter::new(&mut self.obkv_buffer); for (i, (field_id, type_)) in typed_fields_ids.iter() { self.value_buffer.clear(); let value = &record[*i]; match type_ { AllowedType::Number => { if value.trim().is_empty() { to_writer(&mut self.value_buffer, &Value::Null)?; } else { match value.trim().parse::() { Ok(float) => { to_writer(&mut self.value_buffer, &float)?; } Err(error) => { return Err(Error::ParseFloat { error, line, value: value.to_string(), }); } } } } AllowedType::String => { if value.is_empty() { to_writer(&mut self.value_buffer, &Value::Null)?; } else { to_writer(&mut self.value_buffer, value)?; } } } // We insert into the obkv writer the value buffer that has been filled just above. writer.insert(*field_id, &self.value_buffer)?; } let internal_id = self.documents_count.to_be_bytes(); let document_bytes = writer.into_inner()?; self.writer.insert(internal_id, &document_bytes)?; self.documents_count += 1; } Ok(()) } /// Flushes the content on disk and stores the final version of the `DocumentsBatchIndex`. pub fn into_inner(mut self) -> io::Result { let DocumentsBatchBuilder { mut writer, fields_index, .. } = self; // We serialize and insert the `DocumentsBatchIndex` as the last key of the grenad writer. self.value_buffer.clear(); to_writer(&mut self.value_buffer, &fields_index)?; writer.insert(DOCUMENTS_BATCH_INDEX_KEY, &self.value_buffer)?; writer.into_inner() } } #[derive(Debug)] enum AllowedType { String, Number, } fn parse_csv_header(header: &str) -> (&str, AllowedType) { // if there are several separators we only split on the last one. match header.rsplit_once(':') { Some((field_name, field_type)) => match field_type { "string" => (field_name, AllowedType::String), "number" => (field_name, AllowedType::Number), // if the pattern isn't reconized, we keep the whole field. 
            _otherwise => (header, AllowedType::String),
        },
        None => (header, AllowedType::String),
    }
}

#[cfg(test)]
mod test {
    use std::io::Cursor;

    use serde_json::{json, Map};

    use super::*;
    use crate::documents::DocumentBatchReader;

    // NOTE(review): this entire test module exercises an older API surface
    // (`DocumentBatchBuilder::new(&mut cursor)`, `from_csv`, `extend_from_json`,
    // `finish`, `len`, and `DocumentBatchReader::next_document_with_index`)
    // that does not match the `DocumentsBatchBuilder` defined above
    // (`new(writer)`, `append_json_object`, `append_csv`, `into_inner`).
    // These tests look stale — confirm and port them to the current API.

    // Decodes one obkv-encoded document back into a JSON object, using `index`
    // to translate field ids into field names.
    // NOTE(review): `obkv::KvReader` appears here without its key-type parameter
    // (presumably `<FieldId>`, lost in formatting) — confirm against the obkv crate.
    fn obkv_to_value(obkv: &obkv::KvReader, index: &DocumentsBatchIndex) -> Value {
        let mut map = Map::new();

        for (fid, value) in obkv.iter() {
            let field_name = index.name(fid).unwrap().clone();
            let value: Value = serde_json::from_slice(value).unwrap();

            map.insert(field_name, value);
        }

        Value::Object(map)
    }

    // Appending two JSON objects one by one: the shared index grows to the union
    // of the fields while each stored document keeps only its own fields.
    #[test]
    fn add_single_documents_json() {
        let mut cursor = Cursor::new(Vec::new());
        let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();

        let json = serde_json::json!({
            "id": 1,
            "field": "hello!",
        });

        builder.extend_from_json(Cursor::new(serde_json::to_vec(&json).unwrap())).unwrap();

        let json = serde_json::json!({
            "blabla": false,
            "field": "hello!",
            "id": 1,
        });

        builder.extend_from_json(Cursor::new(serde_json::to_vec(&json).unwrap())).unwrap();

        assert_eq!(builder.len(), 2);

        builder.finish().unwrap();

        cursor.set_position(0);

        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();

        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
        assert_eq!(index.len(), 3);
        assert_eq!(document.iter().count(), 2);

        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
        assert_eq!(index.len(), 3);
        assert_eq!(document.iter().count(), 3);

        assert!(reader.next_document_with_index().unwrap().is_none());
    }

    // Same as above but both documents arrive in a single JSON array.
    #[test]
    fn add_documents_seq_json() {
        let mut cursor = Cursor::new(Vec::new());
        let mut builder = DocumentBatchBuilder::new(&mut cursor).unwrap();

        let json = serde_json::json!([{
            "id": 1,
            "field": "hello!",
        },{
            "blabla": false,
            "field": "hello!",
            "id": 1,
        }
        ]);

        builder.extend_from_json(Cursor::new(serde_json::to_vec(&json).unwrap())).unwrap();

        assert_eq!(builder.len(), 2);

        builder.finish().unwrap();

        cursor.set_position(0);

        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();

        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
        assert_eq!(index.len(), 3);
        assert_eq!(document.iter().count(), 2);

        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
        assert_eq!(index.len(), 3);
        assert_eq!(document.iter().count(), 3);

        assert!(reader.next_document_with_index().unwrap().is_none());
    }

    // Two CSV records with typed headers round-trip into two two-field documents.
    #[test]
    fn add_documents_csv() {
        let mut cursor = Cursor::new(Vec::new());

        let csv = "id:number,field:string\n1,hello!\n2,blabla";

        let builder =
            DocumentBatchBuilder::from_csv(Cursor::new(csv.as_bytes()), &mut cursor).unwrap();
        builder.finish().unwrap();

        cursor.set_position(0);

        let mut reader = DocumentBatchReader::from_reader(cursor).unwrap();

        let (index, document) = reader.next_document_with_index().unwrap().unwrap();
        assert_eq!(index.len(), 2);
        assert_eq!(document.iter().count(), 2);

        let (_index, document) = reader.next_document_with_index().unwrap().unwrap();
        assert_eq!(document.iter().count(), 2);

        assert!(reader.next_document_with_index().unwrap().is_none());
    }

    // Untyped CSV columns are stored as JSON strings.
    #[test]
    fn simple_csv_document() {
        let documents = r#"city,country,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city": "Boston",
                "country": "United States",
                "pop": "4628910",
            })
        );

        assert!(reader.next_document_with_index().unwrap().is_none());
    }

    // A quoted comma stays inside the field value.
    // NOTE(review): test name has a typo ("coma" → "comma"); kept to avoid a code change.
    #[test]
    fn coma_in_field() {
        let documents = r#"city,country,pop
"Boston","United, States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city": "Boston",
                "country": "United, States",
                "pop": "4628910",
            })
        );
    }

    // A doubled quote inside a quoted field unescapes to a single quote.
    #[test]
    fn quote_in_field() {
        let documents = r#"city,country,pop
"Boston","United"" States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city": "Boston",
                "country": "United\" States",
                "pop": "4628910",
            })
        );
    }

    // A `:number` column parses an integer literal into a JSON float.
    #[test]
    fn integer_in_field() {
        let documents = r#"city,country,pop:number
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city": "Boston",
                "country": "United States",
                "pop": 4628910.0,
            })
        );
    }

    // A `:number` column parses a decimal literal.
    #[test]
    fn float_in_field() {
        let documents = r#"city,country,pop:number
"Boston","United States","4628910.01""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city": "Boston",
                "country": "United States",
                "pop": 4628910.01,
            })
        );
    }

    // Only the last `:` is a type separator; earlier colons stay in the name.
    #[test]
    fn several_colon_in_header() {
        let documents = r#"city:love:string,country:state,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city:love": "Boston",
                "country:state": "United States",
                "pop": "4628910",
            })
        );
    }

    // A trailing colon with an empty type keeps the whole header as the name.
    #[test]
    fn ending_by_colon_in_header() {
        let documents = r#"city:,country,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city:": "Boston",
                "country": "United States",
                "pop": "4628910",
            })
        );
    }

    // A leading colon is part of the field name, not a type separator.
    #[test]
    fn starting_by_colon_in_header() {
        let documents = r#":city,country,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                ":city": "Boston",
                "country": "United States",
                "pop": "4628910",
            })
        );
    }

    // `:string` alone yields an empty field name; expected to be rejected when read.
    #[ignore]
    #[test]
    fn starting_by_colon_in_header2() {
        let documents = r#":string,country,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();

        assert!(reader.next_document_with_index().is_err());
    }

    // `city::string` splits on the last colon, leaving `city:` as the field name.
    #[test]
    fn double_colon_in_header() {
        let documents = r#"city::string,country,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf))
            .unwrap()
            .finish()
            .unwrap();
        let mut reader = DocumentBatchReader::from_reader(Cursor::new(buf)).unwrap();
        let (index, doc) = reader.next_document_with_index().unwrap().unwrap();
        let val = obkv_to_value(&doc, index);

        assert_eq!(
            val,
            json!({
                "city:": "Boston",
                "country": "United States",
                "pop": "4628910",
            })
        );
    }

    // A non-numeric value in a `:number` column must make the build fail.
    #[test]
    fn bad_type_in_header() {
        let documents = r#"city,country:number,pop
"Boston","United States","4628910""#;

        let mut buf = Vec::new();
        assert!(
            DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf)).is_err()
        );
    }

    // A record with more columns than the header must be rejected.
    #[test]
    fn bad_column_count1() {
        let documents = r#"city,country,pop
"Boston","United States","4628910", "too much""#;

        let mut buf = Vec::new();
        assert!(
            DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf)).is_err()
        );
    }

    // A record with fewer columns than the header must be rejected.
    #[test]
    fn bad_column_count2() {
        let documents = r#"city,country,pop
"Boston","United States""#;

        let mut buf = Vec::new();
        assert!(
            DocumentBatchBuilder::from_csv(documents.as_bytes(), Cursor::new(&mut buf)).is_err()
        );
    }
}