From a745819ddf0a9dda0da557ca980f439cdab1dcb0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 26 Feb 2019 12:16:10 +0100
Subject: [PATCH] feat: Simplify the Tokenizer to use the LinearStrGroupBy type

---
 meilidb-tokenizer/Cargo.toml |   2 +-
 meilidb-tokenizer/src/lib.rs | 259 +++++++++++++----------------------
 2 files changed, 96 insertions(+), 165 deletions(-)

diff --git a/meilidb-tokenizer/Cargo.toml b/meilidb-tokenizer/Cargo.toml
index c8b643d09..32c9429b7 100644
--- a/meilidb-tokenizer/Cargo.toml
+++ b/meilidb-tokenizer/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "meilidb-tokenizer"
 version = "0.1.0"
-authors = ["Clément Renault "]
+authors = ["Kerollmops "]
 edition = "2018"
 
 [dependencies]

diff --git a/meilidb-tokenizer/src/lib.rs b/meilidb-tokenizer/src/lib.rs
index 8cdb32dc3..48bce151b 100644
--- a/meilidb-tokenizer/src/lib.rs
+++ b/meilidb-tokenizer/src/lib.rs
@@ -1,6 +1,5 @@
-use std::mem;
-use slice_group_by::LinearStrGroupBy;
-use self::Separator::*;
+use slice_group_by::StrGroupBy;
+use self::SeparatorCategory::*;
 
 pub fn is_cjk(c: char) -> bool {
     (c >= '\u{2e80}' && c <= '\u{2eff}') ||
@@ -14,208 +13,140 @@ pub fn is_cjk(c: char) -> bool {
     (c >= '\u{f900}' && c <= '\u{faff}')
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+enum SeparatorCategory {
+    Soft,
+    Hard,
+}
+
+impl SeparatorCategory {
+    fn merge(self, other: SeparatorCategory) -> SeparatorCategory {
+        if let (Soft, Soft) = (self, other) { Soft } else { Hard }
+    }
+
+    fn to_usize(self) -> usize {
+        match self {
+            Soft => 1,
+            Hard => 8,
+        }
+    }
+}
+
+fn is_separator(c: char) -> bool {
+    classify_separator(c).is_some()
+}
+
+fn classify_separator(c: char) -> Option<SeparatorCategory> {
+    match c {
+        ' ' | '\'' | '"' => Some(Soft),
+        '.' | ';' | ',' | '!' | '?'
+        | '-' | '(' | ')' => Some(Hard),
+        _ => None,
+    }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
 enum CharCategory {
-    Space,
+    Separator(SeparatorCategory),
     Cjk,
     Other,
 }
 
 fn classify_char(c: char) -> CharCategory {
-    if c.is_whitespace() { CharCategory::Space }
-    else if is_cjk(c) { CharCategory::Cjk }
-    else { CharCategory::Other }
-}
-
-fn is_word(s: &&str) -> bool {
-    !s.chars().any(char::is_whitespace)
-}
-
-fn same_group_category(a: char, b: char) -> bool {
-    let ca = classify_char(a);
-    let cb = classify_char(b);
-    if ca == CharCategory::Cjk || cb == CharCategory::Cjk { false } else { ca == cb }
-}
-
-pub fn split_query_string(query: &str) -> impl Iterator<Item=&str> {
-    LinearStrGroupBy::new(query, same_group_category).filter(is_word)
-}
-
-pub trait TokenizerBuilder {
-    fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=Token<'a>> + 'a>;
-}
-
-pub struct DefaultBuilder;
-
-impl DefaultBuilder {
-    pub fn new() -> DefaultBuilder {
-        DefaultBuilder
+    if let Some(category) = classify_separator(c) {
+        CharCategory::Separator(category)
+    } else if is_cjk(c) {
+        CharCategory::Cjk
+    } else {
+        CharCategory::Other
     }
 }
 
-#[derive(Debug, PartialEq, Eq)]
+fn is_str_word(s: &str) -> bool {
+    !s.chars().any(is_separator)
+}
+
+fn same_group_category(a: char, b: char) -> bool {
+    match (classify_char(a), classify_char(b)) {
+        (CharCategory::Cjk, _) | (_, CharCategory::Cjk) => false,
+        (CharCategory::Separator(_), CharCategory::Separator(_)) => true,
+        (a, b) => a == b,
+    }
+}
+
+// fold the number of chars along with the index position
+fn chars_count_index((n, _): (usize, usize), (i, c): (usize, char)) -> (usize, usize) {
+    (n + 1, i + c.len_utf8())
+}
+
+pub fn split_query_string(query: &str) -> impl Iterator<Item=&str> {
+    Tokenizer::new(query).map(|t| t.word)
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub struct Token<'a> {
     pub word: &'a str,
     pub word_index: usize,
     pub char_index: usize,
 }
 
-impl TokenizerBuilder for DefaultBuilder {
-    fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=Token<'a>> + 'a> {
-        Box::new(Tokenizer::new(text))
-    }
-}
-
 pub struct Tokenizer<'a> {
+    inner: &'a str,
     word_index: usize,
     char_index: usize,
-    inner: &'a str,
 }
 
 impl<'a> Tokenizer<'a> {
     pub fn new(string: &str) -> Tokenizer {
-        let mut char_advance = 0;
-        let mut index_advance = 0;
-        for (n, (i, c)) in string.char_indices().enumerate() {
-            char_advance = n;
-            index_advance = i;
-            if detect_separator(c).is_none() { break }
-        }
+        // skip every separator and set `char_index`
+        // to the number of char trimmed
+        let (count, index) = string.char_indices()
+            .take_while(|(_, c)| is_separator(*c))
+            .fold((0, 0), chars_count_index);
 
         Tokenizer {
+            inner: &string[index..],
             word_index: 0,
-            char_index: char_advance,
-            inner: &string[index_advance..],
+            char_index: count,
         }
     }
 }
 
-#[derive(Debug, Clone, Copy)]
-enum Separator {
-    Short,
-    Long,
-}
-
-impl Separator {
-    fn add(self, add: Separator) -> Separator {
-        match (self, add) {
-            (_, Long) => Long,
-            (Short, Short) => Short,
-            (Long, Short) => Long,
-        }
-    }
-
-    fn to_usize(self) -> usize {
-        match self {
-            Short => 1,
-            Long => 8,
-        }
-    }
-}
-
-fn detect_separator(c: char) -> Option<Separator> {
-    match c {
-        '.' | ';' | ',' | '!' | '?'
-        | '-' | '(' | ')' => Some(Long),
-        ' ' | '\'' | '"' => Some(Short),
-        _ => None,
-    }
-}
-
 impl<'a> Iterator for Tokenizer<'a> {
     type Item = Token<'a>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        let mut start_word = None;
-        let mut distance = None;
+        let mut iter = self.inner.linear_group_by(same_group_category).peekable();
 
-        for (i, c) in self.inner.char_indices() {
-            match detect_separator(c) {
-                Some(sep) => {
-                    if let Some(start_word) = start_word {
-                        let (prefix, tail) = self.inner.split_at(i);
-                        let (spaces, word) = prefix.split_at(start_word);
+        while let (Some(string), next_string) = (iter.next(), iter.peek()) {
+            let (count, index) = string.char_indices().fold((0, 0), chars_count_index);
 
-                        self.inner = tail;
-                        self.char_index += spaces.chars().count();
-                        self.word_index += distance.map(Separator::to_usize).unwrap_or(0);
-
-                        let token = Token {
-                            word: word,
-                            word_index: self.word_index,
-                            char_index: self.char_index,
-                        };
-
-                        self.char_index += word.chars().count();
-                        return Some(token)
-                    }
-
-                    distance = Some(distance.map_or(sep, |s| s.add(sep)));
-                },
-                None => {
-                    // if this is a Chinese, a Japanese or a Korean character
-                    // See
-                    if is_cjk(c) {
-                        match start_word {
-                            Some(start_word) => {
-                                let (prefix, tail) = self.inner.split_at(i);
-                                let (spaces, word) = prefix.split_at(start_word);
-
-                                self.inner = tail;
-                                self.char_index += spaces.chars().count();
-                                self.word_index += distance.map(Separator::to_usize).unwrap_or(0);
-
-                                let token = Token {
-                                    word: word,
-                                    word_index: self.word_index,
-                                    char_index: self.char_index,
-                                };
-
-                                self.word_index += 1;
-                                self.char_index += word.chars().count();
-
-                                return Some(token)
-                            },
-                            None => {
-                                let (prefix, tail) = self.inner.split_at(i + c.len_utf8());
-                                let (spaces, word) = prefix.split_at(i);
-
-                                self.inner = tail;
-                                self.char_index += spaces.chars().count();
-                                self.word_index += distance.map(Separator::to_usize).unwrap_or(0);
-
-                                let token = Token {
-                                    word: word,
-                                    word_index: self.word_index,
-                                    char_index: self.char_index,
-                                };
-
-                                if tail.chars().next().and_then(detect_separator).is_none() {
-                                    self.word_index += 1;
-                                }
-                                self.char_index += 1;
-
-                                return Some(token)
-                            }
-                        }
-                    }
-
-                    if start_word.is_none() { start_word = Some(i) }
-                },
+            if !is_str_word(string) {
+                self.word_index += string.chars()
+                    .filter_map(classify_separator)
+                    .fold(Soft, |a, x| a.merge(x))
+                    .to_usize();
+                self.char_index += count;
+                self.inner = &self.inner[index..];
+                continue;
             }
-        }
-
-        if let Some(start_word) = start_word {
-            let prefix = mem::replace(&mut self.inner, "");
-            let (spaces, word) = prefix.split_at(start_word);
 
             let token = Token {
-                word: word,
-                word_index: self.word_index + distance.map(Separator::to_usize).unwrap_or(0),
-                char_index: self.char_index + spaces.chars().count(),
+                word: string,
+                word_index: self.word_index,
+                char_index: self.char_index,
             };
-            return Some(token)
+
+            if next_string.filter(|s| is_str_word(s)).is_some() {
+                self.word_index += 1;
+            }
+
+            self.char_index += count;
+            self.inner = &self.inner[index..];
+
+            return Some(token);
         }
 
+        self.inner = "";
         None
     }
 }
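
Not part of the patch, but useful for review: a minimal sketch of how the rewritten Tokenizer assigns indices, assuming the crate is consumed under its package name as `meilidb_tokenizer` with the `Tokenizer` and `Token` types shown above. A run of soft separators (space, quotes) advances `word_index` by 1, while a run containing any hard separator advances it by 8, since `SeparatorCategory::merge` yields `Soft` only when both sides are soft:

    use meilidb_tokenizer::{Token, Tokenizer};

    fn main() {
        // only a soft separator (a space) between the words:
        // the word indices stay contiguous
        let tokens: Vec<_> = Tokenizer::new("hello world").collect();
        assert_eq!(tokens[0], Token { word: "hello", word_index: 0, char_index: 0 });
        assert_eq!(tokens[1], Token { word: "world", word_index: 1, char_index: 6 });

        // the comma makes the whole ", " separator group hard:
        // the second word is pushed 8 positions away
        let tokens: Vec<_> = Tokenizer::new("hello, world").collect();
        assert_eq!(tokens[1], Token { word: "world", word_index: 8, char_index: 7 });
    }

The widened gap presumably lets ranking treat words separated by hard punctuation as further apart when scoring proximity.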
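
The core building block here is `linear_group_by` from the slice_group_by crate (brought in by the `StrGroupBy` import at the top of lib.rs): it splits a `&str` into maximal runs of consecutive chars for which a binary predicate holds, and `same_group_category` is that predicate. A standalone illustration with a simpler stand-in predicate (not the one from the patch):

    use slice_group_by::StrGroupBy;

    fn main() {
        // group consecutive chars by whether they are alphabetic,
        // mimicking how the tokenizer separates word chars from separators
        let groups: Vec<&str> = "hello, world!"
            .linear_group_by(|a, b| a.is_alphabetic() == b.is_alphabetic())
            .collect();

        assert_eq!(groups, ["hello", ", ", "world", "!"]);
    }

Returning `false` whenever either char is CJK, as `same_group_category` does, forces every CJK char into its own group, so each one becomes its own token.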