mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-29 16:45:30 +08:00
Add similar route
This commit is contained in:
parent
3bd9d2478c
commit
ca6cc4654b
@ -189,3 +189,4 @@ merge_with_error_impl_take_error_message!(ParseTaskKindError);
|
|||||||
merge_with_error_impl_take_error_message!(ParseTaskStatusError);
|
merge_with_error_impl_take_error_message!(ParseTaskStatusError);
|
||||||
merge_with_error_impl_take_error_message!(IndexUidFormatError);
|
merge_with_error_impl_take_error_message!(IndexUidFormatError);
|
||||||
merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio);
|
merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio);
|
||||||
|
merge_with_error_impl_take_error_message!(InvalidSimilarId);
|
||||||
|
@ -246,7 +246,7 @@ InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
|
|||||||
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
|
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidRecommendId , InvalidRequest , BAD_REQUEST ;
|
InvalidSimilarId , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSimilarFilter , InvalidRequest , BAD_REQUEST ;
|
InvalidSimilarFilter , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
|
||||||
@ -494,6 +494,17 @@ impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for deserr_codes::InvalidSimilarId {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"the value of `id` is invalid. \
|
||||||
|
A document identifier can be of type integer or string, \
|
||||||
|
only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and underscores (_)."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! internal_error {
|
macro_rules! internal_error {
|
||||||
($target:ty : $($other:path), *) => {
|
($target:ty : $($other:path), *) => {
|
||||||
|
@ -23,8 +23,6 @@ pub enum MeilisearchHttpError {
|
|||||||
InvalidContentType(String, Vec<String>),
|
InvalidContentType(String, Vec<String>),
|
||||||
#[error("Document `{0}` not found.")]
|
#[error("Document `{0}` not found.")]
|
||||||
DocumentNotFound(String),
|
DocumentNotFound(String),
|
||||||
#[error("Document `{0}` not found.")]
|
|
||||||
InvalidDocumentId(String),
|
|
||||||
#[error("Sending an empty filter is forbidden.")]
|
#[error("Sending an empty filter is forbidden.")]
|
||||||
EmptyFilter,
|
EmptyFilter,
|
||||||
#[error("Invalid syntax for the filter parameter: `expected {}, found: {1}`.", .0.join(", "))]
|
#[error("Invalid syntax for the filter parameter: `expected {}, found: {1}`.", .0.join(", "))]
|
||||||
@ -72,7 +70,6 @@ impl ErrorCode for MeilisearchHttpError {
|
|||||||
MeilisearchHttpError::MissingPayload(_) => Code::MissingPayload,
|
MeilisearchHttpError::MissingPayload(_) => Code::MissingPayload,
|
||||||
MeilisearchHttpError::InvalidContentType(_, _) => Code::InvalidContentType,
|
MeilisearchHttpError::InvalidContentType(_, _) => Code::InvalidContentType,
|
||||||
MeilisearchHttpError::DocumentNotFound(_) => Code::DocumentNotFound,
|
MeilisearchHttpError::DocumentNotFound(_) => Code::DocumentNotFound,
|
||||||
MeilisearchHttpError::InvalidDocumentId(_) => Code::InvalidDocumentId,
|
|
||||||
MeilisearchHttpError::EmptyFilter => Code::InvalidDocumentFilter,
|
MeilisearchHttpError::EmptyFilter => Code::InvalidDocumentFilter,
|
||||||
MeilisearchHttpError::InvalidExpression(_, _) => Code::InvalidSearchFilter,
|
MeilisearchHttpError::InvalidExpression(_, _) => Code::InvalidSearchFilter,
|
||||||
MeilisearchHttpError::PayloadTooLarge(_) => Code::PayloadTooLarge,
|
MeilisearchHttpError::PayloadTooLarge(_) => Code::PayloadTooLarge,
|
||||||
|
@ -29,6 +29,7 @@ pub mod documents;
|
|||||||
pub mod facet_search;
|
pub mod facet_search;
|
||||||
pub mod search;
|
pub mod search;
|
||||||
pub mod settings;
|
pub mod settings;
|
||||||
|
pub mod similar;
|
||||||
|
|
||||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
cfg.service(
|
cfg.service(
|
||||||
@ -48,6 +49,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||||||
.service(web::scope("/documents").configure(documents::configure))
|
.service(web::scope("/documents").configure(documents::configure))
|
||||||
.service(web::scope("/search").configure(search::configure))
|
.service(web::scope("/search").configure(search::configure))
|
||||||
.service(web::scope("/facet-search").configure(facet_search::configure))
|
.service(web::scope("/facet-search").configure(facet_search::configure))
|
||||||
|
.service(web::scope("/similar").configure(similar::configure))
|
||||||
.service(web::scope("/settings").configure(settings::configure)),
|
.service(web::scope("/settings").configure(settings::configure)),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
171
meilisearch/src/routes/indexes/similar.rs
Normal file
171
meilisearch/src/routes/indexes/similar.rs
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
use actix_web::web::{self, Data};
|
||||||
|
use actix_web::{HttpRequest, HttpResponse};
|
||||||
|
use deserr::actix_web::{AwebJson, AwebQueryParameter};
|
||||||
|
use index_scheduler::IndexScheduler;
|
||||||
|
use meilisearch_types::deserr::query_params::Param;
|
||||||
|
use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
|
||||||
|
use meilisearch_types::error::deserr_codes::{
|
||||||
|
InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId,
|
||||||
|
InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarShowRankingScore,
|
||||||
|
InvalidSimilarShowRankingScoreDetails,
|
||||||
|
};
|
||||||
|
use meilisearch_types::error::{ErrorCode as _, ResponseError};
|
||||||
|
use meilisearch_types::index_uid::IndexUid;
|
||||||
|
use meilisearch_types::keys::actions;
|
||||||
|
use meilisearch_types::serde_cs::vec::CS;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use super::ActionPolicy;
|
||||||
|
use crate::analytics::{Analytics, SimilarAggregator};
|
||||||
|
use crate::extractors::authentication::GuardedData;
|
||||||
|
use crate::extractors::sequential_extractor::SeqHandler;
|
||||||
|
use crate::search::{
|
||||||
|
add_search_rules, perform_similar, SearchKind, SimilarQuery, SimilarResult,
|
||||||
|
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
|
cfg.service(
|
||||||
|
web::resource("")
|
||||||
|
.route(web::get().to(SeqHandler(similar_get)))
|
||||||
|
.route(web::post().to(SeqHandler(similar_post))),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn similar_get(
|
||||||
|
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||||
|
index_uid: web::Path<String>,
|
||||||
|
params: AwebQueryParameter<SimilarQueryGet, DeserrQueryParamError>,
|
||||||
|
req: HttpRequest,
|
||||||
|
analytics: web::Data<dyn Analytics>,
|
||||||
|
) -> Result<HttpResponse, ResponseError> {
|
||||||
|
let index_uid = IndexUid::try_from(index_uid.into_inner())?;
|
||||||
|
|
||||||
|
let query = params.0.try_into().map_err(|code: InvalidSimilarId| {
|
||||||
|
ResponseError::from_msg(code.to_string(), code.error_code())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut aggregate = SimilarAggregator::from_query(&query, &req);
|
||||||
|
|
||||||
|
debug!(parameters = ?query, "Similar get");
|
||||||
|
|
||||||
|
let similar = similar(index_scheduler, index_uid, query).await;
|
||||||
|
|
||||||
|
if let Ok(similar) = &similar {
|
||||||
|
aggregate.succeed(similar);
|
||||||
|
}
|
||||||
|
analytics.get_similar(aggregate);
|
||||||
|
|
||||||
|
let similar = similar?;
|
||||||
|
|
||||||
|
debug!(returns = ?similar, "Similar get");
|
||||||
|
Ok(HttpResponse::Ok().json(similar))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn similar_post(
|
||||||
|
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||||
|
index_uid: web::Path<String>,
|
||||||
|
params: AwebJson<SimilarQuery, DeserrJsonError>,
|
||||||
|
req: HttpRequest,
|
||||||
|
analytics: web::Data<dyn Analytics>,
|
||||||
|
) -> Result<HttpResponse, ResponseError> {
|
||||||
|
let index_uid = IndexUid::try_from(index_uid.into_inner())?;
|
||||||
|
|
||||||
|
let query = params.into_inner();
|
||||||
|
debug!(parameters = ?query, "Similar post");
|
||||||
|
|
||||||
|
let mut aggregate = SimilarAggregator::from_query(&query, &req);
|
||||||
|
|
||||||
|
let similar = similar(index_scheduler, index_uid, query).await;
|
||||||
|
|
||||||
|
if let Ok(similar) = &similar {
|
||||||
|
aggregate.succeed(similar);
|
||||||
|
}
|
||||||
|
analytics.post_similar(aggregate);
|
||||||
|
|
||||||
|
let similar = similar?;
|
||||||
|
|
||||||
|
debug!(returns = ?similar, "Similar post");
|
||||||
|
Ok(HttpResponse::Ok().json(similar))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn similar(
|
||||||
|
index_scheduler: GuardedData<ActionPolicy<{ actions::SEARCH }>, Data<IndexScheduler>>,
|
||||||
|
index_uid: IndexUid,
|
||||||
|
mut query: SimilarQuery,
|
||||||
|
) -> Result<SimilarResult, ResponseError> {
|
||||||
|
let features = index_scheduler.features();
|
||||||
|
|
||||||
|
features.check_vector("Using the similar API")?;
|
||||||
|
|
||||||
|
// Tenant token search_rules.
|
||||||
|
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
|
||||||
|
add_search_rules(&mut query.filter, search_rules);
|
||||||
|
}
|
||||||
|
|
||||||
|
let index = index_scheduler.index(&index_uid)?;
|
||||||
|
|
||||||
|
let (embedder_name, embedder) =
|
||||||
|
SearchKind::embedder(&index_scheduler, &index, query.embedder.as_deref(), None)?;
|
||||||
|
|
||||||
|
tokio::task::spawn_blocking(move || perform_similar(&index, query, embedder_name, embedder))
|
||||||
|
.await?
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, deserr::Deserr)]
|
||||||
|
#[deserr(error = DeserrQueryParamError, rename_all = camelCase, deny_unknown_fields)]
|
||||||
|
pub struct SimilarQueryGet {
|
||||||
|
#[deserr(error = DeserrQueryParamError<InvalidSimilarId>)]
|
||||||
|
id: Param<String>,
|
||||||
|
#[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError<InvalidSimilarOffset>)]
|
||||||
|
offset: Param<usize>,
|
||||||
|
#[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError<InvalidSimilarLimit>)]
|
||||||
|
limit: Param<usize>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarAttributesToRetrieve>)]
|
||||||
|
attributes_to_retrieve: Option<CS<String>>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarFilter>)]
|
||||||
|
filter: Option<String>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarShowRankingScore>)]
|
||||||
|
show_ranking_score: Param<bool>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarShowRankingScoreDetails>)]
|
||||||
|
show_ranking_score_details: Param<bool>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)]
|
||||||
|
pub embedder: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<SimilarQueryGet> for SimilarQuery {
|
||||||
|
type Error = InvalidSimilarId;
|
||||||
|
|
||||||
|
fn try_from(
|
||||||
|
SimilarQueryGet {
|
||||||
|
id,
|
||||||
|
offset,
|
||||||
|
limit,
|
||||||
|
attributes_to_retrieve,
|
||||||
|
filter,
|
||||||
|
show_ranking_score,
|
||||||
|
show_ranking_score_details,
|
||||||
|
embedder,
|
||||||
|
}: SimilarQueryGet,
|
||||||
|
) -> Result<Self, Self::Error> {
|
||||||
|
let filter = match filter {
|
||||||
|
Some(f) => match serde_json::from_str(&f) {
|
||||||
|
Ok(v) => Some(v),
|
||||||
|
_ => Some(Value::String(f)),
|
||||||
|
},
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(SimilarQuery {
|
||||||
|
id: id.0.try_into()?,
|
||||||
|
offset: offset.0,
|
||||||
|
limit: limit.0,
|
||||||
|
filter,
|
||||||
|
embedder,
|
||||||
|
attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()),
|
||||||
|
show_ranking_score: show_ranking_score.0,
|
||||||
|
show_ranking_score_details: show_ranking_score_details.0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -11,7 +11,7 @@ use indexmap::IndexMap;
|
|||||||
use meilisearch_auth::IndexSearchRules;
|
use meilisearch_auth::IndexSearchRules;
|
||||||
use meilisearch_types::deserr::DeserrJsonError;
|
use meilisearch_types::deserr::DeserrJsonError;
|
||||||
use meilisearch_types::error::deserr_codes::*;
|
use meilisearch_types::error::deserr_codes::*;
|
||||||
use meilisearch_types::error::ResponseError;
|
use meilisearch_types::error::{Code, ResponseError};
|
||||||
use meilisearch_types::heed::RoTxn;
|
use meilisearch_types::heed::RoTxn;
|
||||||
use meilisearch_types::index_uid::IndexUid;
|
use meilisearch_types::index_uid::IndexUid;
|
||||||
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
||||||
@ -417,6 +417,59 @@ impl SearchQueryWithIndex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserr)]
|
||||||
|
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||||
|
pub struct SimilarQuery {
|
||||||
|
#[deserr(error = DeserrJsonError<InvalidSimilarId>)]
|
||||||
|
pub id: ExternalDocumentId,
|
||||||
|
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSimilarOffset>)]
|
||||||
|
pub offset: usize,
|
||||||
|
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSimilarLimit>)]
|
||||||
|
pub limit: usize,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSimilarFilter>)]
|
||||||
|
pub filter: Option<Value>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
|
||||||
|
pub embedder: Option<String>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSimilarAttributesToRetrieve>)]
|
||||||
|
pub attributes_to_retrieve: Option<BTreeSet<String>>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScore>, default)]
|
||||||
|
pub show_ranking_score: bool,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScoreDetails>, default)]
|
||||||
|
pub show_ranking_score_details: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserr)]
|
||||||
|
#[deserr(try_from(Value) = TryFrom::try_from -> InvalidSimilarId)]
|
||||||
|
pub struct ExternalDocumentId(String);
|
||||||
|
|
||||||
|
impl AsRef<str> for ExternalDocumentId {
|
||||||
|
fn as_ref(&self) -> &str {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExternalDocumentId {
|
||||||
|
pub fn into_inner(self) -> String {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<String> for ExternalDocumentId {
|
||||||
|
type Error = InvalidSimilarId;
|
||||||
|
|
||||||
|
fn try_from(value: String) -> Result<Self, Self::Error> {
|
||||||
|
serde_json::Value::String(value).try_into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<Value> for ExternalDocumentId {
|
||||||
|
type Error = InvalidSimilarId;
|
||||||
|
|
||||||
|
fn try_from(value: Value) -> Result<Self, Self::Error> {
|
||||||
|
Ok(Self(milli::documents::validate_document_id_value(value).map_err(|_| InvalidSimilarId)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)]
|
||||||
#[deserr(rename_all = camelCase)]
|
#[deserr(rename_all = camelCase)]
|
||||||
pub enum MatchingStrategy {
|
pub enum MatchingStrategy {
|
||||||
@ -538,6 +591,16 @@ impl fmt::Debug for SearchResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Debug, Clone, PartialEq)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SimilarResult {
|
||||||
|
pub hits: Vec<SearchHit>,
|
||||||
|
pub id: String,
|
||||||
|
pub processing_time_ms: u128,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub hits_info: HitsInfo,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Debug, Clone, PartialEq)]
|
#[derive(Serialize, Debug, Clone, PartialEq)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct SearchResultWithIndex {
|
pub struct SearchResultWithIndex {
|
||||||
@ -719,131 +782,52 @@ pub fn perform_search(
|
|||||||
SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?,
|
SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
|
let SearchQuery {
|
||||||
|
q,
|
||||||
|
vector: _,
|
||||||
|
hybrid: _,
|
||||||
|
// already computed from prepare_search
|
||||||
|
offset: _,
|
||||||
|
limit,
|
||||||
|
page,
|
||||||
|
hits_per_page,
|
||||||
|
attributes_to_retrieve,
|
||||||
|
attributes_to_crop,
|
||||||
|
crop_length,
|
||||||
|
attributes_to_highlight,
|
||||||
|
show_matches_position,
|
||||||
|
show_ranking_score,
|
||||||
|
show_ranking_score_details,
|
||||||
|
filter: _,
|
||||||
|
sort,
|
||||||
|
facets,
|
||||||
|
highlight_pre_tag,
|
||||||
|
highlight_post_tag,
|
||||||
|
crop_marker,
|
||||||
|
matching_strategy: _,
|
||||||
|
attributes_to_search_on: _,
|
||||||
|
} = query;
|
||||||
|
|
||||||
let displayed_ids = index
|
let format = AttributesFormat {
|
||||||
.displayed_fields_ids(&rtxn)?
|
attributes_to_retrieve,
|
||||||
.map(|fields| fields.into_iter().collect::<BTreeSet<_>>())
|
attributes_to_highlight,
|
||||||
.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect());
|
attributes_to_crop,
|
||||||
|
crop_length,
|
||||||
let fids = |attrs: &BTreeSet<String>| {
|
crop_marker,
|
||||||
let mut ids = BTreeSet::new();
|
highlight_pre_tag,
|
||||||
for attr in attrs {
|
highlight_post_tag,
|
||||||
if attr == "*" {
|
show_matches_position,
|
||||||
ids.clone_from(&displayed_ids);
|
sort,
|
||||||
break;
|
show_ranking_score,
|
||||||
}
|
show_ranking_score_details,
|
||||||
|
|
||||||
if let Some(id) = fields_ids_map.id(attr) {
|
|
||||||
ids.insert(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ids
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// The attributes to retrieve are the ones explicitly marked as to retrieve (all by default),
|
let documents =
|
||||||
// but these attributes must be also be present
|
make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?;
|
||||||
// - in the fields_ids_map
|
|
||||||
// - in the displayed attributes
|
|
||||||
let to_retrieve_ids: BTreeSet<_> = query
|
|
||||||
.attributes_to_retrieve
|
|
||||||
.as_ref()
|
|
||||||
.map(fids)
|
|
||||||
.unwrap_or_else(|| displayed_ids.clone())
|
|
||||||
.intersection(&displayed_ids)
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let attr_to_highlight = query.attributes_to_highlight.unwrap_or_default();
|
|
||||||
|
|
||||||
let attr_to_crop = query.attributes_to_crop.unwrap_or_default();
|
|
||||||
|
|
||||||
// Attributes in `formatted_options` correspond to the attributes that will be in `_formatted`
|
|
||||||
// These attributes are:
|
|
||||||
// - the attributes asked to be highlighted or cropped (with `attributesToCrop` or `attributesToHighlight`)
|
|
||||||
// - the attributes asked to be retrieved: these attributes will not be highlighted/cropped
|
|
||||||
// But these attributes must be also present in displayed attributes
|
|
||||||
let formatted_options = compute_formatted_options(
|
|
||||||
&attr_to_highlight,
|
|
||||||
&attr_to_crop,
|
|
||||||
query.crop_length,
|
|
||||||
&to_retrieve_ids,
|
|
||||||
&fields_ids_map,
|
|
||||||
&displayed_ids,
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut tokenizer_builder = TokenizerBuilder::default();
|
|
||||||
tokenizer_builder.create_char_map(true);
|
|
||||||
|
|
||||||
let script_lang_map = index.script_language(&rtxn)?;
|
|
||||||
if !script_lang_map.is_empty() {
|
|
||||||
tokenizer_builder.allow_list(&script_lang_map);
|
|
||||||
}
|
|
||||||
|
|
||||||
let separators = index.allowed_separators(&rtxn)?;
|
|
||||||
let separators: Option<Vec<_>> =
|
|
||||||
separators.as_ref().map(|x| x.iter().map(String::as_str).collect());
|
|
||||||
if let Some(ref separators) = separators {
|
|
||||||
tokenizer_builder.separators(separators);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dictionary = index.dictionary(&rtxn)?;
|
|
||||||
let dictionary: Option<Vec<_>> =
|
|
||||||
dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect());
|
|
||||||
if let Some(ref dictionary) = dictionary {
|
|
||||||
tokenizer_builder.words_dict(dictionary);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build());
|
|
||||||
formatter_builder.crop_marker(query.crop_marker);
|
|
||||||
formatter_builder.highlight_prefix(query.highlight_pre_tag);
|
|
||||||
formatter_builder.highlight_suffix(query.highlight_post_tag);
|
|
||||||
|
|
||||||
let mut documents = Vec::new();
|
|
||||||
let documents_iter = index.documents(&rtxn, documents_ids)?;
|
|
||||||
|
|
||||||
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
|
|
||||||
// First generate a document with all the displayed fields
|
|
||||||
let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?;
|
|
||||||
|
|
||||||
// select the attributes to retrieve
|
|
||||||
let attributes_to_retrieve = to_retrieve_ids
|
|
||||||
.iter()
|
|
||||||
.map(|&fid| fields_ids_map.name(fid).expect("Missing field name"));
|
|
||||||
let mut document =
|
|
||||||
permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve);
|
|
||||||
|
|
||||||
let (matches_position, formatted) = format_fields(
|
|
||||||
&displayed_document,
|
|
||||||
&fields_ids_map,
|
|
||||||
&formatter_builder,
|
|
||||||
&formatted_options,
|
|
||||||
query.show_matches_position,
|
|
||||||
&displayed_ids,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
if let Some(sort) = query.sort.as_ref() {
|
|
||||||
insert_geo_distance(sort, &mut document);
|
|
||||||
}
|
|
||||||
|
|
||||||
let ranking_score =
|
|
||||||
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
|
||||||
let ranking_score_details =
|
|
||||||
query.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
|
|
||||||
|
|
||||||
let hit = SearchHit {
|
|
||||||
document,
|
|
||||||
formatted,
|
|
||||||
matches_position,
|
|
||||||
ranking_score_details,
|
|
||||||
ranking_score,
|
|
||||||
};
|
|
||||||
documents.push(hit);
|
|
||||||
}
|
|
||||||
|
|
||||||
let number_of_hits = min(candidates.len() as usize, max_total_hits);
|
let number_of_hits = min(candidates.len() as usize, max_total_hits);
|
||||||
let hits_info = if is_finite_pagination {
|
let hits_info = if is_finite_pagination {
|
||||||
let hits_per_page = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
let hits_per_page = hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
||||||
// If hit_per_page is 0, then pages can't be computed and so we respond 0.
|
// If hit_per_page is 0, then pages can't be computed and so we respond 0.
|
||||||
let total_pages = (number_of_hits + hits_per_page.saturating_sub(1))
|
let total_pages = (number_of_hits + hits_per_page.saturating_sub(1))
|
||||||
.checked_div(hits_per_page)
|
.checked_div(hits_per_page)
|
||||||
@ -851,15 +835,15 @@ pub fn perform_search(
|
|||||||
|
|
||||||
HitsInfo::Pagination {
|
HitsInfo::Pagination {
|
||||||
hits_per_page,
|
hits_per_page,
|
||||||
page: query.page.unwrap_or(1),
|
page: page.unwrap_or(1),
|
||||||
total_pages,
|
total_pages,
|
||||||
total_hits: number_of_hits,
|
total_hits: number_of_hits,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
HitsInfo::OffsetLimit { limit: query.limit, offset, estimated_total_hits: number_of_hits }
|
HitsInfo::OffsetLimit { limit, offset, estimated_total_hits: number_of_hits }
|
||||||
};
|
};
|
||||||
|
|
||||||
let (facet_distribution, facet_stats) = match query.facets {
|
let (facet_distribution, facet_stats) = match facets {
|
||||||
Some(ref fields) => {
|
Some(ref fields) => {
|
||||||
let mut facet_distribution = index.facets_distribution(&rtxn);
|
let mut facet_distribution = index.facets_distribution(&rtxn);
|
||||||
|
|
||||||
@ -896,7 +880,7 @@ pub fn perform_search(
|
|||||||
let result = SearchResult {
|
let result = SearchResult {
|
||||||
hits: documents,
|
hits: documents,
|
||||||
hits_info,
|
hits_info,
|
||||||
query: query.q.unwrap_or_default(),
|
query: q.unwrap_or_default(),
|
||||||
processing_time_ms: before_search.elapsed().as_millis(),
|
processing_time_ms: before_search.elapsed().as_millis(),
|
||||||
facet_distribution,
|
facet_distribution,
|
||||||
facet_stats,
|
facet_stats,
|
||||||
@ -907,6 +891,130 @@ pub fn perform_search(
|
|||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AttributesFormat {
|
||||||
|
attributes_to_retrieve: Option<BTreeSet<String>>,
|
||||||
|
attributes_to_highlight: Option<HashSet<String>>,
|
||||||
|
attributes_to_crop: Option<Vec<String>>,
|
||||||
|
crop_length: usize,
|
||||||
|
crop_marker: String,
|
||||||
|
highlight_pre_tag: String,
|
||||||
|
highlight_post_tag: String,
|
||||||
|
show_matches_position: bool,
|
||||||
|
sort: Option<Vec<String>>,
|
||||||
|
show_ranking_score: bool,
|
||||||
|
show_ranking_score_details: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_hits(
|
||||||
|
index: &Index,
|
||||||
|
rtxn: &RoTxn<'_>,
|
||||||
|
format: AttributesFormat,
|
||||||
|
matching_words: milli::MatchingWords,
|
||||||
|
documents_ids: Vec<u32>,
|
||||||
|
document_scores: Vec<Vec<ScoreDetails>>,
|
||||||
|
) -> Result<Vec<SearchHit>, MeilisearchHttpError> {
|
||||||
|
let fields_ids_map = index.fields_ids_map(rtxn).unwrap();
|
||||||
|
let displayed_ids = index
|
||||||
|
.displayed_fields_ids(rtxn)?
|
||||||
|
.map(|fields| fields.into_iter().collect::<BTreeSet<_>>())
|
||||||
|
.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect());
|
||||||
|
let fids = |attrs: &BTreeSet<String>| {
|
||||||
|
let mut ids = BTreeSet::new();
|
||||||
|
for attr in attrs {
|
||||||
|
if attr == "*" {
|
||||||
|
ids.clone_from(&displayed_ids);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(id) = fields_ids_map.id(attr) {
|
||||||
|
ids.insert(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ids
|
||||||
|
};
|
||||||
|
let to_retrieve_ids: BTreeSet<_> = format
|
||||||
|
.attributes_to_retrieve
|
||||||
|
.as_ref()
|
||||||
|
.map(fids)
|
||||||
|
.unwrap_or_else(|| displayed_ids.clone())
|
||||||
|
.intersection(&displayed_ids)
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default();
|
||||||
|
let attr_to_crop = format.attributes_to_crop.unwrap_or_default();
|
||||||
|
let formatted_options = compute_formatted_options(
|
||||||
|
&attr_to_highlight,
|
||||||
|
&attr_to_crop,
|
||||||
|
format.crop_length,
|
||||||
|
&to_retrieve_ids,
|
||||||
|
&fields_ids_map,
|
||||||
|
&displayed_ids,
|
||||||
|
);
|
||||||
|
let mut tokenizer_builder = TokenizerBuilder::default();
|
||||||
|
tokenizer_builder.create_char_map(true);
|
||||||
|
let script_lang_map = index.script_language(rtxn)?;
|
||||||
|
if !script_lang_map.is_empty() {
|
||||||
|
tokenizer_builder.allow_list(&script_lang_map);
|
||||||
|
}
|
||||||
|
let separators = index.allowed_separators(rtxn)?;
|
||||||
|
let separators: Option<Vec<_>> =
|
||||||
|
separators.as_ref().map(|x| x.iter().map(String::as_str).collect());
|
||||||
|
if let Some(ref separators) = separators {
|
||||||
|
tokenizer_builder.separators(separators);
|
||||||
|
}
|
||||||
|
let dictionary = index.dictionary(rtxn)?;
|
||||||
|
let dictionary: Option<Vec<_>> =
|
||||||
|
dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect());
|
||||||
|
if let Some(ref dictionary) = dictionary {
|
||||||
|
tokenizer_builder.words_dict(dictionary);
|
||||||
|
}
|
||||||
|
let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build());
|
||||||
|
formatter_builder.crop_marker(format.crop_marker);
|
||||||
|
formatter_builder.highlight_prefix(format.highlight_pre_tag);
|
||||||
|
formatter_builder.highlight_suffix(format.highlight_post_tag);
|
||||||
|
let mut documents = Vec::new();
|
||||||
|
let documents_iter = index.documents(rtxn, documents_ids)?;
|
||||||
|
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
|
||||||
|
// First generate a document with all the displayed fields
|
||||||
|
let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?;
|
||||||
|
|
||||||
|
// select the attributes to retrieve
|
||||||
|
let attributes_to_retrieve = to_retrieve_ids
|
||||||
|
.iter()
|
||||||
|
.map(|&fid| fields_ids_map.name(fid).expect("Missing field name"));
|
||||||
|
let mut document =
|
||||||
|
permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve);
|
||||||
|
|
||||||
|
let (matches_position, formatted) = format_fields(
|
||||||
|
&displayed_document,
|
||||||
|
&fields_ids_map,
|
||||||
|
&formatter_builder,
|
||||||
|
&formatted_options,
|
||||||
|
format.show_matches_position,
|
||||||
|
&displayed_ids,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
if let Some(sort) = format.sort.as_ref() {
|
||||||
|
insert_geo_distance(sort, &mut document);
|
||||||
|
}
|
||||||
|
|
||||||
|
let ranking_score =
|
||||||
|
format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
||||||
|
let ranking_score_details =
|
||||||
|
format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
|
||||||
|
|
||||||
|
let hit = SearchHit {
|
||||||
|
document,
|
||||||
|
formatted,
|
||||||
|
matches_position,
|
||||||
|
ranking_score_details,
|
||||||
|
ranking_score,
|
||||||
|
};
|
||||||
|
documents.push(hit);
|
||||||
|
}
|
||||||
|
Ok(documents)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn perform_facet_search(
|
pub fn perform_facet_search(
|
||||||
index: &Index,
|
index: &Index,
|
||||||
search_query: SearchQuery,
|
search_query: SearchQuery,
|
||||||
@ -941,6 +1049,95 @@ pub fn perform_facet_search(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn perform_similar(
|
||||||
|
index: &Index,
|
||||||
|
query: SimilarQuery,
|
||||||
|
embedder_name: String,
|
||||||
|
embedder: Arc<Embedder>,
|
||||||
|
) -> Result<SimilarResult, ResponseError> {
|
||||||
|
let before_search = Instant::now();
|
||||||
|
let rtxn = index.read_txn()?;
|
||||||
|
|
||||||
|
let SimilarQuery {
|
||||||
|
id,
|
||||||
|
offset,
|
||||||
|
limit,
|
||||||
|
filter: _,
|
||||||
|
embedder: _,
|
||||||
|
attributes_to_retrieve,
|
||||||
|
show_ranking_score,
|
||||||
|
show_ranking_score_details,
|
||||||
|
} = query;
|
||||||
|
|
||||||
|
// using let-else rather than `?` so that the borrow checker identifies we're always returning here,
|
||||||
|
// preventing a use-after-move
|
||||||
|
let Some(internal_id) = index.external_documents_ids().get(&rtxn, &id)? else {
|
||||||
|
return Err(ResponseError::from_msg(
|
||||||
|
MeilisearchHttpError::DocumentNotFound(id.into_inner()).to_string(),
|
||||||
|
Code::NotFoundSimilarId,
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut similar =
|
||||||
|
milli::Similar::new(internal_id, offset, limit, index, &rtxn, embedder_name, embedder);
|
||||||
|
|
||||||
|
if let Some(ref filter) = query.filter {
|
||||||
|
if let Some(facets) = parse_filter(filter)
|
||||||
|
// inject InvalidSimilarFilter code
|
||||||
|
.map_err(|e| ResponseError::from_msg(e.to_string(), Code::InvalidSimilarFilter))?
|
||||||
|
{
|
||||||
|
similar.filter(facets);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let milli::SearchResult {
|
||||||
|
documents_ids,
|
||||||
|
matching_words: _,
|
||||||
|
candidates,
|
||||||
|
document_scores,
|
||||||
|
degraded: _,
|
||||||
|
used_negative_operator: _,
|
||||||
|
} = similar.execute().map_err(|err| match err {
|
||||||
|
milli::Error::UserError(milli::UserError::InvalidFilter(_)) => {
|
||||||
|
ResponseError::from_msg(err.to_string(), Code::InvalidSimilarFilter)
|
||||||
|
}
|
||||||
|
err => err.into(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let format = AttributesFormat {
|
||||||
|
attributes_to_retrieve,
|
||||||
|
attributes_to_highlight: None,
|
||||||
|
attributes_to_crop: None,
|
||||||
|
crop_length: DEFAULT_CROP_LENGTH(),
|
||||||
|
crop_marker: DEFAULT_CROP_MARKER(),
|
||||||
|
highlight_pre_tag: DEFAULT_HIGHLIGHT_PRE_TAG(),
|
||||||
|
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
|
||||||
|
show_matches_position: false,
|
||||||
|
sort: None,
|
||||||
|
show_ranking_score,
|
||||||
|
show_ranking_score_details,
|
||||||
|
};
|
||||||
|
|
||||||
|
let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?;
|
||||||
|
|
||||||
|
let max_total_hits = index
|
||||||
|
.pagination_max_total_hits(&rtxn)
|
||||||
|
.map_err(milli::Error::from)?
|
||||||
|
.map(|x| x as usize)
|
||||||
|
.unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS);
|
||||||
|
|
||||||
|
let number_of_hits = min(candidates.len() as usize, max_total_hits);
|
||||||
|
let hits_info = HitsInfo::OffsetLimit { limit, offset, estimated_total_hits: number_of_hits };
|
||||||
|
|
||||||
|
let result = SimilarResult {
|
||||||
|
hits,
|
||||||
|
hits_info,
|
||||||
|
id: id.into_inner(),
|
||||||
|
processing_time_ms: before_search.elapsed().as_millis(),
|
||||||
|
};
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
fn insert_geo_distance(sorts: &[String], document: &mut Document) {
|
fn insert_geo_distance(sorts: &[String], document: &mut Document) {
|
||||||
lazy_static::lazy_static! {
|
lazy_static::lazy_static! {
|
||||||
static ref GEO_REGEX: Regex =
|
static ref GEO_REGEX: Regex =
|
||||||
|
@ -63,6 +63,7 @@ pub use self::heed_codec::{
|
|||||||
};
|
};
|
||||||
pub use self::index::Index;
|
pub use self::index::Index;
|
||||||
pub use self::search::facet::{FacetValueHit, SearchForFacetValues};
|
pub use self::search::facet::{FacetValueHit, SearchForFacetValues};
|
||||||
|
pub use self::search::similar::Similar;
|
||||||
pub use self::search::{
|
pub use self::search::{
|
||||||
FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy,
|
FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy,
|
||||||
Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
Search, SearchResult, SemanticSearch, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
||||||
|
@ -24,6 +24,7 @@ pub mod facet;
|
|||||||
mod fst_utils;
|
mod fst_utils;
|
||||||
pub mod hybrid;
|
pub mod hybrid;
|
||||||
pub mod new;
|
pub mod new;
|
||||||
|
pub mod similar;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SemanticSearch {
|
pub struct SemanticSearch {
|
||||||
|
111
milli/src/search/similar.rs
Normal file
111
milli/src/search/similar.rs
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
|
use roaring::RoaringBitmap;
|
||||||
|
|
||||||
|
use crate::score_details::{self, ScoreDetails};
|
||||||
|
use crate::vector::Embedder;
|
||||||
|
use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult};
|
||||||
|
|
||||||
|
pub struct Similar<'a> {
|
||||||
|
id: DocumentId,
|
||||||
|
// this should be linked to the String in the query
|
||||||
|
filter: Option<Filter<'a>>,
|
||||||
|
offset: usize,
|
||||||
|
limit: usize,
|
||||||
|
rtxn: &'a heed::RoTxn<'a>,
|
||||||
|
index: &'a Index,
|
||||||
|
embedder_name: String,
|
||||||
|
embedder: Arc<Embedder>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Similar<'a> {
|
||||||
|
pub fn new(
|
||||||
|
id: DocumentId,
|
||||||
|
offset: usize,
|
||||||
|
limit: usize,
|
||||||
|
index: &'a Index,
|
||||||
|
rtxn: &'a heed::RoTxn<'a>,
|
||||||
|
embedder_name: String,
|
||||||
|
embedder: Arc<Embedder>,
|
||||||
|
) -> Self {
|
||||||
|
Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self {
|
||||||
|
self.filter = Some(filter);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute(&self) -> Result<SearchResult> {
|
||||||
|
let universe = filtered_universe(self.index, self.rtxn, &self.filter)?;
|
||||||
|
|
||||||
|
let embedder_index =
|
||||||
|
self.index
|
||||||
|
.embedder_category_id
|
||||||
|
.get(self.rtxn, &self.embedder_name)?
|
||||||
|
.ok_or_else(|| crate::UserError::InvalidEmbedder(self.embedder_name.to_owned()))?;
|
||||||
|
|
||||||
|
let readers: std::result::Result<Vec<_>, _> =
|
||||||
|
self.index.arroy_readers(self.rtxn, embedder_index).collect();
|
||||||
|
|
||||||
|
let readers = readers?;
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
for reader in readers.iter() {
|
||||||
|
let nns_by_item = reader.nns_by_item(
|
||||||
|
self.rtxn,
|
||||||
|
self.id,
|
||||||
|
self.limit + self.offset + 1,
|
||||||
|
None,
|
||||||
|
Some(&universe),
|
||||||
|
)?;
|
||||||
|
if let Some(mut nns_by_item) = nns_by_item {
|
||||||
|
results.append(&mut nns_by_item);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
|
||||||
|
|
||||||
|
let mut documents_ids = Vec::with_capacity(self.limit);
|
||||||
|
let mut document_scores = Vec::with_capacity(self.limit);
|
||||||
|
// list of documents we've already seen, so that we don't return the same document multiple times.
|
||||||
|
// initialized to the target document, that we never want to return.
|
||||||
|
let mut documents_seen = RoaringBitmap::new();
|
||||||
|
documents_seen.insert(self.id);
|
||||||
|
|
||||||
|
for (docid, distance) in results
|
||||||
|
.into_iter()
|
||||||
|
// skip documents we've already seen & mark that we saw the current document
|
||||||
|
.filter(|(docid, _)| documents_seen.insert(*docid))
|
||||||
|
.skip(self.offset)
|
||||||
|
// take **after** filter and skip so that we get exactly limit elements if available
|
||||||
|
.take(self.limit)
|
||||||
|
{
|
||||||
|
documents_ids.push(docid);
|
||||||
|
|
||||||
|
let score = 1.0 - distance;
|
||||||
|
let score = self
|
||||||
|
.embedder
|
||||||
|
.distribution()
|
||||||
|
.map(|distribution| distribution.shift(score))
|
||||||
|
.unwrap_or(score);
|
||||||
|
|
||||||
|
let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) });
|
||||||
|
|
||||||
|
document_scores.push(vec![score]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(SearchResult {
|
||||||
|
matching_words: Default::default(),
|
||||||
|
candidates: universe,
|
||||||
|
documents_ids,
|
||||||
|
document_scores,
|
||||||
|
degraded: false,
|
||||||
|
used_negative_operator: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user