mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-29 16:45:30 +08:00
Merge #4740
4740: Make `embeddings` optional and improve error message for `regenerate` r=dureuill a=irevoire # Pull Request ## Related issue Fixes https://github.com/meilisearch/meilisearch/issues/4741 ## What does this PR do? - Make the `embeddings` parameter optional when manually specifying embeddings for an embedder - Adds a lot of tests around malformed `_vectors.embedder` objects - Use `deserr` to deserialize the `_vectors.embedder` field, improving error messages Co-authored-by: Tamo <tamo@meilisearch.com>
This commit is contained in:
commit
f6a00f4a90
@ -398,7 +398,8 @@ impl ErrorCode for milli::Error {
|
|||||||
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
|
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
|
||||||
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
|
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
|
||||||
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
|
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
|
||||||
UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType,
|
UserError::InvalidVectorsMapType { .. }
|
||||||
|
| UserError::InvalidVectorsEmbedderConf { .. } => Code::InvalidVectorsType,
|
||||||
UserError::TooManyVectors(_, _) => Code::TooManyVectors,
|
UserError::TooManyVectors(_, _) => Code::TooManyVectors,
|
||||||
UserError::SortError(_) => Code::InvalidSearchSort,
|
UserError::SortError(_) => Code::InvalidSearchSort,
|
||||||
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
||||||
|
@ -190,6 +190,285 @@ async fn generate_default_user_provided_documents(server: &Server) -> Index {
|
|||||||
index
|
index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn user_provided_embeddings_error() {
|
||||||
|
let server = Server::new().await;
|
||||||
|
let index = generate_default_user_provided_documents(&server).await;
|
||||||
|
|
||||||
|
// First case, we forget to specify the `regenerate`
|
||||||
|
let documents =
|
||||||
|
json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [0, 0, 0] }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 2,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Missing field `regenerate` inside `.manual`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
// Second case, we don't specify anything
|
||||||
|
let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": {}}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 3,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Missing field `regenerate` inside `.manual`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
// Third case, we specify something wrong in place of regenerate
|
||||||
|
let documents =
|
||||||
|
json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "regenerate": "yes please" }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 4,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.regenerate`: expected a boolean, but found a string: `\"yes please\"`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let documents =
|
||||||
|
json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": true }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 5,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings`: expected null or an array, but found a boolean: `true`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let documents =
|
||||||
|
json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [true] }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 6,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[0]`: expected a number or an array, but found a boolean: `true`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let documents =
|
||||||
|
json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [[true]] }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 7,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[0][0]`: expected a number, but found a boolean: `true`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "embeddings": [23, 0.1, -12], "regenerate": true }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task["status"], @r###""succeeded""###);
|
||||||
|
|
||||||
|
let documents =
|
||||||
|
json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "regenerate": false }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task["status"], @r###""succeeded""###);
|
||||||
|
|
||||||
|
let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "regenerate": false, "embeddings": [0.1, [0.2, 0.3]] }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 10,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[1]`: expected a number, but found an array: `[0.2,0.3]`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "regenerate": false, "embeddings": [[0.1, 0.2], 0.3] }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 11,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[1]`: expected an array, but found a number: `0.3`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let documents = json!({"id": 0, "name": "kefir", "_vectors": { "manual": { "regenerate": false, "embeddings": [[0.1, true], 0.3] }}});
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": 12,
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "failed",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 0
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"message": "Bad embedder configuration in the document with id: `\"0\"`. Invalid value type at `.manual.embeddings[0][1]`: expected a number, but found a boolean: `true`",
|
||||||
|
"code": "invalid_vectors_type",
|
||||||
|
"type": "invalid_request",
|
||||||
|
"link": "https://docs.meilisearch.com/errors#invalid_vectors_type"
|
||||||
|
},
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
}
|
||||||
|
|
||||||
#[actix_rt::test]
|
#[actix_rt::test]
|
||||||
async fn clear_documents() {
|
async fn clear_documents() {
|
||||||
let server = Server::new().await;
|
let server = Server::new().await;
|
||||||
|
@ -119,6 +119,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
|||||||
InvalidVectorDimensions { expected: usize, found: usize },
|
InvalidVectorDimensions { expected: usize, found: usize },
|
||||||
#[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
|
#[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
|
||||||
InvalidVectorsMapType { document_id: String, value: Value },
|
InvalidVectorsMapType { document_id: String, value: Value },
|
||||||
|
#[error("Bad embedder configuration in the document with id: `{document_id}`. {error}")]
|
||||||
|
InvalidVectorsEmbedderConf { document_id: String, error: deserr::errors::JsonError },
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
InvalidFilter(String),
|
InvalidFilter(String),
|
||||||
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
|
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use std::collections::{BTreeMap, BTreeSet};
|
use std::collections::{BTreeMap, BTreeSet};
|
||||||
|
|
||||||
|
use deserr::{take_cf_content, DeserializeError, Deserr, Sequence};
|
||||||
use obkv::KvReader;
|
use obkv::KvReader;
|
||||||
use serde_json::{from_slice, Value};
|
use serde_json::{from_slice, Value};
|
||||||
|
|
||||||
@ -10,13 +11,44 @@ use crate::{DocumentId, FieldId, InternalError, UserError};
|
|||||||
|
|
||||||
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
#[derive(serde::Serialize, Debug)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum Vectors {
|
pub enum Vectors {
|
||||||
ImplicitlyUserProvided(VectorOrArrayOfVectors),
|
ImplicitlyUserProvided(VectorOrArrayOfVectors),
|
||||||
Explicit(ExplicitVectors),
|
Explicit(ExplicitVectors),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<E: DeserializeError> Deserr<E> for Vectors {
|
||||||
|
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||||
|
value: deserr::Value<V>,
|
||||||
|
location: deserr::ValuePointerRef,
|
||||||
|
) -> Result<Self, E> {
|
||||||
|
match value {
|
||||||
|
deserr::Value::Sequence(_) | deserr::Value::Null => {
|
||||||
|
Ok(Vectors::ImplicitlyUserProvided(VectorOrArrayOfVectors::deserialize_from_value(
|
||||||
|
value, location,
|
||||||
|
)?))
|
||||||
|
}
|
||||||
|
deserr::Value::Map(_) => {
|
||||||
|
Ok(Vectors::Explicit(ExplicitVectors::deserialize_from_value(value, location)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
value => Err(take_cf_content(E::error(
|
||||||
|
None,
|
||||||
|
deserr::ErrorKind::IncorrectValueKind {
|
||||||
|
actual: value,
|
||||||
|
accepted: &[
|
||||||
|
deserr::ValueKind::Sequence,
|
||||||
|
deserr::ValueKind::Map,
|
||||||
|
deserr::ValueKind::Null,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
location,
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Vectors {
|
impl Vectors {
|
||||||
pub fn must_regenerate(&self) -> bool {
|
pub fn must_regenerate(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
@ -37,9 +69,11 @@ impl Vectors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
#[derive(serde::Serialize, Deserr, Debug)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ExplicitVectors {
|
pub struct ExplicitVectors {
|
||||||
|
#[serde(default)]
|
||||||
|
#[deserr(default)]
|
||||||
pub embeddings: Option<VectorOrArrayOfVectors>,
|
pub embeddings: Option<VectorOrArrayOfVectors>,
|
||||||
pub regenerate: bool,
|
pub regenerate: bool,
|
||||||
}
|
}
|
||||||
@ -149,13 +183,20 @@ impl ParsedVectorsDiff {
|
|||||||
|
|
||||||
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
|
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
|
||||||
|
|
||||||
|
impl<E: DeserializeError> Deserr<E> for ParsedVectors {
|
||||||
|
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||||
|
value: deserr::Value<V>,
|
||||||
|
location: deserr::ValuePointerRef,
|
||||||
|
) -> Result<Self, E> {
|
||||||
|
let value = <BTreeMap<String, Vectors>>::deserialize_from_value(value, location)?;
|
||||||
|
Ok(ParsedVectors(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ParsedVectors {
|
impl ParsedVectors {
|
||||||
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
|
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
|
||||||
let Ok(value) = from_slice(value) else {
|
let value: serde_json::Value = from_slice(value).map_err(Error::InternalSerdeJson)?;
|
||||||
let value = from_slice(value).map_err(Error::InternalSerdeJson)?;
|
deserr::deserialize(value).map_err(|error| Error::InvalidEmbedderConf { error })
|
||||||
return Err(Error::InvalidMap(value));
|
|
||||||
};
|
|
||||||
Ok(ParsedVectors(value))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
|
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
|
||||||
@ -165,6 +206,7 @@ impl ParsedVectors {
|
|||||||
|
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
InvalidMap(Value),
|
InvalidMap(Value),
|
||||||
|
InvalidEmbedderConf { error: deserr::errors::JsonError },
|
||||||
InternalSerdeJson(serde_json::Error),
|
InternalSerdeJson(serde_json::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,6 +216,12 @@ impl Error {
|
|||||||
Error::InvalidMap(value) => {
|
Error::InvalidMap(value) => {
|
||||||
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
|
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
|
||||||
}
|
}
|
||||||
|
Error::InvalidEmbedderConf { error } => {
|
||||||
|
crate::Error::UserError(UserError::InvalidVectorsEmbedderConf {
|
||||||
|
document_id,
|
||||||
|
error,
|
||||||
|
})
|
||||||
|
}
|
||||||
Error::InternalSerdeJson(error) => {
|
Error::InternalSerdeJson(error) => {
|
||||||
crate::Error::InternalError(InternalError::SerdeJson(error))
|
crate::Error::InternalError(InternalError::SerdeJson(error))
|
||||||
}
|
}
|
||||||
@ -194,13 +242,84 @@ fn to_vector_map(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Represents either a vector or an array of multiple vectors.
|
/// Represents either a vector or an array of multiple vectors.
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
#[derive(serde::Serialize, Debug)]
|
||||||
#[serde(transparent)]
|
#[serde(transparent)]
|
||||||
pub struct VectorOrArrayOfVectors {
|
pub struct VectorOrArrayOfVectors {
|
||||||
#[serde(with = "either::serde_untagged_optional")]
|
#[serde(with = "either::serde_untagged_optional")]
|
||||||
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
|
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<E: DeserializeError> Deserr<E> for VectorOrArrayOfVectors {
|
||||||
|
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||||
|
value: deserr::Value<V>,
|
||||||
|
location: deserr::ValuePointerRef,
|
||||||
|
) -> Result<Self, E> {
|
||||||
|
match value {
|
||||||
|
deserr::Value::Null => Ok(VectorOrArrayOfVectors { inner: None }),
|
||||||
|
deserr::Value::Sequence(seq) => {
|
||||||
|
let mut iter = seq.into_iter();
|
||||||
|
match iter.next().map(|v| v.into_value()) {
|
||||||
|
None => {
|
||||||
|
// With the strange way serde serialize the `Either`, we must send the left part
|
||||||
|
// otherwise it'll consider we returned [[]]
|
||||||
|
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(Vec::new())) })
|
||||||
|
}
|
||||||
|
Some(val @ deserr::Value::Sequence(_)) => {
|
||||||
|
let first = Embedding::deserialize_from_value(val, location.push_index(0))?;
|
||||||
|
let mut collect = vec![first];
|
||||||
|
let mut tail = iter
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, v)| {
|
||||||
|
Embedding::deserialize_from_value(
|
||||||
|
v.into_value(),
|
||||||
|
location.push_index(i + 1),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
collect.append(&mut tail);
|
||||||
|
|
||||||
|
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(collect)) })
|
||||||
|
}
|
||||||
|
Some(
|
||||||
|
val @ deserr::Value::Integer(_)
|
||||||
|
| val @ deserr::Value::NegativeInteger(_)
|
||||||
|
| val @ deserr::Value::Float(_),
|
||||||
|
) => {
|
||||||
|
let first = <f32>::deserialize_from_value(val, location.push_index(0))?;
|
||||||
|
let mut embedding = iter
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, v)| {
|
||||||
|
<f32>::deserialize_from_value(
|
||||||
|
v.into_value(),
|
||||||
|
location.push_index(i + 1),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
embedding.insert(0, first);
|
||||||
|
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Right(embedding)) })
|
||||||
|
}
|
||||||
|
Some(value) => Err(take_cf_content(E::error(
|
||||||
|
None,
|
||||||
|
deserr::ErrorKind::IncorrectValueKind {
|
||||||
|
actual: value,
|
||||||
|
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Float],
|
||||||
|
},
|
||||||
|
location.push_index(0),
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
value => Err(take_cf_content(E::error(
|
||||||
|
None,
|
||||||
|
deserr::ErrorKind::IncorrectValueKind {
|
||||||
|
actual: value,
|
||||||
|
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Null],
|
||||||
|
},
|
||||||
|
location,
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl VectorOrArrayOfVectors {
|
impl VectorOrArrayOfVectors {
|
||||||
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
|
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
|
||||||
match self.inner? {
|
match self.inner? {
|
||||||
@ -234,15 +353,19 @@ impl From<Vec<Embedding>> for VectorOrArrayOfVectors {
|
|||||||
mod test {
|
mod test {
|
||||||
use super::VectorOrArrayOfVectors;
|
use super::VectorOrArrayOfVectors;
|
||||||
|
|
||||||
|
fn embedding_from_str(s: &str) -> Result<VectorOrArrayOfVectors, deserr::errors::JsonError> {
|
||||||
|
let value: serde_json::Value = serde_json::from_str(s).unwrap();
|
||||||
|
deserr::deserialize(value)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn array_of_vectors() {
|
fn array_of_vectors() {
|
||||||
let null: VectorOrArrayOfVectors = serde_json::from_str("null").unwrap();
|
let null = embedding_from_str("null").unwrap();
|
||||||
let empty: VectorOrArrayOfVectors = serde_json::from_str("[]").unwrap();
|
let empty = embedding_from_str("[]").unwrap();
|
||||||
let one: VectorOrArrayOfVectors = serde_json::from_str("[0.1]").unwrap();
|
let one = embedding_from_str("[0.1]").unwrap();
|
||||||
let two: VectorOrArrayOfVectors = serde_json::from_str("[0.1, 0.2]").unwrap();
|
let two = embedding_from_str("[0.1, 0.2]").unwrap();
|
||||||
let one_vec: VectorOrArrayOfVectors = serde_json::from_str("[[0.1, 0.2]]").unwrap();
|
let one_vec = embedding_from_str("[[0.1, 0.2]]").unwrap();
|
||||||
let two_vecs: VectorOrArrayOfVectors =
|
let two_vecs = embedding_from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
|
||||||
serde_json::from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
|
|
||||||
|
|
||||||
insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
|
insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
|
||||||
insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");
|
insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");
|
||||||
|
Loading…
Reference in New Issue
Block a user