Add vectors field and geo field to document trait

This commit is contained in:
Louis Dureuil 2024-10-21 10:36:27 +02:00
parent 73e29ee155
commit c278024709
No known key found for this signature in database
3 changed files with 158 additions and 58 deletions

View File

@ -1,9 +1,9 @@
use std::collections::BTreeSet;
use heed::RoTxn;
use raw_collections::RawMap;
use serde_json::value::RawValue;
use super::document_change::{Entry, Versions};
use super::{KvReaderFieldId, KvWriterFieldId};
use crate::documents::FieldIdMapper;
use crate::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME;
@ -17,11 +17,26 @@ pub trait Document<'doc> {
/// Iterate over all **top-level** fields of the document, returning their name and raw JSON value.
///
/// - The returned values *may* contain nested fields.
/// - The `_vectors` field is **ignored** by this method, meaning it is **not returned** by this method.
/// - The `_vectors` and `_geo` fields are **ignored** by this method, meaning they are **not returned** by this method.
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'doc str, &'doc RawValue)>>;
/// Returns the unparsed value of the `_vectors` field from the document data.
///
/// This field alone is insufficient to retrieve vectors, as they may be stored in a dedicated location in the database.
/// Use a [`super::vector_document::VectorDocument`] to access the vector.
///
/// This method is meant as a convenience for implementors of [`super::vector_document::VectorDocument`].
fn vectors_field(&self) -> Result<Option<&'doc RawValue>>;
/// Returns the unparsed value of the `_geo` field from the document data.
///
/// This field alone is insufficient to retrieve geo data, as they may be stored in a dedicated location in the database.
/// Use a [`super::geo_document::GeoDocument`] to access the vector.
///
/// This method is meant as a convenience for implementors of [`super::geo_document::GeoDocument`].
fn geo_field(&self) -> Result<Option<&'doc RawValue>>;
}
#[derive(Clone, Copy)]
pub struct DocumentFromDb<'t, Mapper: FieldIdMapper>
where
Mapper: FieldIdMapper,
@ -30,6 +45,14 @@ where
content: &'t KvReaderFieldId,
}
impl<'t, Mapper: FieldIdMapper> Clone for DocumentFromDb<'t, Mapper> {
#[inline]
fn clone(&self) -> Self {
*self
}
}
impl<'t, Mapper: FieldIdMapper> Copy for DocumentFromDb<'t, Mapper> {}
impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> {
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'t str, &'t RawValue)>> {
let mut it = self.content.iter();
@ -53,6 +76,14 @@ impl<'t, Mapper: FieldIdMapper> Document<'t> for DocumentFromDb<'t, Mapper> {
Some(res)
})
}
fn vectors_field(&self) -> Result<Option<&'t RawValue>> {
self.field(RESERVED_VECTORS_FIELD_NAME)
}
fn geo_field(&self) -> Result<Option<&'t RawValue>> {
self.field("_geo")
}
}
impl<'t, Mapper: FieldIdMapper> DocumentFromDb<'t, Mapper> {
@ -66,6 +97,14 @@ impl<'t, Mapper: FieldIdMapper> DocumentFromDb<'t, Mapper> {
reader.map(|reader| Self { fields_ids_map: db_fields_ids_map, content: reader })
})
}
pub fn field(&self, name: &str) -> Result<Option<&'t RawValue>> {
let Some(fid) = self.fields_ids_map.id(name) else {
return Ok(None);
};
let Some(value) = self.content.get(fid) else { return Ok(None) };
Ok(Some(serde_json::from_slice(value).map_err(InternalError::SerdeJson)?))
}
}
#[derive(Clone, Copy)]
@ -81,29 +120,15 @@ impl<'doc> DocumentFromVersions<'doc> {
impl<'doc> Document<'doc> for DocumentFromVersions<'doc> {
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'doc str, &'doc RawValue)>> {
match &self.versions {
Versions::Single(version) => either::Either::Left(version.iter_top_level_fields()),
Versions::Multiple(versions) => {
let mut seen_fields = BTreeSet::new();
let mut it = versions.iter().rev().flat_map(|version| version.iter()).copied();
either::Either::Right(std::iter::from_fn(move || loop {
let (name, value) = it.next()?;
self.versions.iter_top_level_fields().map(Ok)
}
if seen_fields.contains(name) {
continue;
fn vectors_field(&self) -> Result<Option<&'doc RawValue>> {
Ok(self.versions.vectors_field())
}
seen_fields.insert(name);
return Some(Ok((name, value)));
}))
}
}
}
}
// used in document from payload
impl<'doc> Document<'doc> for &'doc [Entry<'doc>] {
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<Entry<'doc>>> {
self.iter().copied().map(|(k, v)| Ok((k, v)))
fn geo_field(&self) -> Result<Option<&'doc RawValue>> {
Ok(self.versions.geo_field())
}
}
@ -164,6 +189,26 @@ impl<'d, 'doc: 'd, 't: 'd, Mapper: FieldIdMapper> Document<'d>
}
})
}
fn vectors_field(&self) -> Result<Option<&'d RawValue>> {
if let Some(vectors) = self.new_doc.vectors_field()? {
return Ok(Some(vectors));
}
let Some(db) = self.db else { return Ok(None) };
db.vectors_field()
}
fn geo_field(&self) -> Result<Option<&'d RawValue>> {
if let Some(geo) = self.new_doc.geo_field()? {
return Ok(Some(geo));
}
let Some(db) = self.db else { return Ok(None) };
db.geo_field()
}
}
impl<'doc, D> Document<'doc> for &D
@ -173,6 +218,14 @@ where
fn iter_top_level_fields(&self) -> impl Iterator<Item = Result<(&'doc str, &'doc RawValue)>> {
D::iter_top_level_fields(self)
}
fn vectors_field(&self) -> Result<Option<&'doc RawValue>> {
D::vectors_field(self)
}
fn geo_field(&self) -> Result<Option<&'doc RawValue>> {
D::geo_field(self)
}
}
/// Turn this document into an obkv, whose fields are indexed by the provided `FieldIdMapper`.
@ -245,3 +298,52 @@ where
writer.finish().unwrap();
Ok(KvReaderFieldId::from_slice(document_buffer))
}
pub type Entry<'doc> = (&'doc str, &'doc RawValue);
#[derive(Clone, Copy)]
pub struct Versions<'doc> {
data: &'doc [Entry<'doc>],
vectors: Option<&'doc RawValue>,
geo: Option<&'doc RawValue>,
}
impl<'doc> Versions<'doc> {
pub fn multiple(
mut versions: impl Iterator<Item = Result<RawMap<'doc>>>,
) -> Result<Option<Self>> {
let Some(data) = versions.next() else { return Ok(None) };
let mut data = data?;
for future_version in versions {
let future_version = future_version?;
for (field, value) in future_version {
data.insert(field, value);
}
}
Ok(Some(Self::single(data)))
}
pub fn single(version: RawMap<'doc>) -> Self {
let vectors_id = version.get_index(RESERVED_VECTORS_FIELD_NAME);
let geo_id = version.get_index("_geo");
let mut data = version.into_vec();
let geo = geo_id.map(|geo_id| data.remove(geo_id).1);
let vectors = vectors_id.map(|vectors_id| data.remove(vectors_id).1);
let data = data.into_bump_slice();
Self { data, geo, vectors }
}
pub fn iter_top_level_fields(&self) -> impl Iterator<Item = Entry<'doc>> {
self.data.iter().copied()
}
pub fn vectors_field(&self) -> Option<&'doc RawValue> {
self.vectors
}
pub fn geo_field(&self) -> Option<&'doc RawValue> {
self.geo
}
}

