mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-26 12:05:05 +08:00
Display the _semanticSimilarity even if the _vectors
field is not displayed
This commit is contained in:
parent
737aec1705
commit
7aa1275337
@ -17,7 +17,7 @@ use meilisearch_types::{milli, Document};
|
|||||||
use milli::tokenizer::TokenizerBuilder;
|
use milli::tokenizer::TokenizerBuilder;
|
||||||
use milli::{
|
use milli::{
|
||||||
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
|
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
|
||||||
SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET,
|
||||||
};
|
};
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
@ -432,7 +432,6 @@ pub fn perform_search(
|
|||||||
formatter_builder.highlight_suffix(query.highlight_post_tag);
|
formatter_builder.highlight_suffix(query.highlight_post_tag);
|
||||||
|
|
||||||
let mut documents = Vec::new();
|
let mut documents = Vec::new();
|
||||||
|
|
||||||
let documents_iter = index.documents(&rtxn, documents_ids)?;
|
let documents_iter = index.documents(&rtxn, documents_ids)?;
|
||||||
|
|
||||||
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
|
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
|
||||||
@ -460,7 +459,9 @@ pub fn perform_search(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(vector) = query.vector.as_ref() {
|
if let Some(vector) = query.vector.as_ref() {
|
||||||
insert_semantic_similarity(&vector, &mut document);
|
if let Some(vectors) = extract_field("_vectors", &fields_ids_map, obkv)? {
|
||||||
|
insert_semantic_similarity(vector, vectors, &mut document);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let ranking_score =
|
let ranking_score =
|
||||||
@ -548,20 +549,18 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_semantic_similarity(query: &[f32], document: &mut Document) {
|
fn insert_semantic_similarity(query: &[f32], vectors: Value, document: &mut Document) {
|
||||||
if let Some(value) = document.get("_vectors") {
|
let vectors =
|
||||||
let vectors: Vec<Vec<f32>> = match serde_json::from_value(value.clone()) {
|
match serde_json::from_value(vectors).map(VectorOrArrayOfVectors::into_array_of_vectors) {
|
||||||
Ok(Either::Left(vector)) => vec![vector],
|
Ok(vectors) => vectors,
|
||||||
Ok(Either::Right(vectors)) => vectors,
|
|
||||||
Err(_) => return,
|
Err(_) => return,
|
||||||
};
|
};
|
||||||
let similarity = vectors
|
let similarity = vectors
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
|
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
|
||||||
.max()
|
.max()
|
||||||
.map(OrderedFloat::into_inner);
|
.map(OrderedFloat::into_inner);
|
||||||
document.insert("_semanticSimilarity".to_string(), json!(similarity));
|
document.insert("_semanticSimilarity".to_string(), json!(similarity));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_formatted_options(
|
fn compute_formatted_options(
|
||||||
@ -691,6 +690,22 @@ fn make_document(
|
|||||||
Ok(document)
|
Ok(document)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract the JSON value under the field name specified
|
||||||
|
/// but doesn't support nested objects.
|
||||||
|
fn extract_field(
|
||||||
|
field_name: &str,
|
||||||
|
field_ids_map: &FieldsIdsMap,
|
||||||
|
obkv: obkv::KvReaderU16,
|
||||||
|
) -> Result<Option<serde_json::Value>, MeilisearchHttpError> {
|
||||||
|
match field_ids_map.id(field_name) {
|
||||||
|
Some(fid) => match obkv.get(fid) {
|
||||||
|
Some(value) => Ok(serde_json::from_slice(value).map(Some)?),
|
||||||
|
None => Ok(None),
|
||||||
|
},
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn format_fields<A: AsRef<[u8]>>(
|
fn format_fields<A: AsRef<[u8]>>(
|
||||||
document: &Document,
|
document: &Document,
|
||||||
field_ids_map: &FieldsIdsMap,
|
field_ids_map: &FieldsIdsMap,
|
||||||
|
@ -286,6 +286,23 @@ pub fn normalize_facet(original: &str) -> String {
|
|||||||
CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
|
CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Represents either a vector or an array of multiple vectors.
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||||
|
#[serde(transparent)]
|
||||||
|
pub struct VectorOrArrayOfVectors {
|
||||||
|
#[serde(with = "either::serde_untagged")]
|
||||||
|
inner: either::Either<Vec<f32>, Vec<Vec<f32>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VectorOrArrayOfVectors {
|
||||||
|
pub fn into_array_of_vectors(self) -> Vec<Vec<f32>> {
|
||||||
|
match self.inner {
|
||||||
|
either::Either::Left(vector) => vec![vector],
|
||||||
|
either::Either::Right(vectors) => vectors,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Normalize a vector by dividing the dimensions by the length of it.
|
/// Normalize a vector by dividing the dimensions by the length of it.
|
||||||
pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
|
pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
|
||||||
let squared: f32 = vector.iter().map(|x| x * x).sum();
|
let squared: f32 = vector.iter().map(|x| x * x).sum();
|
||||||
|
@ -3,11 +3,10 @@ use std::fs::File;
|
|||||||
use std::io;
|
use std::io;
|
||||||
|
|
||||||
use bytemuck::cast_slice;
|
use bytemuck::cast_slice;
|
||||||
use either::Either;
|
|
||||||
use serde_json::from_slice;
|
use serde_json::from_slice;
|
||||||
|
|
||||||
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
||||||
use crate::{FieldId, InternalError, Result};
|
use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
||||||
|
|
||||||
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
||||||
///
|
///
|
||||||
@ -31,9 +30,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||||||
// first we retrieve the _vectors field
|
// first we retrieve the _vectors field
|
||||||
if let Some(vectors) = obkv.get(vectors_fid) {
|
if let Some(vectors) = obkv.get(vectors_fid) {
|
||||||
// extract the vectors
|
// extract the vectors
|
||||||
let vectors: Either<Vec<Vec<f32>>, Vec<f32>> =
|
// TODO return a user error before unwrapping
|
||||||
from_slice(vectors).map_err(InternalError::SerdeJson).unwrap();
|
let vectors = from_slice(vectors)
|
||||||
let vectors = vectors.map_right(|v| vec![v]).into_inner();
|
.map_err(InternalError::SerdeJson)
|
||||||
|
.map(VectorOrArrayOfVectors::into_array_of_vectors)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
for (i, vector) in vectors.into_iter().enumerate() {
|
for (i, vector) in vectors.into_iter().enumerate() {
|
||||||
match u16::try_from(i) {
|
match u16::try_from(i) {
|
||||||
|
Loading…
Reference in New Issue
Block a user