diff --git a/src/update/index_documents/mod.rs b/src/update/index_documents/mod.rs index ce4e8c2df..8860c1a93 100644 --- a/src/update/index_documents/mod.rs +++ b/src/update/index_documents/mod.rs @@ -697,17 +697,17 @@ mod tests { assert_eq!(count, 3); let docs = index.documents(&rtxn, vec![0, 1, 2]).unwrap(); - let (kevin_id, _) = docs.iter().find(|(_, d)| d.get(1).unwrap() == br#""kevin""#).unwrap(); + let (kevin_id, _) = docs.iter().find(|(_, d)| { + d.get(0).unwrap() == br#""updated kevin""# + }).unwrap(); let (id, doc) = docs[*kevin_id as usize]; assert_eq!(id, *kevin_id); // Check that this document is equal to the last // one sent and that an UUID has been generated. - let mut doc_iter = doc.iter(); + assert_eq!(doc.get(0), Some(&br#""updated kevin""#[..])); // This is an UUID, it must be 36 bytes long plus the 2 surrounding string quotes ("). - doc_iter.next().filter(|(_, id)| id.len() == 36 + 2).unwrap(); - assert_eq!(doc_iter.next(), Some((1, &br#""kevin""#[..]))); - assert_eq!(doc_iter.next(), None); + assert!(doc.get(1).unwrap().len() == 36 + 2); drop(rtxn); } @@ -842,4 +842,36 @@ mod tests { assert_eq!(count, 3); drop(rtxn); } + + #[test] + fn invalid_documents_ids() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // First we send 1 document with an invalid id. + let mut wtxn = index.write_txn().unwrap(); + // There is a space in the document id. + let content = &b"id,name\nbrume bleue,kevin\n"[..]; + let mut builder = IndexDocuments::new(&mut wtxn, &index); + builder.update_format(UpdateFormat::Csv); + assert!(builder.execute(content, |_, _| ()).is_err()); + wtxn.commit().unwrap(); + + // First we send 1 document with a valid id. + let mut wtxn = index.write_txn().unwrap(); + // There is a space in the document id. + let content = &b"id,name\n32,kevin\n"[..]; + let mut builder = IndexDocuments::new(&mut wtxn, &index); + builder.update_format(UpdateFormat::Csv); + builder.execute(content, |_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Check that there is 1 document now. + let rtxn = index.read_txn().unwrap(); + let count = index.number_of_documents(&rtxn).unwrap(); + assert_eq!(count, 1); + drop(rtxn); + } } diff --git a/src/update/index_documents/transform.rs b/src/update/index_documents/transform.rs index df0dccc3f..ff71928cb 100644 --- a/src/update/index_documents/transform.rs +++ b/src/update/index_documents/transform.rs @@ -172,6 +172,12 @@ impl Transform<'_, '_> { writer.insert(field_id, &json_buffer)?; } else if field_id == primary_key { + // We validate the document id [a-zA-Z0-9\-_]. + let user_id = match validate_document_id(&user_id) { + Some(valid) => valid, + None => return Err(anyhow!("invalid document id: {:?}", user_id)), + }; + // We serialize the document id. serde_json::to_writer(&mut json_buffer, &user_id)?; writer.insert(field_id, &json_buffer)?; @@ -256,9 +262,15 @@ impl Transform<'_, '_> { let mut writer = obkv::KvWriter::new(&mut obkv_buffer); // We extract the user id if we know where it is or generate an UUID V4 otherwise. - // TODO we must validate the user id (i.e. [a-zA-Z0-9\-_]). let user_id = match user_id_pos { - Some(pos) => &record[pos], + Some(pos) => { + let user_id = &record[pos]; + // We validate the document id [a-zA-Z0-9\-_]. + match validate_document_id(&user_id) { + Some(valid) => valid, + None => return Err(anyhow!("invalid document id: {:?}", user_id)), + } + }, None => uuid::Uuid::new_v4().to_hyphenated().encode_lower(&mut uuid_buffer), }; @@ -411,3 +423,12 @@ fn merge_obkvs(_key: &[u8], obkvs: &[Cow<[u8]>]) -> anyhow::Result> { buffer })) } + +fn validate_document_id(document_id: &str) -> Option<&str> { + let document_id = document_id.trim(); + Some(document_id).filter(|id| { + !id.is_empty() && id.chars().all(|c| { + matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_') + }) + }) +}