WIP normalized score merge strategy

This commit is contained in:
Louis Dureuil 2023-06-08 22:22:21 +02:00
parent fc0eb3901d
commit 5e99f16859
No known key found for this signature in database
2 changed files with 75 additions and 3 deletions

View File

@ -236,6 +236,7 @@ InvalidSearchHighlightPreTag , InvalidRequest , BAD_REQUEST ;
InvalidSearchHitsPerPage , InvalidRequest , BAD_REQUEST ;
InvalidSearchLimit , InvalidRequest , BAD_REQUEST ;
InvalidSearchMatchingStrategy , InvalidRequest , BAD_REQUEST ;
InvalidMultiSearchMergeStrategy , InvalidRequest , BAD_REQUEST ;
InvalidSearchOffset , InvalidRequest , BAD_REQUEST ;
InvalidSearchPage , InvalidRequest , BAD_REQUEST ;
InvalidSearchQ , InvalidRequest , BAD_REQUEST ;

View File

@ -2,9 +2,11 @@ use actix_http::StatusCode;
use actix_web::web::{self, Data};
use actix_web::{HttpRequest, HttpResponse};
use deserr::actix_web::AwebJson;
use deserr::Deserr;
use index_scheduler::IndexScheduler;
use log::debug;
use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::InvalidMultiSearchMergeStrategy;
use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions;
use serde::Serialize;
@ -14,7 +16,7 @@ use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{AuthenticationError, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::search::{
add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
add_search_rules, perform_search, SearchHit, SearchQueryWithIndex, SearchResultWithIndex,
};
pub fn configure(cfg: &mut web::ServiceConfig) {
@ -23,13 +25,34 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
#[derive(Serialize)]
struct SearchResults {
#[serde(skip_serializing_if = "Option::is_none")]
aggregate_hits: Option<Vec<SearchHitWithIndex>>,
results: Vec<SearchResultWithIndex>,
}
#[derive(Serialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
struct SearchHitWithIndex {
pub index_uid: String,
#[serde(flatten)]
pub hit: SearchHit,
}
#[derive(Debug, deserr::Deserr)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
pub struct SearchQueries {
queries: Vec<SearchQueryWithIndex>,
#[deserr(default, error = DeserrJsonError<InvalidMultiSearchMergeStrategy>, default)]
merge_strategy: MergeStrategy,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserr, Default)]
#[deserr(rename_all = camelCase)]
pub enum MergeStrategy {
#[default]
None,
ByNormalizedScore,
ByScoreDetails,
}
pub async fn multi_search_with_post(
@ -38,7 +61,13 @@ pub async fn multi_search_with_post(
req: HttpRequest,
analytics: web::Data<dyn Analytics>,
) -> Result<HttpResponse, ResponseError> {
let queries = params.into_inner().queries;
let SearchQueries { queries, merge_strategy } = params.into_inner();
// FIXME: REMOVE UNWRAP
let max_hits = queries
.iter()
.map(|SearchQueryWithIndex { limit, hits_per_page, .. }| hits_per_page.unwrap_or(*limit))
.max()
.unwrap();
let mut multi_aggregate = MultiSearchAggregator::from_queries(&queries, &req);
@ -104,7 +133,49 @@ pub async fn multi_search_with_post(
debug!("returns: {:?}", search_results);
Ok(HttpResponse::Ok().json(SearchResults { results: search_results }))
let aggregate_hits = match merge_strategy {
MergeStrategy::None => None,
MergeStrategy::ByScoreDetails => todo!(),
MergeStrategy::ByNormalizedScore => {
Some(merge_by_normalized_score(&search_results, max_hits))
}
};
Ok(HttpResponse::Ok().json(SearchResults { aggregate_hits, results: search_results }))
}
fn merge_by_normalized_score(
search_results: &[SearchResultWithIndex],
max_hits: usize,
) -> Vec<SearchHitWithIndex> {
let mut iterators: Vec<_> = search_results
.iter()
.filter_map(|SearchResultWithIndex { index_uid, result }| {
let mut it = result.hits.iter();
let next = it.next()?;
Some((index_uid, it, next))
})
.collect();
let mut hits = Vec::with_capacity(max_hits);
for _ in 0..max_hits {
iterators.sort_by_key(|(_, _, peeked)| peeked.ranking_score.unwrap());
let Some((index_uid, it, next)) = iterators.last_mut()
else {
break;
};
let hit = SearchHitWithIndex { index_uid: index_uid.clone(), hit: next.clone() };
if let Some(next_hit) = it.next() {
*next = next_hit;
} else {
iterators.pop();
}
hits.push(hit);
}
hits
}
/// Local `Result` extension trait to avoid `map_err` boilerplate.