From 922a640188bd4b4930bf18b6b6ce9b8a73927d28 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 21:19:48 +0100 Subject: [PATCH] WIP multi embedders fixed template bugs --- index-scheduler/src/lib.rs | 1 - meilisearch-types/src/error.rs | 5 ++ .../src/analytics/segment_analytics.rs | 33 ++++++- .../src/routes/indexes/facet_search.rs | 10 ++- meilisearch/src/routes/indexes/search.rs | 51 ++++++++--- meilisearch/src/search.rs | 16 ++++ milli/src/error.rs | 19 +++- milli/src/prompt/mod.rs | 78 +++++++++++++++- milli/src/prompt/template_checker.rs | 45 +++++++--- milli/src/search/new/mod.rs | 2 +- .../extract/extract_vector_points.rs | 89 +++++++------------ .../src/update/index_documents/extract/mod.rs | 47 +++++----- milli/src/update/index_documents/mod.rs | 20 +++-- .../src/update/index_documents/typed_chunk.rs | 21 +++-- milli/src/update/settings.rs | 18 +++- milli/src/vector/error.rs | 15 ++++ milli/src/vector/hf.rs | 21 +++-- milli/src/vector/mod.rs | 24 ++++- milli/src/vector/openai.rs | 23 ++++- milli/src/vector/settings.rs | 58 ++++++++---- 20 files changed, 438 insertions(+), 158 deletions(-) diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index d01b0a17d..65d257ea0 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -1361,7 +1361,6 @@ impl IndexScheduler { let embedder = Arc::new( Embedder::new(embedder_options.clone()) .map_err(meilisearch_types::milli::vector::Error::from) - .map_err(meilisearch_types::milli::UserError::from) .map_err(meilisearch_types::milli::Error::from)?, ); { diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index b1cc7cf82..5df1ae106 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; +InvalidEmbedder , InvalidRequest , BAD_REQUEST ; +InvalidHybridQuery , InvalidRequest , BAD_REQUEST ; InvalidIndexLimit , InvalidRequest , BAD_REQUEST ; InvalidIndexOffset , InvalidRequest , BAD_REQUEST ; InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ; @@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; +InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -340,6 +343,7 @@ impl ErrorCode for milli::Error { } UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::MultiplePrimaryKeyCandidatesFound { .. } => { @@ -363,6 +367,7 @@ impl ErrorCode for milli::Error { UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance } + UserError::InvalidEmbedder(_) => Code::InvalidEmbedder, UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, } } diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index d5f08936d..67770d87c 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats}; use crate::search::{ FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, - DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO, }; use crate::Opt; @@ -586,6 +586,11 @@ pub struct SearchAggregator { // vector // The maximum number of floats in a vector request max_vector_size: usize, + // Whether the semantic ratio passed to a hybrid search equals the default ratio. + semantic_ratio: bool, + // Whether a non-default embedder was specified + embedder: bool, + hybrid: bool, // every time a search is done, we increment the counter linked to the used settings matching_strategy: HashMap, @@ -639,6 +644,7 @@ impl SearchAggregator { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, } = query; let mut ret = Self::default(); @@ -712,6 +718,12 @@ impl SearchAggregator { ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score_details = *show_ranking_score_details; + if let Some(hybrid) = hybrid { + ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO(); + ret.embedder = hybrid.embedder.is_some(); + ret.hybrid = true; + } + ret } @@ -765,6 +777,9 @@ impl SearchAggregator { facets_total_number_of_facets, show_ranking_score, show_ranking_score_details, + semantic_ratio, + embedder, + hybrid, } = other; if self.timestamp.is_none() { @@ -810,6 +825,9 @@ impl SearchAggregator { // vector self.max_vector_size = self.max_vector_size.max(max_vector_size); + self.semantic_ratio |= semantic_ratio; + self.hybrid |= hybrid; + self.embedder |= embedder; // pagination self.max_limit = self.max_limit.max(max_limit); @@ -878,6 +896,9 @@ impl SearchAggregator { facets_total_number_of_facets, show_ranking_score, show_ranking_score_details, + semantic_ratio, + embedder, + hybrid, } = self; if total_received == 0 { @@ -917,6 +938,11 @@ impl SearchAggregator { "vector": { "max_vector_size": max_vector_size, }, + "hybrid": { + "enabled": hybrid, + "semantic_ratio": semantic_ratio, + "embedder": embedder, + }, "pagination": { "max_limit": max_limit, "max_offset": max_offset, @@ -1012,6 +1038,7 @@ impl MultiSearchAggregator { crop_marker: _, matching_strategy: _, attributes_to_search_on: _, + hybrid: _, } = query; index_uid.as_str() @@ -1158,6 +1185,7 @@ impl FacetSearchAggregator { filter, matching_strategy, attributes_to_search_on, + hybrid, } = query; let mut ret = Self::default(); @@ -1171,7 +1199,8 @@ impl FacetSearchAggregator { || vector.is_some() || filter.is_some() || *matching_strategy != MatchingStrategy::default() - || attributes_to_search_on.is_some(); + || attributes_to_search_on.is_some() + || hybrid.is_some(); ret } diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 72440711c..59c0e7353 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -14,9 +14,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator}; use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::search::{ - add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, - DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, + DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -37,6 +37,8 @@ pub struct FacetSearchQuery { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option>, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default, error = DeserrJsonError)] pub filter: Option, #[deserr(default, error = DeserrJsonError, default)] @@ -96,6 +98,7 @@ impl From for SearchQuery { filter, matching_strategy, attributes_to_search_on, + hybrid, } = value; SearchQuery { @@ -120,6 +123,7 @@ impl From for SearchQuery { matching_strategy, vector: vector.map(VectorQuery::Vector), attributes_to_search_on, + hybrid, } } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index e63a95e60..ec4825661 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::VectorQuery; +use meilisearch_types::milli::{self, VectorQuery}; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -17,9 +17,9 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, - DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, + DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -75,6 +75,10 @@ pub struct SearchQueryGet { matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrQueryParamError)] pub attributes_to_search_on: Option>, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_embedder: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_semantic_ratio: Option, } impl From for SearchQuery { @@ -87,6 +91,18 @@ impl From for SearchQuery { None => None, }; + let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) { + (None, None) => None, + (None, Some(semantic_ratio)) => Some(HybridQuery { semantic_ratio, embedder: None }), + (Some(embedder), None) => Some(HybridQuery { + semantic_ratio: DEFAULT_SEMANTIC_RATIO(), + embedder: Some(embedder), + }), + (Some(embedder), Some(semantic_ratio)) => { + Some(HybridQuery { semantic_ratio, embedder: Some(embedder) }) + } + }; + Self { q: other.q, vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), @@ -109,6 +125,7 @@ impl From for SearchQuery { crop_marker: other.crop_marker, matching_strategy: other.matching_strategy, attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), + hybrid, } } } @@ -159,6 +176,9 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); + + embed(&mut query, index_scheduler.get_ref(), &index).await?; + let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; if let Ok(ref search_result) = search_result { @@ -213,22 +233,31 @@ pub async fn search_with_post( pub async fn embed( query: &mut SearchQuery, index_scheduler: &IndexScheduler, - index: &meilisearch_types::milli::Index, + index: &milli::Index, ) -> Result<(), ResponseError> { if let Some(VectorQuery::String(prompt)) = query.vector.take() { let embedder_configs = index.embedding_configs(&index.read_txn()?)?; let embedder = index_scheduler.embedders(embedder_configs)?; - /// FIXME: add error if no embedder, remove unwrap, support multiple embedders + let embedder_name = if let Some(HybridQuery { + semantic_ratio: _, + embedder: Some(embedder), + }) = &query.hybrid + { + embedder + } else { + "default" + }; + let embeddings = embedder - .get("default") - .unwrap() + .get(embedder_name) + .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) + .map_err(milli::Error::from)? .0 .embed(vec![prompt]) .await - .map_err(meilisearch_types::milli::vector::Error::from) - .map_err(meilisearch_types::milli::UserError::from) - .map_err(meilisearch_types::milli::Error::from)? + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? .pop() .expect("No vector returned from embedding"); diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 9136157f9..c1e667570 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -36,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10; pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); +pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5; #[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] @@ -44,6 +45,8 @@ pub struct SearchQuery { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -84,6 +87,15 @@ pub struct SearchQuery { pub attributes_to_search_on: Option>, } +#[derive(Debug, Clone, Default, PartialEq, Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct HybridQuery { + #[deserr(default, error = DeserrJsonError, default = DEFAULT_SEMANTIC_RATIO())] + pub semantic_ratio: f32, + #[deserr(default, error = DeserrJsonError, default)] + pub embedder: Option, +} + impl SearchQuery { pub fn is_finite_pagination(&self) -> bool { self.page.or(self.hits_per_page).is_some() @@ -103,6 +115,8 @@ pub struct SearchQueryWithIndex { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -168,6 +182,7 @@ impl SearchQueryWithIndex { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, } = self; ( index_uid, @@ -193,6 +208,7 @@ impl SearchQueryWithIndex { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, diff --git a/milli/src/error.rs b/milli/src/error.rs index 3d07590b0..95a0aba6d 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -63,6 +63,8 @@ pub enum InternalError { InvalidMatchingWords, #[error(transparent)] ArroyError(#[from] arroy::Error), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), } #[derive(Error, Debug)] @@ -188,8 +190,23 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), #[error(transparent)] InvalidPrompt(#[from] crate::prompt::error::NewPromptError), - #[error("Invalid prompt in for embeddings with name '{0}': {1}")] + #[error("Invalid prompt in for embeddings with name '{0}': {1}.")] InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), + #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] + TooManyEmbedders(usize), + #[error("Cannot find embedder with name {0}.")] + InvalidEmbedder(String), +} + +impl From for Error { + fn from(value: crate::vector::Error) -> Self { + match value.fault() { + FaultSource::User => Error::UserError(value.into()), + FaultSource::Runtime => Error::InternalError(value.into()), + FaultSource::Bug => Error::InternalError(value.into()), + FaultSource::Undecided => Error::InternalError(value.into()), + } + } } impl From for Error { diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs index 351a51bb1..67ef8b4f6 100644 --- a/milli/src/prompt/mod.rs +++ b/milli/src/prompt/mod.rs @@ -110,7 +110,6 @@ impl Prompt { }; // render template with special object that's OK with `doc.*` and `fields.*` - /// FIXME: doesn't work for nested objects e.g. `doc.a.b` this.template .render(&template_checker::TemplateChecker) .map_err(NewPromptError::invalid_fields_in_template)?; @@ -142,3 +141,80 @@ pub enum PromptFallbackStrategy { #[default] Error, } + +#[cfg(test)] +mod test { + use super::Prompt; + use crate::error::FaultSource; + use crate::prompt::error::{NewPromptError, NewPromptErrorKind}; + + #[test] + fn default_template() { + // does not panic + Prompt::default(); + } + + #[test] + fn empty_template() { + Prompt::new("".into(), None, None).unwrap(); + } + + #[test] + fn template_ok() { + Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None, None).unwrap(); + } + + #[test] + fn template_syntax() { + assert!(matches!( + Prompt::new("{{doc.title: {{doc.overview}}".into(), None, None), + Err(NewPromptError { + kind: NewPromptErrorKind::CannotParseTemplate(_), + fault: FaultSource::User + }) + )); + } + + #[test] + fn template_missing_doc() { + assert!(matches!( + Prompt::new("{{title}}: {{overview}}".into(), None, None), + Err(NewPromptError { + kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), + fault: FaultSource::User + }) + )); + } + + #[test] + fn template_nested_doc() { + Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None, None).unwrap(); + } + + #[test] + fn template_fields() { + Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None, None).unwrap(); + } + + #[test] + fn template_fields_ok() { + Prompt::new( + "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(), + None, + None, + ) + .unwrap(); + } + + #[test] + fn template_fields_invalid() { + assert!(matches!( + // intentionally garbled field + Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None, None), + Err(NewPromptError { + kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), + fault: FaultSource::User + }) + )); + } +} diff --git a/milli/src/prompt/template_checker.rs b/milli/src/prompt/template_checker.rs index 641a9ed64..4cda4a70d 100644 --- a/milli/src/prompt/template_checker.rs +++ b/milli/src/prompt/template_checker.rs @@ -1,7 +1,7 @@ use liquid::model::{ ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, }; -use liquid::{ObjectView, ValueView}; +use liquid::{Object, ObjectView, ValueView}; #[derive(Debug)] pub struct TemplateChecker; @@ -31,11 +31,11 @@ impl ObjectView for DummyField { } fn values<'k>(&'k self) -> Box + 'k> { - Box::new(std::iter::empty()) + Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter()) } fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { - Box::new(std::iter::empty()) + Box::new(self.keys().zip(self.values())) } fn contains_key(&self, index: &str) -> bool { @@ -69,7 +69,12 @@ impl ValueView for DummyField { } fn query_state(&self, state: State) -> bool { - DUMMY_VALUE.query_state(state) + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } } fn to_kstr(&self) -> KStringCow<'_> { @@ -77,7 +82,10 @@ impl ValueView for DummyField { } fn to_value(&self) -> LiquidValue { - LiquidValue::Nil + let mut this = Object::new(); + this.insert("name".into(), LiquidValue::Nil); + this.insert("value".into(), LiquidValue::Nil); + LiquidValue::Object(this) } fn as_object(&self) -> Option<&dyn ObjectView> { @@ -103,7 +111,12 @@ impl ValueView for DummyFields { } fn query_state(&self, state: State) -> bool { - DUMMY_VALUE.query_state(state) + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } } fn to_kstr(&self) -> KStringCow<'_> { @@ -111,7 +124,7 @@ impl ValueView for DummyFields { } fn to_value(&self) -> LiquidValue { - LiquidValue::Nil + LiquidValue::Array(vec![DummyField.to_value()]) } fn as_array(&self) -> Option<&dyn ArrayView> { @@ -125,15 +138,15 @@ impl ArrayView for DummyFields { } fn size(&self) -> i64 { - i64::MAX + u16::MAX as i64 } fn values<'k>(&'k self) -> Box + 'k> { - Box::new(std::iter::empty()) + Box::new(std::iter::once(DummyField.as_value())) } - fn contains_key(&self, _index: i64) -> bool { - true + fn contains_key(&self, index: i64) -> bool { + index < self.size() } fn get(&self, _index: i64) -> Option<&dyn ValueView> { @@ -167,7 +180,8 @@ impl ObjectView for DummyDoc { } fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { - Some(DUMMY_VALUE.as_view()) + // Recursively sends itself + Some(self) } } @@ -189,7 +203,12 @@ impl ValueView for DummyDoc { } fn query_state(&self, state: State) -> bool { - DUMMY_VALUE.query_state(state) + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } } fn to_kstr(&self) -> KStringCow<'_> { diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index bba6cf119..bc7f6fb08 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -516,7 +516,7 @@ pub fn execute_vector_search( ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; - /// FIXME: input universe = universe & documents_with_vectors + // FIXME: input universe = universe & documents_with_vectors // for now if we're computing embeddings for ALL documents, we can assume that this is just universe let ranking_rules = get_ranking_rules_for_vector( ctx, diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index d8d6c933c..6edde98fb 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -71,8 +71,8 @@ impl VectorStateDelta { pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - field_id_map: FieldsIdsMap, - prompt: Option<&Prompt>, + field_id_map: &FieldsIdsMap, + prompt: &Prompt, ) -> Result { puffin::profile_function!(); @@ -142,14 +142,11 @@ pub fn extract_vector_points( .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { // becomes autogenerated - match prompt { - Some(prompt) => VectorStateDelta::NowGenerated(prompt.render( - obkv, - DelAdd::Addition, - &field_id_map, - )?), - None => VectorStateDelta::NowRemoved, - } + VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + field_id_map, + )?) } else { VectorStateDelta::NowRemoved } @@ -162,26 +159,18 @@ pub fn extract_vector_points( .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { - match prompt { - Some(prompt) => { - // Don't give up if the old prompt was failing - let old_prompt = prompt - .render(obkv, DelAdd::Deletion, &field_id_map) - .unwrap_or_default(); - let new_prompt = - prompt.render(obkv, DelAdd::Addition, &field_id_map)?; - if old_prompt != new_prompt { - log::trace!( - "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" - ); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange - } - } - // We no longer have a prompt, so we need to remove any existing vector - None => VectorStateDelta::NowRemoved, + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange } } else { VectorStateDelta::NowRemoved @@ -196,24 +185,16 @@ pub fn extract_vector_points( .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { - match prompt { - Some(prompt) => { - // Don't give up if the old prompt was failing - let old_prompt = prompt - .render(obkv, DelAdd::Deletion, &field_id_map) - .unwrap_or_default(); - let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; - if old_prompt != new_prompt { - log::trace!( - "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" - ); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange - } - } - None => VectorStateDelta::NowRemoved, + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange } } else { VectorStateDelta::NowRemoved @@ -322,7 +303,7 @@ pub fn extract_embeddings( prompt_reader: grenad::Reader, indexer: GrenadParameters, embedder: Arc, -) -> Result<(grenad::Reader>, Option)> { +) -> Result>> { let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism @@ -341,8 +322,6 @@ pub fn extract_embeddings( let mut chunks_ids = Vec::with_capacity(n_chunks); let mut cursor = prompt_reader.into_cursor()?; - let mut expected_dimension = None; - while let Some((key, value)) = cursor.move_on_next()? { let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); // SAFETY: precondition, the grenad value was saved from a string @@ -367,7 +346,6 @@ pub fn extract_embeddings( .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), ) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?; for (docid, embeddings) in chunks_ids @@ -376,7 +354,6 @@ pub fn extract_embeddings( .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) { state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; - expected_dimension = Some(embeddings.dimension()); } chunks_ids.clear(); } @@ -387,7 +364,6 @@ pub fn extract_embeddings( let chunked_embeds = rt .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?; for (docid, embeddings) in chunks_ids .iter() @@ -395,7 +371,6 @@ pub fn extract_embeddings( .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) { state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; - expected_dimension = Some(embeddings.dimension()); } } @@ -403,14 +378,12 @@ pub fn extract_embeddings( let embeds = rt .block_on(embedder.embed(std::mem::take(&mut current_chunk))) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?; for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; - expected_dimension = Some(embeddings.dimension()); } } - Ok((writer_into_reader(state_writer)?, expected_dimension)) + writer_into_reader(state_writer) } diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 69530a507..4831cc69d 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -292,43 +292,42 @@ fn send_original_documents_data( let documents_chunk_cloned = original_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); rayon::spawn(move || { - let (embedder, prompt) = embedders.get("default").cloned().unzip(); - let result = - extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); - match result { - Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { - /// FIXME: support multiple embedders - let results = embedder.and_then(|embedder| { - match extract_embeddings(prompts, indexer, embedder.clone()) { + for (name, (embedder, prompt)) in embedders { + let result = extract_vector_points( + documents_chunk_cloned.clone(), + indexer, + &field_id_map, + &prompt, + ); + match result { + Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { + let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { Ok(results) => Some(results), Err(error) => { let _ = lmdb_writer_sx_cloned.send(Err(error)); None } - } - }); - let (embeddings, expected_dimension) = results.unzip(); - let expected_dimension = expected_dimension.flatten(); - if !(remove_vectors.is_empty() - && manual_vectors.is_empty() - && embeddings.as_ref().map_or(true, |e| e.is_empty())) - { - /// FIXME FIXME FIXME - if expected_dimension.is_some() { + }; + + if !(remove_vectors.is_empty() + && manual_vectors.is_empty() + && embeddings.as_ref().map_or(true, |e| e.is_empty())) + { let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { remove_vectors, embeddings, - /// FIXME: compute an expected dimension from the manual vectors if any - expected_dimension: expected_dimension.unwrap(), + expected_dimension: embedder.dimensions(), manual_vectors, + embedder_name: name, })); } } + + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); + } } - Err(error) => { - let _ = lmdb_writer_sx_cloned.send(Err(error)); - } - }; + } }); // TODO: create a custom internal error diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 472c77111..c3c39b90f 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -435,7 +435,7 @@ where let mut word_docids = None; let mut exact_word_docids = None; - let mut dimension = None; + let mut dimension = HashMap::new(); for result in lmdb_writer_rx { if (self.should_abort)() { @@ -471,13 +471,15 @@ where remove_vectors, embeddings, manual_vectors, + embedder_name, } => { - dimension = Some(expected_dimension); + dimension.insert(embedder_name.clone(), expected_dimension); TypedChunk::VectorPoints { remove_vectors, embeddings, expected_dimension, manual_vectors, + embedder_name, } } otherwise => otherwise, @@ -513,14 +515,22 @@ where self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; - if let Some(dimension) = dimension { + for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; let vector_arroy = self.index.vector_arroy; + /// FIXME: unwrap + let embedder_index = + self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); pool.install(|| { - /// FIXME: do for each embedder + let writer_index = (embedder_index as u16) << 8; let mut rng = rand::rngs::StdRng::from_entropy(); for k in 0..=u8::MAX { - let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?; + let writer = arroy::Writer::prepare( + wtxn, + vector_arroy, + writer_index | (k as u16), + dimension, + )?; if writer.is_empty(wtxn)? { break; } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index da99ed685..dde2124ed 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -47,6 +47,7 @@ pub(crate) enum TypedChunk { embeddings: Option>>, expected_dimension: usize, manual_vectors: grenad::Reader>, + embedder_name: String, }, ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), } @@ -100,8 +101,8 @@ impl TypedChunk { TypedChunk::GeoPoints(grenad) => { format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) } - TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { - format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension) + TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => { + format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name) } TypedChunk::ScriptLanguageDocids(sl_map) => { format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) @@ -360,12 +361,20 @@ pub(crate) fn write_typed_chunk_into_index( manual_vectors, embeddings, expected_dimension, + embedder_name, } => { - /// FIXME: allow customizing distance + /// FIXME: unwrap + let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); + let writer_index = (embedder_index as u16) << 8; + // FIXME: allow customizing distance let writers: std::result::Result, _> = (0..=u8::MAX) .map(|k| { - /// FIXME: allow customizing index and then do index << 8 + k - arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension) + arroy::Writer::prepare( + wtxn, + index.vector_arroy, + writer_index | (k as u16), + expected_dimension, + ) }) .collect(); let writers = writers?; @@ -456,7 +465,7 @@ pub(crate) fn write_typed_chunk_into_index( } } - log::debug!("There are 🤷‍♀️ entries in the arroy so far"); + log::debug!("Finished vector chunk for {}", embedder_name); } TypedChunk::ScriptLanguageDocids(sl_map) => { for (key, (deletion, addition)) in sl_map { diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index b8355be51..1149dbce5 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -431,7 +431,6 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let embedder = Arc::new( Embedder::new(embedder_options.clone()) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?, ); Ok((name, (embedder, prompt))) @@ -976,6 +975,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Setting::NotSet => Some((name, EmbeddingSettings::default().into())), }) .collect(); + + self.index.embedder_category_id.clear(self.wtxn)?; + for (index, (embedder_name, _)) in new_configs.iter().enumerate() { + self.index.embedder_category_id.put_with_flags( + self.wtxn, + heed::PutFlags::APPEND, + embedder_name, + &index + .try_into() + .map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?, + )?; + } + if new_configs.is_empty() { self.index.delete_embedding_configs(self.wtxn)?; } else { @@ -1062,7 +1074,7 @@ fn validate_prompt( match new { Setting::Set(EmbeddingSettings { embedder_options, - prompt: + document_template: Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), }) => { // validate @@ -1072,7 +1084,7 @@ fn validate_prompt( Ok(Setting::Set(EmbeddingSettings { embedder_options, - prompt: Setting::Set(PromptSettings { + document_template: Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback, diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 1ae7a4678..c5cce622d 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -65,6 +65,8 @@ pub enum EmbedErrorKind { OpenAiTooManyTokens(OpenAiError), #[error("received unhandled HTTP status code {0} from OpenAI")] OpenAiUnhandledStatusCode(u16), + #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] + ManualEmbed(String), } impl EmbedError { @@ -111,6 +113,10 @@ impl EmbedError { pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } } + + pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { + Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } + } } #[derive(Debug, thiserror::Error)] @@ -170,6 +176,13 @@ impl NewEmbedderError { Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } } + pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { + Self { + kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), + fault: FaultSource::Runtime, + } + } + pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } } @@ -219,6 +232,8 @@ pub enum NewEmbedderErrorKind { NewApiFail(ApiError), #[error("fetching file from HG_HUB failed: {0}")] ApiGet(ApiError), + #[error("could not determine model dimensions: test embedding failed with {0}")] + CouldNotDetermineDimension(EmbedError), #[error("loading model failed: {0}")] LoadModel(candle_core::Error), // openai diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 81cdd4b34..07185d25c 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -62,6 +62,7 @@ pub struct Embedder { model: BertModel, tokenizer: Tokenizer, options: EmbedderOptions, + dimensions: usize, } impl std::fmt::Debug for Embedder { @@ -126,10 +127,17 @@ impl Embedder { tokenizer.with_padding(Some(pp)); } - Ok(Self { model, tokenizer, options }) + let mut this = Self { model, tokenizer, options, dimensions: 0 }; + + let embeddings = this + .embed(vec!["test".into()]) + .map_err(NewEmbedderError::hf_could_not_determine_dimension)?; + this.dimensions = embeddings.first().unwrap().dimension(); + + Ok(this) } - pub async fn embed( + pub fn embed( &self, mut texts: Vec, ) -> std::result::Result>, EmbedError> { @@ -170,12 +178,11 @@ impl Embedder { Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) } - pub async fn embed_chunks( + pub fn embed_chunks( &self, text_chunks: Vec>, ) -> std::result::Result>>, EmbedError> { - futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) - .await + text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() } pub fn chunk_count_hint(&self) -> usize { @@ -185,6 +192,10 @@ impl Embedder { pub fn prompt_count_in_chunk_hint(&self) -> usize { std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) } + + pub fn dimensions(&self) -> usize { + self.dimensions + } } fn normalize_l2(v: &Tensor) -> Result { diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 91640b8fb..7185e56b1 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -3,6 +3,7 @@ use crate::prompt::PromptData; pub mod error; pub mod hf; +pub mod manual; pub mod openai; pub mod settings; @@ -67,6 +68,7 @@ impl Embeddings { pub enum Embedder { HuggingFace(hf::Embedder), OpenAi(openai::Embedder), + UserProvided(manual::Embedder), } #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] @@ -80,6 +82,7 @@ pub struct EmbeddingConfig { pub enum EmbedderOptions { HuggingFace(hf::EmbedderOptions), OpenAi(openai::EmbedderOptions), + UserProvided(manual::EmbedderOptions), } impl Default for EmbedderOptions { @@ -93,7 +96,7 @@ impl EmbedderOptions { Self::HuggingFace(hf::EmbedderOptions::new()) } - pub fn openai(api_key: String) -> Self { + pub fn openai(api_key: Option) -> Self { Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) } } @@ -103,6 +106,9 @@ impl Embedder { Ok(match options { EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + EmbedderOptions::UserProvided(options) => { + Self::UserProvided(manual::Embedder::new(options)) + } }) } @@ -111,8 +117,9 @@ impl Embedder { texts: Vec, ) -> std::result::Result>, EmbedError> { match self { - Embedder::HuggingFace(embedder) => embedder.embed(texts).await, + Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::OpenAi(embedder) => embedder.embed(texts).await, + Embedder::UserProvided(embedder) => embedder.embed(texts), } } @@ -121,8 +128,9 @@ impl Embedder { text_chunks: Vec>, ) -> std::result::Result>>, EmbedError> { match self { - Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), } } @@ -130,6 +138,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), + Embedder::UserProvided(_) => 1, } } @@ -137,6 +146,15 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::UserProvided(_) => 1, + } + } + + pub fn dimensions(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.dimensions(), + Embedder::OpenAi(embedder) => embedder.dimensions(), + Embedder::UserProvided(embedder) => embedder.dimensions(), } } } diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 670dc8526..bab62f5e4 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -15,7 +15,7 @@ pub struct Embedder { #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { - pub api_key: String, + pub api_key: Option, pub embedding_model: EmbeddingModel, } @@ -68,11 +68,11 @@ impl EmbeddingModel { pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; impl EmbedderOptions { - pub fn with_default_model(api_key: String) -> Self { + pub fn with_default_model(api_key: Option) -> Self { Self { api_key, embedding_model: Default::default() } } - pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self { + pub fn with_embedding_model(api_key: Option, embedding_model: EmbeddingModel) -> Self { Self { api_key, embedding_model } } } @@ -80,9 +80,14 @@ impl EmbedderOptions { impl Embedder { pub fn new(options: EmbedderOptions) -> Result { let mut headers = reqwest::header::HeaderMap::new(); + let mut inferred_api_key = Default::default(); + let api_key = options.api_key.as_ref().unwrap_or_else(|| { + inferred_api_key = infer_api_key(); + &inferred_api_key + }); headers.insert( reqwest::header::AUTHORIZATION, - reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key)) + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) .map_err(NewEmbedderError::openai_invalid_api_key_format)?, ); headers.insert( @@ -315,6 +320,10 @@ impl Embedder { pub fn prompt_count_in_chunk_hint(&self) -> usize { 10 } + + pub fn dimensions(&self) -> usize { + self.options.embedding_model.dimensions() + } } // retrying in case of failure @@ -414,3 +423,9 @@ struct OpenAiEmbedding { // object: String, // index: usize, } + +fn infer_api_key() -> String { + std::env::var("MEILI_OPENAI_API_KEY") + .or_else(|_| std::env::var("OPENAI_API_KEY")) + .unwrap_or_default() +} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 2c0cf7924..f90c3cc71 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -15,14 +15,14 @@ pub struct EmbeddingSettings { pub embedder_options: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] - pub prompt: Setting, + pub document_template: Setting, } impl EmbeddingSettings { pub fn apply(&mut self, new: Self) { - let EmbeddingSettings { embedder_options, prompt } = new; + let EmbeddingSettings { embedder_options, document_template: prompt } = new; self.embedder_options.apply(embedder_options); - self.prompt.apply(prompt); + self.document_template.apply(prompt); } } @@ -30,7 +30,7 @@ impl From for EmbeddingSettings { fn from(value: EmbeddingConfig) -> Self { Self { embedder_options: Setting::Set(value.embedder_options.into()), - prompt: Setting::Set(value.prompt.into()), + document_template: Setting::Set(value.prompt.into()), } } } @@ -38,7 +38,7 @@ impl From for EmbeddingSettings { impl From for EmbeddingConfig { fn from(value: EmbeddingSettings) -> Self { let mut this = Self::default(); - let EmbeddingSettings { embedder_options, prompt } = value; + let EmbeddingSettings { embedder_options, document_template: prompt } = value; if let Some(embedder_options) = embedder_options.set() { this.embedder_options = embedder_options.into(); } @@ -105,6 +105,7 @@ impl From for PromptData { pub enum EmbedderSettings { HuggingFace(Setting), OpenAi(Setting), + UserProvided(UserProvidedSettings), } impl Deserr for EmbedderSettings @@ -145,11 +146,17 @@ where location.push_key(&k), )?, ))), + "userProvided" => Ok(EmbedderSettings::UserProvided( + UserProvidedSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + )), other => Err(deserr::take_cf_content(E::error::( None, deserr::ErrorKind::UnknownKey { key: other, - accepted: &["huggingFace", "openAi"], + accepted: &["huggingFace", "openAi", "userProvided"], }, location, ))), @@ -182,6 +189,9 @@ impl From for EmbedderSettings { crate::vector::EmbedderOptions::OpenAi(openai) => { Self::OpenAi(Setting::Set(openai.into())) } + crate::vector::EmbedderOptions::UserProvided(user_provided) => { + Self::UserProvided(user_provided.into()) + } } } } @@ -192,9 +202,12 @@ impl From for crate::vector::EmbedderOptions { EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), - EmbedderSettings::OpenAi(_setting) => Self::OpenAi( - crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()), - ), + EmbedderSettings::OpenAi(_setting) => { + Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None)) + } + EmbedderSettings::UserProvided(user_provided) => { + Self::UserProvided(user_provided.into()) + } } } } @@ -286,7 +299,7 @@ impl OpenAiEmbedderSettings { impl From for OpenAiEmbedderSettings { fn from(value: crate::vector::openai::EmbedderOptions) -> Self { Self { - api_key: Setting::Set(value.api_key), + api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset), embedding_model: Setting::Set(value.embedding_model), } } @@ -295,14 +308,25 @@ impl From for OpenAiEmbedderSettings { impl From for crate::vector::openai::EmbedderOptions { fn from(value: OpenAiEmbedderSettings) -> Self { let OpenAiEmbedderSettings { api_key, embedding_model } = value; - Self { - api_key: api_key.set().unwrap_or_else(infer_api_key), - embedding_model: embedding_model.set().unwrap_or_default(), - } + Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() } } } -fn infer_api_key() -> String { - /// FIXME: get key from instance options? - std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default() +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct UserProvidedSettings { + pub dimensions: usize, +} + +impl From for crate::vector::manual::EmbedderOptions { + fn from(value: UserProvidedSettings) -> Self { + Self { dimensions: value.dimensions } + } +} + +impl From for UserProvidedSettings { + fn from(value: crate::vector::manual::EmbedderOptions) -> Self { + Self { dimensions: value.dimensions } + } }