From c8aee7ed7a2dd566a29fb7147e80474087b76eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Sat, 1 Dec 2018 18:37:21 +0100 Subject: [PATCH] fix: Make the merge operator work --- src/blob/mod.rs | 1 + src/blob/negative/blob.rs | 12 +++++- src/blob/ops.rs | 18 ++++++--- src/blob/positive/blob.rs | 82 ++++++++++++++++++++++++++++++++++++--- src/blob/positive/ops.rs | 5 +-- src/data/doc_ids.rs | 2 +- src/data/doc_indexes.rs | 69 ++++++++++++++++++++++++-------- src/rank/ranked_stream.rs | 2 +- 8 files changed, 157 insertions(+), 34 deletions(-) diff --git a/src/blob/mod.rs b/src/blob/mod.rs index 05f9367c4..10357e7dc 100644 --- a/src/blob/mod.rs +++ b/src/blob/mod.rs @@ -11,6 +11,7 @@ use std::fmt; use serde::ser::{Serialize, Serializer, SerializeTuple}; use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor}; +#[derive(Debug)] pub enum Blob { Positive(PositiveBlob), Negative(NegativeBlob), diff --git a/src/blob/negative/blob.rs b/src/blob/negative/blob.rs index 425be4cf6..038c90cf2 100644 --- a/src/blob/negative/blob.rs +++ b/src/blob/negative/blob.rs @@ -1,11 +1,13 @@ -use std::path::Path; use std::error::Error; +use std::path::Path; +use std::fmt; use serde::de::{self, Deserialize, Deserializer}; use serde::ser::{Serialize, Serializer}; use crate::data::DocIds; use crate::DocumentId; +#[derive(Default)] pub struct NegativeBlob { doc_ids: DocIds, } @@ -42,6 +44,14 @@ impl AsRef<[DocumentId]> for NegativeBlob { } } +impl fmt::Debug for NegativeBlob { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NegativeBlob(")?; + f.debug_list().entries(self.as_ref()).finish()?; + write!(f, ")") + } +} + impl Serialize for NegativeBlob { fn serialize(&self, serializer: S) -> Result { self.doc_ids.serialize(serializer) diff --git a/src/blob/ops.rs b/src/blob/ops.rs index faceab7cf..b345b8eab 100644 --- a/src/blob/ops.rs +++ b/src/blob/ops.rs @@ -46,8 +46,8 @@ impl OpBuilder { pub fn merge(self) -> Result> { let groups = GroupBy::new(&self.blobs, blob_same_sign); - let mut positives = Vec::new(); - let mut negatives = Vec::new(); + let mut aggregated = Vec::new(); + for blobs in groups { match blobs[0].sign() { Sign::Positive => { @@ -66,7 +66,7 @@ impl OpBuilder { } let (map, doc_indexes) = builder.into_inner().unwrap(); let blob = PositiveBlob::from_bytes(map, doc_indexes).unwrap(); - positives.push(blob); + aggregated.push(Blob::Positive(blob)); }, Sign::Negative => { let mut op_builder = negative::OpBuilder::with_capacity(blobs.len()); @@ -74,14 +74,20 @@ impl OpBuilder { op_builder.push(unwrap_negative(blob)); } let blob = op_builder.union().into_negative_blob(); - negatives.push(blob); + aggregated.push(Blob::Negative(blob)); }, } } - let mut zipped = positives.into_iter().zip(negatives); let mut buffer = Vec::new(); - zipped.try_fold(PositiveBlob::default(), |base, (positive, negative)| { + aggregated.chunks(2).try_fold(PositiveBlob::default(), |base, slice| { + let negative = NegativeBlob::default(); + let (positive, negative) = match slice { + [a, b] => (unwrap_positive(a), unwrap_negative(b)), + [a] => (unwrap_positive(a), &negative), + _ => unreachable!(), + }; + let mut builder = PositiveBlobBuilder::memory(); let doc_ids = Set::new_unchecked(negative.as_ref()); diff --git a/src/blob/positive/blob.rs b/src/blob/positive/blob.rs index 851f4c686..b58bee7a2 100644 --- a/src/blob/positive/blob.rs +++ b/src/blob/positive/blob.rs @@ -37,7 +37,7 @@ impl PositiveBlob { } pub fn get>(&self, key: K) -> Option<&[DocIndex]> { - self.map.get(key).and_then(|index| self.indexes.get(index)) + self.map.get(key).map(|index| &self.indexes[index as usize]) } pub fn as_map(&self) -> &Map { @@ -53,6 +53,22 @@ impl PositiveBlob { } } +impl fmt::Debug for PositiveBlob { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "PositiveBlob([")?; + let mut stream = self.into_stream(); + let mut first = true; + while let Some((k, v)) = stream.next() { + if !first { + write!(f, ", ")?; + } + first = false; + write!(f, "({}, {:?})", String::from_utf8_lossy(k), v)?; + } + write!(f, "])") + } +} + impl<'m, 'a> IntoStreamer<'a> for &'m PositiveBlob { type Item = (&'a [u8], &'a [DocIndex]); /// The type of the stream to be constructed. @@ -78,8 +94,7 @@ impl<'m, 'a> Streamer<'a> for PositiveBlobStream<'m> { fn next(&'a mut self) -> Option { match self.map_stream.next() { Some((input, index)) => { - let doc_indexes = self.doc_indexes.get(index); - let doc_indexes = doc_indexes.expect("BUG: could not find document indexes"); + let doc_indexes = &self.doc_indexes[index as usize]; Some((input, doc_indexes)) }, None => None, @@ -91,7 +106,7 @@ impl Serialize for PositiveBlob { fn serialize(&self, serializer: S) -> Result { let mut tuple = serializer.serialize_tuple(2)?; tuple.serialize_element(&self.map.as_fst().to_vec())?; - tuple.serialize_element(&self.indexes)?; + tuple.serialize_element(&self.indexes.to_vec())?; tuple.end() } } @@ -162,7 +177,9 @@ impl PositiveBlobBuilder { /// then an error is returned. Similarly, if there was a problem writing /// to the underlying writer, an error is returned. // FIXME what if one write doesn't work but the other do ? - pub fn insert(&mut self, key: &[u8], doc_indexes: &[DocIndex]) -> Result<(), Box> { + pub fn insert(&mut self, key: K, doc_indexes: &[DocIndex]) -> Result<(), Box> + where K: AsRef<[u8]>, + { self.map.insert(key, self.value)?; self.indexes.insert(doc_indexes)?; self.value += 1; @@ -179,3 +196,58 @@ impl PositiveBlobBuilder { Ok((map, indexes)) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::error::Error; + + #[test] + fn serialize_deserialize() -> Result<(), Box> { + let a = DocIndex { document_id: 0, attribute: 3, attribute_index: 11 }; + let b = DocIndex { document_id: 1, attribute: 4, attribute_index: 21 }; + let c = DocIndex { document_id: 2, attribute: 8, attribute_index: 2 }; + + let mut builder = PositiveBlobBuilder::memory(); + + builder.insert("aaa", &[a])?; + builder.insert("aab", &[a, b, c])?; + builder.insert("aac", &[a, c])?; + + let (map_bytes, indexes_bytes) = builder.into_inner()?; + let positive_blob = PositiveBlob::from_bytes(map_bytes, indexes_bytes)?; + + assert_eq!(positive_blob.get("aaa"), Some(&[a][..])); + assert_eq!(positive_blob.get("aab"), Some(&[a, b, c][..])); + assert_eq!(positive_blob.get("aac"), Some(&[a, c][..])); + assert_eq!(positive_blob.get("aad"), None); + + Ok(()) + } + + #[test] + fn serde_serialize_deserialize() -> Result<(), Box> { + let a = DocIndex { document_id: 0, attribute: 3, attribute_index: 11 }; + let b = DocIndex { document_id: 1, attribute: 4, attribute_index: 21 }; + let c = DocIndex { document_id: 2, attribute: 8, attribute_index: 2 }; + + let mut builder = PositiveBlobBuilder::memory(); + + builder.insert("aaa", &[a])?; + builder.insert("aab", &[a, b, c])?; + builder.insert("aac", &[a, c])?; + + let (map_bytes, indexes_bytes) = builder.into_inner()?; + let positive_blob = PositiveBlob::from_bytes(map_bytes, indexes_bytes)?; + + let bytes = bincode::serialize(&positive_blob)?; + let positive_blob: PositiveBlob = bincode::deserialize(&bytes)?; + + assert_eq!(positive_blob.get("aaa"), Some(&[a][..])); + assert_eq!(positive_blob.get("aab"), Some(&[a, b, c][..])); + assert_eq!(positive_blob.get("aac"), Some(&[a, c][..])); + assert_eq!(positive_blob.get("aad"), None); + + Ok(()) + } +} diff --git a/src/blob/positive/ops.rs b/src/blob/positive/ops.rs index 2788d0c3c..aed81aa9a 100644 --- a/src/blob/positive/ops.rs +++ b/src/blob/positive/ops.rs @@ -106,9 +106,8 @@ impl<'m, 'a> fst::Streamer<'a> for $name<'m> { let mut builder = SdOpBuilder::with_capacity(ivalues.len()); for ivalue in ivalues { - let indexes = self.indexes[ivalue.index].get(ivalue.value); - let indexes = indexes.expect("BUG: could not find document indexes"); - let set = Set::new_unchecked(indexes); + let doc_indexes = &self.indexes[ivalue.index][ivalue.value as usize]; + let set = Set::new_unchecked(doc_indexes); builder.push(set); } diff --git a/src/data/doc_ids.rs b/src/data/doc_ids.rs index d5650cce6..7c8613744 100644 --- a/src/data/doc_ids.rs +++ b/src/data/doc_ids.rs @@ -10,7 +10,7 @@ use serde::ser::{Serialize, Serializer}; use crate::DocumentId; use crate::data::Data; -#[derive(Clone)] +#[derive(Default, Clone)] pub struct DocIds { data: Data, } diff --git a/src/data/doc_indexes.rs b/src/data/doc_indexes.rs index 78e8ebe73..c7a73a149 100644 --- a/src/data/doc_indexes.rs +++ b/src/data/doc_indexes.rs @@ -1,12 +1,12 @@ use std::slice::from_raw_parts; use std::io::{self, Write}; +use std::mem::size_of; +use std::ops::Index; use std::path::Path; use std::sync::Arc; -use std::mem; use fst::raw::MmapReadOnly; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use serde::ser::{Serialize, Serializer, SerializeTuple}; use crate::DocIndex; use crate::data::Data; @@ -41,7 +41,7 @@ impl DocIndexes { } fn from_data(data: Data) -> io::Result { - let ranges_len_offset = data.len() - mem::size_of::(); + let ranges_len_offset = data.len() - size_of::(); let ranges_len = (&data[ranges_len_offset..]).read_u64::()?; let ranges_len = ranges_len as usize; @@ -53,7 +53,18 @@ impl DocIndexes { Ok(DocIndexes { ranges, indexes }) } - pub fn get(&self, index: u64) -> Option<&[DocIndex]> { + pub fn to_vec(&self) -> Vec { + let capacity = self.indexes.len() + self.ranges.len() + size_of::(); + let mut bytes = Vec::with_capacity(capacity); + + bytes.extend_from_slice(&self.indexes); + bytes.extend_from_slice(&self.ranges); + bytes.write_u64::(self.ranges.len() as u64).unwrap(); + + bytes + } + + pub fn get(&self, index: usize) -> Option<&[DocIndex]> { self.ranges().get(index as usize).map(|Range { start, end }| { let start = *start as usize; let end = *end as usize; @@ -64,24 +75,26 @@ impl DocIndexes { fn ranges(&self) -> &[Range] { let slice = &self.ranges; let ptr = slice.as_ptr() as *const Range; - let len = slice.len() / mem::size_of::(); + let len = slice.len() / size_of::(); unsafe { from_raw_parts(ptr, len) } } fn indexes(&self) -> &[DocIndex] { let slice = &self.indexes; let ptr = slice.as_ptr() as *const DocIndex; - let len = slice.len() / mem::size_of::(); + let len = slice.len() / size_of::(); unsafe { from_raw_parts(ptr, len) } } } -impl Serialize for DocIndexes { - fn serialize(&self, serializer: S) -> Result { - let mut tuple = serializer.serialize_tuple(2)?; - tuple.serialize_element(self.ranges.as_ref())?; - tuple.serialize_element(self.indexes.as_ref())?; - tuple.end() +impl Index for DocIndexes { + type Output = [DocIndex]; + + fn index(&self, index: usize) -> &Self::Output { + match self.get(index) { + Some(indexes) => indexes, + None => panic!("index {} out of range for a maximum of {} ranges", index, self.ranges().len()), + } } } @@ -134,7 +147,7 @@ impl DocIndexesBuilder { unsafe fn into_u8_slice(slice: &[T]) -> &[u8] { let ptr = slice.as_ptr() as *const u8; - let len = slice.len() * mem::size_of::(); + let len = slice.len() * size_of::(); from_raw_parts(ptr, len) } @@ -144,7 +157,7 @@ mod tests { use std::error::Error; #[test] - fn serialize_deserialize() -> Result<(), Box> { + fn builder_serialize_deserialize() -> Result<(), Box> { let a = DocIndex { document_id: 0, attribute: 3, attribute_index: 11 }; let b = DocIndex { document_id: 1, attribute: 4, attribute_index: 21 }; let c = DocIndex { document_id: 2, attribute: 8, attribute_index: 2 }; @@ -158,9 +171,31 @@ mod tests { let bytes = builder.into_inner()?; let docs = DocIndexes::from_bytes(bytes)?; - assert_eq!(docs.get(0).unwrap(), &[a]); - assert_eq!(docs.get(1).unwrap(), &[a, b, c]); - assert_eq!(docs.get(2).unwrap(), &[a, c]); + assert_eq!(docs.get(0), Some(&[a][..])); + assert_eq!(docs.get(1), Some(&[a, b, c][..])); + assert_eq!(docs.get(2), Some(&[a, c][..])); + assert_eq!(docs.get(3), None); + + Ok(()) + } + + #[test] + fn serialize_deserialize() -> Result<(), Box> { + let a = DocIndex { document_id: 0, attribute: 3, attribute_index: 11 }; + let b = DocIndex { document_id: 1, attribute: 4, attribute_index: 21 }; + let c = DocIndex { document_id: 2, attribute: 8, attribute_index: 2 }; + + let mut builder = DocIndexesBuilder::memory(); + + builder.insert(&[a])?; + builder.insert(&[a, b, c])?; + builder.insert(&[a, c])?; + + let builder_bytes = builder.into_inner()?; + let docs = DocIndexes::from_bytes(builder_bytes.clone())?; + let bytes = docs.to_vec(); + + assert_eq!(builder_bytes, bytes); Ok(()) } diff --git a/src/rank/ranked_stream.rs b/src/rank/ranked_stream.rs index b1abc68d9..a2391b98a 100644 --- a/src/rank/ranked_stream.rs +++ b/src/rank/ranked_stream.rs @@ -86,7 +86,7 @@ where T: Deref, let is_exact = distance == 0 && input.len() == automaton.query_len(); let doc_indexes = self.blob.as_indexes(); - let doc_indexes = doc_indexes.get(iv.value).expect("BUG: could not find document indexes"); + let doc_indexes = &doc_indexes[iv.value as usize]; for doc_index in doc_indexes { let match_ = Match {