View File

@ -1,5 +1,4 @@
use heed::RoTxn;
use serde_json::value::RawValue;
use super::document::{DocumentFromDb, DocumentFromVersions, MergedDocument};
use crate::documents::FieldIdMapper;
@ -138,11 +137,3 @@ impl<'doc> Update<'doc> {
}
}
}
pub type Entry<'doc> = (&'doc str, &'doc RawValue);
#[derive(Clone, Copy)]
pub enum Versions<'doc> {
Single(&'doc [Entry<'doc>]),
Multiple(&'doc [&'doc [Entry<'doc>]]),
}

View File

@ -2,7 +2,6 @@ use bumpalo::collections::CollectIn;
use bumpalo::Bump;
use heed::RoTxn;
use memmap2::Mmap;
use rayon::iter::IntoParallelIterator;
use rayon::slice::ParallelSlice;
use serde_json::value::RawValue;
use IndexDocumentsMethod as Idm;
@ -10,8 +9,7 @@ use IndexDocumentsMethod as Idm;
use super::super::document_change::DocumentChange;
use super::document_changes::{DocumentChangeContext, DocumentChanges, MostlySend};
use crate::documents::PrimaryKey;
use crate::update::new::document::DocumentFromVersions;
use crate::update::new::document_change::Versions;
use crate::update::new::document::{DocumentFromVersions, Versions};
use crate::update::new::{Deletion, Insertion, Update};
use crate::update::{AvailableIds, IndexDocumentsMethod};
use crate::{DocumentId, Error, FieldsIdsMap, Index, Result, UserError};
@ -291,8 +289,7 @@ impl MergeChanges for MergeDocumentForReplacement {
let document = raw_collections::RawMap::from_raw_value(document, doc_alloc)
.map_err(UserError::SerdeJson)?;
let document = document.into_bump_slice();
let document = DocumentFromVersions::new(Versions::Single(document));
let document = DocumentFromVersions::new(Versions::single(document));
if is_new {
Ok(Some(DocumentChange::Insertion(Insertion::create(
@ -365,9 +362,22 @@ impl MergeChanges for MergeDocumentForUpdates {
};
}
let mut versions = bumpalo::collections::Vec::with_capacity_in(operations.len(), doc_alloc);
let versions = match operations {
[single] => {
let DocumentOffset { content } = match single {
InnerDocOp::Addition(offset) => offset,
InnerDocOp::Deletion => {
unreachable!("Deletion in document operations")
}
};
let document = serde_json::from_slice(content).unwrap();
let document = raw_collections::RawMap::from_raw_value(document, doc_alloc)
.map_err(UserError::SerdeJson)?;
for operation in operations {
Some(Versions::single(document))
}
operations => {
let versions = operations.iter().map(|operation| {
let DocumentOffset { content } = match operation {
InnerDocOp::Addition(offset) => offset,
InnerDocOp::Deletion => {
@ -378,17 +388,14 @@ impl MergeChanges for MergeDocumentForUpdates {
let document = serde_json::from_slice(content).unwrap();
let document = raw_collections::RawMap::from_raw_value(document, doc_alloc)
.map_err(UserError::SerdeJson)?;
let document = document.into_bump_slice();
versions.push(document);
Ok(document)
});
Versions::multiple(versions)?
}
let versions = versions.into_bump_slice();
let versions = match versions {
[single] => Versions::Single(single),
versions => Versions::Multiple(versions),
};
let Some(versions) = versions else { return Ok(None) };
let document = DocumentFromVersions::new(versions);
if is_new {