2023-11-15 22:46:37 +08:00
|
|
|
use std::cmp::Ordering;
|
|
|
|
|
|
|
|
use itertools::Itertools;
|
|
|
|
use roaring::RoaringBitmap;
|
|
|
|
|
|
|
|
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
|
2024-03-28 18:50:53 +08:00
|
|
|
use crate::search::SemanticSearch;
|
2023-12-14 19:42:37 +08:00
|
|
|
use crate::{MatchingWords, Result, Search, SearchResult};
|
2023-11-15 22:46:37 +08:00
|
|
|
|
2023-12-14 19:42:37 +08:00
|
|
|
struct ScoreWithRatioResult {
|
2023-11-15 22:46:37 +08:00
|
|
|
matching_words: MatchingWords,
|
|
|
|
candidates: RoaringBitmap,
|
2023-12-14 19:42:37 +08:00
|
|
|
document_scores: Vec<(u32, ScoreWithRatio)>,
|
2024-03-19 22:11:21 +08:00
|
|
|
degraded: bool,
|
2024-03-27 01:01:27 +08:00
|
|
|
used_negative_operator: bool,
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
|
2023-12-14 19:42:37 +08:00
|
|
|
/// The score details of one document, paired with the ratio (weight) that multiplies
/// its ranking scores when two result lists are compared.
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
|
2023-11-15 22:46:37 +08:00
|
|
|
|
2023-12-14 19:42:37 +08:00
|
|
|
fn compare_scores(
|
|
|
|
&(ref left_scores, left_ratio): &ScoreWithRatio,
|
|
|
|
&(ref right_scores, right_ratio): &ScoreWithRatio,
|
|
|
|
) -> Ordering {
|
|
|
|
let mut left_it = ScoreDetails::score_values(left_scores.iter());
|
|
|
|
let mut right_it = ScoreDetails::score_values(right_scores.iter());
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
loop {
|
2023-12-14 19:42:37 +08:00
|
|
|
let left = left_it.next();
|
|
|
|
let right = right_it.next();
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
match (left, right) {
|
|
|
|
(None, None) => return Ordering::Equal,
|
|
|
|
(None, Some(_)) => return Ordering::Less,
|
|
|
|
(Some(_), None) => return Ordering::Greater,
|
|
|
|
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
|
2023-12-14 19:42:37 +08:00
|
|
|
let left = left * left_ratio as f64;
|
|
|
|
let right = right * right_ratio as f64;
|
2023-11-15 22:46:37 +08:00
|
|
|
if (left - right).abs() <= f64::EPSILON {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
return left.partial_cmp(&right).unwrap();
|
|
|
|
}
|
|
|
|
(Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => {
|
|
|
|
match left.partial_cmp(right).unwrap() {
|
|
|
|
Ordering::Equal => continue,
|
|
|
|
order => return order,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
(Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => {
|
|
|
|
match left.partial_cmp(right).unwrap() {
|
|
|
|
Ordering::Equal => continue,
|
|
|
|
order => return order,
|
|
|
|
}
|
|
|
|
}
|
2024-03-20 00:32:32 +08:00
|
|
|
(Some(ScoreValue::Score(x)), Some(_)) => {
|
|
|
|
return if x == 0. { Ordering::Less } else { Ordering::Greater }
|
|
|
|
}
|
|
|
|
(Some(_), Some(ScoreValue::Score(x))) => {
|
|
|
|
return if x == 0. { Ordering::Greater } else { Ordering::Less }
|
|
|
|
}
|
2023-11-15 22:46:37 +08:00
|
|
|
// if we have this, we're bad
|
|
|
|
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
|
|
|
|
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
|
|
|
|
unreachable!("Unexpected geo and sort comparison")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-14 19:42:37 +08:00
|
|
|
impl ScoreWithRatioResult {
|
|
|
|
fn new(results: SearchResult, ratio: f32) -> Self {
|
|
|
|
let document_scores = results
|
2023-11-15 22:46:37 +08:00
|
|
|
.documents_ids
|
2023-12-14 19:42:37 +08:00
|
|
|
.into_iter()
|
|
|
|
.zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
|
|
|
|
.collect();
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
Self {
|
2023-12-14 19:42:37 +08:00
|
|
|
matching_words: results.matching_words,
|
|
|
|
candidates: results.candidates,
|
2023-11-15 22:46:37 +08:00
|
|
|
document_scores,
|
2024-03-19 22:11:21 +08:00
|
|
|
degraded: results.degraded,
|
2024-03-27 01:01:27 +08:00
|
|
|
used_negative_operator: results.used_negative_operator,
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-03 15:35:07 +08:00
|
|
|
fn merge(
|
|
|
|
vector_results: Self,
|
|
|
|
keyword_results: Self,
|
|
|
|
from: usize,
|
|
|
|
length: usize,
|
|
|
|
) -> (SearchResult, u32) {
|
|
|
|
#[derive(Clone, Copy)]
|
|
|
|
enum ResultSource {
|
|
|
|
Semantic,
|
|
|
|
Keyword,
|
|
|
|
}
|
|
|
|
let mut semantic_hit_count = 0;
|
|
|
|
|
|
|
|
let mut documents_ids = Vec::with_capacity(
|
|
|
|
vector_results.document_scores.len() + keyword_results.document_scores.len(),
|
|
|
|
);
|
|
|
|
let mut document_scores = Vec::with_capacity(
|
|
|
|
vector_results.document_scores.len() + keyword_results.document_scores.len(),
|
|
|
|
);
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
let mut documents_seen = RoaringBitmap::new();
|
2024-04-03 15:35:07 +08:00
|
|
|
for ((docid, (main_score, _sub_score)), source) in vector_results
|
2023-11-15 22:46:37 +08:00
|
|
|
.document_scores
|
|
|
|
.into_iter()
|
2024-04-03 15:35:07 +08:00
|
|
|
.zip(std::iter::repeat(ResultSource::Semantic))
|
|
|
|
.merge_by(
|
|
|
|
keyword_results
|
|
|
|
.document_scores
|
|
|
|
.into_iter()
|
|
|
|
.zip(std::iter::repeat(ResultSource::Keyword)),
|
|
|
|
|((_, left), _), ((_, right), _)| {
|
|
|
|
// the first value is the one with the greatest score
|
|
|
|
compare_scores(left, right).is_ge()
|
|
|
|
},
|
|
|
|
)
|
2023-11-15 22:46:37 +08:00
|
|
|
// remove documents we already saw
|
2024-04-03 15:35:07 +08:00
|
|
|
.filter(|((docid, _), _)| documents_seen.insert(*docid))
|
2023-11-15 22:46:37 +08:00
|
|
|
// start skipping **after** the filter
|
|
|
|
.skip(from)
|
|
|
|
// take **after** skipping
|
|
|
|
.take(length)
|
|
|
|
{
|
2024-04-03 15:35:07 +08:00
|
|
|
if let ResultSource::Semantic = source {
|
|
|
|
semantic_hit_count += 1;
|
|
|
|
}
|
2023-11-15 22:46:37 +08:00
|
|
|
documents_ids.push(docid);
|
|
|
|
// TODO: pass both scores to documents_score in some way?
|
|
|
|
document_scores.push(main_score);
|
|
|
|
}
|
|
|
|
|
2024-04-03 15:35:07 +08:00
|
|
|
(
|
|
|
|
SearchResult {
|
|
|
|
matching_words: keyword_results.matching_words,
|
|
|
|
candidates: vector_results.candidates | keyword_results.candidates,
|
|
|
|
documents_ids,
|
|
|
|
document_scores,
|
|
|
|
degraded: vector_results.degraded | keyword_results.degraded,
|
|
|
|
used_negative_operator: vector_results.used_negative_operator
|
|
|
|
| keyword_results.used_negative_operator,
|
|
|
|
},
|
|
|
|
semantic_hit_count,
|
|
|
|
)
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a> Search<'a> {
|
2024-04-03 15:35:07 +08:00
|
|
|
pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<(SearchResult, Option<u32>)> {
|
2023-11-15 22:46:37 +08:00
|
|
|
// TODO: find classier way to achieve that than to reset vector and query params
|
|
|
|
// create separate keyword and semantic searches
|
|
|
|
let mut search = Search {
|
|
|
|
query: self.query.clone(),
|
|
|
|
filter: self.filter.clone(),
|
|
|
|
offset: 0,
|
|
|
|
limit: self.limit + self.offset,
|
|
|
|
sort_criteria: self.sort_criteria.clone(),
|
2024-06-11 23:39:35 +08:00
|
|
|
distinct: self.distinct.clone(),
|
2023-11-15 22:46:37 +08:00
|
|
|
searchable_attributes: self.searchable_attributes,
|
|
|
|
geo_strategy: self.geo_strategy,
|
|
|
|
terms_matching_strategy: self.terms_matching_strategy,
|
|
|
|
scoring_strategy: ScoringStrategy::Detailed,
|
|
|
|
words_limit: self.words_limit,
|
|
|
|
exhaustive_number_hits: self.exhaustive_number_hits,
|
|
|
|
rtxn: self.rtxn,
|
|
|
|
index: self.index,
|
2024-03-28 18:50:53 +08:00
|
|
|
semantic: self.semantic.clone(),
|
2024-03-15 00:34:46 +08:00
|
|
|
time_budget: self.time_budget.clone(),
|
2024-04-12 01:04:06 +08:00
|
|
|
ranking_score_threshold: self.ranking_score_threshold,
|
2023-11-15 22:46:37 +08:00
|
|
|
};
|
|
|
|
|
2024-03-28 18:50:53 +08:00
|
|
|
let semantic = search.semantic.take();
|
2023-11-15 22:46:37 +08:00
|
|
|
let keyword_results = search.execute()?;
|
|
|
|
|
|
|
|
// completely skip semantic search if the results of the keyword search are good enough
|
2023-12-14 19:42:37 +08:00
|
|
|
if self.results_good_enough(&keyword_results, semantic_ratio) {
|
2024-06-27 20:21:48 +08:00
|
|
|
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
|
2024-03-28 18:50:53 +08:00
|
|
|
// no vector search against placeholder search
|
2024-04-03 15:35:07 +08:00
|
|
|
let Some(query) = search.query.take() else {
|
2024-06-27 20:21:48 +08:00
|
|
|
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
|
2024-04-03 15:35:07 +08:00
|
|
|
};
|
2024-03-28 18:50:53 +08:00
|
|
|
// no embedder, no semantic search
|
|
|
|
let Some(SemanticSearch { vector, embedder_name, embedder }) = semantic else {
|
2024-06-27 20:21:48 +08:00
|
|
|
return Ok(return_keyword_results(self.limit, self.offset, keyword_results));
|
2024-03-28 18:50:53 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
let vector_query = match vector {
|
|
|
|
Some(vector_query) => vector_query,
|
|
|
|
None => {
|
|
|
|
// attempt to embed the vector
|
|
|
|
match embedder.embed_one(query) {
|
|
|
|
Ok(embedding) => embedding,
|
|
|
|
Err(error) => {
|
|
|
|
tracing::error!(error=%error, "Embedding failed");
|
2024-04-03 15:35:07 +08:00
|
|
|
return Ok((keyword_results, Some(0)));
|
2024-03-28 18:50:53 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
search.semantic =
|
|
|
|
Some(SemanticSearch { vector: Some(vector_query), embedder_name, embedder });
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
// TODO: would be better to have two distinct functions at this point
|
|
|
|
let vector_results = search.execute()?;
|
|
|
|
|
2023-12-14 19:42:37 +08:00
|
|
|
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
|
|
|
|
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
|
2023-11-15 22:46:37 +08:00
|
|
|
|
2024-04-03 15:35:07 +08:00
|
|
|
let (merge_results, semantic_hit_count) =
|
2023-12-14 19:42:37 +08:00
|
|
|
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
|
2023-11-15 22:46:37 +08:00
|
|
|
assert!(merge_results.documents_ids.len() <= self.limit);
|
2024-04-03 15:35:07 +08:00
|
|
|
Ok((merge_results, Some(semantic_hit_count)))
|
2023-11-15 22:46:37 +08:00
|
|
|
}
|
|
|
|
|
2023-12-14 19:42:37 +08:00
|
|
|
fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
|
|
|
|
// A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
|
|
|
|
const GOOD_ENOUGH_SCORE: f64 = 0.45;
|
2023-11-15 22:46:37 +08:00
|
|
|
|
|
|
|
// 1. we check that we got a sufficient number of results
|
|
|
|
if keyword_results.document_scores.len() < self.limit + self.offset {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 2. and that all results have a good enough score.
|
|
|
|
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order
|
|
|
|
for score in &keyword_results.document_scores {
|
|
|
|
let score = ScoreDetails::global_score(score.iter());
|
2023-12-14 19:42:37 +08:00
|
|
|
if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
|
2023-11-15 22:46:37 +08:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
true
|
|
|
|
}
|
|
|
|
}
|
2024-06-27 20:21:48 +08:00
|
|
|
|
|
|
|
fn return_keyword_results(
|
|
|
|
limit: usize,
|
|
|
|
offset: usize,
|
|
|
|
SearchResult {
|
|
|
|
matching_words,
|
|
|
|
candidates,
|
|
|
|
mut documents_ids,
|
|
|
|
mut document_scores,
|
|
|
|
degraded,
|
|
|
|
used_negative_operator,
|
|
|
|
}: SearchResult,
|
|
|
|
) -> (SearchResult, Option<u32>) {
|
|
|
|
let (documents_ids, document_scores) = if offset >= documents_ids.len() ||
|
|
|
|
// technically redudant because documents_ids.len() == document_scores.len(),
|
|
|
|
// defensive programming
|
|
|
|
offset >= document_scores.len()
|
|
|
|
{
|
|
|
|
(vec![], vec![])
|
|
|
|
} else {
|
|
|
|
// PANICS: offset < len
|
|
|
|
documents_ids.rotate_left(offset);
|
|
|
|
documents_ids.truncate(limit);
|
|
|
|
|
|
|
|
// PANICS: offset < len
|
|
|
|
document_scores.rotate_left(offset);
|
|
|
|
document_scores.truncate(limit);
|
|
|
|
(documents_ids, document_scores)
|
|
|
|
};
|
|
|
|
(
|
|
|
|
SearchResult {
|
|
|
|
matching_words,
|
|
|
|
candidates,
|
|
|
|
documents_ids,
|
|
|
|
document_scores,
|
|
|
|
degraded,
|
|
|
|
used_negative_operator,
|
|
|
|
},
|
|
|
|
Some(0),
|
|
|
|
)
|
|
|
|
}
|