feat: Simplify the Tokenizer to use the LinearStrGroupBy type

This commit is contained in:
Clément Renault 2019-02-26 12:16:10 +01:00
parent 5d5bcf7011
commit a745819ddf
No known key found for this signature in database
GPG Key ID: 0151CDAB43460DAE
2 changed files with 96 additions and 165 deletions

View File

@ -1,7 +1,7 @@
[package] [package]
name = "meilidb-tokenizer" name = "meilidb-tokenizer"
version = "0.1.0" version = "0.1.0"
authors = ["Clément Renault <renault.cle@gmail.com>"] authors = ["Kerollmops <renault.cle@gmail.com>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]

View File

@ -1,6 +1,5 @@
use std::mem; use slice_group_by::StrGroupBy;
use slice_group_by::LinearStrGroupBy; use self::SeparatorCategory::*;
use self::Separator::*;
pub fn is_cjk(c: char) -> bool { pub fn is_cjk(c: char) -> bool {
(c >= '\u{2e80}' && c <= '\u{2eff}') || (c >= '\u{2e80}' && c <= '\u{2eff}') ||
@ -14,208 +13,140 @@ pub fn is_cjk(c: char) -> bool {
(c >= '\u{f900}' && c <= '\u{faff}') (c >= '\u{f900}' && c <= '\u{faff}')
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SeparatorCategory {
Soft,
Hard,
}
impl SeparatorCategory {
fn merge(self, other: SeparatorCategory) -> SeparatorCategory {
if let (Soft, Soft) = (self, other) { Soft } else { Hard }
}
fn to_usize(self) -> usize {
match self {
Soft => 1,
Hard => 8,
}
}
}
fn is_separator(c: char) -> bool {
classify_separator(c).is_some()
}
fn classify_separator(c: char) -> Option<SeparatorCategory> {
match c {
' ' | '\'' | '"' => Some(Soft),
'.' | ';' | ',' | '!' | '?' | '-' | '(' | ')' => Some(Hard),
_ => None,
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CharCategory { enum CharCategory {
Space, Separator(SeparatorCategory),
Cjk, Cjk,
Other, Other,
} }
fn classify_char(c: char) -> CharCategory { fn classify_char(c: char) -> CharCategory {
if c.is_whitespace() { CharCategory::Space } if let Some(category) = classify_separator(c) {
else if is_cjk(c) { CharCategory::Cjk } CharCategory::Separator(category)
else { CharCategory::Other } } else if is_cjk(c) {
CharCategory::Cjk
} else {
CharCategory::Other
}
} }
fn is_word(s: &&str) -> bool { fn is_str_word(s: &str) -> bool {
!s.chars().any(char::is_whitespace) !s.chars().any(is_separator)
} }
fn same_group_category(a: char, b: char) -> bool { fn same_group_category(a: char, b: char) -> bool {
let ca = classify_char(a); match (classify_char(a), classify_char(b)) {
let cb = classify_char(b); (CharCategory::Cjk, _) | (_, CharCategory::Cjk) => false,
if ca == CharCategory::Cjk || cb == CharCategory::Cjk { false } else { ca == cb } (CharCategory::Separator(_), CharCategory::Separator(_)) => true,
(a, b) => a == b,
}
}
// fold the number of chars along with the index position
fn chars_count_index((n, _): (usize, usize), (i, c): (usize, char)) -> (usize, usize) {
(n + 1, i + c.len_utf8())
} }
pub fn split_query_string(query: &str) -> impl Iterator<Item=&str> { pub fn split_query_string(query: &str) -> impl Iterator<Item=&str> {
LinearStrGroupBy::new(query, same_group_category).filter(is_word) Tokenizer::new(query).map(|t| t.word)
} }
pub trait TokenizerBuilder { #[derive(Debug, Copy, Clone, PartialEq, Eq)]
fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=Token<'a>> + 'a>;
}
pub struct DefaultBuilder;
impl DefaultBuilder {
pub fn new() -> DefaultBuilder {
DefaultBuilder
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct Token<'a> { pub struct Token<'a> {
pub word: &'a str, pub word: &'a str,
pub word_index: usize, pub word_index: usize,
pub char_index: usize, pub char_index: usize,
} }
impl TokenizerBuilder for DefaultBuilder {
fn build<'a>(&self, text: &'a str) -> Box<Iterator<Item=Token<'a>> + 'a> {
Box::new(Tokenizer::new(text))
}
}
pub struct Tokenizer<'a> { pub struct Tokenizer<'a> {
inner: &'a str,
word_index: usize, word_index: usize,
char_index: usize, char_index: usize,
inner: &'a str,
} }
impl<'a> Tokenizer<'a> { impl<'a> Tokenizer<'a> {
pub fn new(string: &str) -> Tokenizer { pub fn new(string: &str) -> Tokenizer {
let mut char_advance = 0; // skip every separator and set `char_index`
let mut index_advance = 0; // to the number of char trimmed
for (n, (i, c)) in string.char_indices().enumerate() { let (count, index) = string.char_indices()
char_advance = n; .take_while(|(_, c)| is_separator(*c))
index_advance = i; .fold((0, 0), chars_count_index);
if detect_separator(c).is_none() { break }
}
Tokenizer { Tokenizer {
inner: &string[index..],
word_index: 0, word_index: 0,
char_index: char_advance, char_index: count,
inner: &string[index_advance..],
} }
} }
} }
#[derive(Debug, Clone, Copy)]
enum Separator {
Short,
Long,
}
impl Separator {
fn add(self, add: Separator) -> Separator {
match (self, add) {
(_, Long) => Long,
(Short, Short) => Short,
(Long, Short) => Long,
}
}
fn to_usize(self) -> usize {
match self {
Short => 1,
Long => 8,
}
}
}
fn detect_separator(c: char) -> Option<Separator> {
match c {
'.' | ';' | ',' | '!' | '?' | '-' | '(' | ')' => Some(Long),
' ' | '\'' | '"' => Some(Short),
_ => None,
}
}
impl<'a> Iterator for Tokenizer<'a> { impl<'a> Iterator for Tokenizer<'a> {
type Item = Token<'a>; type Item = Token<'a>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let mut start_word = None; let mut iter = self.inner.linear_group_by(same_group_category).peekable();
let mut distance = None;
for (i, c) in self.inner.char_indices() { while let (Some(string), next_string) = (iter.next(), iter.peek()) {
match detect_separator(c) { let (count, index) = string.char_indices().fold((0, 0), chars_count_index);
Some(sep) => {
if let Some(start_word) = start_word {
let (prefix, tail) = self.inner.split_at(i);
let (spaces, word) = prefix.split_at(start_word);
self.inner = tail; if !is_str_word(string) {
self.char_index += spaces.chars().count(); self.word_index += string.chars()
self.word_index += distance.map(Separator::to_usize).unwrap_or(0); .filter_map(classify_separator)
.fold(Soft, |a, x| a.merge(x))
let token = Token { .to_usize();
word: word, self.char_index += count;
word_index: self.word_index, self.inner = &self.inner[index..];
char_index: self.char_index, continue;
};
self.char_index += word.chars().count();
return Some(token)
} }
distance = Some(distance.map_or(sep, |s| s.add(sep)));
},
None => {
// if this is a Chinese, a Japanese or a Korean character
// See <http://unicode-table.com>
if is_cjk(c) {
match start_word {
Some(start_word) => {
let (prefix, tail) = self.inner.split_at(i);
let (spaces, word) = prefix.split_at(start_word);
self.inner = tail;
self.char_index += spaces.chars().count();
self.word_index += distance.map(Separator::to_usize).unwrap_or(0);
let token = Token { let token = Token {
word: word, word: string,
word_index: self.word_index, word_index: self.word_index,
char_index: self.char_index, char_index: self.char_index,
}; };
self.word_index += 1; if next_string.filter(|s| is_str_word(s)).is_some() {
self.char_index += word.chars().count();
return Some(token)
},
None => {
let (prefix, tail) = self.inner.split_at(i + c.len_utf8());
let (spaces, word) = prefix.split_at(i);
self.inner = tail;
self.char_index += spaces.chars().count();
self.word_index += distance.map(Separator::to_usize).unwrap_or(0);
let token = Token {
word: word,
word_index: self.word_index,
char_index: self.char_index,
};
if tail.chars().next().and_then(detect_separator).is_none() {
self.word_index += 1; self.word_index += 1;
} }
self.char_index += 1;
return Some(token) self.char_index += count;
} self.inner = &self.inner[index..];
}
} return Some(token);
if start_word.is_none() { start_word = Some(i) }
},
}
}
if let Some(start_word) = start_word {
let prefix = mem::replace(&mut self.inner, "");
let (spaces, word) = prefix.split_at(start_word);
let token = Token {
word: word,
word_index: self.word_index + distance.map(Separator::to_usize).unwrap_or(0),
char_index: self.char_index + spaces.chars().count(),
};
return Some(token)
} }
self.inner = "";
None None
} }
} }