From ca6cc4654b4182e2415bd35254906e89d32a953a Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 9 Apr 2024 12:03:40 +0200 Subject: [PATCH] Add similar route --- meilisearch-types/src/deserr/mod.rs | 1 + meilisearch-types/src/error.rs | 13 +- meilisearch/src/error.rs | 3 - meilisearch/src/routes/indexes/mod.rs | 2 + meilisearch/src/routes/indexes/similar.rs | 171 +++++++++ meilisearch/src/search.rs | 445 ++++++++++++++++------ milli/src/lib.rs | 1 + milli/src/search/mod.rs | 1 + milli/src/search/similar.rs | 111 ++++++ 9 files changed, 620 insertions(+), 128 deletions(-) create mode 100644 meilisearch/src/routes/indexes/similar.rs create mode 100644 milli/src/search/similar.rs diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index bf1aa1da5..c593c50fb 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -189,3 +189,4 @@ merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); +merge_with_error_impl_take_error_message!(InvalidSimilarId); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 8ae64e5a8..d2218807f 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -246,7 +246,7 @@ InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; -InvalidRecommendId , InvalidRequest , BAD_REQUEST ; +InvalidSimilarId , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSimilarFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -494,6 +494,17 @@ impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { } } +impl fmt::Display for deserr_codes::InvalidSimilarId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `id` is invalid. \ + A document identifier can be of type integer or string, \ + only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_)." + ) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch/src/error.rs b/meilisearch/src/error.rs index 13e460c24..5a0b04020 100644 --- a/meilisearch/src/error.rs +++ b/meilisearch/src/error.rs @@ -23,8 +23,6 @@ pub enum MeilisearchHttpError { InvalidContentType(String, Vec), #[error("Document `{0}` not found.")] DocumentNotFound(String), - #[error("Document `{0}` not found.")] - InvalidDocumentId(String), #[error("Sending an empty filter is forbidden.")] EmptyFilter, #[error("Invalid syntax for the filter parameter: `expected {}, found: {1}`.", .0.join(", "))] @@ -72,7 +70,6 @@ impl ErrorCode for MeilisearchHttpError { MeilisearchHttpError::MissingPayload(_) => Code::MissingPayload, MeilisearchHttpError::InvalidContentType(_, _) => Code::InvalidContentType, MeilisearchHttpError::DocumentNotFound(_) => Code::DocumentNotFound, - MeilisearchHttpError::InvalidDocumentId(_) => Code::InvalidDocumentId, MeilisearchHttpError::EmptyFilter => Code::InvalidDocumentFilter, MeilisearchHttpError::InvalidExpression(_, _) => Code::InvalidSearchFilter, MeilisearchHttpError::PayloadTooLarge(_) => Code::PayloadTooLarge, diff --git a/meilisearch/src/routes/indexes/mod.rs b/meilisearch/src/routes/indexes/mod.rs index 651977723..35b747ccf 100644 --- a/meilisearch/src/routes/indexes/mod.rs +++ b/meilisearch/src/routes/indexes/mod.rs @@ -29,6 +29,7 @@ pub mod documents; pub mod facet_search; pub mod search; pub mod settings; +pub mod similar; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service( @@ -48,6 +49,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) { .service(web::scope("/documents").configure(documents::configure)) .service(web::scope("/search").configure(search::configure)) .service(web::scope("/facet-search").configure(facet_search::configure)) + .service(web::scope("/similar").configure(similar::configure)) .service(web::scope("/settings").configure(settings::configure)), ); } diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs new file mode 100644 index 000000000..da73dd63b --- /dev/null +++ b/meilisearch/src/routes/indexes/similar.rs @@ -0,0 +1,171 @@ +use actix_web::web::{self, Data}; +use actix_web::{HttpRequest, HttpResponse}; +use deserr::actix_web::{AwebJson, AwebQueryParameter}; +use index_scheduler::IndexScheduler; +use meilisearch_types::deserr::query_params::Param; +use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; +use meilisearch_types::error::deserr_codes::{ + InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId, + InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarShowRankingScore, + InvalidSimilarShowRankingScoreDetails, +}; +use meilisearch_types::error::{ErrorCode as _, ResponseError}; +use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::keys::actions; +use meilisearch_types::serde_cs::vec::CS; +use serde_json::Value; +use tracing::debug; + +use super::ActionPolicy; +use crate::analytics::{Analytics, SimilarAggregator}; +use crate::extractors::authentication::GuardedData; +use crate::extractors::sequential_extractor::SeqHandler; +use crate::search::{ + add_search_rules, perform_similar, SearchKind, SimilarQuery, SimilarResult, + DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, +}; + +pub fn configure(cfg: &mut web::ServiceConfig) { + cfg.service( + web::resource("") + .route(web::get().to(SeqHandler(similar_get))) + .route(web::post().to(SeqHandler(similar_post))), + ); +} + +pub async fn similar_get( + index_scheduler: GuardedData, Data>, + index_uid: web::Path, + params: AwebQueryParameter, + req: HttpRequest, + analytics: web::Data, +) -> Result { + let index_uid = IndexUid::try_from(index_uid.into_inner())?; + + let query = params.0.try_into().map_err(|code: InvalidSimilarId| { + ResponseError::from_msg(code.to_string(), code.error_code()) + })?; + + let mut aggregate = SimilarAggregator::from_query(&query, &req); + + debug!(parameters = ?query, "Similar get"); + + let similar = similar(index_scheduler, index_uid, query).await; + + if let Ok(similar) = &similar { + aggregate.succeed(similar); + } + analytics.get_similar(aggregate); + + let similar = similar?; + + debug!(returns = ?similar, "Similar get"); + Ok(HttpResponse::Ok().json(similar)) +} + +pub async fn similar_post( + index_scheduler: GuardedData, Data>, + index_uid: web::Path, + params: AwebJson, + req: HttpRequest, + analytics: web::Data, +) -> Result { + let index_uid = IndexUid::try_from(index_uid.into_inner())?; + + let query = params.into_inner(); + debug!(parameters = ?query, "Similar post"); + + let mut aggregate = SimilarAggregator::from_query(&query, &req); + + let similar = similar(index_scheduler, index_uid, query).await; + + if let Ok(similar) = &similar { + aggregate.succeed(similar); + } + analytics.post_similar(aggregate); + + let similar = similar?; + + debug!(returns = ?similar, "Similar post"); + Ok(HttpResponse::Ok().json(similar)) +} + +async fn similar( + index_scheduler: GuardedData, Data>, + index_uid: IndexUid, + mut query: SimilarQuery, +) -> Result { + let features = index_scheduler.features(); + + features.check_vector("Using the similar API")?; + + // Tenant token search_rules. + if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { + add_search_rules(&mut query.filter, search_rules); + } + + let index = index_scheduler.index(&index_uid)?; + + let (embedder_name, embedder) = + SearchKind::embedder(&index_scheduler, &index, query.embedder.as_deref(), None)?; + + tokio::task::spawn_blocking(move || perform_similar(&index, query, embedder_name, embedder)) + .await? +} + +#[derive(Debug, deserr::Deserr)] +#[deserr(error = DeserrQueryParamError, rename_all = camelCase, deny_unknown_fields)] +pub struct SimilarQueryGet { + #[deserr(error = DeserrQueryParamError)] + id: Param, + #[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError)] + offset: Param, + #[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError)] + limit: Param, + #[deserr(default, error = DeserrQueryParamError)] + attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrQueryParamError)] + filter: Option, + #[deserr(default, error = DeserrQueryParamError)] + show_ranking_score: Param, + #[deserr(default, error = DeserrQueryParamError)] + show_ranking_score_details: Param, + #[deserr(default, error = DeserrQueryParamError)] + pub embedder: Option, +} + +impl TryFrom for SimilarQuery { + type Error = InvalidSimilarId; + + fn try_from( + SimilarQueryGet { + id, + offset, + limit, + attributes_to_retrieve, + filter, + show_ranking_score, + show_ranking_score_details, + embedder, + }: SimilarQueryGet, + ) -> Result { + let filter = match filter { + Some(f) => match serde_json::from_str(&f) { + Ok(v) => Some(v), + _ => Some(Value::String(f)), + }, + None => None, + }; + + Ok(SimilarQuery { + id: id.0.try_into()?, + offset: offset.0, + limit: limit.0, + filter, + embedder, + attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), + show_ranking_score: show_ranking_score.0, + show_ranking_score_details: show_ranking_score_details.0, + }) + } +} diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 7c3813b55..c6c4e88ca 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -11,7 +11,7 @@ use indexmap::IndexMap; use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; -use meilisearch_types::error::ResponseError; +use meilisearch_types::error::{Code, ResponseError}; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; @@ -417,6 +417,59 @@ impl SearchQueryWithIndex { } } +#[derive(Debug, Clone, PartialEq, Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct SimilarQuery { + #[deserr(error = DeserrJsonError)] + pub id: ExternalDocumentId, + #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] + pub offset: usize, + #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] + pub limit: usize, + #[deserr(default, error = DeserrJsonError)] + pub filter: Option, + #[deserr(default, error = DeserrJsonError, default)] + pub embedder: Option, + #[deserr(default, error = DeserrJsonError)] + pub attributes_to_retrieve: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub show_ranking_score: bool, + #[deserr(default, error = DeserrJsonError, default)] + pub show_ranking_score_details: bool, +} + +#[derive(Debug, Clone, PartialEq, Deserr)] +#[deserr(try_from(Value) = TryFrom::try_from -> InvalidSimilarId)] +pub struct ExternalDocumentId(String); + +impl AsRef for ExternalDocumentId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl ExternalDocumentId { + pub fn into_inner(self) -> String { + self.0 + } +} + +impl TryFrom for ExternalDocumentId { + type Error = InvalidSimilarId; + + fn try_from(value: String) -> Result { + serde_json::Value::String(value).try_into() + } +} + +impl TryFrom for ExternalDocumentId { + type Error = InvalidSimilarId; + + fn try_from(value: Value) -> Result { + Ok(Self(milli::documents::validate_document_id_value(value).map_err(|_| InvalidSimilarId)?)) + } +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)] #[deserr(rename_all = camelCase)] pub enum MatchingStrategy { @@ -538,6 +591,16 @@ impl fmt::Debug for SearchResult { } } +#[derive(Serialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct SimilarResult { + pub hits: Vec, + pub id: String, + pub processing_time_ms: u128, + #[serde(flatten)] + pub hits_info: HitsInfo, +} + #[derive(Serialize, Debug, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub struct SearchResultWithIndex { @@ -719,131 +782,52 @@ pub fn perform_search( SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?, }; - let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); + let SearchQuery { + q, + vector: _, + hybrid: _, + // already computed from prepare_search + offset: _, + limit, + page, + hits_per_page, + attributes_to_retrieve, + attributes_to_crop, + crop_length, + attributes_to_highlight, + show_matches_position, + show_ranking_score, + show_ranking_score_details, + filter: _, + sort, + facets, + highlight_pre_tag, + highlight_post_tag, + crop_marker, + matching_strategy: _, + attributes_to_search_on: _, + } = query; - let displayed_ids = index - .displayed_fields_ids(&rtxn)? - .map(|fields| fields.into_iter().collect::>()) - .unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); - - let fids = |attrs: &BTreeSet| { - let mut ids = BTreeSet::new(); - for attr in attrs { - if attr == "*" { - ids.clone_from(&displayed_ids); - break; - } - - if let Some(id) = fields_ids_map.id(attr) { - ids.insert(id); - } - } - ids + let format = AttributesFormat { + attributes_to_retrieve, + attributes_to_highlight, + attributes_to_crop, + crop_length, + crop_marker, + highlight_pre_tag, + highlight_post_tag, + show_matches_position, + sort, + show_ranking_score, + show_ranking_score_details, }; - // The attributes to retrieve are the ones explicitly marked as to retrieve (all by default), - // but these attributes must be also be present - // - in the fields_ids_map - // - in the displayed attributes - let to_retrieve_ids: BTreeSet<_> = query - .attributes_to_retrieve - .as_ref() - .map(fids) - .unwrap_or_else(|| displayed_ids.clone()) - .intersection(&displayed_ids) - .cloned() - .collect(); - - let attr_to_highlight = query.attributes_to_highlight.unwrap_or_default(); - - let attr_to_crop = query.attributes_to_crop.unwrap_or_default(); - - // Attributes in `formatted_options` correspond to the attributes that will be in `_formatted` - // These attributes are: - // - the attributes asked to be highlighted or cropped (with `attributesToCrop` or `attributesToHighlight`) - // - the attributes asked to be retrieved: these attributes will not be highlighted/cropped - // But these attributes must be also present in displayed attributes - let formatted_options = compute_formatted_options( - &attr_to_highlight, - &attr_to_crop, - query.crop_length, - &to_retrieve_ids, - &fields_ids_map, - &displayed_ids, - ); - - let mut tokenizer_builder = TokenizerBuilder::default(); - tokenizer_builder.create_char_map(true); - - let script_lang_map = index.script_language(&rtxn)?; - if !script_lang_map.is_empty() { - tokenizer_builder.allow_list(&script_lang_map); - } - - let separators = index.allowed_separators(&rtxn)?; - let separators: Option> = - separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref separators) = separators { - tokenizer_builder.separators(separators); - } - - let dictionary = index.dictionary(&rtxn)?; - let dictionary: Option> = - dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref dictionary) = dictionary { - tokenizer_builder.words_dict(dictionary); - } - - let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build()); - formatter_builder.crop_marker(query.crop_marker); - formatter_builder.highlight_prefix(query.highlight_pre_tag); - formatter_builder.highlight_suffix(query.highlight_post_tag); - - let mut documents = Vec::new(); - let documents_iter = index.documents(&rtxn, documents_ids)?; - - for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { - // First generate a document with all the displayed fields - let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; - - // select the attributes to retrieve - let attributes_to_retrieve = to_retrieve_ids - .iter() - .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); - let mut document = - permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); - - let (matches_position, formatted) = format_fields( - &displayed_document, - &fields_ids_map, - &formatter_builder, - &formatted_options, - query.show_matches_position, - &displayed_ids, - )?; - - if let Some(sort) = query.sort.as_ref() { - insert_geo_distance(sort, &mut document); - } - - let ranking_score = - query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); - let ranking_score_details = - query.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); - - let hit = SearchHit { - document, - formatted, - matches_position, - ranking_score_details, - ranking_score, - }; - documents.push(hit); - } + let documents = + make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?; let number_of_hits = min(candidates.len() as usize, max_total_hits); let hits_info = if is_finite_pagination { - let hits_per_page = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); + let hits_per_page = hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); // If hit_per_page is 0, then pages can't be computed and so we respond 0. let total_pages = (number_of_hits + hits_per_page.saturating_sub(1)) .checked_div(hits_per_page) @@ -851,15 +835,15 @@ pub fn perform_search( HitsInfo::Pagination { hits_per_page, - page: query.page.unwrap_or(1), + page: page.unwrap_or(1), total_pages, total_hits: number_of_hits, } } else { - HitsInfo::OffsetLimit { limit: query.limit, offset, estimated_total_hits: number_of_hits } + HitsInfo::OffsetLimit { limit, offset, estimated_total_hits: number_of_hits } }; - let (facet_distribution, facet_stats) = match query.facets { + let (facet_distribution, facet_stats) = match facets { Some(ref fields) => { let mut facet_distribution = index.facets_distribution(&rtxn); @@ -896,7 +880,7 @@ pub fn perform_search( let result = SearchResult { hits: documents, hits_info, - query: query.q.unwrap_or_default(), + query: q.unwrap_or_default(), processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, @@ -907,6 +891,130 @@ pub fn perform_search( Ok(result) } +struct AttributesFormat { + attributes_to_retrieve: Option>, + attributes_to_highlight: Option>, + attributes_to_crop: Option>, + crop_length: usize, + crop_marker: String, + highlight_pre_tag: String, + highlight_post_tag: String, + show_matches_position: bool, + sort: Option>, + show_ranking_score: bool, + show_ranking_score_details: bool, +} + +fn make_hits( + index: &Index, + rtxn: &RoTxn<'_>, + format: AttributesFormat, + matching_words: milli::MatchingWords, + documents_ids: Vec, + document_scores: Vec>, +) -> Result, MeilisearchHttpError> { + let fields_ids_map = index.fields_ids_map(rtxn).unwrap(); + let displayed_ids = index + .displayed_fields_ids(rtxn)? + .map(|fields| fields.into_iter().collect::>()) + .unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); + let fids = |attrs: &BTreeSet| { + let mut ids = BTreeSet::new(); + for attr in attrs { + if attr == "*" { + ids.clone_from(&displayed_ids); + break; + } + + if let Some(id) = fields_ids_map.id(attr) { + ids.insert(id); + } + } + ids + }; + let to_retrieve_ids: BTreeSet<_> = format + .attributes_to_retrieve + .as_ref() + .map(fids) + .unwrap_or_else(|| displayed_ids.clone()) + .intersection(&displayed_ids) + .cloned() + .collect(); + let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); + let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); + let formatted_options = compute_formatted_options( + &attr_to_highlight, + &attr_to_crop, + format.crop_length, + &to_retrieve_ids, + &fields_ids_map, + &displayed_ids, + ); + let mut tokenizer_builder = TokenizerBuilder::default(); + tokenizer_builder.create_char_map(true); + let script_lang_map = index.script_language(rtxn)?; + if !script_lang_map.is_empty() { + tokenizer_builder.allow_list(&script_lang_map); + } + let separators = index.allowed_separators(rtxn)?; + let separators: Option> = + separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); + if let Some(ref separators) = separators { + tokenizer_builder.separators(separators); + } + let dictionary = index.dictionary(rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); + if let Some(ref dictionary) = dictionary { + tokenizer_builder.words_dict(dictionary); + } + let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build()); + formatter_builder.crop_marker(format.crop_marker); + formatter_builder.highlight_prefix(format.highlight_pre_tag); + formatter_builder.highlight_suffix(format.highlight_post_tag); + let mut documents = Vec::new(); + let documents_iter = index.documents(rtxn, documents_ids)?; + for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { + // First generate a document with all the displayed fields + let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; + + // select the attributes to retrieve + let attributes_to_retrieve = to_retrieve_ids + .iter() + .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); + let mut document = + permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); + + let (matches_position, formatted) = format_fields( + &displayed_document, + &fields_ids_map, + &formatter_builder, + &formatted_options, + format.show_matches_position, + &displayed_ids, + )?; + + if let Some(sort) = format.sort.as_ref() { + insert_geo_distance(sort, &mut document); + } + + let ranking_score = + format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); + let ranking_score_details = + format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); + + let hit = SearchHit { + document, + formatted, + matches_position, + ranking_score_details, + ranking_score, + }; + documents.push(hit); + } + Ok(documents) +} + pub fn perform_facet_search( index: &Index, search_query: SearchQuery, @@ -941,6 +1049,95 @@ pub fn perform_facet_search( }) } +pub fn perform_similar( + index: &Index, + query: SimilarQuery, + embedder_name: String, + embedder: Arc, +) -> Result { + let before_search = Instant::now(); + let rtxn = index.read_txn()?; + + let SimilarQuery { + id, + offset, + limit, + filter: _, + embedder: _, + attributes_to_retrieve, + show_ranking_score, + show_ranking_score_details, + } = query; + + // using let-else rather than `?` so that the borrow checker identifies we're always returning here, + // preventing a use-after-move + let Some(internal_id) = index.external_documents_ids().get(&rtxn, &id)? else { + return Err(ResponseError::from_msg( + MeilisearchHttpError::DocumentNotFound(id.into_inner()).to_string(), + Code::NotFoundSimilarId, + )); + }; + + let mut similar = + milli::Similar::new(internal_id, offset, limit, index, &rtxn, embedder_name, embedder); + + if let Some(ref filter) = query.filter { + if let Some(facets) = parse_filter(filter) + // inject InvalidSimilarFilter code + .map_err(|e| ResponseError::from_msg(e.to_string(), Code::InvalidSimilarFilter))? + { + similar.filter(facets); + } + } + + let milli::SearchResult { + documents_ids, + matching_words: _, + candidates, + document_scores, + degraded: _, + used_negative_operator: _, + } = similar.execute().map_err(|err| match err { + milli::Error::UserError(milli::UserError::InvalidFilter(_)) => { + ResponseError::from_msg(err.to_string(), Code::InvalidSimilarFilter) + } + err => err.into(), + })?; + + let format = AttributesFormat { + attributes_to_retrieve, + attributes_to_highlight: None, + attributes_to_crop: None, + crop_length: DEFAULT_CROP_LENGTH(), + crop_marker: DEFAULT_CROP_MARKER(), + highlight_pre_tag: DEFAULT_HIGHLIGHT_PRE_TAG(), + highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), + show_matches_position: false, + sort: None, + show_ranking_score, + show_ranking_score_details, + }; + + let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?; + + let max_total_hits = index + .pagination_max_total_hits(&rtxn) + .map_err(milli::Error::from)? + .map(|x| x as usize) + .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); + + let number_of_hits = min(candidates.len() as usize, max_total_hits); + let hits_info = HitsInfo::OffsetLimit { limit, offset, estimated_total_hits: number_of_hits }; + + let result = SimilarResult { + hits, + hits_info, + id: id.into_inner(), + processing_time_ms: before_search.elapsed().as_millis(), + }; + Ok(result) +} + fn insert_geo_distance(sorts: &[String], document: &mut Document) { lazy_static::lazy_static! { static ref GEO_REGEX: Regex = diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 4d4cdaf9b..095fe1b94 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -63,6 +63,7 @@ pub use self::heed_codec::{ }; pub use self::index::Index; pub use self::search::facet::{FacetValueHit, SearchForFacetValues}; +pub use self::search::similar::Similar; pub use self::search::{ FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 7d1254aa7..76068b1f2 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -24,6 +24,7 @@ pub mod facet; mod fst_utils; pub mod hybrid; pub mod new; +pub mod similar; #[derive(Debug, Clone)] pub struct SemanticSearch { diff --git a/milli/src/search/similar.rs b/milli/src/search/similar.rs new file mode 100644 index 000000000..49b7c876f --- /dev/null +++ b/milli/src/search/similar.rs @@ -0,0 +1,111 @@ +use std::sync::Arc; + +use ordered_float::OrderedFloat; +use roaring::RoaringBitmap; + +use crate::score_details::{self, ScoreDetails}; +use crate::vector::Embedder; +use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult}; + +pub struct Similar<'a> { + id: DocumentId, + // this should be linked to the String in the query + filter: Option>, + offset: usize, + limit: usize, + rtxn: &'a heed::RoTxn<'a>, + index: &'a Index, + embedder_name: String, + embedder: Arc, +} + +impl<'a> Similar<'a> { + pub fn new( + id: DocumentId, + offset: usize, + limit: usize, + index: &'a Index, + rtxn: &'a heed::RoTxn<'a>, + embedder_name: String, + embedder: Arc, + ) -> Self { + Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } + } + + pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { + self.filter = Some(filter); + self + } + + pub fn execute(&self) -> Result { + let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + + let embedder_index = + self.index + .embedder_category_id + .get(self.rtxn, &self.embedder_name)? + .ok_or_else(|| crate::UserError::InvalidEmbedder(self.embedder_name.to_owned()))?; + + let readers: std::result::Result, _> = + self.index.arroy_readers(self.rtxn, embedder_index).collect(); + + let readers = readers?; + + let mut results = Vec::new(); + + for reader in readers.iter() { + let nns_by_item = reader.nns_by_item( + self.rtxn, + self.id, + self.limit + self.offset + 1, + None, + Some(&universe), + )?; + if let Some(mut nns_by_item) = nns_by_item { + results.append(&mut nns_by_item); + } else { + break; + } + } + + results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance)); + + let mut documents_ids = Vec::with_capacity(self.limit); + let mut document_scores = Vec::with_capacity(self.limit); + // list of documents we've already seen, so that we don't return the same document multiple times. + // initialized to the target document, that we never want to return. + let mut documents_seen = RoaringBitmap::new(); + documents_seen.insert(self.id); + + for (docid, distance) in results + .into_iter() + // skip documents we've already seen & mark that we saw the current document + .filter(|(docid, _)| documents_seen.insert(*docid)) + .skip(self.offset) + // take **after** filter and skip so that we get exactly limit elements if available + .take(self.limit) + { + documents_ids.push(docid); + + let score = 1.0 - distance; + let score = self + .embedder + .distribution() + .map(|distribution| distribution.shift(score)) + .unwrap_or(score); + + let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); + + document_scores.push(vec![score]); + } + + Ok(SearchResult { + matching_words: Default::default(), + candidates: universe, + documents_ids, + document_scores, + degraded: false, + used_negative_operator: false, + }) + } +}