From e7af499314f24e51f1bff27ff231ceb898aa27a1 Mon Sep 17 00:00:00 2001
From: "F. Levi" <55688616+flevi29@users.noreply.github.com>
Date: Thu, 12 Sep 2024 16:58:13 +0300
Subject: [PATCH] Improve changes to Matcher

---
 milli/src/search/new/matches/mod.rs | 136 +++++++++++++++++++++-------
 1 file changed, 104 insertions(+), 32 deletions(-)

diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs
index 6ddb81c6a..26dd6f6e8 100644
--- a/milli/src/search/new/matches/mod.rs
+++ b/milli/src/search/new/matches/mod.rs
@@ -93,15 +93,28 @@ impl FormatOptions {
     }
 }
 
+#[derive(Clone, Debug)]
+pub enum MatchPosition {
+    Word {
+        // position of the word in the whole text.
+        word_position: usize,
+        // position of the token in the whole text.
+        token_position: usize,
+    },
+    Phrase {
+        // position of the first and last word in the phrase in the whole text.
+        word_positions: (usize, usize),
+        // position of the first and last token in the phrase in the whole text.
+        token_positions: (usize, usize),
+    },
+}
+
 #[derive(Clone, Debug)]
 pub struct Match {
     match_len: usize,
     // ids of the query words that matches.
     ids: Vec<WordId>,
-    // position of the word in the whole text.
-    word_position: usize,
-    // position of the token in the whole text.
-    token_position: usize,
+    position: MatchPosition,
 }
 
 #[derive(Serialize, Debug, Clone, PartialEq, Eq)]
@@ -130,13 +143,13 @@ impl<'t, 'tokenizer> Matcher<'t, 'tokenizer, '_, '_> {
     /// compute_partial_match peek into next words to validate if the match is complete.
     fn compute_partial_match<'a>(
         mut partial: PartialMatch<'a>,
-        token_position: usize,
-        word_position: usize,
+        first_token_position: usize,
+        first_word_position: usize,
         first_word_char_start: &usize,
         words_positions: &mut impl Iterator<Item = (usize, usize, &'a Token<'a>)>,
         matches: &mut Vec<Match>,
     ) -> bool {
-        for (_, _, word) in words_positions {
+        for (token_position, word_position, word) in words_positions {
             partial = match partial.match_token(word) {
                 // token matches the partial match, but the match is not full,
                 // we temporarily save the current token then we try to match the next one.
@@ -145,10 +158,12 @@
                 Some(MatchType::Full { ids, .. }) => {
                     // save the token that closes the partial match as a match.
                     matches.push(Match {
-                        match_len: word.char_end - first_word_char_start,
+                        match_len: word.char_end - *first_word_char_start,
                         ids: ids.clone().collect(),
-                        word_position,
-                        token_position,
+                        position: MatchPosition::Phrase {
+                            word_positions: (first_word_position, word_position),
+                            token_positions: (first_token_position, token_position),
+                        },
                     });
 
                     // the match is complete, we return true.
@@ -191,8 +206,7 @@
                         matches.push(Match {
                             match_len: char_len,
                             ids,
-                            word_position,
-                            token_position,
+                            position: MatchPosition::Word { word_position, token_position },
                         });
                         break;
                     }
@@ -228,13 +242,47 @@
             Some((tokens, matches)) => matches
                 .iter()
                 .map(|m| MatchBounds {
-                    start: tokens[m.token_position].byte_start,
+                    start: tokens[match m.position {
+                        MatchPosition::Word { token_position, .. } => token_position,
+                        MatchPosition::Phrase {
+                            token_positions: (first_token_position, _),
+                            ..
+                        } => first_token_position,
+                    }]
+                    .byte_start,
                     length: m.match_len,
                 })
                 .collect(),
         }
     }
 
+    // @TODO: This should be improved, looks nasty
+    fn get_match_pos(&self, m: &Match, is_first: bool, is_word: bool) -> usize {
+        match m.position {
+            MatchPosition::Word { word_position, token_position } => {
+                if is_word {
+                    word_position
+                } else {
+                    token_position
+                }
+            }
+            MatchPosition::Phrase { word_positions: (wpf, wpl), token_positions: (tpf, tpl) } => {
+                if is_word {
+                    if is_first {
+                        return wpf;
+                    } else {
+                        return wpl;
+                    }
+                }
+                if is_first {
+                    tpf
+                } else {
+                    tpl
+                }
+            }
+        }
+    }
+
     /// Returns the bounds in byte index of the crop window.
     fn crop_bounds(
         &self,
@@ -243,10 +291,14 @@
         crop_size: usize,
     ) -> (usize, usize) {
         // if there is no match, we start from the beginning of the string by default.
-        let first_match_word_position = matches.first().map(|m| m.word_position).unwrap_or(0);
-        let first_match_token_position = matches.first().map(|m| m.token_position).unwrap_or(0);
-        let last_match_word_position = matches.last().map(|m| m.word_position).unwrap_or(0);
-        let last_match_token_position = matches.last().map(|m| m.token_position).unwrap_or(0);
+        let first_match_word_position =
+            matches.first().map(|m| self.get_match_pos(m, true, true)).unwrap_or(0);
+        let first_match_token_position =
+            matches.first().map(|m| self.get_match_pos(m, true, false)).unwrap_or(0);
+        let last_match_word_position =
+            matches.last().map(|m| self.get_match_pos(m, false, true)).unwrap_or(0);
+        let last_match_token_position =
+            matches.last().map(|m| self.get_match_pos(m, false, false)).unwrap_or(0);
 
         // matches needs to be counted in the crop len.
         let mut remaining_words = crop_size + first_match_word_position - last_match_word_position;
@@ -350,7 +402,9 @@
             }
 
             // compute distance between matches
-            distance_score -= (next_match.word_position - m.word_position).min(7) as i16;
+            distance_score -= (self.get_match_pos(next_match, true, true)
+                - self.get_match_pos(m, true, true))
+            .min(7) as i16;
         }
 
         ids.extend(m.ids.iter());
@@ -378,7 +432,12 @@
                 // if next match would make interval gross more than crop_size,
                 // we compare the current interval with the best one,
                 // then we increase `interval_first` until next match can be added.
-                if next_match.word_position - matches[interval_first].word_position >= crop_size {
+                let next_match_word_position = self.get_match_pos(next_match, true, true);
+
+                if next_match_word_position
+                    - self.get_match_pos(&matches[interval_first], false, true)
+                    >= crop_size
+                {
                     let interval_score =
                         self.match_interval_score(&matches[interval_first..=interval_last]);
 
@@ -389,10 +448,15 @@
                     }
 
                     // advance start of the interval while interval is longer than crop_size.
-                    while next_match.word_position - matches[interval_first].word_position
-                        >= crop_size
-                    {
+                    loop {
                         interval_first += 1;
+
+                        if next_match_word_position
+                            - self.get_match_pos(&matches[interval_first], false, true)
+                            < crop_size
+                        {
+                            break;
+                        }
                     }
                 }
                 interval_last = index;
@@ -441,33 +505,41 @@
         if format_options.highlight {
             // insert highlight markers around matches.
             for m in matches {
-                let token = &tokens[m.token_position];
+                let (current_byte_start, current_byte_end) = match m.position {
+                    MatchPosition::Word { token_position, .. } => {
+                        let token = &tokens[token_position];
+                        (&token.byte_start, &token.byte_end)
+                    }
+                    MatchPosition::Phrase { token_positions: (ftp, ltp), .. } => {
+                        (&tokens[ftp].byte_start, &tokens[ltp].byte_end)
+                    }
+                };
 
                 // skip matches out of the crop window.
-                if token.byte_start < byte_start || token.byte_end > byte_end {
+                if *current_byte_start < byte_start || *current_byte_end > byte_end {
                     continue;
                 }
 
-                if byte_index < token.byte_start {
-                    formatted.push(&self.text[byte_index..token.byte_start]);
+                if byte_index < *current_byte_start {
+                    formatted.push(&self.text[byte_index..*current_byte_start]);
                 }
 
-                let highlight_byte_index = self.text[token.byte_start..]
+                let highlight_byte_index = self.text[*current_byte_start..]
                     .char_indices()
                     .enumerate()
                     .find(|(i, _)| *i == m.match_len)
-                    .map_or(token.byte_end, |(_, (i, _))| i + token.byte_start);
+                    .map_or(*current_byte_end, |(_, (i, _))| i + *current_byte_start);
                 formatted.push(self.highlight_prefix);
-                formatted.push(&self.text[token.byte_start..highlight_byte_index]);
+                formatted.push(&self.text[*current_byte_start..highlight_byte_index]);
                 formatted.push(self.highlight_suffix);
                 // if it's a prefix highlight, we put the end of the word after the highlight marker.
-                if highlight_byte_index < token.byte_end {
-                    formatted.push(&self.text[highlight_byte_index..token.byte_end]);
+                if highlight_byte_index < *current_byte_end {
+                    formatted.push(&self.text[highlight_byte_index..*current_byte_end]);
                 }
-                byte_index = token.byte_start + m.match_len;
+                byte_index = *current_byte_end;
             }
         }
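
A note on the helper the patch itself marks "@TODO: This should be improved, looks nasty": get_match_pos never reads any field of its &self receiver, and the is_first/is_word boolean flags are hard to follow at call sites such as self.get_match_pos(m, true, true). Below is a minimal sketch of one possible cleanup, written against the MatchPosition enum introduced by this patch; the accessor names are illustrative and not part of the patch:

    // Hypothetical accessors on Match replacing the boolean-flag helper.
    // Names are illustrative only; they are not part of this patch.
    impl Match {
        fn get_first_word_pos(&self) -> usize {
            match self.position {
                MatchPosition::Word { word_position, .. } => word_position,
                MatchPosition::Phrase { word_positions: (first, _), .. } => first,
            }
        }

        fn get_last_word_pos(&self) -> usize {
            match self.position {
                MatchPosition::Word { word_position, .. } => word_position,
                MatchPosition::Phrase { word_positions: (_, last), .. } => last,
            }
        }

        fn get_first_token_pos(&self) -> usize {
            match self.position {
                MatchPosition::Word { token_position, .. } => token_position,
                MatchPosition::Phrase { token_positions: (first, _), .. } => first,
            }
        }

        fn get_last_token_pos(&self) -> usize {
            match self.position {
                MatchPosition::Word { token_position, .. } => token_position,
                MatchPosition::Phrase { token_positions: (_, last), .. } => last,
            }
        }
    }

With accessors like these, self.get_match_pos(m, true, true) becomes m.get_first_word_pos(), self.get_match_pos(&matches[interval_first], false, true) becomes matches[interval_first].get_last_word_pos(), and the helper on Matcher, along with its unused &self receiver, can be dropped entirely.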