147 lines
4.6 KiB
Rust
Raw Normal View History

2024-04-09 12:03:40 +02:00
use std::sync::Arc;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use crate::score_details::{self, ScoreDetails};
use crate::vector::Embedder;
use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult};
/// Builder for a "similar documents" query: finds the nearest neighbours of one
/// document among the vectors of a given embedder.
///
/// Configure it with [`Similar::new`], optionally [`Similar::filter`] and
/// [`Similar::ranking_score_threshold`], then run it with [`Similar::execute`].
pub struct Similar<'a> {
    /// Document whose neighbours are searched; it is removed from the universe
    /// and is never returned in the results.
    id: DocumentId,
    // this should be linked to the String in the query
    /// Optional filter used to compute the universe of eligible documents.
    filter: Option<Filter<'a>>,
    /// Number of (deduplicated) neighbours to skip before collecting results.
    offset: usize,
    /// Maximum number of neighbours to return.
    limit: usize,
    /// Read transaction used for every index access made by the query.
    rtxn: &'a heed::RoTxn<'a>,
    /// Index the query runs against.
    index: &'a Index,
    /// Name used to look up the embedder's category id in the index.
    embedder_name: String,
    /// Embedder whose distribution, when it has one, is used to shift the raw
    /// similarity score.
    embedder: Arc<Embedder>,
    /// When set, iteration stops at the first neighbour whose score falls
    /// below this value.
    ranking_score_threshold: Option<f64>,
}
impl<'a> Similar<'a> {
pub fn new(
id: DocumentId,
offset: usize,
limit: usize,
index: &'a Index,
rtxn: &'a heed::RoTxn<'a>,
embedder_name: String,
embedder: Arc<Embedder>,
) -> Self {
2024-05-30 10:34:09 +02:00
Self {
id,
filter: None,
offset,
limit,
rtxn,
index,
embedder_name,
embedder,
ranking_score_threshold: None,
}
2024-04-09 12:03:40 +02:00
}
pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self {
self.filter = Some(filter);
self
}
2024-05-30 10:34:09 +02:00
pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Self {
self.ranking_score_threshold = Some(ranking_score_threshold);
self
}
2024-04-09 12:03:40 +02:00
pub fn execute(&self) -> Result<SearchResult> {
2024-05-30 10:34:09 +02:00
let mut universe = filtered_universe(self.index, self.rtxn, &self.filter)?;
// we never want to receive the docid
universe.remove(self.id);
let universe = universe;
2024-04-09 12:03:40 +02:00
let embedder_index =
self.index
.embedder_category_id
.get(self.rtxn, &self.embedder_name)?
.ok_or_else(|| crate::UserError::InvalidEmbedder(self.embedder_name.to_owned()))?;
let readers: std::result::Result<Vec<_>, _> =
self.index.arroy_readers(self.rtxn, embedder_index).collect();
let readers = readers?;
let mut results = Vec::new();
for reader in readers.iter() {
let nns_by_item = reader.nns_by_item(
self.rtxn,
self.id,
self.limit + self.offset + 1,
None,
Some(&universe),
)?;
if let Some(mut nns_by_item) = nns_by_item {
results.append(&mut nns_by_item);
} else {
break;
}
}
results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
let mut documents_ids = Vec::with_capacity(self.limit);
let mut document_scores = Vec::with_capacity(self.limit);
// list of documents we've already seen, so that we don't return the same document multiple times.
// initialized to the target document, that we never want to return.
let mut documents_seen = RoaringBitmap::new();
documents_seen.insert(self.id);
2024-05-30 10:34:09 +02:00
let mut candidates = universe;
2024-04-09 12:03:40 +02:00
for (docid, distance) in results
.into_iter()
// skip documents we've already seen & mark that we saw the current document
.filter(|(docid, _)| documents_seen.insert(*docid))
.skip(self.offset)
// take **after** filter and skip so that we get exactly limit elements if available
.take(self.limit)
{
let score = 1.0 - distance;
let score = self
.embedder
.distribution()
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
2024-05-30 10:34:09 +02:00
let score_details =
vec![ScoreDetails::Vector(score_details::Vector { similarity: Some(score) })];
let score = ScoreDetails::global_score(score_details.iter());
2024-04-09 12:03:40 +02:00
2024-05-30 10:34:09 +02:00
if let Some(ranking_score_threshold) = &self.ranking_score_threshold {
if score < *ranking_score_threshold {
// this document is no longer a candidate
candidates.remove(docid);
// any document after this one is no longer a candidate either, so restrict the set to documents already seen.
candidates &= documents_seen;
break;
}
}
documents_ids.push(docid);
document_scores.push(score_details);
2024-04-09 12:03:40 +02:00
}
Ok(SearchResult {
matching_words: Default::default(),
2024-05-30 10:34:09 +02:00
candidates,
2024-04-09 12:03:40 +02:00
documents_ids,
document_scores,
degraded: false,
used_negative_operator: false,
})
}
}