Implement cropping around the first matching word

This commit is contained in:
Marin Postma 2021-05-11 18:30:55 +02:00 committed by Clémentine Urquizar
parent 56c9633c53
commit 7473cc6e27
No known key found for this signature in database
GPG Key ID: D8E7CC7422E77E1A

View File

@ -1,6 +1,6 @@
use std::borrow::Cow;
use std::collections::{BTreeMap, HashSet};
use std::collections::{BTreeMap, HashSet, VecDeque};
use std::time::Instant;
use std::{borrow::Cow, collections::HashMap};
use anyhow::bail;
use either::Either;
@ -157,7 +157,12 @@ impl Index {
let stop_words = fst::Set::default();
let highlighter =
Highlighter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
Formatter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
let to_crop = to_crop_ids
.into_iter()
.map(|id| (id, query.crop_length))
.collect::<HashMap<_, _>>();
for (_id, obkv) in self.documents(&rtxn, documents_ids)? {
let document = make_document(&all_attributes, &fields_ids_map, obkv)?;
@ -168,7 +173,7 @@ impl Index {
&matching_words,
all_formatted.as_ref().as_slice(),
&to_highlight_ids,
&to_crop_ids,
&to_crop,
)?;
let hit = SearchHit {
document,
@ -230,11 +235,11 @@ fn make_document(
fn compute_formatted<A: AsRef<[u8]>>(
field_ids_map: &FieldsIdsMap,
obkv: obkv::KvReader,
highlighter: &Highlighter<A>,
highlighter: &Formatter<A>,
matching_words: &impl Matcher,
all_formatted: &[FieldId],
to_highlight_fields: &HashSet<FieldId>,
to_crop_fields: &HashSet<FieldId>,
to_crop_fields: &HashMap<FieldId, Option<usize>>,
) -> anyhow::Result<Document> {
let mut document = Document::new();
@ -242,15 +247,12 @@ fn compute_formatted<A: AsRef<[u8]>>(
if let Some(value) = obkv.get(*field) {
let mut value: Value = serde_json::from_slice(value)?;
let need_to_crop = if to_crop_fields.contains(field) {
Some(200) // TO CHANGE
} else {
None
};
if to_highlight_fields.contains(field) {
value = highlighter.format_value(value, matching_words, need_to_crop, to_highlight_fields.contains(field));
}
value = highlighter.format_value(
value,
matching_words,
to_crop_fields.get(field).copied().flatten(),
to_highlight_fields.contains(field),
);
// This unwrap must be safe since we got the ids from the fields_ids_map just
// before.
@ -284,12 +286,12 @@ impl Matcher for MatchingWords {
}
}
struct Highlighter<'a, A> {
/// Formats field values for search hits.
///
/// String values are tokenized with `analyzer`; depending on the arguments
/// given to `format_string`, the text is cropped around the first token that
/// the matcher accepts and/or matching words are wrapped in `marks`.
struct Formatter<'a, A> {
/// Analyzer used to split a string into tokens before cropping/highlighting.
analyzer: Analyzer<'a, A>,
/// Highlight markers: `marks.0` is written before a highlighted word;
/// `marks.1` is presumably written after it (closing side not visible here —
/// TODO confirm).
marks: (String, String),
}
impl<'a, A: AsRef<[u8]>> Highlighter<'a, A> {
impl<'a, A: AsRef<[u8]>> Formatter<'a, A> {
pub fn new(stop_words: &'a fst::Set<A>, marks: (String, String)) -> Self {
let mut config = AnalyzerConfig::default();
config.stop_words(stop_words);
@ -308,7 +310,8 @@ impl<'a, A: AsRef<[u8]>> Highlighter<'a, A> {
) -> Value {
match value {
Value::String(old_string) => {
let value = self.format_string(old_string, matcher, need_to_crop, need_to_highlight);
let value =
self.format_string(old_string, matcher, need_to_crop, need_to_highlight);
Value::String(value)
}
Value::Array(values) => Value::Array(
@ -326,27 +329,54 @@ impl<'a, A: AsRef<[u8]>> Highlighter<'a, A> {
value => value,
}
}
fn format_string(&self, s: String, matcher: &impl Matcher, need_to_crop: Option<usize>, need_to_highlight: bool) -> String {
fn format_string(
&self,
s: String,
matcher: &impl Matcher,
need_to_crop: Option<usize>,
need_to_highlight: bool,
) -> String {
let analyzed = self.analyzer.analyze(&s);
let tokens: Box<dyn Iterator<Item=(&str, Token)>> = match need_to_crop {
let tokens: Box<dyn Iterator<Item = (&str, Token)>> = match need_to_crop {
Some(crop_len) => {
let mut taken = 0;
let iter = analyzed
.reconstruct()
.skip_while(|(_, token)| !matcher.matches(token.text()))
let mut buffer = VecDeque::new();
let mut tokens = analyzed.reconstruct().peekable();
let mut taken_before = 0;
while let Some((word, token)) = tokens.next_if(|(_, token)| !matcher.matches(token.text())) {
buffer.push_back((word, token));
taken_before += word.chars().count();
while taken_before > crop_len {
if let Some((word, _)) = buffer.pop_front() {
taken_before -= word.chars().count();
}
}
}
if let Some(token) = tokens.next() {
buffer.push_back(token);
}
let mut taken_after = 0;
let after_iter = tokens
.take_while(move |(word, _)| {
let take = taken < crop_len;
taken += word.chars().count();
let take = taken_after <= crop_len;
taken_after += word.chars().count();
take
});
let iter = buffer
.into_iter()
.chain(after_iter);
Box::new(iter)
},
}
None => Box::new(analyzed.reconstruct()),
};
tokens.map(|(word, token)| {
if need_to_highlight && token.is_word() && matcher.matches(token.text()){
tokens
.map(|(word, token)| {
if need_to_highlight && token.is_word() && matcher.matches(token.text()) {
let mut new_word = String::new();
new_word.push_str(&self.marks.0);
new_word.push_str(&word);
@ -360,7 +390,6 @@ impl<'a, A: AsRef<[u8]>> Highlighter<'a, A> {
}
}
fn parse_facets(
facets: &Value,
index: &Index,
@ -412,7 +441,7 @@ mod test {
fn no_formatted() {
let stop_words = fst::Set::default();
let highlighter =
Highlighter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
Formatter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
let mut fields = FieldsIdsMap::new();
let id = fields.insert("test").unwrap();
@ -439,7 +468,8 @@ mod test {
&all_formatted,
&to_highlight_ids,
&to_crop_ids,
).unwrap();
)
.unwrap();
assert!(value.is_empty());
}
@ -448,7 +478,7 @@ mod test {
fn formatted_no_highlight() {
let stop_words = fst::Set::default();
let highlighter =
Highlighter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
Formatter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
let mut fields = FieldsIdsMap::new();
let id = fields.insert("test").unwrap();
@ -475,7 +505,8 @@ mod test {
&all_formatted,
&to_highlight_ids,
&to_crop_ids,
).unwrap();
)
.unwrap();
assert_eq!(value["test"], "hello");
}
@ -484,7 +515,7 @@ mod test {
fn formatted_with_highlight() {
let stop_words = fst::Set::default();
let highlighter =
Highlighter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
Formatter::new(&stop_words, (String::from("<em>"), String::from("</em>")));
let mut fields = FieldsIdsMap::new();
let id = fields.insert("test").unwrap();
@ -511,7 +542,8 @@ mod test {
&all_formatted,
&to_highlight_ids,
&to_crop_ids,
).unwrap();
)
.unwrap();
assert_eq!(value["test"], "<em>hello</em>");
}