diff --git a/examples/create-database.rs b/examples/create-database.rs index e7a8e72e4..89e96014b 100644 --- a/examples/create-database.rs +++ b/examples/create-database.rs @@ -1,6 +1,4 @@ -use std::collections::hash_map::DefaultHasher; use std::path::{Path, PathBuf}; -use std::hash::{Hash, Hasher}; use std::error::Error; use serde_derive::{Serialize, Deserialize}; @@ -10,7 +8,6 @@ use meilidb::database::schema::{Schema, SchemaBuilder, STORED, INDEXED}; use meilidb::database::update::PositiveUpdateBuilder; use meilidb::tokenizer::DefaultBuilder; use meilidb::database::Database; -use meilidb::DocumentId; #[derive(Debug, StructOpt)] pub struct Opt { @@ -31,14 +28,8 @@ struct Document<'a> { image: &'a str, } -fn calculate_hash(t: &T) -> u64 { - let mut s = DefaultHasher::new(); - t.hash(&mut s); - s.finish() -} - fn create_schema() -> Schema { - let mut schema = SchemaBuilder::new(); + let mut schema = SchemaBuilder::with_identifier("id"); schema.new_attribute("id", STORED); schema.new_attribute("title", STORED | INDEXED); schema.new_attribute("description", STORED | INDEXED); @@ -68,8 +59,7 @@ fn index(schema: Schema, database_path: &Path, csv_data_path: &Path) -> Result Result<(), Box> { let number_of_documents = documents.len(); for doc in documents { - match view.retrieve_document::(doc.id) { + match view.document_by_id::(doc.id) { Ok(document) => { print!("title: "); diff --git a/src/database/database_view.rs b/src/database/database_view.rs index c8eed37c6..6e8d32a78 100644 --- a/src/database/database_view.rs +++ b/src/database/database_view.rs @@ -75,15 +75,14 @@ where D: Deref QueryBuilder::new(self) } - // TODO create an enum error type - pub fn retrieve_document(&self, id: DocumentId) -> Result> + pub fn document_by_id(&self, id: DocumentId) -> Result> where T: DeserializeOwned { let mut deserializer = Deserializer::new(&self.snapshot, &self.schema, id); Ok(T::deserialize(&mut deserializer)?) 
} - pub fn retrieve_documents(&self, ids: I) -> DocumentIter + pub fn documents_by_id(&self, ids: I) -> DocumentIter where T: DeserializeOwned, I: IntoIterator, { @@ -149,7 +148,7 @@ where D: Deref, fn next(&mut self) -> Option { match self.document_ids.next() { - Some(id) => Some(self.database_view.retrieve_document(id)), + Some(id) => Some(self.database_view.document_by_id(id)), None => None } } @@ -168,7 +167,7 @@ where D: Deref, { fn next_back(&mut self) -> Option { match self.document_ids.next_back() { - Some(id) => Some(self.database_view.retrieve_document(id)), + Some(id) => Some(self.database_view.document_by_id(id)), None => None } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 2351c658c..245ec3db6 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,4 +1,6 @@ use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::error::Error; use std::path::Path; use std::ops::Deref; @@ -14,6 +16,16 @@ use self::update::Update; use self::schema::Schema; use self::blob::Blob; +macro_rules! 
forward_to_unserializable_type { + ($($ty:ident => $se_method:ident,)*) => { + $( + fn $se_method(self, _v: $ty) -> Result { + Err(SerializerError::UnserializableType { name: stringify!($ty) }) + } + )* + } +} + pub mod blob; pub mod schema; pub mod update; @@ -24,6 +36,12 @@ mod deserializer; const DATA_INDEX: &[u8] = b"data-index"; const DATA_SCHEMA: &[u8] = b"data-schema"; +fn calculate_hash(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() +} + pub fn retrieve_data_schema(snapshot: &Snapshot) -> Result> where D: Deref { @@ -194,7 +212,6 @@ mod tests { use serde_derive::{Serialize, Deserialize}; use tempfile::tempdir; - use crate::DocumentId; use crate::tokenizer::DefaultBuilder; use crate::database::update::PositiveUpdateBuilder; use crate::database::schema::{SchemaBuilder, STORED, INDEXED}; @@ -207,13 +224,15 @@ mod tests { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] struct SimpleDoc { + id: u64, title: String, description: String, timestamp: u64, } let schema = { - let mut builder = SchemaBuilder::new(); + let mut builder = SchemaBuilder::with_identifier("id"); + builder.new_attribute("id", STORED); builder.new_attribute("title", STORED | INDEXED); builder.new_attribute("description", STORED | INDEXED); builder.new_attribute("timestamp", STORED); @@ -226,21 +245,25 @@ mod tests { let update_path = dir.path().join("update.sst"); let doc0 = SimpleDoc { + id: 0, title: String::from("I am a title"), description: String::from("I am a description"), timestamp: 1234567, }; let doc1 = SimpleDoc { + id: 1, title: String::from("I am the second title"), description: String::from("I am the second description"), timestamp: 7654321, }; + let docid0; + let docid1; let mut update = { let mut builder = PositiveUpdateBuilder::new(update_path, schema, tokenizer_builder); - builder.update(DocumentId(0), &doc0).unwrap(); - builder.update(DocumentId(1), &doc1).unwrap(); + docid0 = builder.update(&doc0).unwrap(); + docid1 = 
builder.update(&doc1).unwrap(); builder.build()? }; @@ -249,8 +272,8 @@ mod tests { database.ingest_update_file(update)?; let view = database.view(); - let de_doc0: SimpleDoc = view.retrieve_document(DocumentId(0))?; - let de_doc1: SimpleDoc = view.retrieve_document(DocumentId(1))?; + let de_doc0: SimpleDoc = view.document_by_id(docid0)?; + let de_doc1: SimpleDoc = view.document_by_id(docid1)?; assert_eq!(doc0, de_doc0); assert_eq!(doc1, de_doc1); diff --git a/src/database/schema.rs b/src/database/schema.rs index edb19ad79..0d02f342b 100644 --- a/src/database/schema.rs +++ b/src/database/schema.rs @@ -1,4 +1,6 @@ +use crate::database::update::SerializerError; use std::collections::{HashMap, BTreeMap}; +use crate::database::calculate_hash; use std::io::{Read, Write}; use std::{fmt, u16}; use std::path::Path; @@ -7,8 +9,11 @@ use std::sync::Arc; use std::fs::File; use serde_derive::{Serialize, Deserialize}; +use serde::ser::{self, Serialize}; use linked_hash_map::LinkedHashMap; +use crate::DocumentId; + pub const STORED: SchemaProps = SchemaProps { stored: true, indexed: false }; pub const INDEXED: SchemaProps = SchemaProps { stored: false, indexed: true }; @@ -40,12 +45,16 @@ impl BitOr for SchemaProps { } pub struct SchemaBuilder { + identifier: String, attrs: LinkedHashMap, } impl SchemaBuilder { - pub fn new() -> SchemaBuilder { - SchemaBuilder { attrs: LinkedHashMap::new() } + pub fn with_identifier>(name: S) -> SchemaBuilder { + SchemaBuilder { + identifier: name.into(), + attrs: LinkedHashMap::new(), + } } pub fn new_attribute>(&mut self, name: S, props: SchemaProps) -> SchemaAttr { @@ -65,7 +74,8 @@ impl SchemaBuilder { props.push((name, prop)); } - Schema { inner: Arc::new(InnerSchema { attrs, props }) } + let identifier = self.identifier; + Schema { inner: Arc::new(InnerSchema { identifier, attrs, props }) } } } @@ -76,6 +86,7 @@ pub struct Schema { #[derive(Debug, Clone, PartialEq, Eq)] struct InnerSchema { + identifier: String, attrs: HashMap, props: 
Vec<(String, SchemaProps)>, } @@ -87,8 +98,8 @@ impl Schema { } pub fn read_from(reader: R) -> bincode::Result { - let attrs = bincode::deserialize_from(reader)?; - let builder = SchemaBuilder { attrs }; + let (identifier, attrs) = bincode::deserialize_from(reader)?; + let builder = SchemaBuilder { identifier, attrs }; Ok(builder.build()) } @@ -99,12 +110,22 @@ impl Schema { ordered.insert(attr.0, (name, props)); } + let identifier = &self.inner.identifier; let mut attrs = LinkedHashMap::with_capacity(ordered.len()); for (_, (name, props)) in ordered { attrs.insert(name, props); } - bincode::serialize_into(writer, &attrs) + bincode::serialize_into(writer, &(identifier, attrs)) + } + + pub fn document_id(&self, document: &T) -> Result + where T: Serialize, + { + let find_document_id = FindDocumentIdSerializer { + id_attribute_name: self.identifier_name(), + }; + document.serialize(find_document_id) } pub fn props(&self, attr: SchemaAttr) -> SchemaProps { @@ -112,6 +133,10 @@ impl Schema { props } + pub fn identifier_name(&self) -> &str { + &self.inner.identifier + } + pub fn attribute>(&self, name: S) -> Option { self.inner.attrs.get(name.as_ref()).cloned() } @@ -141,13 +166,199 @@ impl fmt::Display for SchemaAttr { } } +struct FindDocumentIdSerializer<'a> { + id_attribute_name: &'a str, +} + +impl<'a> ser::Serializer for FindDocumentIdSerializer<'a> { + type Ok = DocumentId; + type Error = SerializerError; + type SerializeSeq = ser::Impossible; + type SerializeTuple = ser::Impossible; + type SerializeTupleStruct = ser::Impossible; + type SerializeTupleVariant = ser::Impossible; + type SerializeMap = ser::Impossible; + type SerializeStruct = FindDocumentIdStructSerializer<'a>; + type SerializeStructVariant = ser::Impossible; + + forward_to_unserializable_type! 
{ + bool => serialize_bool, + char => serialize_char, + + i8 => serialize_i8, + i16 => serialize_i16, + i32 => serialize_i32, + i64 => serialize_i64, + + u8 => serialize_u8, + u16 => serialize_u16, + u32 => serialize_u32, + u64 => serialize_u64, + + f32 => serialize_f32, + f64 => serialize_f64, + } + + fn serialize_str(self, _v: &str) -> Result { + Err(SerializerError::UnserializableType { name: "str" }) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(SerializerError::UnserializableType { name: "&[u8]" }) + } + + fn serialize_none(self) -> Result { + Err(SerializerError::UnserializableType { name: "Option" }) + } + + fn serialize_some(self, _value: &T) -> Result + where T: Serialize, + { + Err(SerializerError::UnserializableType { name: "Option" }) + } + + fn serialize_unit(self) -> Result { + Err(SerializerError::UnserializableType { name: "()" }) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(SerializerError::UnserializableType { name: "unit struct" }) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str + ) -> Result + { + Err(SerializerError::UnserializableType { name: "unit variant" }) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T + ) -> Result + where T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T + ) -> Result + where T: Serialize, + { + Err(SerializerError::UnserializableType { name: "newtype variant" }) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(SerializerError::UnserializableType { name: "sequence" }) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(SerializerError::UnserializableType { name: "tuple" }) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "tuple 
struct" }) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "tuple variant" }) + } + + fn serialize_map(self, _len: Option) -> Result { + // Ok(MapSerializer { + // schema: self.schema, + // document_id: self.document_id, + // new_states: self.new_states, + // }) + Err(SerializerError::UnserializableType { name: "map" }) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize + ) -> Result + { + Ok(FindDocumentIdStructSerializer { + id_attribute_name: self.id_attribute_name, + document_id: None, + }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize + ) -> Result + { + Err(SerializerError::UnserializableType { name: "struct variant" }) + } +} + +struct FindDocumentIdStructSerializer<'a> { + id_attribute_name: &'a str, + document_id: Option, +} + +impl<'a> ser::SerializeStruct for FindDocumentIdStructSerializer<'a> { + type Ok = DocumentId; + type Error = SerializerError; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T + ) -> Result<(), Self::Error> + where T: Serialize, + { + if self.id_attribute_name == key { + // TODO can it be possible to have multiple ids? 
+ let id = bincode::serialize(value).unwrap(); + let hash = calculate_hash(&id); + self.document_id = Some(DocumentId(hash)); + } + + Ok(()) + } + + fn end(self) -> Result { + match self.document_id { + Some(document_id) => Ok(document_id), + None => Err(SerializerError::DocumentIdNotFound) + } + } +} + #[cfg(test)] mod tests { use super::*; #[test] fn serialize_deserialize() -> bincode::Result<()> { - let mut builder = SchemaBuilder::new(); + let mut builder = SchemaBuilder::with_identifier("id"); builder.new_attribute("alphabet", STORED); builder.new_attribute("beta", STORED | INDEXED); builder.new_attribute("gamma", INDEXED); diff --git a/src/database/update/mod.rs b/src/database/update/mod.rs index d298a656f..433624022 100644 --- a/src/database/update/mod.rs +++ b/src/database/update/mod.rs @@ -4,7 +4,7 @@ use std::error::Error; mod negative; mod positive; -pub use self::positive::{PositiveUpdateBuilder, NewState}; +pub use self::positive::{PositiveUpdateBuilder, NewState, SerializerError}; pub use self::negative::NegativeUpdateBuilder; pub struct Update { diff --git a/src/database/update/positive/mod.rs b/src/database/update/positive/mod.rs index e05bd9dff..414f88722 100644 --- a/src/database/update/positive/mod.rs +++ b/src/database/update/positive/mod.rs @@ -1,4 +1,4 @@ mod update; mod unordered_builder; -pub use self::update::{PositiveUpdateBuilder, NewState}; +pub use self::update::{PositiveUpdateBuilder, NewState, SerializerError}; diff --git a/src/database/update/positive/update.rs b/src/database/update/positive/update.rs index 595307cd2..de064e5a1 100644 --- a/src/database/update/positive/update.rs +++ b/src/database/update/positive/update.rs @@ -40,18 +40,21 @@ impl PositiveUpdateBuilder { } } - pub fn update(&mut self, id: DocumentId, document: &T) -> Result<(), Box> + pub fn update(&mut self, document: &T) -> Result where B: TokenizerBuilder { + let document_id = self.schema.document_id(document)?; + let serializer = Serializer { schema: 
&self.schema, - document_id: id, tokenizer_builder: &self.tokenizer_builder, + document_id: document_id, builder: &mut self.builder, new_states: &mut self.new_states }; + document.serialize(serializer)?; - Ok(ser::Serialize::serialize(document, serializer)?) + Ok(document_id) } // TODO value must be a field that can be indexed @@ -67,7 +70,7 @@ impl PositiveUpdateBuilder { #[derive(Debug)] pub enum SerializerError { - SchemaDontMatch { attribute: String }, + DocumentIdNotFound, UnserializableType { name: &'static str }, Custom(String), } @@ -81,10 +84,9 @@ impl ser::Error for SerializerError { impl fmt::Display for SerializerError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - SerializerError::SchemaDontMatch { attribute } => { - write!(f, "serialized document try to specify the \ - {:?} attribute that is not known by the schema", attribute) - }, + SerializerError::DocumentIdNotFound => { + write!(f, "serialized document does not have an id according to the schema") + } SerializerError::UnserializableType { name } => { write!(f, "Only struct and map types are considered valid documents and can be serialized, not {} types directly.", name) @@ -104,16 +106,6 @@ struct Serializer<'a, B> { new_states: &'a mut BTreeMap, } -macro_rules! 
forward_to_unserializable_type { - ($($ty:ident => $se_method:ident,)*) => { - $( - fn $se_method(self, _v: $ty) -> Result { - Err(SerializerError::UnserializableType { name: "$ty" }) - } - )* - } -} - impl<'a, B> ser::Serializer for Serializer<'a, B> where B: TokenizerBuilder { @@ -288,27 +280,25 @@ where B: TokenizerBuilder ) -> Result<(), Self::Error> where T: Serialize, { - match self.schema.attribute(key) { - Some(attr) => { - let props = self.schema.props(attr); - if props.is_stored() { - let value = bincode::serialize(value).unwrap(); - let key = DocumentKeyAttr::new(self.document_id, attr); - self.new_states.insert(key, NewState::Updated { value }); - } - if props.is_indexed() { - let serializer = IndexerSerializer { - builder: self.builder, - tokenizer_builder: self.tokenizer_builder, - document_id: self.document_id, - attribute: attr, - }; - value.serialize(serializer)?; - } - Ok(()) - }, - None => Err(SerializerError::SchemaDontMatch { attribute: key.to_owned() }), + if let Some(attr) = self.schema.attribute(key) { + let props = self.schema.props(attr); + if props.is_stored() { + let value = bincode::serialize(value).unwrap(); + let key = DocumentKeyAttr::new(self.document_id, attr); + self.new_states.insert(key, NewState::Updated { value }); + } + if props.is_indexed() { + let serializer = IndexerSerializer { + builder: self.builder, + tokenizer_builder: self.tokenizer_builder, + document_id: self.document_id, + attribute: attr, + }; + value.serialize(serializer)?; + } } + + Ok(()) } fn end(self) -> Result { diff --git a/src/lib.rs b/src/lib.rs index 10daf8d4f..2bb82a4b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ pub use self::common_words::CommonWords; /// It is used to inform the database the document you want to deserialize. /// Helpful for custom ranking. 
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] -pub struct DocumentId(pub u64); +pub struct DocumentId(u64); /// Represent an attribute number along with the word index /// according to the tokenizer used. diff --git a/src/rank/criterion/sort_by.rs b/src/rank/criterion/sort_by.rs index 7f60962aa..bce8d0d90 100644 --- a/src/rank/criterion/sort_by.rs +++ b/src/rank/criterion/sort_by.rs @@ -62,12 +62,12 @@ where D: Deref, T: DeserializeOwned + Ord, { fn evaluate(&self, lhs: &Document, rhs: &Document, view: &DatabaseView) -> Ordering { - let lhs = match view.retrieve_document::(lhs.id) { + let lhs = match view.document_by_id::(lhs.id) { Ok(doc) => Some(doc), Err(e) => { eprintln!("{}", e); None }, }; - let rhs = match view.retrieve_document::(rhs.id) { + let rhs = match view.document_by_id::(rhs.id) { Ok(doc) => Some(doc), Err(e) => { eprintln!("{}", e); None }, };