diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs
index 3e08498de..e4e93416d 100644
--- a/meilisearch-types/src/error.rs
+++ b/meilisearch-types/src/error.rs
@@ -236,6 +236,7 @@ InvalidSearchHighlightPreTag          , InvalidRequest       , BAD_REQUEST ;
 InvalidSearchHitsPerPage              , InvalidRequest       , BAD_REQUEST ;
 InvalidSearchLimit                    , InvalidRequest       , BAD_REQUEST ;
 InvalidSearchMatchingStrategy         , InvalidRequest       , BAD_REQUEST ;
+InvalidMultiSearchMergeStrategy       , InvalidRequest       , BAD_REQUEST ;
 InvalidSearchOffset                   , InvalidRequest       , BAD_REQUEST ;
 InvalidSearchPage                     , InvalidRequest       , BAD_REQUEST ;
 InvalidSearchQ                        , InvalidRequest       , BAD_REQUEST ;
diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs
index fd78df5e5..c2da30d60 100644
--- a/meilisearch/src/routes/multi_search.rs
+++ b/meilisearch/src/routes/multi_search.rs
@@ -2,9 +2,11 @@ use actix_http::StatusCode;
 use actix_web::web::{self, Data};
 use actix_web::{HttpRequest, HttpResponse};
 use deserr::actix_web::AwebJson;
+use deserr::Deserr;
 use index_scheduler::IndexScheduler;
 use log::debug;
 use meilisearch_types::deserr::DeserrJsonError;
+use meilisearch_types::error::deserr_codes::InvalidMultiSearchMergeStrategy;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::keys::actions;
 use serde::Serialize;
@@ -14,7 +16,7 @@ use crate::extractors::authentication::policies::ActionPolicy;
 use crate::extractors::authentication::{AuthenticationError, GuardedData};
 use crate::extractors::sequential_extractor::SeqHandler;
 use crate::search::{
-    add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
+    add_search_rules, perform_search, SearchHit, SearchQueryWithIndex, SearchResultWithIndex,
 };
 
 pub fn configure(cfg: &mut web::ServiceConfig) {
@@ -23,13 +25,34 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 
 #[derive(Serialize)]
 struct SearchResults {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    aggregate_hits: Option<Vec<SearchHitWithIndex>>,
     results: Vec<SearchResultWithIndex>,
 }
 
+#[derive(Serialize, Debug, Clone, PartialEq)]
+#[serde(rename_all = "camelCase")]
+struct SearchHitWithIndex {
+    pub index_uid: String,
+    #[serde(flatten)]
+    pub hit: SearchHit,
+}
+
 #[derive(Debug, deserr::Deserr)]
 #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
 pub struct SearchQueries {
     queries: Vec<SearchQueryWithIndex>,
+    #[deserr(default, error = DeserrJsonError<InvalidMultiSearchMergeStrategy>)]
+    merge_strategy: MergeStrategy,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Deserr, Default)]
+#[deserr(rename_all = camelCase)]
+pub enum MergeStrategy {
+    #[default]
+    None,
+    ByNormalizedScore,
+    ByScoreDetails,
 }
 
 pub async fn multi_search_with_post(
@@ -38,7 +61,13 @@ pub async fn multi_search_with_post(
     req: HttpRequest,
     analytics: web::Data<dyn Analytics>,
 ) -> Result<HttpResponse, ResponseError> {
-    let queries = params.into_inner().queries;
+    let SearchQueries { queries, merge_strategy } = params.into_inner();
+    // FIXME: REMOVE UNWRAP
+    let max_hits = queries
+        .iter()
+        .map(|SearchQueryWithIndex { limit, hits_per_page, .. }| hits_per_page.unwrap_or(*limit))
+        .max()
+        .unwrap();
 
     let mut multi_aggregate = MultiSearchAggregator::from_queries(&queries, &req);
 
@@ -104,7 +133,49 @@ pub async fn multi_search_with_post(
     debug!("returns: {:?}", search_results);
 
-    Ok(HttpResponse::Ok().json(SearchResults { results: search_results }))
+    let aggregate_hits = match merge_strategy {
+        MergeStrategy::None => None,
+        MergeStrategy::ByScoreDetails => todo!(),
+        MergeStrategy::ByNormalizedScore => {
+            Some(merge_by_normalized_score(&search_results, max_hits))
+        }
+    };
+
+    Ok(HttpResponse::Ok().json(SearchResults { aggregate_hits, results: search_results }))
+}
+
+fn merge_by_normalized_score(
+    search_results: &[SearchResultWithIndex],
+    max_hits: usize,
+) -> Vec<SearchHitWithIndex> {
+    let mut iterators: Vec<_> = search_results
+        .iter()
+        .filter_map(|SearchResultWithIndex { index_uid, result }| {
+            let mut it = result.hits.iter();
+            let next = it.next()?;
+            Some((index_uid, it, next))
+        })
+        .collect();
+
+    let mut hits = Vec::with_capacity(max_hits);
+
+    for _ in 0..max_hits {
+        iterators.sort_by(|(_, _, a), (_, _, b)| a.ranking_score.unwrap().total_cmp(&b.ranking_score.unwrap()));
+
+        let Some((index_uid, it, next)) = iterators.last_mut()
+        else {
+            break;
+        };
+
+        let hit = SearchHitWithIndex { index_uid: index_uid.clone(), hit: next.clone() };
+        if let Some(next_hit) = it.next() {
+            *next = next_hit;
+        } else {
+            iterators.pop();
+        }
+        hits.push(hit);
+    }
+    hits
 }
 
 /// Local `Result` extension trait to avoid `map_err` boilerplate.
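
For reviewers, here is a minimal, self-contained sketch of the merge that `merge_by_normalized_score` performs: a k-way merge that repeatedly takes the highest-scored head among the per-index hit lists until `max_hits` hits are collected. `Hit`, `ResultSet`, and `merge_by_score` below are stand-ins for illustration, not Meilisearch types, and each hit list is assumed to already be sorted best-first, as per-index results are.

// Standalone sketch of the merge performed by `merge_by_normalized_score`.
// The types here are illustrative stand-ins, not Meilisearch types.

#[derive(Debug, Clone)]
struct Hit {
    id: u64,
    ranking_score: f64,
}

struct ResultSet {
    index_uid: String,
    hits: Vec<Hit>,
}

// K-way merge: repeatedly take the highest-scored head among all result
// sets until `max_hits` hits have been collected or every set is exhausted.
fn merge_by_score(results: &[ResultSet], max_hits: usize) -> Vec<(String, Hit)> {
    // One cursor per non-empty result set: (index uid, iterator, peeked head).
    let mut cursors: Vec<_> = results
        .iter()
        .filter_map(|set| {
            let mut it = set.hits.iter();
            let head = it.next()?;
            Some((set.index_uid.as_str(), it, head))
        })
        .collect();

    let mut merged = Vec::with_capacity(max_hits);
    for _ in 0..max_hits {
        // Ascending sort by score, so the best head sits at the end.
        cursors.sort_by(|(_, _, a), (_, _, b)| a.ranking_score.total_cmp(&b.ranking_score));
        let Some((uid, it, head)) = cursors.last_mut() else { break };
        merged.push((uid.to_string(), (*head).clone()));
        // Advance the cursor we just consumed, or drop it once exhausted.
        if let Some(next) = it.next() {
            *head = next;
        } else {
            cursors.pop();
        }
    }
    merged
}

fn main() {
    let results = vec![
        ResultSet {
            index_uid: "movies".into(),
            hits: vec![Hit { id: 1, ranking_score: 0.9 }, Hit { id: 2, ranking_score: 0.4 }],
        },
        ResultSet {
            index_uid: "comics".into(),
            hits: vec![Hit { id: 7, ranking_score: 0.7 }],
        },
    ];
    // Prints: movies#1 (0.9), comics#7 (0.7), movies#2 (0.4).
    for (uid, hit) in merge_by_score(&results, 3) {
        println!("{uid}#{} ({})", hit.id, hit.ranking_score);
    }
}

Re-sorting the cursor vector on every iteration keeps the code short; with many indexes, a `std::collections::BinaryHeap` keyed on the score would avoid the repeated O(k log k) sort, at the cost of an ordered wrapper around `f64`.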
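
And a rough sketch of what a client request could look like once this lands, built with `serde_json` to stay in Rust. The index uids and search terms are invented; `mergeStrategy` and the variant name `byNormalizedScore` follow the camelCase renames in the diff, and `showRankingScore` is enabled because the merge reads `ranking_score` from each hit (and currently unwraps it).

use serde_json::json;

fn main() {
    // Hypothetical POST /multi-search body exercising the new field.
    let body = json!({
        "mergeStrategy": "byNormalizedScore",
        "queries": [
            { "indexUid": "movies", "q": "batman", "showRankingScore": true },
            { "indexUid": "comics", "q": "batman", "showRankingScore": true }
        ]
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}

When `mergeStrategy` is omitted it defaults to `none` and the response is unchanged; otherwise the response gains an aggregated hits array (serialized as `aggregate_hits` as the diff stands, since `SearchResults` carries no camelCase rename) alongside the per-query `results`.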