Add ranking score threshold to similar

This commit is contained in:
Louis Dureuil 2024-05-30 10:34:09 +02:00
parent c26db7878c
commit 4f03b0cf5b
No known key found for this signature in database

View File

@ -17,6 +17,7 @@ pub struct Similar<'a> {
index: &'a Index, index: &'a Index,
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
ranking_score_threshold: Option<f64>,
} }
impl<'a> Similar<'a> { impl<'a> Similar<'a> {
@ -29,7 +30,17 @@ impl<'a> Similar<'a> {
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
) -> Self { ) -> Self {
Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } Self {
id,
filter: None,
offset,
limit,
rtxn,
index,
embedder_name,
embedder,
ranking_score_threshold: None,
}
} }
pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self {
@ -37,8 +48,18 @@ impl<'a> Similar<'a> {
self self
} }
/// Sets the minimum ranking score a document must reach to be returned.
///
/// Once a threshold is set, `execute` stops collecting results at the first
/// document whose score is below it and drops that document from the
/// candidates.
///
/// Returns `self` so that builder calls can be chained.
pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Self {
    let threshold = Some(ranking_score_threshold);
    self.ranking_score_threshold = threshold;
    self
}
pub fn execute(&self) -> Result<SearchResult> { pub fn execute(&self) -> Result<SearchResult> {
let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; let mut universe = filtered_universe(self.index, self.rtxn, &self.filter)?;
// we never want to receive the docid
universe.remove(self.id);
let universe = universe;
let embedder_index = let embedder_index =
self.index self.index
@ -77,6 +98,8 @@ impl<'a> Similar<'a> {
let mut documents_seen = RoaringBitmap::new(); let mut documents_seen = RoaringBitmap::new();
documents_seen.insert(self.id); documents_seen.insert(self.id);
let mut candidates = universe;
for (docid, distance) in results for (docid, distance) in results
.into_iter() .into_iter()
// skip documents we've already seen & mark that we saw the current document // skip documents we've already seen & mark that we saw the current document
@ -85,8 +108,6 @@ impl<'a> Similar<'a> {
// take **after** filter and skip so that we get exactly limit elements if available // take **after** filter and skip so that we get exactly limit elements if available
.take(self.limit) .take(self.limit)
{ {
documents_ids.push(docid);
let score = 1.0 - distance; let score = 1.0 - distance;
let score = self let score = self
.embedder .embedder
@ -94,14 +115,28 @@ impl<'a> Similar<'a> {
.map(|distribution| distribution.shift(score)) .map(|distribution| distribution.shift(score))
.unwrap_or(score); .unwrap_or(score);
let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); let score_details =
vec![ScoreDetails::Vector(score_details::Vector { similarity: Some(score) })];
document_scores.push(vec![score]); let score = ScoreDetails::global_score(score_details.iter());
if let Some(ranking_score_threshold) = &self.ranking_score_threshold {
if score < *ranking_score_threshold {
// this document is no longer a candidate
candidates.remove(docid);
// any document after this one is no longer a candidate either, so restrict the set to documents already seen.
candidates &= documents_seen;
break;
}
}
documents_ids.push(docid);
document_scores.push(score_details);
} }
Ok(SearchResult { Ok(SearchResult {
matching_words: Default::default(), matching_words: Default::default(),
candidates: universe, candidates,
documents_ids, documents_ids,
document_scores, document_scores,
degraded: false, degraded: false,