mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-22 01:57:41 +08:00
Fix vector parsing
This commit is contained in:
parent
d97af4d8e6
commit
4706a0eb49
@ -167,7 +167,7 @@ fn entry_from_raw_value(
|
||||
value: &RawValue,
|
||||
has_configured_embedder: bool,
|
||||
) -> std::result::Result<VectorEntry<'_>, serde_json::Error> {
|
||||
let value: RawVectors = serde_json::from_str(value.get())?;
|
||||
let value: RawVectors = RawVectors::from_raw_value(value)?;
|
||||
|
||||
Ok(match value {
|
||||
RawVectors::Explicit(raw_explicit_vectors) => VectorEntry {
|
||||
@ -177,7 +177,7 @@ fn entry_from_raw_value(
|
||||
},
|
||||
RawVectors::ImplicitlyUserProvided(value) => VectorEntry {
|
||||
has_configured_embedder,
|
||||
embeddings: Some(Embeddings::FromJsonImplicityUserProvided(value)),
|
||||
embeddings: value.map(Embeddings::FromJsonImplicityUserProvided),
|
||||
regenerate: false,
|
||||
},
|
||||
})
|
||||
|
@ -12,11 +12,96 @@ use crate::{DocumentId, FieldId, InternalError, UserError};
|
||||
|
||||
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum RawVectors<'doc> {
|
||||
Explicit(#[serde(borrow)] RawExplicitVectors<'doc>),
|
||||
ImplicitlyUserProvided(#[serde(borrow)] &'doc RawValue),
|
||||
ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>),
|
||||
}
|
||||
|
||||
impl<'doc> RawVectors<'doc> {
|
||||
pub fn from_raw_value(raw: &'doc RawValue) -> Result<Self, serde_json::Error> {
|
||||
use serde::de::Deserializer as _;
|
||||
Ok(match raw.deserialize_any(RawVectorsVisitor)? {
|
||||
RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None),
|
||||
RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)),
|
||||
RawVectorsVisitorValue::Explicit { regenerate, embeddings } => {
|
||||
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate })
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct RawVectorsVisitor;
|
||||
|
||||
enum RawVectorsVisitorValue<'doc> {
|
||||
ImplicitNone,
|
||||
Implicit,
|
||||
Explicit { regenerate: bool, embeddings: Option<&'doc RawValue> },
|
||||
}
|
||||
|
||||
impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
|
||||
type Value = RawVectorsVisitorValue<'doc>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(formatter, "a map containing at least `regenerate`, or an array of floats`")
|
||||
}
|
||||
|
||||
fn visit_none<E>(self) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(RawVectorsVisitorValue::ImplicitNone)
|
||||
}
|
||||
|
||||
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'doc>,
|
||||
{
|
||||
deserializer.deserialize_any(self)
|
||||
}
|
||||
|
||||
fn visit_unit<E>(self) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(RawVectorsVisitorValue::ImplicitNone)
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'doc>,
|
||||
{
|
||||
// must consume all elements or parsing fails
|
||||
while let Some(_) = seq.next_element::<&RawValue>()? {}
|
||||
Ok(RawVectorsVisitorValue::Implicit)
|
||||
}
|
||||
|
||||
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::MapAccess<'doc>,
|
||||
{
|
||||
use serde::de::Error as _;
|
||||
let mut regenerate = None;
|
||||
let mut embeddings = None;
|
||||
while let Some(s) = map.next_key()? {
|
||||
match s {
|
||||
"regenerate" => {
|
||||
let value: bool = map.next_value()?;
|
||||
regenerate = Some(value);
|
||||
}
|
||||
"embeddings" => {
|
||||
let value: &RawValue = map.next_value()?;
|
||||
embeddings = Some(value);
|
||||
}
|
||||
other => return Err(A::Error::unknown_field(other, &["regenerate", "embeddings"])),
|
||||
}
|
||||
}
|
||||
let Some(regenerate) = regenerate else {
|
||||
return Err(A::Error::missing_field("regenerate"));
|
||||
};
|
||||
Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
@ -86,7 +171,7 @@ impl<'doc> RawVectors<'doc> {
|
||||
}
|
||||
pub fn embeddings(&self) -> Option<&'doc RawValue> {
|
||||
match self {
|
||||
RawVectors::ImplicitlyUserProvided(embeddings) => Some(embeddings),
|
||||
RawVectors::ImplicitlyUserProvided(embeddings) => *embeddings,
|
||||
RawVectors::Explicit(RawExplicitVectors { embeddings, regenerate: _ }) => *embeddings,
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user