Lazily embed, don't fail hybrid search on embedding failure

Louis Dureuil 2024-03-28 11:50:53 +01:00
parent fabc9cf14a
commit 6ebb6b55a6
11 changed files with 237 additions and 203 deletions
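
In short: instead of embedding the query up front in the HTTP route (where any embedder error failed the whole request), the query is now embedded only when the hybrid search actually needs a vector, and an embedding failure degrades to keyword-only results. A minimal, self-contained sketch of that fallback pattern follows; the names (embed_query, hybrid_search, KeywordResults) are illustrative stand-ins, not the Meilisearch APIs touched in the hunks below.

    // Illustrative sketch only (hypothetical helper names, not the Meilisearch API):
    // embed lazily, and degrade to keyword-only results when embedding fails.
    struct KeywordResults(Vec<u32>);

    // Stand-in for a call to the configured embedder; may fail (network, quota, ...).
    fn embed_query(_query: &str) -> Result<Vec<f32>, String> {
        Err("embedder unreachable".to_string())
    }

    fn hybrid_search(query: &str, precomputed_vector: Option<Vec<f32>>) -> KeywordResults {
        // The keyword side is always computed first.
        let keyword_results = KeywordResults(vec![1, 2, 3]);

        // Embed only when the caller did not supply a vector (lazy embedding).
        let _vector = match precomputed_vector {
            Some(v) => v,
            None => match embed_query(query) {
                Ok(v) => v,
                Err(error) => {
                    // On embedding failure, log and fall back to keyword results
                    // instead of failing the whole request.
                    eprintln!("Embedding failed: {error}");
                    return keyword_results;
                }
            },
        };

        // ...here the vector search would run and be merged with keyword_results...
        keyword_results
    }

    fn main() {
        let hits = hybrid_search("lazily embed", None);
        println!("{} keyword hits returned", hits.0.len());
    }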

View File

@@ -12,6 +12,7 @@ use tracing::debug;
 use crate::analytics::{Analytics, FacetSearchAggregator};
 use crate::extractors::authentication::policies::*;
 use crate::extractors::authentication::GuardedData;
+use crate::routes::indexes::search::search_kind;
 use crate::search::{
     add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery,
     DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
@@ -73,9 +74,10 @@ pub async fn search(
     let index = index_scheduler.index(&index_uid)?;
     let features = index_scheduler.features();
 
+    let search_kind = search_kind(&search_query, &index_scheduler, &index)?;
     let _permit = search_queue.try_get_search_permit().await?;
     let search_result = tokio::task::spawn_blocking(move || {
-        perform_facet_search(&index, search_query, facet_query, facet_name, features)
+        perform_facet_search(&index, search_query, facet_query, facet_name, features, search_kind)
     })
     .await?;

View File

@@ -8,19 +8,19 @@ use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli;
-use meilisearch_types::milli::vector::DistributionShift;
 use meilisearch_types::serde_cs::vec::CS;
 use serde_json::Value;
-use tracing::{debug, warn};
+use tracing::debug;
 
 use crate::analytics::{Analytics, SearchAggregator};
+use crate::error::MeilisearchHttpError;
 use crate::extractors::authentication::policies::*;
 use crate::extractors::authentication::GuardedData;
 use crate::extractors::sequential_extractor::SeqHandler;
 use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
 use crate::search::{
-    add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio,
-    DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
+    add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchKind, SearchQuery,
+    SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
     DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
 };
 use crate::search_queue::SearchQueue;
@@ -204,11 +204,11 @@ pub async fn search_with_url_query(
     let index = index_scheduler.index(&index_uid)?;
     let features = index_scheduler.features();
 
-    let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
+    let search_kind = search_kind(&query, index_scheduler.get_ref(), &index)?;
     let _permit = search_queue.try_get_search_permit().await?;
     let search_result =
-        tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
+        tokio::task::spawn_blocking(move || perform_search(&index, query, features, search_kind))
             .await?;
     if let Ok(ref search_result) = search_result {
         aggregate.succeed(search_result);
@@ -245,11 +245,11 @@ pub async fn search_with_post(
     let features = index_scheduler.features();
 
-    let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
+    let search_kind = search_kind(&query, index_scheduler.get_ref(), &index)?;
     let _permit = search_queue.try_get_search_permit().await?;
     let search_result =
-        tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
+        tokio::task::spawn_blocking(move || perform_search(&index, query, features, search_kind))
            .await?;
     if let Ok(ref search_result) = search_result {
         aggregate.succeed(search_result);
@@ -265,76 +265,49 @@ pub async fn search_with_post(
     Ok(HttpResponse::Ok().json(search_result))
 }
-pub fn embed(
-    query: &mut SearchQuery,
+pub fn search_kind(
+    query: &SearchQuery,
     index_scheduler: &IndexScheduler,
     index: &milli::Index,
-) -> Result<Option<DistributionShift>, ResponseError> {
-    match (&query.hybrid, &query.vector, &query.q) {
-        (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q))
-            if !q.trim().is_empty() =>
-        {
-            let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
-            let embedders = index_scheduler.embedders(embedder_configs)?;
-
-            let embedder = if let Some(embedder_name) = embedder {
-                embedders.get(embedder_name)
-            } else {
-                embedders.get_default()
-            };
-
-            let embedder = embedder
-                .ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
-                .map_err(milli::Error::from)?
-                .0;
-            let distribution = embedder.distribution();
-            let embeddings = embedder
-                .embed(vec![q.to_owned()])
-                .map_err(milli::vector::Error::from)
-                .map_err(milli::Error::from)?
-                .pop()
-                .expect("No vector returned from embedding");
-
-            if embeddings.iter().nth(1).is_some() {
-                warn!("Ignoring embeddings past the first one in long search query");
-                query.vector = Some(embeddings.iter().next().unwrap().to_vec());
-            } else {
-                query.vector = Some(embeddings.into_inner());
-            }
-            Ok(distribution)
-        }
-        (Some(hybrid), vector, _) => {
-            let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
-            let embedders = index_scheduler.embedders(embedder_configs)?;
-
-            let embedder = if let Some(embedder_name) = &hybrid.embedder {
-                embedders.get(embedder_name)
-            } else {
-                embedders.get_default()
-            };
-
-            let embedder = embedder
-                .ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
-                .map_err(milli::Error::from)?
-                .0;
-
-            if let Some(vector) = vector {
-                if vector.len() != embedder.dimensions() {
-                    return Err(meilisearch_types::milli::Error::UserError(
-                        meilisearch_types::milli::UserError::InvalidVectorDimensions {
-                            expected: embedder.dimensions(),
-                            found: vector.len(),
-                        },
-                    )
-                    .into());
-                }
-            }
-
-            Ok(embedder.distribution())
-        }
-        _ => Ok(None),
+) -> Result<SearchKind, ResponseError> {
+    // regardless of anything, always do a keyword search when we don't have a vector and the query is whitespace or missing
+    if query.vector.is_none() {
+        match &query.q {
+            Some(q) if q.trim().is_empty() => return Ok(SearchKind::KeywordOnly),
+            None => return Ok(SearchKind::KeywordOnly),
+            _ => {}
+        }
+    }
+
+    match &query.hybrid {
+        Some(HybridQuery { semantic_ratio, embedder }) if **semantic_ratio == 1.0 => {
+            Ok(SearchKind::semantic(
+                index_scheduler,
+                index,
+                embedder.as_deref(),
+                query.vector.as_ref().map(Vec::len),
+            )?)
+        }
+        Some(HybridQuery { semantic_ratio, embedder: _ }) if **semantic_ratio == 0.0 => {
+            Ok(SearchKind::KeywordOnly)
+        }
+        Some(HybridQuery { semantic_ratio, embedder }) => Ok(SearchKind::hybrid(
+            index_scheduler,
+            index,
+            embedder.as_deref(),
+            **semantic_ratio,
+            query.vector.as_ref().map(Vec::len),
+        )?),
+        None => match (query.q.as_deref(), query.vector.as_deref()) {
+            (_query, None) => Ok(SearchKind::KeywordOnly),
+            (None, Some(_vector)) => Ok(SearchKind::semantic(
+                index_scheduler,
+                index,
+                None,
+                query.vector.as_ref().map(Vec::len),
+            )?),
+            (Some(_), Some(_)) => Err(MeilisearchHttpError::MissingSearchHybrid.into()),
        },
     }
 }

View File

@@ -13,7 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator};
 use crate::extractors::authentication::policies::ActionPolicy;
 use crate::extractors::authentication::{AuthenticationError, GuardedData};
 use crate::extractors::sequential_extractor::SeqHandler;
-use crate::routes::indexes::search::embed;
+use crate::routes::indexes::search::search_kind;
 use crate::search::{
     add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
 };
@@ -81,11 +81,11 @@ pub async fn multi_search_with_post(
             })
             .with_index(query_index)?;

-            let distribution =
-                embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?;
+            let search_kind =
+                search_kind(&query, index_scheduler.get_ref(), &index).with_index(query_index)?;

             let search_result = tokio::task::spawn_blocking(move || {
-                perform_search(&index, query, features, distribution)
+                perform_search(&index, query, features, search_kind)
             })
             .await
             .with_index(query_index)?;

View File

@@ -1,6 +1,7 @@
 use std::cmp::min;
 use std::collections::{BTreeMap, BTreeSet, HashSet};
 use std::str::FromStr;
+use std::sync::Arc;
 use std::time::{Duration, Instant};
 
 use deserr::Deserr;
@@ -10,10 +11,11 @@ use indexmap::IndexMap;
 use meilisearch_auth::IndexSearchRules;
 use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
+use meilisearch_types::error::ResponseError;
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
-use meilisearch_types::milli::vector::DistributionShift;
+use meilisearch_types::milli::vector::Embedder;
 use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, TimeBudget};
 use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
 use meilisearch_types::{milli, Document};
@@ -90,13 +92,75 @@ pub struct SearchQuery {
 #[derive(Debug, Clone, Default, PartialEq, Deserr)]
 #[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
 pub struct HybridQuery {
-    /// TODO validate that sementic ratio is between 0.0 and 1,0
     #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
     pub semantic_ratio: SemanticRatio,
     #[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
     pub embedder: Option<String>,
 }
 
+pub enum SearchKind {
+    KeywordOnly,
+    SemanticOnly { embedder_name: String, embedder: Arc<Embedder> },
+    Hybrid { embedder_name: String, embedder: Arc<Embedder>, semantic_ratio: f32 },
+}
+
+impl SearchKind {
+    pub(crate) fn semantic(
+        index_scheduler: &index_scheduler::IndexScheduler,
+        index: &Index,
+        embedder_name: Option<&str>,
+        vector_len: Option<usize>,
+    ) -> Result<Self, ResponseError> {
+        let (embedder_name, embedder) =
+            Self::embedder(index_scheduler, index, embedder_name, vector_len)?;
+        Ok(Self::SemanticOnly { embedder_name, embedder })
+    }
+
+    pub(crate) fn hybrid(
+        index_scheduler: &index_scheduler::IndexScheduler,
+        index: &Index,
+        embedder_name: Option<&str>,
+        semantic_ratio: f32,
+        vector_len: Option<usize>,
+    ) -> Result<Self, ResponseError> {
+        let (embedder_name, embedder) =
+            Self::embedder(index_scheduler, index, embedder_name, vector_len)?;
+        Ok(Self::Hybrid { embedder_name, embedder, semantic_ratio })
+    }
+
+    fn embedder(
+        index_scheduler: &index_scheduler::IndexScheduler,
+        index: &Index,
+        embedder_name: Option<&str>,
+        vector_len: Option<usize>,
+    ) -> Result<(String, Arc<Embedder>), ResponseError> {
+        let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
+        let embedders = index_scheduler.embedders(embedder_configs)?;
+
+        let embedder_name = embedder_name.unwrap_or_else(|| embedders.get_default_embedder_name());
+
+        let embedder = embedders.get(embedder_name);
+
+        let embedder = embedder
+            .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
+            .map_err(milli::Error::from)?
+            .0;
+
+        if let Some(vector_len) = vector_len {
+            if vector_len != embedder.dimensions() {
+                return Err(meilisearch_types::milli::Error::UserError(
+                    meilisearch_types::milli::UserError::InvalidVectorDimensions {
+                        expected: embedder.dimensions(),
+                        found: vector_len,
+                    },
+                )
+                .into());
+            }
+        }
+
+        Ok((embedder_name.to_owned(), embedder))
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Deserr)]
 #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
 pub struct SemanticRatio(f32);
@@ -385,7 +449,7 @@ fn prepare_search<'t>(
     rtxn: &'t RoTxn,
     query: &'t SearchQuery,
     features: RoFeatures,
-    distribution: Option<DistributionShift>,
+    search_kind: &SearchKind,
     time_budget: TimeBudget,
 ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
     let mut search = index.search(rtxn);
@@ -399,32 +463,30 @@ fn prepare_search<'t>(
         features.check_vector("Passing `hybrid` as a query parameter")?;
     }
 
-    if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
-        return Err(MeilisearchHttpError::MissingSearchHybrid);
-    }
-
-    search.distribution_shift(distribution);
-
-    if let Some(ref vector) = query.vector {
-        match &query.hybrid {
-            // If semantic ratio is 0.0, only the query search will impact the search results,
-            // skip the vector
-            Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
-            _otherwise => {
-                search.vector(vector.clone());
-            }
-        }
-    }
-
-    if let Some(ref q) = query.q {
-        match &query.hybrid {
-            // If semantic ratio is 1.0, only the vector search will impact the search results,
-            // skip the query
-            Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (),
-            _otherwise => {
+    match search_kind {
+        SearchKind::KeywordOnly => {
+            if let Some(q) = &query.q {
                 search.query(q);
             }
         }
+        SearchKind::SemanticOnly { embedder_name, embedder } => {
+            let vector = match query.vector.clone() {
+                Some(vector) => vector,
+                None => embedder
+                    .embed_one(query.q.clone().unwrap())
+                    .map_err(milli::vector::Error::from)
+                    .map_err(milli::Error::from)?,
+            };
+
+            search.semantic(embedder_name.clone(), embedder.clone(), Some(vector));
+        }
+        SearchKind::Hybrid { embedder_name, embedder, semantic_ratio: _ } => {
+            if let Some(q) = &query.q {
+                search.query(q);
+            }
+
+            // will be embedded in hybrid search if necessary
+            search.semantic(embedder_name.clone(), embedder.clone(), query.vector.clone());
+        }
     }
 
     if let Some(ref searchable) = query.attributes_to_search_on {
@@ -447,10 +509,6 @@ fn prepare_search<'t>(
             ScoringStrategy::Skip
         });
 
-    if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
-        search.embedder_name(embedder);
-    }
-
     // compute the offset on the limit depending on the pagination mode.
     let (offset, limit) = if is_finite_pagination {
         let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
@@ -494,7 +552,7 @@ pub fn perform_search(
     index: &Index,
     query: SearchQuery,
     features: RoFeatures,
-    distribution: Option<DistributionShift>,
+    search_kind: SearchKind,
 ) -> Result<SearchResult, MeilisearchHttpError> {
     let before_search = Instant::now();
     let rtxn = index.read_txn()?;
@@ -504,7 +562,7 @@ pub fn perform_search(
     };
 
     let (search, is_finite_pagination, max_total_hits, offset) =
-        prepare_search(index, &rtxn, &query, features, distribution, time_budget)?;
+        prepare_search(index, &rtxn, &query, features, &search_kind, time_budget)?;
 
     let milli::SearchResult {
         documents_ids,
@@ -514,12 +572,9 @@ pub fn perform_search(
         degraded,
         used_negative_operator,
         ..
-    } = match &query.hybrid {
-        Some(hybrid) => match *hybrid.semantic_ratio {
-            ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
-            ratio => search.execute_hybrid(ratio)?,
-        },
-        None => search.execute()?,
+    } = match &search_kind {
+        SearchKind::KeywordOnly | SearchKind::SemanticOnly { .. } => search.execute()?,
+        SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?,
     };
 
     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
@@ -726,6 +781,7 @@ pub fn perform_facet_search(
     facet_query: Option<String>,
     facet_name: String,
     features: RoFeatures,
+    search_kind: SearchKind,
 ) -> Result<FacetSearchResult, MeilisearchHttpError> {
     let before_search = Instant::now();
     let rtxn = index.read_txn()?;
@@ -735,9 +791,12 @@ pub fn perform_facet_search(
     };
 
     let (search, _, _, _) =
-        prepare_search(index, &rtxn, &search_query, features, None, time_budget)?;
-    let mut facet_search =
-        SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some());
+        prepare_search(index, &rtxn, &search_query, features, &search_kind, time_budget)?;
+    let mut facet_search = SearchForFacetValues::new(
+        facet_name,
+        search,
+        matches!(search_kind, SearchKind::Hybrid { .. }),
+    );
 
     if let Some(facet_query) = &facet_query {
         facet_search.query(facet_query);
     }

View File

@@ -1499,14 +1499,6 @@ impl Index {
             .unwrap_or_default())
     }
 
-    pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> {
-        let configs = self.embedding_configs(rtxn)?;
-        Ok(match configs.as_slice() {
-            [(ref first_name, _)] => first_name.clone(),
-            _ => "default".to_owned(),
-        })
-    }
-
     pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> {
         self.main.remap_types::<Str, BEU64>().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff)
     }

View File

@@ -61,7 +61,7 @@ pub use self::index::Index;
 pub use self::search::facet::{FacetValueHit, SearchForFacetValues};
 pub use self::search::{
     FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy,
-    Search, SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
+    Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
 };
 
 pub type Result<T> = std::result::Result<T, error::Error>;

View File

@@ -92,9 +92,15 @@ impl<'a> SearchForFacetValues<'a> {
             None => return Ok(Vec::new()),
         };
 
-        let search_candidates = self
-            .search_query
-            .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
+        let search_candidates = self.search_query.execute_for_candidates(
+            self.is_hybrid
+                || self
+                    .search_query
+                    .semantic
+                    .as_ref()
+                    .and_then(|semantic| semantic.vector.as_ref())
+                    .is_some(),
+        )?;
 
         let mut results = match index.sort_facet_values_by(rtxn)?.get(&self.facet) {
             OrderBy::Lexicographic => ValuesCollection::by_lexicographic(self.max_values),

View File

@@ -4,6 +4,7 @@ use itertools::Itertools;
 use roaring::RoaringBitmap;
 
 use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
+use crate::search::SemanticSearch;
 use crate::{MatchingWords, Result, Search, SearchResult};
 
 struct ScoreWithRatioResult {
@@ -126,7 +127,6 @@ impl<'a> Search<'a> {
         // create separate keyword and semantic searches
         let mut search = Search {
             query: self.query.clone(),
-            vector: self.vector.clone(),
             filter: self.filter.clone(),
             offset: 0,
             limit: self.limit + self.offset,
@@ -139,26 +139,41 @@ impl<'a> Search<'a> {
             exhaustive_number_hits: self.exhaustive_number_hits,
             rtxn: self.rtxn,
             index: self.index,
-            distribution_shift: self.distribution_shift,
-            embedder_name: self.embedder_name.clone(),
+            semantic: self.semantic.clone(),
             time_budget: self.time_budget.clone(),
         };
 
-        let vector_query = search.vector.take();
+        let semantic = search.semantic.take();
         let keyword_results = search.execute()?;
 
-        // skip semantic search if we don't have a vector query (placeholder search)
-        let Some(vector_query) = vector_query else {
-            return Ok(keyword_results);
-        };
-
         // completely skip semantic search if the results of the keyword search are good enough
         if self.results_good_enough(&keyword_results, semantic_ratio) {
             return Ok(keyword_results);
         }
 
-        search.vector = Some(vector_query);
-        search.query = None;
+        // no vector search against placeholder search
+        let Some(query) = search.query.take() else { return Ok(keyword_results) };
+        // no embedder, no semantic search
+        let Some(SemanticSearch { vector, embedder_name, embedder }) = semantic else {
+            return Ok(keyword_results);
+        };
+
+        let vector_query = match vector {
+            Some(vector_query) => vector_query,
+            None => {
+                // attempt to embed the vector
+                match embedder.embed_one(query) {
+                    Ok(embedding) => embedding,
+                    Err(error) => {
+                        tracing::error!(error=%error, "Embedding failed");
+                        return Ok(keyword_results);
+                    }
+                }
+            }
+        };
+
+        search.semantic =
+            Some(SemanticSearch { vector: Some(vector_query), embedder_name, embedder });
 
         // TODO: would be better to have two distinct functions at this point
         let vector_results = search.execute()?;

View File

@@ -1,4 +1,5 @@
 use std::fmt;
+use std::sync::Arc;
 
 use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
 use once_cell::sync::Lazy;
@@ -8,7 +9,7 @@ pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FAC
 pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
 use self::new::{execute_vector_search, PartialSearchResult};
 use crate::score_details::{ScoreDetails, ScoringStrategy};
-use crate::vector::DistributionShift;
+use crate::vector::Embedder;
 use crate::{
     execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result,
     SearchContext, TimeBudget,
@@ -24,9 +25,15 @@ mod fst_utils;
 pub mod hybrid;
 pub mod new;
 
+#[derive(Debug, Clone)]
+pub struct SemanticSearch {
+    vector: Option<Vec<f32>>,
+    embedder_name: String,
+    embedder: Arc<Embedder>,
+}
+
 pub struct Search<'a> {
     query: Option<String>,
-    vector: Option<Vec<f32>>,
     // this should be linked to the String in the query
     filter: Option<Filter<'a>>,
     offset: usize,
@@ -38,12 +45,9 @@ pub struct Search<'a> {
     scoring_strategy: ScoringStrategy,
     words_limit: usize,
     exhaustive_number_hits: bool,
-    /// TODO: Add semantic ratio or pass it directly to execute_hybrid()
     rtxn: &'a heed::RoTxn<'a>,
     index: &'a Index,
-    distribution_shift: Option<DistributionShift>,
-    embedder_name: Option<String>,
+    semantic: Option<SemanticSearch>,
     time_budget: TimeBudget,
 }
 
@@ -51,7 +55,6 @@ impl<'a> Search<'a> {
     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
         Search {
             query: None,
-            vector: None,
             filter: None,
             offset: 0,
             limit: 20,
@@ -64,8 +67,7 @@ impl<'a> Search<'a> {
             words_limit: 10,
             rtxn,
             index,
-            distribution_shift: None,
-            embedder_name: None,
+            semantic: None,
             time_budget: TimeBudget::max(),
         }
     }
@@ -75,8 +77,13 @@ impl<'a> Search<'a> {
         self
     }
 
-    pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
-        self.vector = Some(vector);
+    pub fn semantic(
+        &mut self,
+        embedder_name: String,
+        embedder: Arc<Embedder>,
+        vector: Option<Vec<f32>>,
+    ) -> &mut Search<'a> {
+        self.semantic = Some(SemanticSearch { embedder_name, embedder, vector });
         self
     }
 
@@ -133,19 +140,6 @@ impl<'a> Search<'a> {
         self
     }
 
-    pub fn distribution_shift(
-        &mut self,
-        distribution_shift: Option<DistributionShift>,
-    ) -> &mut Search<'a> {
-        self.distribution_shift = distribution_shift;
-        self
-    }
-
-    pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
-        self.embedder_name = Some(embedder_name.into());
-        self
-    }
-
     pub fn time_budget(&mut self, time_budget: TimeBudget) -> &mut Search<'a> {
         self.time_budget = time_budget;
         self
@@ -161,15 +155,6 @@ impl<'a> Search<'a> {
     }
 
     pub fn execute(&self) -> Result<SearchResult> {
-        let embedder_name;
-        let embedder_name = match &self.embedder_name {
-            Some(embedder_name) => embedder_name,
-            None => {
-                embedder_name = self.index.default_embedding_name(self.rtxn)?;
-                &embedder_name
-            }
-        };
-
         let mut ctx = SearchContext::new(self.index, self.rtxn);
 
         if let Some(searchable_attributes) = self.searchable_attributes {
@@ -184,8 +169,9 @@ impl<'a> Search<'a> {
             document_scores,
             degraded,
             used_negative_operator,
-        } = match self.vector.as_ref() {
-            Some(vector) => execute_vector_search(
+        } = match self.semantic.as_ref() {
+            Some(SemanticSearch { vector: Some(vector), embedder_name, embedder }) => {
+                execute_vector_search(
                 &mut ctx,
                 vector,
                 self.scoring_strategy,
@@ -194,11 +180,12 @@ impl<'a> Search<'a> {
                 self.geo_strategy,
                 self.offset,
                 self.limit,
-                self.distribution_shift,
                 embedder_name,
+                embedder,
                 self.time_budget.clone(),
-            )?,
-            None => execute_search(
+                )?
+            }
+            _ => execute_search(
                 &mut ctx,
                 self.query.as_deref(),
                 self.terms_matching_strategy,
@@ -237,7 +224,6 @@ impl fmt::Debug for Search<'_> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         let Search {
             query,
-            vector: _,
             filter,
             offset,
             limit,
@@ -250,8 +236,7 @@ impl fmt::Debug for Search<'_> {
             exhaustive_number_hits,
             rtxn: _,
             index: _,
-            distribution_shift,
-            embedder_name,
+            semantic,
             time_budget,
         } = self;
         f.debug_struct("Search")
@@ -266,8 +251,10 @@ impl fmt::Debug for Search<'_> {
             .field("scoring_strategy", scoring_strategy)
             .field("exhaustive_number_hits", exhaustive_number_hits)
             .field("words_limit", words_limit)
-            .field("distribution_shift", distribution_shift)
-            .field("embedder_name", embedder_name)
+            .field(
+                "semantic.embedder_name",
+                &semantic.as_ref().map(|semantic| &semantic.embedder_name),
+            )
             .field("time_budget", time_budget)
             .finish()
     }

View File

@@ -52,7 +52,7 @@ use self::vector_sort::VectorSort;
 use crate::error::FieldIdMapMissingEntry;
 use crate::score_details::{ScoreDetails, ScoringStrategy};
 use crate::search::new::distinct::apply_distinct_rule;
-use crate::vector::DistributionShift;
+use crate::vector::Embedder;
 use crate::{
     AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, TimeBudget,
     UserError,
@@ -298,8 +298,8 @@ fn get_ranking_rules_for_vector<'ctx>(
     geo_strategy: geo_sort::Strategy,
     limit_plus_offset: usize,
     target: &[f32],
-    distribution_shift: Option<DistributionShift>,
     embedder_name: &str,
+    embedder: &Embedder,
 ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
     // query graph search
 
@@ -325,8 +325,8 @@ fn get_ranking_rules_for_vector<'ctx>(
                     target.to_vec(),
                     vector_candidates,
                     limit_plus_offset,
-                    distribution_shift,
                     embedder_name,
+                    embedder,
                 )?;
                 ranking_rules.push(Box::new(vector_sort));
                 vector = true;
@@ -548,8 +548,8 @@ pub fn execute_vector_search(
     geo_strategy: geo_sort::Strategy,
     from: usize,
     length: usize,
-    distribution_shift: Option<DistributionShift>,
    embedder_name: &str,
+    embedder: &Embedder,
     time_budget: TimeBudget,
 ) -> Result<PartialSearchResult> {
     check_sort_criteria(ctx, sort_criteria.as_ref())?;
@@ -562,8 +562,8 @@ pub fn execute_vector_search(
         geo_strategy,
         from + length,
         vector,
-        distribution_shift,
         embedder_name,
+        embedder,
     )?;
 
     let mut placeholder_search_logger = logger::DefaultSearchLogger;

View File

@@ -5,7 +5,7 @@ use roaring::RoaringBitmap;
 
 use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
 use crate::score_details::{self, ScoreDetails};
-use crate::vector::DistributionShift;
+use crate::vector::{DistributionShift, Embedder};
 use crate::{DocumentId, Result, SearchContext, SearchLogger};
 
 pub struct VectorSort<Q: RankingRuleQueryTrait> {
@@ -24,8 +24,8 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         target: Vec<f32>,
         vector_candidates: RoaringBitmap,
         limit: usize,
-        distribution_shift: Option<DistributionShift>,
         embedder_name: &str,
+        embedder: &Embedder,
     ) -> Result<Self> {
         let embedder_index = ctx
             .index
@@ -39,7 +39,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
             vector_candidates,
             cached_sorted_docids: Default::default(),
             limit,
-            distribution_shift,
+            distribution_shift: embedder.distribution(),
             embedder_index,
         })
     }