// NOTE(review): this file is a git unified diff of Rust sources, not a Rust module. Its line
// breaks have been collapsed (the whole patch sits on two physical lines), and several generic
// argument lists look truncated (`-> Option {`, `-> Option> {`, `Lazy = Lazy::new`,
// `Database,`) -- presumably stripped during extraction; compare against the original commit
// before applying the patch or compiling anything taken from it.
//
// What the patch does, as far as the text below shows:
//
// * New file src/heed_codec/mod.rs: declares the private module `roaring_bitmap` and re-exports
//   `RoaringBitmapCodec` from it.
// * New file src/heed_codec/roaring_bitmap.rs: defines the unit struct `RoaringBitmapCodec` and
//   implements `heed::BytesDecode` and `heed::BytesEncode` for it, both with `RoaringBitmap` as
//   the associated item type.
//     - `bytes_decode` calls `RoaringBitmap::deserialize_from(bytes)` and converts the result
//       with `.ok()`, so a failed deserialization is reported as `None` and the underlying
//       error value is dropped.
//     - `bytes_encode` serializes the bitmap into a freshly created `Vec`, returns `None`
//       (through `.ok()?`) when `serialize_into` fails, and otherwise hands the buffer back as
//       `Cow::Owned`.
// * src/lib.rs: adds `mod heed_codec;`, imports `RoaringBitmapCodec`, and moves the
//   `query_tokens` import below the `best_proximity` one. The `postings_attrs`,
//   `prefix_postings_attrs`, `postings_ids` and `prefix_postings_ids` fields of `Index` are
//   replaced; their type arguments are not visible here (presumably they now name
//   `RoaringBitmapCodec` as the value codec -- TODO confirm against the original commit).
// * Two lookups in `impl Index` (on `postings_attrs` and on `postings_ids`) no longer call
//   `RoaringBitmap::deserialize_from_slice` on the value returned by `get`; the value is bound
//   directly as `right` and passed to `union_with`.
//
// NOTE(review): the `postings_ids` lookup keeps `.unwrap()` on the result of `get`, while the
// `postings_attrs` lookup propagates the error with `?`.
diff --git a/src/heed_codec/mod.rs b/src/heed_codec/mod.rs new file mode 100644 index 000000000..3324f1006 --- /dev/null +++ b/src/heed_codec/mod.rs @@ -0,0 +1,3 @@ +mod roaring_bitmap; + +pub use self::roaring_bitmap::RoaringBitmapCodec; diff --git a/src/heed_codec/roaring_bitmap.rs b/src/heed_codec/roaring_bitmap.rs new file mode 100644 index 000000000..abc89e90d --- /dev/null +++ b/src/heed_codec/roaring_bitmap.rs @@ -0,0 +1,22 @@ +use std::borrow::Cow; +use roaring::RoaringBitmap; + +pub struct RoaringBitmapCodec; + +impl heed::BytesDecode<'_> for RoaringBitmapCodec { + type DItem = RoaringBitmap; + + fn bytes_decode(bytes: &[u8]) -> Option { + RoaringBitmap::deserialize_from(bytes).ok() + } +} + +impl heed::BytesEncode<'_> for RoaringBitmapCodec { + type EItem = RoaringBitmap; + + fn bytes_encode(item: &Self::EItem) -> Option> { + let mut bytes = Vec::new(); + item.serialize_into(&mut bytes).ok()?; + Some(Cow::Owned(bytes)) + } +} diff --git a/src/lib.rs b/src/lib.rs index f09158fb7..90fe6e541 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod best_proximity; +mod heed_codec; mod iter_shortest_paths; mod query_tokens; @@ -16,8 +17,9 @@ use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; use once_cell::sync::Lazy; use roaring::RoaringBitmap; -use self::query_tokens::{QueryTokens, QueryToken}; use self::best_proximity::BestProximity; +use self::heed_codec::RoaringBitmapCodec; +use self::query_tokens::{QueryTokens, QueryToken}; // Building these factories is not free. 
static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); @@ -35,10 +37,10 @@ pub type AttributeId = u32; #[derive(Clone)] pub struct Index { pub main: PolyDatabase, - pub postings_attrs: Database, - pub prefix_postings_attrs: Database, - pub postings_ids: Database, - pub prefix_postings_ids: Database, + pub postings_attrs: Database, + pub prefix_postings_attrs: Database, + pub postings_ids: Database, + pub prefix_postings_ids: Database, pub documents: Database, ByteSlice>, } @@ -105,8 +107,7 @@ impl Index { let mut stream = fst.search(&dfa).into_stream(); while let Some(word) = stream.next() { let word = std::str::from_utf8(word)?; - if let Some(attrs) = self.postings_attrs.get(rtxn, word)? { - let right = RoaringBitmap::deserialize_from_slice(attrs)?; + if let Some(right) = self.postings_attrs.get(rtxn, word)? { union_positions.union_with(&right); derived_words.push((word.as_bytes().to_vec(), right)); count += 1; @@ -130,8 +131,7 @@ impl Index { if attrs.contains(pos) { let mut key = word.clone(); key.extend_from_slice(&pos.to_be_bytes()); - if let Some(attrs) = self.postings_ids.get(rtxn, &key).unwrap() { - let right = RoaringBitmap::deserialize_from_slice(attrs).unwrap(); + if let Some(right) = self.postings_ids.get(rtxn, &key).unwrap() { union_docids.union_with(&right); } }