Mirror of https://github.com/meilisearch/meilisearch.git (synced 2024-11-23 10:37:41 +08:00)
hybrid search uses semantic ratio, error handling
This commit is contained in:
parent 1b7c164a55
commit 217105b7da
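The commit threads a semantic ratio through hybrid search: keyword results are weighted by 1.0 - semantic_ratio, vector results by semantic_ratio, and a new MissingSearchHybrid error rejects requests that send both `q` and `vector` without a `hybrid` object. Below is a minimal standalone Rust sketch of the weighting idea; the names and values are illustrative only and are not Meilisearch APIs.

    // Standalone sketch of ratio-weighted scoring, assuming global scores in [0.0, 1.0].
    fn weighted(score: f64, ratio: f32) -> f64 {
        score * ratio as f64
    }

    fn main() {
        let semantic_ratio = 0.5_f32; // hypothetical stand-in for a request's semantic ratio
        let keyword_score = 0.9_f64;  // hypothetical keyword (text) hit score
        let vector_score = 0.8_f64;   // hypothetical vector (semantic) hit score

        // Keyword hits carry `1.0 - semantic_ratio`, vector hits carry `semantic_ratio`;
        // the merged list is ordered by the weighted values.
        let k = weighted(keyword_score, 1.0 - semantic_ratio);
        let v = weighted(vector_score, semantic_ratio);
        assert!((k - 0.45).abs() < 1e-9 && (v - 0.40).abs() < 1e-9);
        println!("weighted keyword: {k:.2}, weighted vector: {v:.2}");
    }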
@@ -299,6 +299,7 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
 MissingIndexUid , InvalidRequest , BAD_REQUEST ;
 MissingMasterKey , Auth , UNAUTHORIZED ;
 MissingPayload , InvalidRequest , BAD_REQUEST ;
+MissingSearchHybrid , InvalidRequest , BAD_REQUEST ;
 MissingSwapIndexes , InvalidRequest , BAD_REQUEST ;
 MissingTaskFilters , InvalidRequest , BAD_REQUEST ;
 NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY;
@@ -692,7 +692,7 @@ impl SearchAggregator {
             ret.max_terms_number = q.split_whitespace().count();
         }

-        if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector {
+        if let Some(ref vector) = vector {
            ret.max_vector_size = vector.len();
         }

@@ -51,6 +51,8 @@ pub enum MeilisearchHttpError {
     DocumentFormat(#[from] DocumentFormatError),
     #[error(transparent)]
     Join(#[from] JoinError),
+    #[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
+    MissingSearchHybrid,
 }

 impl ErrorCode for MeilisearchHttpError {
@@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError {
             MeilisearchHttpError::FileStore(_) => Code::Internal,
             MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
             MeilisearchHttpError::Join(_) => Code::Internal,
+            MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
         }
     }
 }
@@ -7,7 +7,6 @@ use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::index_uid::IndexUid;
-use meilisearch_types::milli::VectorQuery;
 use serde_json::Value;

 use crate::analytics::{Analytics, FacetSearchAggregator};
@@ -121,7 +120,7 @@ impl From<FacetSearchQuery> for SearchQuery {
             highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
             crop_marker: DEFAULT_CROP_MARKER(),
             matching_strategy,
-            vector: vector.map(VectorQuery::Vector),
+            vector,
             attributes_to_search_on,
             hybrid,
         }
@@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::index_uid::IndexUid;
-use meilisearch_types::milli::{self, VectorQuery};
+use meilisearch_types::milli;
 use meilisearch_types::serde_cs::vec::CS;
 use serde_json::Value;

@@ -128,7 +128,7 @@ impl From<SearchQueryGet> for SearchQuery {

         Self {
             q: other.q,
-            vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
+            vector: other.vector.map(CS::into_inner),
             offset: other.offset.0,
             limit: other.limit.0,
             page: other.page.as_deref().copied(),
@@ -258,21 +258,13 @@ pub async fn embed(
     index_scheduler: &IndexScheduler,
     index: &milli::Index,
 ) -> Result<(), ResponseError> {
-    match query.vector.take() {
-        Some(VectorQuery::String(prompt)) => {
+    if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) =
+        (&query.vector, &query.q, &query.hybrid)
+    {
         let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
         let embedders = index_scheduler.embedders(embedder_configs)?;

-            let embedder_name =
-                if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
-                    &query.hybrid
-                {
-                    Some(embedder)
-                } else {
-                    None
-                };
-
-            let embedder = if let Some(embedder_name) = embedder_name {
+        let embedder = if let Some(embedder_name) = embedder {
             embedders.get(embedder_name)
         } else {
             embedders.get_default()
@@ -283,7 +275,7 @@ pub async fn embed(
             .map_err(milli::Error::from)?
             .0;
         let embeddings = embedder
-            .embed(vec![prompt])
+            .embed(vec![q.to_owned()])
             .await
             .map_err(milli::vector::Error::from)
             .map_err(milli::Error::from)?
@@ -292,15 +284,11 @@ pub async fn embed(

         if embeddings.iter().nth(1).is_some() {
             warn!("Ignoring embeddings past the first one in long search query");
-            query.vector =
-                Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
+            query.vector = Some(embeddings.iter().next().unwrap().to_vec());
         } else {
-            query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
+            query.vector = Some(embeddings.into_inner());
         }
     }
-        Some(vector) => query.vector = Some(vector),
-        None => {}
-    };
     Ok(())
 }

@@ -7,14 +7,13 @@ use deserr::Deserr;
 use either::Either;
 use index_scheduler::RoFeatures;
 use indexmap::IndexMap;
-use log::warn;
 use meilisearch_auth::IndexSearchRules;
 use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
-use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
+use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
 use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
 use meilisearch_types::{milli, Document};
 use milli::tokenizer::TokenizerBuilder;
@@ -44,7 +43,7 @@ pub struct SearchQuery {
     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
     pub q: Option<String>,
     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
-    pub vector: Option<milli::VectorQuery>,
+    pub vector: Option<Vec<f32>>,
     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
     pub hybrid: Option<HybridQuery>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@@ -105,6 +104,8 @@ impl std::convert::TryFrom<f32> for SemanticRatio {
     type Error = InvalidSearchSemanticRatio;

     fn try_from(f: f32) -> Result<Self, Self::Error> {
+        // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
+        #[allow(clippy::manual_range_contains)]
         if f > 1.0 || f < 0.0 {
             Err(InvalidSearchSemanticRatio)
         } else {
@@ -139,7 +140,7 @@ pub struct SearchQueryWithIndex {
     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
     pub q: Option<String>,
     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
-    pub vector: Option<VectorQuery>,
+    pub vector: Option<Vec<f32>>,
     #[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
     pub hybrid: Option<HybridQuery>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@@ -376,8 +377,16 @@ fn prepare_search<'t>(
 ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
     let mut search = index.search(rtxn);

-    if query.vector.is_some() && query.q.is_some() {
-        warn!("Attempting hybrid search");
+    if query.vector.is_some() {
+        features.check_vector("Passing `vector` as a query parameter")?;
+    }
+
+    if query.hybrid.is_some() {
+        features.check_vector("Passing `hybrid` as a query parameter")?;
+    }
+
+    if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
+        return Err(MeilisearchHttpError::MissingSearchHybrid);
     }

     if let Some(ref vector) = query.vector {
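The hunk above adds three request-level checks to prepare_search: `vector` and `hybrid` each require the vector feature, and sending both `q` and `vector` without `hybrid` now fails with MissingSearchHybrid. A rough standalone sketch of the same decision logic, with hypothetical names and types rather than the actual prepare_search signature:

    // Hypothetical, simplified mirror of the checks in the hunk above.
    #[derive(Debug, PartialEq)]
    enum CheckError {
        VectorFeatureDisabled,
        MissingSearchHybrid,
    }

    fn check_query(
        q: Option<&str>,
        vector: Option<&[f32]>,
        has_hybrid: bool,
        vector_feature_enabled: bool,
    ) -> Result<(), CheckError> {
        // `vector` and `hybrid` both depend on the vector feature being enabled.
        if (vector.is_some() || has_hybrid) && !vector_feature_enabled {
            return Err(CheckError::VectorFeatureDisabled);
        }
        // Both `q` and `vector` without an explicit `hybrid` object is ambiguous, so reject it.
        if !has_hybrid && q.is_some() && vector.is_some() {
            return Err(CheckError::MissingSearchHybrid);
        }
        Ok(())
    }

    fn main() {
        assert_eq!(
            check_query(Some("kitchen utensils"), Some(&[0.1_f32, 0.2_f32][..]), false, true),
            Err(CheckError::MissingSearchHybrid)
        );
        assert_eq!(check_query(Some("kitchen utensils"), None, false, true), Ok(()));
    }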
@@ -385,14 +394,9 @@ fn prepare_search<'t>(
             // If semantic ratio is 0.0, only the query search will impact the search results,
             // skip the vector
             Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
-            _otherwise => match vector {
-                VectorQuery::Vector(vector) => {
+            _otherwise => {
                 search.vector(vector.clone());
             }
-                VectorQuery::String(_) => {
-                    panic!("Failed while preparing search; caller did not generate embedding for query")
-                }
-            },
         }
     }

@@ -431,10 +435,6 @@ fn prepare_search<'t>(
         features.check_score_details()?;
     }

-    if query.vector.is_some() {
-        features.check_vector("Passing `vector` as a query parameter")?;
-    }
-
    if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
         search.embedder_name(embedder);
     }
@@ -492,7 +492,7 @@ pub fn perform_search(
     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
         match &query.hybrid {
             Some(hybrid) => match *hybrid.semantic_ratio {
-                0.0 | 1.0 => search.execute()?,
+                ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
                 ratio => search.execute_hybrid(ratio)?,
             },
             None => search.execute()?,
@@ -700,10 +700,7 @@ pub fn perform_search(
         hits: documents,
         hits_info,
         query: query.q.unwrap_or_default(),
-        vector: match query.vector {
-            Some(VectorQuery::Vector(vector)) => Some(vector),
-            _ => None,
-        },
+        vector: query.vector,
         processing_time_ms: before_search.elapsed().as_millis(),
         facet_distribution,
         facet_stats,
@@ -1,4 +1,4 @@
-use meili_snap::{json_string, snapshot};
+use meili_snap::snapshot;
 use once_cell::sync::Lazy;

 use crate::common::index::Index;
@@ -59,7 +59,7 @@ pub use self::index::Index;
 pub use self::search::{
     FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
     MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
-    VectorQuery, DEFAULT_VALUES_PER_FACET,
+    DEFAULT_VALUES_PER_FACET,
 };

 pub type Result<T> = std::result::Result<T, error::Error>;
@@ -1,49 +1,37 @@
 use std::cmp::Ordering;
-use std::collections::HashMap;

 use itertools::Itertools;
 use roaring::RoaringBitmap;

-use super::new::{execute_vector_search, PartialSearchResult};
 use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
-use crate::{
-    execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult,
-};
+use crate::{MatchingWords, Result, Search, SearchResult};

-struct CombinedSearchResult {
+struct ScoreWithRatioResult {
     matching_words: MatchingWords,
     candidates: RoaringBitmap,
-    document_scores: Vec<(u32, CombinedScore)>,
+    document_scores: Vec<(u32, ScoreWithRatio)>,
 }

-type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>);
+type ScoreWithRatio = (Vec<ScoreDetails>, f32);

-fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
-    let mut left_main_it = ScoreDetails::score_values(left.0.iter());
-    let mut left_sub_it =
-        ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten());
-
-    let mut right_main_it = ScoreDetails::score_values(right.0.iter());
-    let mut right_sub_it =
-        ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten());
-
-    let mut left_main = left_main_it.next();
-    let mut left_sub = left_sub_it.next();
-    let mut right_main = right_main_it.next();
-    let mut right_sub = right_sub_it.next();
+fn compare_scores(
+    &(ref left_scores, left_ratio): &ScoreWithRatio,
+    &(ref right_scores, right_ratio): &ScoreWithRatio,
+) -> Ordering {
+    let mut left_it = ScoreDetails::score_values(left_scores.iter());
+    let mut right_it = ScoreDetails::score_values(right_scores.iter());

     loop {
-        let left =
-            take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it);
-
-        let right =
-            take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it);
+        let left = left_it.next();
+        let right = right_it.next();

         match (left, right) {
             (None, None) => return Ordering::Equal,
             (None, Some(_)) => return Ordering::Less,
             (Some(_), None) => return Ordering::Greater,
             (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
+                let left = left * left_ratio as f64;
+                let right = right * right_ratio as f64;
                 if (left - right).abs() <= f64::EPSILON {
                     continue;
                 }
@@ -72,94 +60,17 @@ fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
     }
 }

-fn take_best_score<'a>(
-    main_score: &mut Option<ScoreValue<'a>>,
-    sub_score: &mut Option<ScoreValue<'a>>,
-    main_it: &mut impl Iterator<Item = ScoreValue<'a>>,
-    sub_it: &mut impl Iterator<Item = ScoreValue<'a>>,
-) -> Option<ScoreValue<'a>> {
-    match (*main_score, *sub_score) {
-        (Some(main), None) => {
-            *main_score = main_it.next();
-            Some(main)
-        }
-        (None, Some(sub)) => {
-            *sub_score = sub_it.next();
-            Some(sub)
-        }
-        (main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => {
-            // take max, both advance
-            *main_score = main_it.next();
-            *sub_score = sub_it.next();
-            if main_f >= sub_v {
-                main
-            } else {
-                sub
-            }
-        }
-        (main @ Some(ScoreValue::Score(_)), _) => {
-            *main_score = main_it.next();
-            main
-        }
-        (_, sub @ Some(ScoreValue::Score(_))) => {
-            *sub_score = sub_it.next();
-            sub
-        }
-        (main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => {
-            // take best advance both
-            *main_score = main_it.next();
-            *sub_score = sub_it.next();
-            if main_geo >= sub_geo {
-                main
-            } else {
-                sub
-            }
-        }
-        (main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => {
-            // take best advance both
-            *main_score = main_it.next();
-            *sub_score = sub_it.next();
-            if main_sort >= sub_sort {
-                main
-            } else {
-                sub
-            }
-        }
-        (
-            Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
-            Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
-        ) => None,
-
-        (None, None) => None,
-    }
-}
-
-impl CombinedSearchResult {
-    fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self {
-        let mut docid_scores = HashMap::new();
-        for (docid, score) in
-            main_results.documents_ids.iter().zip(main_results.document_scores.into_iter())
-        {
-            docid_scores.insert(*docid, (score, None));
-        }
-
-        for (docid, score) in ancillary_results
+impl ScoreWithRatioResult {
+    fn new(results: SearchResult, ratio: f32) -> Self {
+        let document_scores = results
             .documents_ids
-            .iter()
-            .zip(ancillary_results.document_scores.into_iter())
-        {
-            docid_scores
-                .entry(*docid)
-                .and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score));
-        }
-
-        let mut document_scores: Vec<_> = docid_scores.into_iter().collect();
-
-        document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse());
+            .into_iter()
+            .zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
+            .collect();

         Self {
-            matching_words: main_results.matching_words,
-            candidates: main_results.candidates,
+            matching_words: results.matching_words,
+            candidates: results.candidates,
             document_scores,
         }
     }
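ScoreWithRatioResult::new in the hunk above replaces the HashMap-based pairing of main and ancillary scores with a single zip that tags each document's scores with the ratio of its source list. A standalone sketch of that pairing, using plain tuples in place of the real ScoreDetails type:

    // Standalone sketch: attach a weighting ratio to each (doc_id, score) pair.
    fn with_ratio(doc_ids: Vec<u32>, scores: Vec<f64>, ratio: f32) -> Vec<(u32, (f64, f32))> {
        doc_ids
            .into_iter()
            .zip(scores.into_iter().map(|score| (score, ratio)))
            .collect()
    }

    fn main() {
        // Two hypothetical documents from the vector list, tagged with a ratio of 0.5.
        let paired = with_ratio(vec![7, 11], vec![0.9, 0.4], 0.5);
        assert_eq!(paired, vec![(7, (0.9, 0.5)), (11, (0.4, 0.5))]);
    }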
@@ -200,7 +111,7 @@ impl CombinedSearchResult {
 }

 impl<'a> Search<'a> {
-    pub fn execute_hybrid(&self) -> Result<SearchResult> {
+    pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> {
         // TODO: find classier way to achieve that than to reset vector and query params
         // create separate keyword and semantic searches
         let mut search = Search {
@@ -223,8 +134,6 @@ impl<'a> Search<'a> {
         };

         let vector_query = search.vector.take();
-        let keyword_query = self.query.as_deref();
-
         let keyword_results = search.execute()?;

         // skip semantic search if we don't have a vector query (placeholder search)
@@ -233,7 +142,7 @@ impl<'a> Search<'a> {
         };

         // completely skip semantic search if the results of the keyword search are good enough
-        if self.results_good_enough(&keyword_results) {
+        if self.results_good_enough(&keyword_results, semantic_ratio) {
             return Ok(keyword_results);
         }

|
|||||||
// TODO: would be better to have two distinct functions at this point
|
// TODO: would be better to have two distinct functions at this point
|
||||||
let vector_results = search.execute()?;
|
let vector_results = search.execute()?;
|
||||||
|
|
||||||
// Compute keyword scores for vector_results
|
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
|
||||||
let keyword_results_for_vector =
|
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
|
||||||
self.keyword_results_for_vector(keyword_query, &vector_results)?;
|
|
||||||
|
|
||||||
// compute vector scores for keyword_results
|
|
||||||
let vector_results_for_keyword =
|
|
||||||
// can unwrap because we returned already if there was no vector query
|
|
||||||
self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?;
|
|
||||||
|
|
||||||
/// TODO apply sementic ratio
|
|
||||||
let keyword_results =
|
|
||||||
CombinedSearchResult::new(keyword_results, vector_results_for_keyword);
|
|
||||||
let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector);
|
|
||||||
|
|
||||||
let merge_results =
|
let merge_results =
|
||||||
CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit);
|
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
|
||||||
assert!(merge_results.documents_ids.len() <= self.limit);
|
assert!(merge_results.documents_ids.len() <= self.limit);
|
||||||
Ok(merge_results)
|
Ok(merge_results)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vector_results_for_keyword(
|
fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
|
||||||
&self,
|
// A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
|
||||||
vector: &[f32],
|
const GOOD_ENOUGH_SCORE: f64 = 0.45;
|
||||||
keyword_results: &SearchResult,
|
|
||||||
) -> Result<PartialSearchResult> {
|
|
||||||
let embedder_name;
|
|
||||||
let embedder_name = match &self.embedder_name {
|
|
||||||
Some(embedder_name) => embedder_name,
|
|
||||||
None => {
|
|
||||||
embedder_name = self.index.default_embedding_name(self.rtxn)?;
|
|
||||||
&embedder_name
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
|
||||||
|
|
||||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
|
||||||
ctx.searchable_attributes(searchable_attributes)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let universe = keyword_results.documents_ids.iter().collect();
|
|
||||||
|
|
||||||
execute_vector_search(
|
|
||||||
&mut ctx,
|
|
||||||
vector,
|
|
||||||
ScoringStrategy::Detailed,
|
|
||||||
universe,
|
|
||||||
&self.sort_criteria,
|
|
||||||
self.geo_strategy,
|
|
||||||
0,
|
|
||||||
self.limit + self.offset,
|
|
||||||
self.distribution_shift,
|
|
||||||
embedder_name,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn keyword_results_for_vector(
|
|
||||||
&self,
|
|
||||||
query: Option<&str>,
|
|
||||||
vector_results: &SearchResult,
|
|
||||||
) -> Result<PartialSearchResult> {
|
|
||||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
|
||||||
|
|
||||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
|
||||||
ctx.searchable_attributes(searchable_attributes)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let universe = vector_results.documents_ids.iter().collect();
|
|
||||||
|
|
||||||
execute_search(
|
|
||||||
&mut ctx,
|
|
||||||
query,
|
|
||||||
self.terms_matching_strategy,
|
|
||||||
ScoringStrategy::Detailed,
|
|
||||||
self.exhaustive_number_hits,
|
|
||||||
universe,
|
|
||||||
&self.sort_criteria,
|
|
||||||
self.geo_strategy,
|
|
||||||
0,
|
|
||||||
self.limit + self.offset,
|
|
||||||
Some(self.words_limit),
|
|
||||||
&mut DefaultSearchLogger,
|
|
||||||
&mut DefaultSearchLogger,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn results_good_enough(&self, keyword_results: &SearchResult) -> bool {
|
|
||||||
const GOOD_ENOUGH_SCORE: f64 = 0.9;
|
|
||||||
|
|
||||||
// 1. we check that we got a sufficient number of results
|
// 1. we check that we got a sufficient number of results
|
||||||
if keyword_results.document_scores.len() < self.limit + self.offset {
|
if keyword_results.document_scores.len() < self.limit + self.offset {
|
||||||
@@ -341,7 +174,7 @@ impl<'a> Search<'a> {
         // we need to check all results because due to sort like rules, they're not necessarily in relevancy order
         for score in &keyword_results.document_scores {
             let score = ScoreDetails::global_score(score.iter());
-            if score < GOOD_ENOUGH_SCORE {
+            if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
                 return false;
             }
         }
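The GOOD_ENOUGH_SCORE threshold in the hunks above drops from 0.9 to 0.45 because the keyword score is now compared after weighting by (1 - semantic_ratio): at the reference ratio of 0.5, a keyword score of 0.9 gives 0.9 * 0.5 = 0.45. A one-line check of that arithmetic, illustrative only:

    fn main() {
        const GOOD_ENOUGH_SCORE: f64 = 0.45;
        let (score, semantic_ratio) = (0.9_f64, 0.5_f64);
        // Same shape as the comparison in the hunk above.
        assert!(score * (1.0 - semantic_ratio) >= GOOD_ENOUGH_SCORE);
    }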
@@ -3,7 +3,6 @@ use std::ops::ControlFlow;

 use charabia::normalizer::NormalizerOption;
 use charabia::Normalize;
-use deserr::{DeserializeError, Deserr, Sequence};
 use fst::automaton::{Automaton, Str};
 use fst::{IntoStreamer, Streamer};
 use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
@@ -57,53 +56,6 @@ pub struct Search<'a> {
     embedder_name: Option<String>,
 }

-#[derive(Debug, Clone, PartialEq)]
-pub enum VectorQuery {
-    Vector(Vec<f32>),
-    String(String),
-}
-
-impl<E> Deserr<E> for VectorQuery
-where
-    E: DeserializeError,
-{
-    fn deserialize_from_value<V: deserr::IntoValue>(
-        value: deserr::Value<V>,
-        location: deserr::ValuePointerRef,
-    ) -> std::result::Result<Self, E> {
-        match value {
-            deserr::Value::String(s) => Ok(VectorQuery::String(s)),
-            deserr::Value::Sequence(seq) => {
-                let v: std::result::Result<Vec<f32>, _> = seq
-                    .into_iter()
-                    .enumerate()
-                    .map(|(index, v)| match v.into_value() {
-                        deserr::Value::Float(f) => Ok(f as f32),
-                        deserr::Value::Integer(i) => Ok(i as f32),
-                        v => Err(deserr::take_cf_content(E::error::<V>(
-                            None,
-                            deserr::ErrorKind::IncorrectValueKind {
-                                actual: v,
-                                accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer],
-                            },
-                            location.push_index(index),
-                        ))),
-                    })
-                    .collect();
-                Ok(VectorQuery::Vector(v?))
-            }
-            _ => Err(deserr::take_cf_content(E::error::<V>(
-                None,
-                deserr::ErrorKind::IncorrectValueKind {
-                    actual: value,
-                    accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence],
-                },
-                location,
-            ))),
-        }
-    }
-}
-
 impl<'a> Search<'a> {
     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
         Search {