From b32c96cdc90db62d5d9e13dc01598ff464da698e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Sun, 23 Dec 2018 16:46:49 +0100
Subject: [PATCH] feat: Introduce a WordArea struct

Useful to highlight matching areas in the original text.
---
 Cargo.toml                                   |   2 +
 examples/query-database.rs                   |  61 +++++-
 src/data/doc_indexes.rs                      |  13 +-
 src/database/blob/positive/blob.rs           |  13 +-
 src/database/document_key.rs                 |   6 +-
 src/database/schema.rs                       |  29 +--
 src/database/update/positive/update.rs       |  10 +-
 src/lib.rs                                   | 200 ++++++++++++++++---
 src/rank/criterion/sum_of_typos.rs           |  22 +-
 src/rank/criterion/sum_of_words_attribute.rs |   4 +-
 src/rank/criterion/sum_of_words_position.rs  |   2 +-
 src/rank/criterion/words_proximity.rs        |  28 +--
 src/rank/query_builder.rs                    |   2 +-
 src/tokenizer/mod.rs                         | 117 ++++++++---
 14 files changed, 373 insertions(+), 136 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 970365b18..e04a2f521 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,5 +35,7 @@ nightly = []
 [dev-dependencies]
 csv = "1.0"
 elapsed = "0.1"
+quickcheck = "0.7"
 structopt = "0.2"
 tempfile = "3.0"
+termcolor = "1.0"

diff --git a/examples/query-database.rs b/examples/query-database.rs
index 97bf49bf7..9b6a067cd 100644
--- a/examples/query-database.rs
+++ b/examples/query-database.rs
@@ -2,10 +2,12 @@ use std::io::{self, Write};
 use std::path::PathBuf;
 use std::error::Error;
 
+use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor};
 use serde_derive::{Serialize, Deserialize};
 use structopt::StructOpt;
 
 use meilidb::database::Database;
+use meilidb::Match;
 
 #[derive(Debug, StructOpt)]
 pub struct Opt {
@@ -26,6 +28,40 @@ struct Document {
     image: String,
 }
 
+fn display_highlights(text: &str, ranges: &[usize]) -> io::Result<()> {
+    let mut stdout = StandardStream::stdout(ColorChoice::Always);
+    let mut highlighted = false;
+
+    for range in ranges.windows(2) {
+        let [start, end] = match range { [start, end] => [*start, *end], _ => unreachable!() };
+        if highlighted {
+            stdout.set_color(ColorSpec::new().set_fg(Some(Color::Yellow)))?;
+        }
+        write!(&mut stdout, "{}", &text[start..end])?;
+        stdout.reset()?;
+        highlighted = !highlighted;
+    }
+
+    Ok(())
+}
+
+fn create_highlight_areas(text: &str, matches: &[Match], attribute: u16) -> Vec<usize> {
+    let mut title_areas = Vec::new();
+
+    title_areas.push(0);
+    for match_ in matches {
+        if match_.attribute.attribute() == attribute {
+            let word_area = match_.word_area;
+            let byte_index = word_area.byte_index() as usize;
+            let length = word_area.length() as usize;
+            title_areas.push(byte_index);
+            title_areas.push(byte_index + length);
+        }
+    }
+    title_areas.push(text.len());
+    title_areas
+}
+
 fn main() -> Result<(), Box<Error>> {
     let opt = Opt::from_args();
 
@@ -41,26 +77,35 @@ fn main() -> Result<(), Box<Error>> {
         io::stdout().flush()?;
 
         if input.read_line(&mut buffer)? == 0 { break }
+        let query = buffer.trim_end_matches('\n');
 
         let view = database.view();
 
         let (elapsed, documents) = elapsed::measure_time(|| {
             let builder = view.query_builder().unwrap();
-            builder.query(&buffer, 0..opt.number_results)
+            builder.query(query, 0..opt.number_results)
         });
 
-        let mut full_documents = Vec::with_capacity(documents.len());
+        let number_of_documents = documents.len();
+        for doc in documents {
+            match view.retrieve_document::<Document>(doc.id) {
+                Ok(document) => {
 
-        for document in documents {
-            match view.retrieve_document::<Document>(document.id) {
-                Ok(document) => full_documents.push(document),
+                    print!("title: ");
+                    let title_areas = create_highlight_areas(&document.title, &doc.matches, 1);
+                    display_highlights(&document.title, &title_areas)?;
+                    println!();
+
+                    print!("description: ");
+                    let description_areas = create_highlight_areas(&document.description, &doc.matches, 2);
+                    display_highlights(&document.description, &description_areas)?;
+                    println!();
+                },
                 Err(e) => eprintln!("{}", e),
             }
         }
 
-        println!("{:#?}", full_documents);
-        println!("Found {} results in {}", full_documents.len(), elapsed);
-
+        println!("Found {} results in {}", number_of_documents, elapsed);
         buffer.clear();
     }
 
diff --git a/src/data/doc_indexes.rs b/src/data/doc_indexes.rs
index 5d451f83c..ee4ec9d0a 100644
--- a/src/data/doc_indexes.rs
+++ b/src/data/doc_indexes.rs
@@ -158,14 +158,15 @@ mod tests {
     use super::*;
     use std::error::Error;
 
+    use crate::{Attribute, WordArea};
     use crate::DocumentId;
 
     #[test]
     fn builder_serialize_deserialize() -> Result<(), Box<Error>> {
-        let a = DocIndex { document_id: DocumentId(0), attribute: 3, attribute_index: 11 };
-        let b = DocIndex { document_id: DocumentId(1), attribute: 4, attribute_index: 21 };
-        let c = DocIndex { document_id: DocumentId(2), attribute: 8, attribute_index: 2 };
+        let a = DocIndex { document_id: DocumentId(0), attribute: Attribute::new(3, 11), word_area: WordArea::new(30, 4) };
+        let b = DocIndex { document_id: DocumentId(1), attribute: Attribute::new(4, 21), word_area: WordArea::new(35, 6) };
+        let c = DocIndex { document_id: DocumentId(2), attribute: Attribute::new(8, 2), word_area: WordArea::new(89, 6) };
 
         let mut builder = DocIndexesBuilder::memory();
 
@@ -186,9 +187,9 @@ mod tests {
 
     #[test]
     fn serialize_deserialize() -> Result<(), Box<Error>> {
-        let a = DocIndex { document_id: DocumentId(0), attribute: 3, attribute_index: 11 };
-        let b = DocIndex { document_id: DocumentId(1), attribute: 4, attribute_index: 21 };
-        let c = DocIndex { document_id: DocumentId(2), attribute: 8, attribute_index: 2 };
+        let a = DocIndex { document_id: DocumentId(0), attribute: Attribute::new(3, 11), word_area: WordArea::new(30, 4) };
+        let b = DocIndex { document_id: DocumentId(1), attribute: Attribute::new(4, 21), word_area: WordArea::new(35, 6) };
+        let c = DocIndex { document_id: DocumentId(2), attribute: Attribute::new(8, 2), word_area: WordArea::new(89, 6) };
 
         let mut builder = DocIndexesBuilder::memory();
 
diff --git a/src/database/blob/positive/blob.rs b/src/database/blob/positive/blob.rs
index 3687bc1bb..bd1f32d6f 100644
--- a/src/database/blob/positive/blob.rs
+++ b/src/database/blob/positive/blob.rs
@@ -203,14 +203,15 @@ mod tests {
     use super::*;
     use std::error::Error;
 
+    use crate::{Attribute, WordArea};
     use crate::DocumentId;
 
     #[test]
     fn serialize_deserialize() -> Result<(), Box<Error>> {
-        let a = DocIndex { document_id: DocumentId(0), attribute: 3, attribute_index: 11 };
-        let b = DocIndex { document_id: DocumentId(1), attribute: 4, attribute_index: 21 };
-        let c = DocIndex { document_id: DocumentId(2), attribute: 8, attribute_index: 2 };
+        let a = DocIndex { document_id: DocumentId(0), attribute: Attribute::new(3, 11), word_area: WordArea::new(30, 4) };
+        let b = DocIndex { document_id: DocumentId(1), attribute: Attribute::new(4, 21), word_area: WordArea::new(35, 6) };
+        let c = DocIndex { document_id: DocumentId(2), attribute: Attribute::new(8, 2), word_area: WordArea::new(89, 6) };
 
         let mut builder = PositiveBlobBuilder::memory();
 
@@ -231,9 +232,9 @@ mod tests {
 
     #[test]
     fn serde_serialize_deserialize() -> Result<(), Box<Error>> {
-        let a = DocIndex { document_id: DocumentId(0), attribute: 3, attribute_index: 11 };
-        let b = DocIndex { document_id: DocumentId(1), attribute: 4, attribute_index: 21 };
-        let c = DocIndex { document_id: DocumentId(2), attribute: 8, attribute_index: 2 };
+        let a = DocIndex { document_id: DocumentId(0), attribute: Attribute::new(3, 11), word_area: WordArea::new(30, 4) };
+        let b = DocIndex { document_id: DocumentId(1), attribute: Attribute::new(4, 21), word_area: WordArea::new(35, 6) };
+        let c = DocIndex { document_id: DocumentId(2), attribute: Attribute::new(8, 2), word_area: WordArea::new(89, 6) };
 
         let mut builder = PositiveBlobBuilder::memory();
 
diff --git a/src/database/document_key.rs b/src/database/document_key.rs
index 9104df5f6..b0a952a97 100644
--- a/src/database/document_key.rs
+++ b/src/database/document_key.rs
@@ -73,7 +73,7 @@ impl DocumentKeyAttr {
         let mut wtr = Cursor::new(&mut buffer[..]);
         wtr.write_all(&raw_key).unwrap();
         wtr.write_all(b"-").unwrap();
-        wtr.write_u32::<BigEndian>(attr.as_u32()).unwrap();
+        wtr.write_u16::<BigEndian>(attr.0).unwrap();
 
         DocumentKeyAttr(buffer)
     }
@@ -95,7 +95,7 @@ impl DocumentKeyAttr {
     pub fn attribute(&self) -> SchemaAttr {
         let offset = 4 + size_of::<u64>() + 1;
-        let value = (&self.0[offset..]).read_u32::<BigEndian>().unwrap();
+        let value = (&self.0[offset..]).read_u16::<BigEndian>().unwrap();
         SchemaAttr::new(value)
     }
 
@@ -114,7 +114,7 @@ impl fmt::Debug for DocumentKeyAttr {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("DocumentKeyAttr")
             .field("document_id", &self.document_id())
-            .field("attribute", &self.attribute().as_u32())
+            .field("attribute", &self.attribute().0)
             .finish()
     }
 }
diff --git a/src/database/schema.rs b/src/database/schema.rs
index 255be9fa5..edb19ad79 100644
--- a/src/database/schema.rs
+++ b/src/database/schema.rs
@@ -1,6 +1,6 @@
 use std::collections::{HashMap, BTreeMap};
 use std::io::{Read, Write};
-use std::{fmt, u32};
+use std::{fmt, u16};
 use std::path::Path;
 use std::ops::BitOr;
 use std::sync::Arc;
@@ -53,7 +53,7 @@ impl SchemaBuilder {
         if self.attrs.insert(name.into(), props).is_some() {
             panic!("Field already inserted.")
         }
-        SchemaAttr(len as u32)
+        SchemaAttr(len as u16)
     }
 
     pub fn build(self) -> Schema {
@@ -61,7 +61,7 @@ impl SchemaBuilder {
         let mut props = Vec::new();
 
         for (i, (name, prop)) in self.attrs.into_iter().enumerate() {
-            attrs.insert(name.clone(), SchemaAttr(i as u32));
+            attrs.insert(name.clone(), SchemaAttr(i as u16));
             props.push((name, prop));
         }
 
@@ -94,10 +94,9 @@ impl Schema {
     pub fn write_to<W: Write>(&self, writer: W) -> bincode::Result<()> {
         let mut ordered = BTreeMap::new();
-        for (name, field) in &self.inner.attrs {
-            let index = field.as_u32();
-            let (_, props) = self.inner.props[index as usize];
-            ordered.insert(index, (name, props));
+        for (name, attr) in &self.inner.attrs {
+            let (_, props) = self.inner.props[attr.0 as usize];
+            ordered.insert(attr.0, (name, props));
         }
 
         let mut attrs = LinkedHashMap::with_capacity(ordered.len());
@@ -109,8 +108,7 @@ impl Schema {
     }
 
     pub fn props(&self, attr: SchemaAttr) -> SchemaProps {
-        let index = attr.as_u32();
-        let (_, props) = self.inner.props[index as usize];
+        let (_, props) = self.inner.props[attr.0 as usize];
         props
     }
 
@@ -119,26 +117,21 @@ impl Schema {
     }
 
     pub fn attribute_name(&self, attr: SchemaAttr) -> &str {
-        let index = attr.as_u32();
-        let (name, _) = &self.inner.props[index as usize];
+        let (name, _) = &self.inner.props[attr.0 as usize];
         name
     }
 }
 
 #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)]
-pub struct SchemaAttr(u32);
+pub struct SchemaAttr(pub(crate) u16);
 
 impl SchemaAttr {
-    pub fn new(value: u32) -> SchemaAttr {
+    pub fn new(value: u16) -> SchemaAttr {
         SchemaAttr(value)
     }
 
     pub fn max() -> SchemaAttr {
-        SchemaAttr(u32::MAX)
-    }
-
-    pub fn as_u32(&self) -> u32 {
-        self.0
+        SchemaAttr(u16::MAX)
     }
 }
 
diff --git a/src/database/update/positive/update.rs b/src/database/update/positive/update.rs
index 13daebe6c..595307cd2 100644
--- a/src/database/update/positive/update.rs
+++ b/src/database/update/positive/update.rs
@@ -9,12 +9,12 @@ use serde::ser::{self, Serialize};
 use crate::database::update::positive::unordered_builder::UnorderedPositiveBlobBuilder;
 use crate::database::blob::positive::PositiveBlob;
 use crate::database::schema::{Schema, SchemaAttr};
-use crate::tokenizer::TokenizerBuilder;
+use crate::tokenizer::{TokenizerBuilder, Token};
 use crate::database::DocumentKeyAttr;
 use crate::database::update::Update;
-use crate::{DocumentId, DocIndex};
 use crate::database::DATA_INDEX;
 use crate::database::blob::Blob;
+use crate::{DocumentId, DocIndex, Attribute, WordArea};
 
 pub enum NewState {
     Updated { value: Vec<u8> },
@@ -355,11 +355,11 @@ where B: TokenizerBuilder
     }
 
     fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
-        for (index, word) in self.tokenizer_builder.build(v) {
+        for Token { word, word_index, char_index } in self.tokenizer_builder.build(v) {
             let doc_index = DocIndex {
                 document_id: self.document_id,
-                attribute: self.attribute.as_u32() as u8,
-                attribute_index: index as u32,
+                attribute: Attribute::new(self.attribute.0, word_index as u32),
+                word_area: WordArea::new(char_index as u32, word.len() as u16),
             };
 
             // insert the exact representation
diff --git a/src/lib.rs b/src/lib.rs
index d95dcc2ae..10daf8d4f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,6 +6,8 @@ pub mod tokenizer;
 pub mod vec_read_only;
 mod common_words;
 
+use std::fmt;
+
 pub use rocksdb;
 
 pub use self::tokenizer::Tokenizer;
@@ -18,28 +20,110 @@ pub use self::common_words::CommonWords;
 #[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
 pub struct DocumentId(pub u64);
 
+/// Represents an attribute number along with the word index
+/// according to the tokenizer used.
+///
+/// It can represent up to 1024 attributes and word indexes
+/// up to 2^22.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Attribute(u32);
+
+impl Attribute {
+    /// Constructs an `Attribute` from an attribute number and
+    /// the word position of a match according to the tokenizer used.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the attribute number is 1024 or greater,
+    /// or if the word index is 2^22 or greater.
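+    ///
+    /// # Examples
+    ///
+    /// A round-trip sketch; `new` is crate-private, so this block is
+    /// illustrative and marked `ignore` rather than run as a doc-test:
+    ///
+    /// ```ignore
+    /// let attribute = Attribute::new(7, 42);
+    /// assert_eq!(attribute.attribute(), 7);
+    /// assert_eq!(attribute.word_index(), 42);
+    /// ```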
+    fn new(attribute: u16, index: u32) -> Attribute {
+        assert!(attribute & 0b1111_1100_0000_0000 == 0);
+        assert!(index & 0b1111_1111_1100_0000_0000_0000_0000_0000 == 0);
+
+        let attribute = (attribute as u32) << 22;
+        Attribute(attribute | index)
+    }
+
+    pub fn attribute(&self) -> u16 {
+        (self.0 >> 22) as u16
+    }
+
+    pub fn word_index(&self) -> u32 {
+        self.0 & 0b0000_0000_0011_1111_1111_1111_1111_1111
+    }
+}
+
+impl fmt::Debug for Attribute {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("Attribute")
+            .field("attribute", &self.attribute())
+            .field("word_index", &self.word_index())
+            .finish()
+    }
+}
+
+/// Represents a word position in bytes along with its length.
+///
+/// It can represent byte indexes up to 2^22 and word lengths
+/// up to 1024 bytes.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct WordArea(u32);
+
+impl WordArea {
+    /// Constructs a `WordArea` from a word position in bytes
+    /// and the length of the word.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the byte index is 2^22 or greater,
+    /// or if the length is 1024 or greater.
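+    ///
+    /// # Examples
+    ///
+    /// A round-trip sketch; as with `Attribute::new`, the constructor
+    /// is crate-private, so this block is marked `ignore`:
+    ///
+    /// ```ignore
+    /// let area = WordArea::new(12, 6);
+    /// assert_eq!(area.byte_index(), 12);
+    /// assert_eq!(area.length(), 6);
+    /// ```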
+    fn new(byte_index: u32, length: u16) -> WordArea {
+        assert!(byte_index & 0b1111_1111_1100_0000_0000_0000_0000_0000 == 0);
+        assert!(length & 0b1111_1100_0000_0000 == 0);
+
+        let byte_index = byte_index << 10;
+        WordArea(byte_index | (length as u32))
+    }
+
+    pub fn byte_index(&self) -> u32 {
+        self.0 >> 10
+    }
+
+    pub fn length(&self) -> u16 {
+        (self.0 & 0b0000_0000_0000_0000_0000_0011_1111_1111) as u16
+    }
+}
+
+impl fmt::Debug for WordArea {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("WordArea")
+            .field("byte_index", &self.byte_index())
+            .field("length", &self.length())
+            .finish()
+    }
+}
+
 /// This structure represents the position of a word
 /// in a document and its attributes.
 ///
 /// This is stored in the map, generated at index time,
 /// extracted and interpreted at search time.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 #[repr(C)]
 pub struct DocIndex {
     /// The document identifier where the word was found.
     pub document_id: DocumentId,
 
-    /// The attribute identifier in the document
-    /// where the word was found.
-    ///
-    /// This is an `u8` therefore a document
-    /// can not have more than `2^8` attributes.
-    pub attribute: u8,
+    /// The attribute in the document where the word was found,
+    /// along with the word index within that attribute.
+    pub attribute: Attribute,
 
-    /// The index where the word was found in the attribute.
+    /// The position in bytes where the word was found,
+    /// along with its length.
     ///
-    /// Only the first 1000 words are indexed.
-    pub attribute_index: u32,
+    /// It gives access to the original word area in the indexed
+    /// text without needing to run the tokenizer again.
+    pub word_area: WordArea,
 }
 
 /// This structure represents a matching word with information
@@ -50,7 +134,7 @@ pub struct DocIndex {
 ///
 /// The word in itself is not important.
 // TODO do data oriented programming ? very arrays ?
-#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 pub struct Match {
     /// The word index in the query sentence.
     /// Same as the `attribute_index` but for the query words.
     pub query_index: u32,
 
@@ -62,23 +146,19 @@ pub struct Match {
     /// (i.e. the Levenshtein distance).
     pub distance: u8,
 
-    /// The attribute in which the word is located
-    /// (i.e. Title is 0, Description is 1).
-    ///
-    /// This is an `u8` therefore a document
-    /// can not have more than `2^8` attributes.
-    pub attribute: u8,
-
-    /// Where this word is located in the attribute string
-    /// (i.e. at the start or the end of the attribute).
-    ///
-    /// The index in the attribute is limited to a maximum of `2^32`
-    /// this is because we index only the first 1000 words
-    /// in an attribute.
-    pub attribute_index: u32,
+    /// The attribute in the document where the word was found,
+    /// along with the word index within that attribute.
+    pub attribute: Attribute,
 
     /// Whether the word that matched is an exact match or a prefix.
     pub is_exact: bool,
+
+    /// The position in bytes where the word was found,
+    /// along with its length.
+    ///
+    /// It gives access to the original word area in the indexed
+    /// text without needing to run the tokenizer again.
+    pub word_area: WordArea,
 }
 
 impl Match {
@@ -86,9 +166,9 @@
         Match {
             query_index: 0,
             distance: 0,
-            attribute: 0,
-            attribute_index: 0,
+            attribute: Attribute::new(0, 0),
             is_exact: false,
+            word_area: WordArea::new(0, 0),
         }
     }
 
@@ -96,9 +176,71 @@
         Match {
             query_index: u32::max_value(),
             distance: u8::max_value(),
-            attribute: u8::max_value(),
-            attribute_index: u32::max_value(),
+            attribute: Attribute(u32::max_value()),
             is_exact: true,
+            word_area: WordArea(u32::max_value()),
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use quickcheck::{quickcheck, TestResult};
+    use std::mem;
+
+    #[test]
+    fn docindex_mem_size() {
+        assert_eq!(mem::size_of::<DocIndex>(), 16);
+    }
+
+    quickcheck! {
+        fn qc_attribute(gen_attr: u16, gen_index: u32) -> TestResult {
+            if gen_attr >= 2_u16.pow(10) || gen_index >= 2_u32.pow(22) {
+                return TestResult::discard()
+            }
+
+            let attribute = Attribute::new(gen_attr, gen_index);
+
+            let valid_attribute = attribute.attribute() == gen_attr;
+            let valid_index = attribute.word_index() == gen_index;
+
+            TestResult::from_bool(valid_attribute && valid_index)
+        }
+
+        fn qc_attribute_ord(gen_attr: u16, gen_index: u32) -> TestResult {
+            if gen_attr >= 2_u16.pow(10) - 1 || gen_index >= 2_u32.pow(22) - 1 {
+                return TestResult::discard()
+            }
+
+            let a = Attribute::new(gen_attr, gen_index);
+            let b = Attribute::new(gen_attr + 1, gen_index + 1);
+
+            TestResult::from_bool(a < b)
+        }
+
+        fn qc_word_area(gen_byte_index: u32, gen_length: u16) -> TestResult {
+            if gen_byte_index >= 2_u32.pow(22) || gen_length >= 2_u16.pow(10) {
+                return TestResult::discard()
+            }
+
+            let word_area = WordArea::new(gen_byte_index, gen_length);
+
+            let valid_byte_index = word_area.byte_index() == gen_byte_index;
+            let valid_length = word_area.length() == gen_length;
+
+            TestResult::from_bool(valid_byte_index && valid_length)
+        }
+
+        fn qc_word_area_ord(gen_byte_index: u32, gen_length: u16) -> TestResult {
+            if gen_byte_index >= 2_u32.pow(22) - 1 || gen_length >= 2_u16.pow(10) - 1 {
+                return TestResult::discard()
+            }
+
+            let a = WordArea::new(gen_byte_index, gen_length);
+            let b = WordArea::new(gen_byte_index + 1, gen_length + 1);
+
+            TestResult::from_bool(a < b)
+        }
+    }
+}
diff --git a/src/rank/criterion/sum_of_typos.rs b/src/rank/criterion/sum_of_typos.rs
index 3af339233..3015a6b4b 100644
--- a/src/rank/criterion/sum_of_typos.rs
+++ b/src/rank/criterion/sum_of_typos.rs
@@ -44,7 +44,7 @@ where D: Deref<Target=DB>
 mod tests {
     use super::*;
 
-    use crate::DocumentId;
+    use crate::{DocumentId, Attribute, WordArea};
 
     // typing: "Geox CEO"
     //
@@ -54,8 +54,8 @@
     #[test]
     fn one_typo_reference() {
         let doc0 = {
             let matches = vec![
-                Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false },
-                Match { query_index: 1, distance: 0, attribute: 0, attribute_index: 2, is_exact: false },
+                Match { query_index: 0, distance: 0, attribute: Attribute::new(0, 0), is_exact: false, word_area: WordArea::new(0, 6) },
+                Match { query_index: 1, distance: 0, attribute: Attribute::new(0, 2), is_exact: false, word_area: WordArea::new(0, 6) },
             ];
             Document {
                 id: DocumentId(0),
@@ -65,8 +65,8 @@ fn one_typo_reference() {
 
         let doc1 = {
             let matches = vec![
-                Match { query_index: 0, distance: 1, attribute: 0, attribute_index: 0, is_exact: false },
-                Match { query_index: 1, distance: 0, attribute: 0, attribute_index: 2, is_exact: false },
+                Match { query_index: 0, distance: 1, attribute: Attribute::new(0, 0), is_exact: false, word_area: WordArea::new(0, 6) },
+                Match { query_index: 1, distance: 0, attribute: Attribute::new(0, 2), is_exact: false, word_area: WordArea::new(0, 6) },
             ];
             Document {
                 id: DocumentId(1),
@@ -87,8 +87,8 @@ fn no_typo() {
         let doc0 = {
             let matches = vec![
-                Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false },
-                Match { query_index: 1, distance: 0, attribute: 0, attribute_index: 1, is_exact: false },
+                Match { query_index: 0, distance: 0, attribute: Attribute::new(0, 0), is_exact: false, word_area: WordArea::new(0, 6) },
+                Match { query_index: 1, distance: 0, attribute: Attribute::new(0, 1), is_exact: false, word_area: WordArea::new(0, 6) },
             ];
             Document {
                 id: DocumentId(0),
@@ -98,7 +98,7 @@ fn no_typo() {
 
         let doc1 = {
             let matches = vec![
-                Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false },
+                Match { query_index: 0, distance: 0, attribute: Attribute::new(0, 0), is_exact: false, word_area: WordArea::new(0, 6) },
             ];
             Document {
                 id: DocumentId(1),
@@ -119,8 +119,8 @@ fn one_typo() {
         let doc0 = {
             let matches = vec![
-                Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false },
-                Match { query_index: 1, distance: 1, attribute: 0, attribute_index: 1, is_exact: false },
+                Match { query_index: 0, distance: 0, attribute: Attribute::new(0, 0), is_exact: false, word_area: WordArea::new(0, 6) },
+                Match { query_index: 1, distance: 1, attribute: Attribute::new(0, 1), is_exact: false, word_area: WordArea::new(0, 6) },
             ];
             Document {
                 id: DocumentId(0),
@@ -130,7 +130,7 @@ fn one_typo() {
 
         let doc1 = {
             let matches = vec![
-                Match { query_index: 0, distance: 0, attribute: 0, attribute_index: 0, is_exact: false },
+                Match { query_index: 0, distance: 0, attribute: Attribute::new(0, 0), is_exact: false, word_area: WordArea::new(0, 6) },
             ];
             Document {
                 id: DocumentId(1),
diff --git a/src/rank/criterion/sum_of_words_attribute.rs b/src/rank/criterion/sum_of_words_attribute.rs
index 800fe7c7f..718ae7447 100644
--- a/src/rank/criterion/sum_of_words_attribute.rs
+++ b/src/rank/criterion/sum_of_words_attribute.rs
@@ -10,11 +10,11 @@ use crate::rank::criterion::Criterion;
 use crate::Match;
 
 #[inline]
-fn sum_matches_attributes(matches: &[Match]) -> u8 {
+fn sum_matches_attributes(matches: &[Match]) -> u16 {
     // note that GroupBy will never return an empty group
     // so we can do this assumption safely
     GroupBy::new(matches, match_query_index).map(|group| unsafe {
-        group.get_unchecked(0).attribute
+        group.get_unchecked(0).attribute.attribute()
     }).sum()
 }
 
diff --git a/src/rank/criterion/sum_of_words_position.rs b/src/rank/criterion/sum_of_words_position.rs
index 2a54b1098..d0ebaa74f 100644
--- a/src/rank/criterion/sum_of_words_position.rs
+++ b/src/rank/criterion/sum_of_words_position.rs
@@ -14,7 +14,7 @@
 fn sum_matches_attribute_index(matches: &[Match]) -> u32 {
     // note that GroupBy will never return an empty group
     // so we can do this assumption safely
     GroupBy::new(matches, match_query_index).map(|group| unsafe {
-        group.get_unchecked(0).attribute_index
+        group.get_unchecked(0).attribute.word_index()
     }).sum()
 }
 
diff --git a/src/rank/criterion/words_proximity.rs b/src/rank/criterion/words_proximity.rs
index 14eb1ad0e..5d7e96122 100644
--- a/src/rank/criterion/words_proximity.rs
+++ b/src/rank/criterion/words_proximity.rs
@@ -20,8 +20,8 @@ fn index_proximity(lhs: u32, rhs: u32) -> u32 {
 }
 
 fn attribute_proximity(lhs: &Match, rhs: &Match) -> u32 {
-    if lhs.attribute != rhs.attribute { return MAX_DISTANCE }
-    index_proximity(lhs.attribute_index, rhs.attribute_index)
+    if lhs.attribute.attribute() != rhs.attribute.attribute() { return MAX_DISTANCE }
+    index_proximity(lhs.attribute.word_index(), rhs.attribute.word_index())
 }
 
 fn min_proximity(lhs: &[Match], rhs: &[Match]) -> u32 {
@@ -67,6 +67,8 @@ where D: Deref<Target=DB>
 mod tests {
     use super::*;
 
+    use crate::Attribute;
+
     #[test]
     fn three_different_attributes() {
 
@@ -79,11 +81,11 @@ fn three_different_attributes() {
         // { id: 3, attr: 3, attr_index: 1 }
 
         let matches = &[
-            Match { query_index: 0, attribute: 0, attribute_index: 0, ..Match::zero() },
-            Match { query_index: 1, attribute: 1, attribute_index: 0, ..Match::zero() },
-            Match { query_index: 2, attribute: 1, attribute_index: 1, ..Match::zero() },
-            Match { query_index: 2, attribute: 2, attribute_index: 0, ..Match::zero() },
-            Match { query_index: 3, attribute: 3, attribute_index: 1, ..Match::zero() },
+            Match { query_index: 0, attribute: Attribute::new(0, 0), ..Match::zero() },
+            Match { query_index: 1, attribute: Attribute::new(1, 0), ..Match::zero() },
+            Match { query_index: 2, attribute: Attribute::new(1, 1), ..Match::zero() },
+            Match { query_index: 2, attribute: Attribute::new(2, 0), ..Match::zero() },
+            Match { query_index: 3, attribute: Attribute::new(3, 1), ..Match::zero() },
         ];
 
         // soup -> of = 8
@@ -105,12 +107,12 @@
         // { id: 3, attr: 1, attr_index: 3 }
 
         let matches = &[
-            Match { query_index: 0, attribute: 0, attribute_index: 0, ..Match::zero() },
-            Match { query_index: 0, attribute: 1, attribute_index: 0, ..Match::zero() },
-            Match { query_index: 1, attribute: 1, attribute_index: 1, ..Match::zero() },
-            Match { query_index: 2, attribute: 1, attribute_index: 2, ..Match::zero() },
-            Match { query_index: 3, attribute: 0, attribute_index: 1, ..Match::zero() },
-            Match { query_index: 3, attribute: 1, attribute_index: 3, ..Match::zero() },
+            Match { query_index: 0, attribute: Attribute::new(0, 0), ..Match::zero() },
+            Match { query_index: 0, attribute: Attribute::new(1, 0), ..Match::zero() },
+            Match { query_index: 1, attribute: Attribute::new(1, 1), ..Match::zero() },
+            Match { query_index: 2, attribute: Attribute::new(1, 2), ..Match::zero() },
+            Match { query_index: 3, attribute: Attribute::new(0, 1), ..Match::zero() },
+            Match { query_index: 3, attribute: Attribute::new(1, 3), ..Match::zero() },
         ];
 
         // soup -> of = 1
diff --git a/src/rank/query_builder.rs b/src/rank/query_builder.rs
index 1dff4f2f0..fe0904160 100644
--- a/src/rank/query_builder.rs
+++ b/src/rank/query_builder.rs
@@ -97,8 +97,8 @@ where D: Deref<Target=DB>
                         query_index: iv.index as u32,
                         distance: distance,
                         attribute: doc_index.attribute,
-                        attribute_index: doc_index.attribute_index,
                         is_exact: is_exact,
+                        word_area: doc_index.word_area,
                     };
                     matches.entry(doc_index.document_id).or_insert_with(Vec::new).push(match_);
                 }
diff --git a/src/tokenizer/mod.rs b/src/tokenizer/mod.rs
index 9b075786b..79794f6d8 100644
--- a/src/tokenizer/mod.rs
+++ b/src/tokenizer/mod.rs
@@ -2,7 +2,7 @@ use std::mem;
 use self::Separator::*;
 
 pub trait TokenizerBuilder {
-    fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=(usize, &'a str)> + 'a>;
+    fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=Token<'a>> + 'a>;
 }
 
 pub struct DefaultBuilder;
@@ -13,22 +13,39 @@ impl DefaultBuilder {
     }
 }
 
+#[derive(Debug, PartialEq, Eq)]
+pub struct Token<'a> {
+    pub word: &'a str,
+    pub word_index: usize,
+    pub char_index: usize,
+}
+
 impl TokenizerBuilder for DefaultBuilder {
-    fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=(usize, &'a str)> + 'a> {
+    fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=Token<'a>> + 'a> {
         Box::new(Tokenizer::new(text))
     }
 }
 
 pub struct Tokenizer<'a> {
-    index: usize,
+    word_index: usize,
+    char_index: usize,
     inner: &'a str,
 }
 
 impl<'a> Tokenizer<'a> {
     pub fn new(string: &str) -> Tokenizer {
+        let mut char_advance = 0;
+        let mut index_advance = 0;
+        for (n, (i, c)) in string.char_indices().enumerate() {
+            char_advance = n;
+            index_advance = i;
+            if detect_separator(c).is_none() { break }
+        }
+
         Tokenizer {
-            index: 0,
-            inner: string.trim_matches(&[' ', '.', ';', ',', '!', '?', '-', '\'', '"'][..]),
+            word_index: 0,
+            char_index: char_advance,
+            inner: &string[index_advance..],
         }
     }
 }
@@ -56,43 +73,58 @@ impl Separator {
     }
 }
 
+fn detect_separator(c: char) -> Option<Separator> {
+    match c {
+        '.' | ';' | ',' | '!' | '?' | '-' => Some(Long),
+        ' ' | '\'' | '"' => Some(Short),
+        _ => None,
+    }
+}
+
 impl<'a> Iterator for Tokenizer<'a> {
-    type Item = (usize, &'a str);
+    type Item = Token<'a>;
 
     fn next(&mut self) -> Option<Self::Item> {
         let mut start_word = None;
         let mut distance = None;
 
         for (i, c) in self.inner.char_indices() {
-            let separator = match c {
-                '.' | ';' | ',' | '!' | '?' | '-' => Some(Long),
-                ' ' | '\'' | '"' => Some(Short),
-                _ => None,
-            };
-
-            match separator {
-                Some(dist) => {
+            match detect_separator(c) {
+                Some(sep) => {
                     if let Some(start_word) = start_word {
-                        let (word, tail) = self.inner.split_at(i);
+                        let (prefix, tail) = self.inner.split_at(i);
+                        let (spaces, word) = prefix.split_at(start_word);
 
                         self.inner = tail;
-                        self.index += distance.map(Separator::to_usize).unwrap_or(0);
+                        self.char_index += spaces.len();
+                        self.word_index += distance.map(Separator::to_usize).unwrap_or(0);
 
-                        let word = &word[start_word..];
-                        return Some((self.index, word))
+                        let token = Token {
+                            word: word,
+                            word_index: self.word_index,
+                            char_index: self.char_index,
+                        };
+
+                        self.char_index += word.len();
+                        return Some(token)
                     }
-                    distance = Some(distance.map(|s| s.add(dist)).unwrap_or(dist));
+
+                    distance.replace(distance.map_or(sep, |s| s.add(sep)));
                 },
                 None => { start_word.get_or_insert(i); },
            }
         }
 
         if let Some(start_word) = start_word {
-            let word = mem::replace(&mut self.inner, "");
-            self.index += distance.map(Separator::to_usize).unwrap_or(0);
+            let prefix = mem::replace(&mut self.inner, "");
+            let (spaces, word) = prefix.split_at(start_word);
 
-            let word = &word[start_word..];
-            return Some((self.index, word))
+            let token = Token {
+                word: word,
+                word_index: self.word_index + distance.map(Separator::to_usize).unwrap_or(0),
+                char_index: self.char_index + spaces.len(),
+            };
+            return Some(token)
         }
 
         None
     }
 }
@@ -107,12 +139,12 @@ mod tests {
     fn easy() {
         let mut tokenizer = Tokenizer::new("salut");
 
-        assert_eq!(tokenizer.next(), Some((0, "salut")));
+        assert_eq!(tokenizer.next(), Some(Token { word: "salut", word_index: 0, char_index: 0 }));
         assert_eq!(tokenizer.next(), None);
 
         let mut tokenizer = Tokenizer::new("yo ");
 
-        assert_eq!(tokenizer.next(), Some((0, "yo")));
+        assert_eq!(tokenizer.next(), Some(Token { word: "yo", word_index: 0, char_index: 0 }));
         assert_eq!(tokenizer.next(), None);
     }
 
@@ -120,18 +152,37 @@ mod tests {
     fn hard() {
         let mut tokenizer = Tokenizer::new(" .? yo lolo. aïe");
 
-        assert_eq!(tokenizer.next(), Some((0, "yo")));
-        assert_eq!(tokenizer.next(), Some((1, "lolo")));
-        assert_eq!(tokenizer.next(), Some((9, "aïe")));
+        assert_eq!(tokenizer.next(), Some(Token { word: "yo", word_index: 0, char_index: 4 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "lolo", word_index: 1, char_index: 7 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "aïe", word_index: 9, char_index: 13 }));
         assert_eq!(tokenizer.next(), None);
 
         let mut tokenizer = Tokenizer::new("yo ! lolo ? wtf - lol . aïe ,");
 
-        assert_eq!(tokenizer.next(), Some((0, "yo")));
-        assert_eq!(tokenizer.next(), Some((8, "lolo")));
-        assert_eq!(tokenizer.next(), Some((16, "wtf")));
-        assert_eq!(tokenizer.next(), Some((24, "lol")));
-        assert_eq!(tokenizer.next(), Some((32, "aïe")));
+        assert_eq!(tokenizer.next(), Some(Token { word: "yo", word_index: 0, char_index: 0 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "lolo", word_index: 8, char_index: 5 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "wtf", word_index: 16, char_index: 12 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "lol", word_index: 24, char_index: 18 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "aïe", word_index: 32, char_index: 24 }));
+        assert_eq!(tokenizer.next(), None);
+    }
+
+    #[test]
+    fn hard_long_chars() {
+        let mut tokenizer = Tokenizer::new(" .? yo 😂. aïe");
+
+        assert_eq!(tokenizer.next(), Some(Token { word: "yo", word_index: 0, char_index: 4 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "😂", word_index: 1, char_index: 7 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "aïe", word_index: 9, char_index: 13 }));
+        assert_eq!(tokenizer.next(), None);
+
+        let mut tokenizer = Tokenizer::new("yo ! lolo ? 😱 - lol . 😣 ,");
+
+        assert_eq!(tokenizer.next(), Some(Token { word: "yo", word_index: 0, char_index: 0 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "lolo", word_index: 8, char_index: 5 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "😱", word_index: 16, char_index: 12 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "lol", word_index: 24, char_index: 19 }));
+        assert_eq!(tokenizer.next(), Some(Token { word: "😣", word_index: 32, char_index: 25 }));
         assert_eq!(tokenizer.next(), None);
     }
 }
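-- 
Note for reviewers: a standalone sketch of the `WordArea` bit packing
introduced in src/lib.rs above. This re-implements the layout by hand
for illustration; it is not the crate's type (whose constructor stays
crate-private), but it shows the byte-index round-trip and the ordering
property that the quickcheck tests rely on.

// 22 bits of byte index in the high bits, 10 bits of length below,
// mirroring the WordArea layout in src/lib.rs.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct WordArea(u32);

impl WordArea {
    fn new(byte_index: u32, length: u16) -> WordArea {
        assert!(byte_index < (1 << 22));
        assert!(length < (1 << 10));
        WordArea((byte_index << 10) | length as u32)
    }

    fn byte_index(&self) -> u32 { self.0 >> 10 }
    fn length(&self) -> u16 { (self.0 & 0x3ff) as u16 }
}

fn main() {
    // Slice the matched word straight out of the original text,
    // without running the tokenizer a second time.
    let text = "une bonne salade de tomates";
    let area = WordArea::new(10, 6);

    let start = area.byte_index() as usize;
    let end = start + area.length() as usize;
    assert_eq!(&text[start..end], "salade");

    // The byte index sits in the high bits, so the derived ordering
    // sorts earlier matches first.
    assert!(WordArea::new(10, 6) < WordArea::new(11, 2));
}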