- manual embedder
- multi embedders OK
- clippy + tests OK
Louis Dureuil 2023-12-12 23:39:01 +01:00
parent 922a640188
commit 12940d79a9
13 changed files with 292 additions and 140 deletions

View File

@@ -305,6 +305,7 @@ NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENT
 PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ;
 TaskNotFound , InvalidRequest , NOT_FOUND ;
 TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ;
+TooManyVectors , InvalidRequest , BAD_REQUEST ;
 UnretrievableDocument , Internal , BAD_REQUEST ;
 UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ;
 UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
@@ -362,7 +363,9 @@ impl ErrorCode for milli::Error {
                 UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
                 UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
                 UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
+                UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType,
                 UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType,
+                UserError::TooManyVectors(_, _) => Code::TooManyVectors,
                 UserError::SortError(_) => Code::InvalidSearchSort,
                 UserError::InvalidMinTypoWordLenSetting(_, _) => {
                     Code::InvalidSettingsTypoTolerance

View File

@@ -235,38 +235,42 @@ pub async fn embed(
     index_scheduler: &IndexScheduler,
     index: &milli::Index,
 ) -> Result<(), ResponseError> {
-    if let Some(VectorQuery::String(prompt)) = query.vector.take() {
-        let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
-        let embedder = index_scheduler.embedders(embedder_configs)?;
-
-        let embedder_name = if let Some(HybridQuery {
-            semantic_ratio: _,
-            embedder: Some(embedder),
-        }) = &query.hybrid
-        {
-            embedder
-        } else {
-            "default"
-        };
-
-        let embeddings = embedder
-            .get(embedder_name)
-            .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
-            .map_err(milli::Error::from)?
-            .0
-            .embed(vec![prompt])
-            .await
-            .map_err(milli::vector::Error::from)
-            .map_err(milli::Error::from)?
-            .pop()
-            .expect("No vector returned from embedding");
-
-        if embeddings.iter().nth(1).is_some() {
-            warn!("Ignoring embeddings past the first one in long search query");
-            query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
-        } else {
-            query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
-        }
-    };
+    match query.vector.take() {
+        Some(VectorQuery::String(prompt)) => {
+            let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
+            let embedder = index_scheduler.embedders(embedder_configs)?;
+
+            let embedder_name =
+                if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
+                    &query.hybrid
+                {
+                    embedder
+                } else {
+                    "default"
+                };
+
+            let embeddings = embedder
+                .get(embedder_name)
+                .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
+                .map_err(milli::Error::from)?
+                .0
+                .embed(vec![prompt])
+                .await
+                .map_err(milli::vector::Error::from)
+                .map_err(milli::Error::from)?
+                .pop()
+                .expect("No vector returned from embedding");
+
+            if embeddings.iter().nth(1).is_some() {
+                warn!("Ignoring embeddings past the first one in long search query");
+                query.vector =
+                    Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
+            } else {
+                query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
+            }
+        }
+        Some(vector) => query.vector = Some(vector),
+        None => {}
+    };
     Ok(())
 }
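
Note: the behavioral change above is subtle but important. The old `if let` only handled the `VectorQuery::String` case, so a caller-supplied raw vector was dropped by the `take()`; the new `match` forwards it unchanged. A minimal sketch of that pass-through, using hypothetical trimmed-down stand-ins rather than the real Meilisearch types:

// Hypothetical stand-ins, only to show the control flow of the new `match`.
enum VectorQuery {
    String(String),   // a textual query that still needs embedding
    Vector(Vec<f32>), // an already-computed embedding
}

fn normalize(vector: Option<VectorQuery>) -> Option<VectorQuery> {
    match vector {
        Some(VectorQuery::String(prompt)) => {
            // the real code sends `prompt` to the configured embedder here
            Some(VectorQuery::Vector(embed_stub(&prompt)))
        }
        // the arm added by this commit: a raw vector passes through untouched
        Some(vector) => Some(vector),
        None => None,
    }
}

fn embed_stub(_prompt: &str) -> Vec<f32> {
    vec![0.0; 3] // placeholder embedding
}

fn main() {
    let q = normalize(Some(VectorQuery::Vector(vec![0.1, 0.2, 0.3])));
    assert!(matches!(q, Some(VectorQuery::Vector(_))));
}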

View File

@@ -13,7 +13,7 @@ use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
-use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
+use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
 use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
 use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
 use meilisearch_types::{milli, Document};
@@ -562,8 +562,17 @@ pub fn perform_search(
             insert_geo_distance(sort, &mut document);
         }
 
-        /// FIXME: remove this or set to value from the score details
-        let semantic_score = None;
+        let mut semantic_score = None;
+        for details in &score {
+            if let ScoreDetails::Vector(score_details::Vector {
+                target_vector: _,
+                value_similarity: Some((_matching_vector, similarity)),
+            }) = details
+            {
+                semantic_score = Some(*similarity);
+                break;
+            }
+        }
 
         let ranking_score =
             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
@@ -648,8 +657,10 @@ pub fn perform_search(
         hits: documents,
         hits_info,
         query: query.q.unwrap_or_default(),
-        // FIXME: display input vector
-        vector: None,
+        vector: match query.vector {
+            Some(VectorQuery::Vector(vector)) => Some(vector),
+            _ => None,
+        },
         processing_time_ms: before_search.elapsed().as_millis(),
         facet_distribution,
         facet_stats,

View File

@@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
     }
     "###
     );
@@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() {
       },
       "pagination": {
         "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
     }
     "###
     );
@@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###
    );
@@ -1896,7 +1908,8 @@ async fn import_dump_v6_containing_experimental_features() {
      },
      "pagination": {
        "maxTotalHits": 1000
-      }
+      },
+      "embedders": {}
    }
    "###);

View File

@@ -876,7 +876,31 @@ async fn experimental_feature_vector_store() {
         }))
         .await;
     meili_snap::snapshot!(code, @"200 OK");
-    meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]");
+    // vector search returns all documents that don't have vectors in the last bucket, like all sorts
+    meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
+    [
+      {
+        "title": "Shazam!",
+        "id": "287947"
+      },
+      {
+        "title": "Captain Marvel",
+        "id": "299537"
+      },
+      {
+        "title": "Escape Room",
+        "id": "522681"
+      },
+      {
+        "title": "How to Train Your Dragon: The Hidden World",
+        "id": "166428"
+      },
+      {
+        "title": "Gläss",
+        "id": "450465"
+      }
+    ]
+    "###);
 }
 
 #[cfg(feature = "default")]

View File

@@ -54,7 +54,7 @@ async fn get_settings() {
     let (response, code) = index.settings().await;
     assert_eq!(code, 200);
     let settings = response.as_object().unwrap();
-    assert_eq!(settings.keys().len(), 15);
+    assert_eq!(settings.keys().len(), 16);
     assert_eq!(settings["displayedAttributes"], json!(["*"]));
     assert_eq!(settings["searchableAttributes"], json!(["*"]));
     assert_eq!(settings["filterableAttributes"], json!([]));
@@ -83,6 +83,7 @@ async fn get_settings() {
             "maxTotalHits": 1000,
         })
     );
+    assert_eq!(settings["embedders"], json!({}));
 }
 
 #[actix_rt::test]

View File

@@ -114,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
     InvalidGeoField(#[from] GeoError),
     #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
     InvalidVectorDimensions { expected: usize, found: usize },
-    #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
-    InvalidVectorsType { document_id: Value, value: Value },
+    #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
+    InvalidVectorsType { document_id: Value, value: Value, subfield: String },
+    #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
+    InvalidVectorsMapType { document_id: Value, value: Value },
     #[error("{0}")]
     InvalidFilter(String),
     #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
@@ -196,6 +198,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
     TooManyEmbedders(usize),
     #[error("Cannot find embedder with name {0}.")]
     InvalidEmbedder(String),
+    #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
+    TooManyVectors(String, usize),
 }
 
 impl From<crate::vector::Error> for Error {

View File

@@ -73,6 +73,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
     indexer: GrenadParameters,
     field_id_map: &FieldsIdsMap,
     prompt: &Prompt,
+    embedder_name: &str,
 ) -> Result<ExtractedVectorPoints> {
     puffin::profile_function!();
@@ -115,89 +116,87 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
         // lazily get it when needed
         let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
 
-        let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) {
-            let vectors_obkv = KvReaderDelAdd::new(value);
-            match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) {
-                (Some(old), Some(new)) => {
-                    // no autogeneration
-                    let del_vectors = extract_vectors(old, document_id)?;
-                    let add_vectors = extract_vectors(new, document_id)?;
-
-                    VectorStateDelta::ManualDelta(
-                        del_vectors.unwrap_or_default(),
-                        add_vectors.unwrap_or_default(),
-                    )
-                }
-                (None, Some(new)) => {
-                    // was possibly autogenerated, remove all vectors for that document
-                    let add_vectors = extract_vectors(new, document_id)?;
-                    VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default())
-                }
-                (Some(_old), None) => {
-                    // Do we keep this document?
-                    let document_is_kept = obkv
-                        .iter()
-                        .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
-                        .any(|deladd| deladd.get(DelAdd::Addition).is_some());
-                    if document_is_kept {
-                        // becomes autogenerated
-                        VectorStateDelta::NowGenerated(prompt.render(
-                            obkv,
-                            DelAdd::Addition,
-                            field_id_map,
-                        )?)
-                    } else {
-                        VectorStateDelta::NowRemoved
-                    }
-                }
-                (None, None) => {
-                    // Do we keep this document?
-                    let document_is_kept = obkv
-                        .iter()
-                        .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
-                        .any(|deladd| deladd.get(DelAdd::Addition).is_some());
-                    if document_is_kept {
-                        // Don't give up if the old prompt was failing
-                        let old_prompt =
-                            prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
-                        let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
-                        if old_prompt != new_prompt {
-                            log::trace!(
-                                "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
-                            );
-                            VectorStateDelta::NowGenerated(new_prompt)
-                        } else {
-                            log::trace!("⏭️ Prompt unmodified, skipping");
-                            VectorStateDelta::NoChange
-                        }
-                    } else {
-                        VectorStateDelta::NowRemoved
-                    }
-                }
-            }
-        } else {
-            // Do we keep this document?
-            let document_is_kept = obkv
-                .iter()
-                .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
-                .any(|deladd| deladd.get(DelAdd::Addition).is_some());
-
-            if document_is_kept {
-                // Don't give up if the old prompt was failing
-                let old_prompt =
-                    prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
-                let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
-                if old_prompt != new_prompt {
-                    log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
-                    VectorStateDelta::NowGenerated(new_prompt)
-                } else {
-                    log::trace!("⏭️ Prompt unmodified, skipping");
-                    VectorStateDelta::NoChange
-                }
-            } else {
-                VectorStateDelta::NowRemoved
-            }
-        };
+        let vectors_field = vectors_fid
+            .and_then(|vectors_fid| obkv.get(vectors_fid))
+            .map(KvReaderDelAdd::new)
+            .map(|obkv| to_vector_maps(obkv, document_id))
+            .transpose()?;
+
+        let (del_map, add_map) = vectors_field.unzip();
+        let del_map = del_map.flatten();
+        let add_map = add_map.flatten();
+
+        let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
+        let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
+
+        let delta = match (del_value, add_value) {
+            (Some(old), Some(new)) => {
+                // no autogeneration
+                let del_vectors = extract_vectors(old, document_id, embedder_name)?;
+                let add_vectors = extract_vectors(new, document_id, embedder_name)?;
+
+                if add_vectors.len() > u8::MAX.into() {
+                    return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
+                        document_id().to_string(),
+                        add_vectors.len(),
+                    )));
+                }
+
+                VectorStateDelta::ManualDelta(del_vectors, add_vectors)
+            }
+            (Some(_old), None) => {
+                // Do we keep this document?
+                let document_is_kept = obkv
+                    .iter()
+                    .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
+                    .any(|deladd| deladd.get(DelAdd::Addition).is_some());
+                if document_is_kept {
+                    // becomes autogenerated
+                    VectorStateDelta::NowGenerated(prompt.render(
+                        obkv,
+                        DelAdd::Addition,
+                        field_id_map,
+                    )?)
+                } else {
+                    VectorStateDelta::NowRemoved
+                }
+            }
+            (None, Some(new)) => {
+                // was possibly autogenerated, remove all vectors for that document
+                let add_vectors = extract_vectors(new, document_id, embedder_name)?;
+                if add_vectors.len() > u8::MAX.into() {
+                    return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
+                        document_id().to_string(),
+                        add_vectors.len(),
+                    )));
+                }
+
+                VectorStateDelta::WasGeneratedNowManual(add_vectors)
+            }
+            (None, None) => {
+                // Do we keep this document?
+                let document_is_kept = obkv
+                    .iter()
+                    .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
+                    .any(|deladd| deladd.get(DelAdd::Addition).is_some());
+                if document_is_kept {
+                    // Don't give up if the old prompt was failing
+                    let old_prompt =
+                        prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
+                    let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
+                    if old_prompt != new_prompt {
+                        log::trace!(
+                            "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
+                        );
+                        VectorStateDelta::NowGenerated(new_prompt)
+                    } else {
+                        log::trace!("⏭️ Prompt unmodified, skipping");
+                        VectorStateDelta::NoChange
+                    }
+                } else {
+                    VectorStateDelta::NowRemoved
+                }
+            }
+        };
@@ -221,6 +220,34 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
     })
 }
 
+fn to_vector_maps(
+    obkv: KvReaderDelAdd,
+    document_id: impl Fn() -> Value,
+) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> {
+    let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?;
+    let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?;
+    Ok((del, add))
+}
+
+fn to_vector_map(
+    obkv: KvReaderDelAdd,
+    side: DelAdd,
+    document_id: &impl Fn() -> Value,
+) -> Result<Option<serde_json::Map<String, Value>>> {
+    Ok(if let Some(value) = obkv.get(side) {
+        let Ok(value) = from_slice(value) else {
+            let value = from_slice(value).map_err(InternalError::SerdeJson)?;
+            return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
+                document_id: document_id(),
+                value,
+            }));
+        };
+        Some(value)
+    } else {
+        None
+    })
+}
+
 /// Computes the diff between both Del and Add numbers and
 /// only inserts the parts that differ in the sorter.
 fn push_vectors_diff(
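
Note: the two helpers above encode the new document convention this commit introduces: `_vectors` is an object keyed by embedder name rather than a bare array. A sketch of the accepted shape, assuming serde_json (the ids and values are made up, not taken from the test data):

use serde_json::json;

fn main() {
    // One entry per embedder; each value is a single vector or an array of vectors.
    let document = json!({
        "id": "42",
        "_vectors": {
            "default": [0.1, 0.2, 0.3],
            "manual": [[0.1, 0.2], [0.3, 0.4]],
        }
    });
    // to_vector_map() pulls `_vectors[<embedder name>]` out per embedder;
    // a `_vectors` value that is not an object raises InvalidVectorsMapType.
    assert!(document["_vectors"]["default"].is_array());
}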
@@ -286,12 +313,20 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
 }
 
 /// Extracts the vectors from a JSON value.
-fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> {
-    match from_slice(value) {
-        Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)),
+fn extract_vectors(
+    value: Value,
+    document_id: impl Fn() -> Value,
+    name: &str,
+) -> Result<Vec<Vec<f32>>> {
+    // FIXME: ugly clone of the vectors here
+    match serde_json::from_value(value.clone()) {
+        Ok(vectors) => {
+            Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
+        }
         Err(_) => Err(UserError::InvalidVectorsType {
             document_id: document_id(),
-            value: from_slice(value).map_err(InternalError::SerdeJson)?,
+            value,
+            subfield: name.to_owned(),
         }
         .into()),
     }
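
Note: `extract_vectors` leans on `VectorOrArrayOfVectors` to accept both a single embedding and a list of embeddings for one embedder. A sketch of how such a type is commonly modeled with an untagged serde enum; this is an assumption about its shape, and the real milli type also tolerates `null`, which is why the code above calls `unwrap_or_default()`:

use serde::Deserialize;

// Accepts `[0.1, 0.2]` as well as `[[0.1, 0.2], [0.3, 0.4]]`.
#[derive(Deserialize)]
#[serde(untagged)]
enum VectorOrArrayOfVectors {
    Vector(Vec<f32>),
    ArrayOfVectors(Vec<Vec<f32>>),
}

impl VectorOrArrayOfVectors {
    fn into_array_of_vectors(self) -> Vec<Vec<f32>> {
        match self {
            Self::Vector(v) => vec![v],
            Self::ArrayOfVectors(vs) => vs,
        }
    }
}

fn main() {
    let one: VectorOrArrayOfVectors = serde_json::from_str("[0.1, 0.2]").unwrap();
    assert_eq!(one.into_array_of_vectors().len(), 1);
    let many: VectorOrArrayOfVectors = serde_json::from_str("[[0.1], [0.2]]").unwrap();
    assert_eq!(many.into_array_of_vectors().len(), 2);
}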

View File

@@ -298,6 +298,7 @@ fn send_original_documents_data(
         indexer,
         &field_id_map,
         &prompt,
+        &name,
     );
 
     match result {
         Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {

View File

@@ -514,16 +514,18 @@ where
         // We write the primary key field id into the main database
         self.index.put_primary_key(self.wtxn, &primary_key)?;
         let number_of_documents = self.index.number_of_documents(self.wtxn)?;
+        let mut rng = rand::rngs::StdRng::from_entropy();
 
         for (embedder_name, dimension) in dimension {
             let wtxn = &mut *self.wtxn;
             let vector_arroy = self.index.vector_arroy;
 
-            /// FIXME: unwrap
-            let embedder_index =
-                self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
+            let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
+                InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
+            )?;
 
             pool.install(|| {
                 let writer_index = (embedder_index as u16) << 8;
-                let mut rng = rand::rngs::StdRng::from_entropy();
                 for k in 0..=u8::MAX {
                     let writer = arroy::Writer::prepare(
                         wtxn,
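
Note: this hunk and the matching one in the typed-chunk writer below replace a FIXME'd `unwrap` with the same Option-to-Result conversion, so a missing `embedder_category_id` entry becomes a typed error instead of a panic. A generic sketch of the pattern, with a plain `HashMap` and a `String` error standing in for the LMDB database and `InternalError::DatabaseMissingEntry`:

use std::collections::HashMap;

// Stand-in for the real lookup: only the shape of the fix matters here.
fn embedder_index(db: &HashMap<String, u8>, name: &str) -> Result<u8, String> {
    db.get(name)
        .copied()
        .ok_or_else(|| format!("database missing entry: embedder_category_id/{name}"))
}

fn main() {
    let db = HashMap::from([("default".to_string(), 0u8)]);
    assert_eq!(embedder_index(&db, "default"), Ok(0));
    assert!(embedder_index(&db, "missing").is_err());
}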

View File

@@ -22,7 +22,9 @@ use crate::index::db_name::DOCUMENTS;
 use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd};
 use crate::update::facet::FacetsUpdate;
 use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
-use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError};
+use crate::{
+    lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError,
+};
 
 pub(crate) enum TypedChunk {
     FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@@ -363,8 +365,9 @@ pub(crate) fn write_typed_chunk_into_index(
             expected_dimension,
             embedder_name,
         } => {
-            /// FIXME: unwrap
-            let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
+            let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
+                InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
+            )?;
             let writer_index = (embedder_index as u16) << 8;
             // FIXME: allow customizing distance
             let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
@@ -404,7 +407,20 @@ pub(crate) fn write_typed_chunk_into_index(
                 // code error if we somehow got the wrong dimension
                 .unwrap();
 
-            /// FIXME: detect overflow
+            if embeddings.embedding_count() > u8::MAX.into() {
+                let external_docid = if let Ok(Some(Ok(index))) = index
+                    .external_id_of(wtxn, std::iter::once(docid))
+                    .map(|it| it.into_iter().next())
+                {
+                    index
+                } else {
+                    format!("internal docid={docid}")
+                };
+                return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
+                    external_docid,
+                    embeddings.embedding_count(),
+                )));
+            }
             for (embedding, writer) in embeddings.iter().zip(&writers) {
                 writer.add_item(wtxn, docid, embedding)?;
             }
@@ -455,7 +471,7 @@ pub(crate) fn write_typed_chunk_into_index(
             if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
                 let vector = pod_collect_to_vec(value);
 
-                /// FIXME: detect overflow
+                // overflow was detected during vector extraction.
                 for writer in &writers {
                     if !writer.contains_item(wtxn, docid)? {
                         writer.add_item(wtxn, docid, &vector)?;
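
Note: the 256-vector cap behind `TooManyVectors` falls out of the writer addressing scheme visible above: each embedder owns a block of arroy writers addressed as `(embedder_index << 8) | k`, where `k` is a `u8`, so there are exactly 256 slots per document per embedder. A sketch of that packing; the helper function is illustrative, not from the codebase:

// Each (embedder, k) pair maps to one arroy writer index; k is a u8,
// which is where the 256-vectors-per-document limit comes from.
fn writer_index(embedder_index: u8, k: u8) -> u16 {
    ((embedder_index as u16) << 8) | (k as u16)
}

fn main() {
    assert_eq!(writer_index(0, 0), 0);
    assert_eq!(writer_index(1, 0), 256);   // first slot of embedder 1
    assert_eq!(writer_index(1, 255), 511); // last of its 256 slots
}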

View File

@@ -0,0 +1,34 @@
+use super::error::EmbedError;
+use super::Embeddings;
+
+#[derive(Debug, Clone, Copy)]
+pub struct Embedder {
+    dimensions: usize,
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
+pub struct EmbedderOptions {
+    pub dimensions: usize,
+}
+
+impl Embedder {
+    pub fn new(options: EmbedderOptions) -> Self {
+        Self { dimensions: options.dimensions }
+    }
+
+    pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        let Some(text) = texts.pop() else { return Ok(Default::default()) };
+        Err(EmbedError::embed_on_manual_embedder(text))
+    }
+
+    pub fn dimensions(&self) -> usize {
+        self.dimensions
+    }
+
+    pub fn embed_chunks(
+        &self,
+        text_chunks: Vec<Vec<String>>,
+    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+        text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
+    }
+}
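
Note: the new manual embedder never produces embeddings; it only records the expected dimensions so that user-provided vectors can be validated. Embedding an empty batch succeeds vacuously, while any actual prompt is an error. A behavioral sketch with simplified stand-in types (the real module returns `EmbedError::embed_on_manual_embedder`):

// Simplified stand-ins mirroring the new module above.
struct ManualEmbedder {
    dimensions: usize,
}

impl ManualEmbedder {
    fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Vec<f32>>, String> {
        // nothing to embed: vacuously fine
        let Some(text) = texts.pop() else { return Ok(Vec::new()) };
        // any real prompt is an error: vectors must arrive with the documents
        Err(format!("attempted to embed `{text}` with a manual embedder"))
    }
}

fn main() {
    let embedder = ManualEmbedder { dimensions: 3 };
    assert!(embedder.embed(vec![]).is_ok());
    assert!(embedder.embed(vec!["a prompt".into()]).is_err());
    assert_eq!(embedder.dimensions, 3);
}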

View File

@@ -31,6 +31,10 @@ impl<F> Embeddings<F> {
         Ok(this)
     }
 
+    pub fn embedding_count(&self) -> usize {
+        self.data.len() / self.dimension
+    }
+
     pub fn dimension(&self) -> usize {
         self.dimension
     }
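
Note: `embedding_count` divides by `dimension`, which only works if `Embeddings` stores all of a document's vectors back to back in one flat buffer. A sketch under that layout assumption:

// Assumed layout: `data` holds `count * dimension` floats back to back.
struct Embeddings {
    data: Vec<f32>,
    dimension: usize,
}

impl Embeddings {
    fn embedding_count(&self) -> usize {
        self.data.len() / self.dimension
    }

    // yields one embedding (one `dimension`-sized slice) at a time
    fn iter(&self) -> impl Iterator<Item = &[f32]> + '_ {
        self.data.chunks_exact(self.dimension)
    }
}

fn main() {
    let e = Embeddings { data: vec![0.0; 6], dimension: 3 };
    assert_eq!(e.embedding_count(), 2);
    assert_eq!(e.iter().count(), 2);
}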